Skip to main content

netvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The user-mode netvsp VMBus device implementation.
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
19use crate::protocol::Version;
20use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
21use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
22use async_trait::async_trait;
23pub use buffers::BufferPool;
24use buffers::sub_allocation_size_for_mtu;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures::channel::mpsc;
28use futures::channel::mpsc::TrySendError;
29use futures_concurrency::future::Race;
30use guestmem::AccessError;
31use guestmem::GuestMemory;
32use guestmem::GuestMemoryError;
33use guestmem::MemoryRead;
34use guestmem::MemoryWrite;
35use guestmem::ranges::PagedRange;
36use guestmem::ranges::PagedRanges;
37use guestmem::ranges::PagedRangesReader;
38use guid::Guid;
39use hvdef::hypercall::HvGuestOsId;
40use hvdef::hypercall::HvGuestOsMicrosoft;
41use hvdef::hypercall::HvGuestOsMicrosoftIds;
42use hvdef::hypercall::HvGuestOsOpenSourceType;
43use inspect::Inspect;
44use inspect::InspectMut;
45use inspect::SensitivityLevel;
46use inspect_counters::Counter;
47use inspect_counters::Histogram;
48use mesh::rpc::Rpc;
49use net_backend::Endpoint;
50use net_backend::EndpointAction;
51use net_backend::QueueConfig;
52use net_backend::RxId;
53use net_backend::TxError;
54use net_backend::TxId;
55use net_backend::TxSegment;
56use net_backend_resources::mac_address::MacAddress;
57use pal_async::timer::Instant;
58use pal_async::timer::PolledTimer;
59use ring::gparange::MultiPagedRangeIter;
60use rx_bufs::RxBuffers;
61use rx_bufs::SubAllocationInUse;
62use std::collections::VecDeque;
63use std::fmt::Debug;
64use std::future::pending;
65use std::mem::offset_of;
66use std::ops::Range;
67use std::sync::Arc;
68use std::sync::atomic::AtomicUsize;
69use std::sync::atomic::Ordering;
70use std::task::Poll;
71use std::time::Duration;
72use task_control::AsyncRun;
73use task_control::InspectTaskMut;
74use task_control::StopTask;
75use task_control::TaskControl;
76use thiserror::Error;
77use tracing::Instrument;
78use vmbus_async::queue;
79use vmbus_async::queue::ExternalDataError;
80use vmbus_async::queue::IncomingPacket;
81use vmbus_async::queue::Queue;
82use vmbus_channel::bus::OfferParams;
83use vmbus_channel::bus::OpenRequest;
84use vmbus_channel::channel::ChannelControl;
85use vmbus_channel::channel::ChannelOpenError;
86use vmbus_channel::channel::ChannelRestoreError;
87use vmbus_channel::channel::DeviceResources;
88use vmbus_channel::channel::RestoreControl;
89use vmbus_channel::channel::SaveRestoreVmbusDevice;
90use vmbus_channel::channel::VmbusDevice;
91use vmbus_channel::gpadl::GpadlId;
92use vmbus_channel::gpadl::GpadlMapView;
93use vmbus_channel::gpadl::GpadlView;
94use vmbus_channel::gpadl::UnknownGpadlId;
95use vmbus_channel::gpadl_ring::GpadlRingMem;
96use vmbus_channel::gpadl_ring::gpadl_channel;
97use vmbus_ring as ring;
98use vmbus_ring::OutgoingPacketType;
99use vmbus_ring::RingMem;
100use vmbus_ring::gparange::MultiPagedRangeBuf;
101use vmcore::save_restore::RestoreError;
102use vmcore::save_restore::SaveError;
103use vmcore::save_restore::SavedStateBlob;
104use vmcore::vm_task::VmTaskDriver;
105use vmcore::vm_task::VmTaskDriverSource;
106use zerocopy::FromBytes;
107use zerocopy::FromZeros;
108use zerocopy::Immutable;
109use zerocopy::IntoBytes;
110use zerocopy::KnownLayout;
111
// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but room is also reserved for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. Worst case requires a completion message plus two
// additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION).
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// NOTE(review): judging by the name, the maximum number of subchannels for a
// single virtual NIC -- confirm at the use sites.
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the AN device is offered too quickly, it
// could get the "eth0" name. In provisioning scenarios, the scripts will make
// assumptions about which interface should be used, with eth0 being the
// default.
//
// Test builds use a much shorter delay.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
//
// Test builds use a much shorter delay.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
145
/// Selects which pieces of state a [`CoordinatorMessage::Update`] asks to be
/// refreshed. Both flags default to `false`.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}
154
/// A message posted through the `coordinator_send` channel that each
/// [`Worker`] holds.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state. The payload selects which state to update.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    /// `channel_idx` identifies the channel that requested the restart: 0 = primary; >0 = sub-channel.
    Restart { channel_idx: u16 },
    /// Start a timer.
    /// NOTE(review): the `Instant` is presumably the timer's deadline -- confirm.
    StartTimer(Instant),
}
166
/// Processing state for a single channel, generic over the ring memory type
/// used by the channel's queue.
struct Worker<T: RingMem> {
    /// Index of this channel.
    /// NOTE(review): presumably 0 = primary and >0 = sub-channel, matching
    /// `CoordinatorMessage::Restart` -- confirm.
    channel_idx: u16,
    target_vp: Option<u32>,
    mem: GuestMemory,
    /// The internal channel state, including the ring queue.
    channel: NetChannel<T>,
    /// Protocol state of this channel.
    state: WorkerState,
    /// Sender used to post [`CoordinatorMessage`]s.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
175
/// Per-queue state that is inspected together with an optional [`Worker`]
/// (see the `InspectTaskMut` implementation).
struct NetQueue {
    driver: VmTaskDriver,
    /// Backend queue state, when present.
    queue_state: Option<QueueState>,
}
180
181impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
182    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
183        if worker.is_none() && self.queue_state.is_none() {
184            req.ignore();
185            return;
186        }
187
188        let mut resp = req.respond();
189        resp.field("driver", &self.driver);
190        if let Some(worker) = worker {
191            resp.field(
192                "protocol_state",
193                match &worker.state {
194                    WorkerState::Init(None) => "version",
195                    WorkerState::Init(Some(_)) => "init",
196                    WorkerState::Ready(_) => "ready",
197                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
198                },
199            )
200            .field("ring", &worker.channel.queue)
201            .field(
202                "can_use_ring_size_optimization",
203                worker.channel.can_use_ring_size_opt,
204            );
205
206            if let Some(state) = worker.state.ready() {
207                resp.field(
208                    "outstanding_tx_packets",
209                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
210                )
211                .field("pending_rx_packets", state.state.pending_rx_packets.len())
212                .field(
213                    "pending_tx_completions",
214                    state.state.pending_tx_completions.len(),
215                )
216                .field("free_tx_packets", state.state.free_tx_packets.len())
217                .merge(&state.state.stats);
218            }
219
220            resp.field("packet_filter", worker.channel.packet_filter);
221        }
222
223        if let Some(queue_state) = &mut self.queue_state {
224            resp.field_mut("queue", &mut queue_state.queue)
225                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
226                .field(
227                    "rx_buffers_start",
228                    queue_state.rx_buffer_range.id_range.start,
229                );
230        }
231    }
232}
233
/// Protocol state of a [`Worker`]. The quoted labels below are the
/// `protocol_state` strings reported through inspection.
enum WorkerState {
    /// Reported as "version" while no [`InitState`] is held and as "init" once
    /// one is.
    Init(Option<InitState>),
    /// Reported as "ready".
    Ready(ReadyState),
    /// Reported as "waiting for coordinator". A held [`ReadyState`] remains
    /// reachable through `ready()` / `ready_mut()`.
    WaitingForCoordinator(Option<ReadyState>),
}
239
240impl WorkerState {
241    fn ready(&self) -> Option<&ReadyState> {
242        match self {
243            Self::Ready(state) | Self::WaitingForCoordinator(Some(state)) => Some(state),
244            _ => None,
245        }
246    }
247
248    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249        match self {
250            Self::Ready(state) | Self::WaitingForCoordinator(Some(state)) => Some(state),
251            _ => None,
252        }
253    }
254}
255
/// State held by [`WorkerState::Init`]. Everything other than `version` is
/// optional.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
263
/// An NDIS version as a major/minor pair. Both parts are shown in hexadecimal
/// when inspected.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
271
/// NDIS configuration: the MTU and the capability set. Both fields are marked
/// safe for inspection.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
279
/// State held by [`WorkerState::Ready`], and optionally retained by
/// [`WorkerState::WaitingForCoordinator`].
struct ReadyState {
    /// Reference-counted channel buffers.
    buffers: Arc<ChannelBuffers>,
    /// Mutable per-channel state.
    state: ActiveState,
    /// Scratch buffers used during packet processing.
    data: ProcessingData,
}
285
286impl ReadyState {
287    /// Any in-flight TX packets submitted to the old endpoint queues will
288    /// never be completed, so this method:
289    /// 1. Queues completions for all outstanding sends so the guest gets
290    ///    responses and the associated transmit slots are reclaimed, ensuring
291    ///    the worker is ready to poll post restart of the queues.
292    /// 2. Clears leftover transmit segments that referenced the old endpoint
293    ///    queue so they are not submitted to the new one.
294    ///
295    fn reset_tx_after_endpoint_stop(&mut self) {
296        let state = &mut self.state;
297
298        // Queue completions for in-flight TX packets that were lost when the
299        // endpoint stopped. They will get picked up when the worker restarts.
300        let pending_tx = state
301            .pending_tx_packets
302            .iter_mut()
303            .enumerate()
304            .filter_map(|(id, inflight)| {
305                if inflight.pending_packet_count > 0 {
306                    inflight.pending_packet_count = 0;
307                    Some(PendingTxCompletion {
308                        transaction_id: inflight.transaction_id,
309                        tx_id: Some(TxId(id as u32)),
310                        status: protocol::Status::SUCCESS,
311                    })
312                } else {
313                    None
314                }
315            })
316            .collect::<Vec<_>>();
317        state.pending_tx_completions.extend(pending_tx);
318
319        // Clear leftover TX segments from the previous endpoint queue;
320        // they cannot be submitted to the new queue.
321        self.data.tx_segments.clear();
322        self.data.tx_segments_sent = 0;
323    }
324}
325
/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of `None` signifies that the VF is
    /// not currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The `Rpc` result
    /// will indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
340
/// Adapter-wide configuration and counters. Channels hold it through
/// `Arc<Adapter>` (see [`NetChannel`]).
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    max_queues: u16,
    indirection_table_size: u16,
    /// The offloads supported by this adapter.
    /// NOTE(review): presumably the mask applied via
    /// `OffloadConfig::mask_to_supported` -- confirm.
    offload_support: OffloadConfig,
    /// Atomic, so it can be changed through a shared reference.
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Optional callback that returns the guest OS ID.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    /// Atomic, so it can be changed through a shared reference.
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}
355
/// State associated with one backend queue.
struct QueueState {
    /// The backend queue.
    queue: Box<dyn net_backend::Queue>,
    pool: BufferPool,
    /// The receive buffer IDs owned by this queue, plus the means to forward
    /// IDs owned by other queues.
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}
362
/// The range of receive buffer IDs owned by one queue, together with the
/// senders used to hand other IDs to the queue that owns them.
struct RxBufferRange {
    /// Receive buffer IDs owned by this queue.
    id_range: Range<u32>,
    /// Receiver of buffer IDs.
    /// NOTE(review): presumably the IDs other queues forward through
    /// `send_if_remote` -- confirm.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    /// Shared per-queue senders and the per-queue buffer count.
    remote_ranges: Arc<RxBufferRanges>,
}
368
369impl RxBufferRange {
370    fn new(
371        ranges: Arc<RxBufferRanges>,
372        id_range: Range<u32>,
373        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
374    ) -> Self {
375        Self {
376            id_range,
377            remote_buffer_id_recv,
378            remote_ranges: ranges,
379        }
380    }
381
382    fn send_if_remote(&self, id: u32) -> bool {
383        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
384        // ID is owned by the current range.
385        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
386            false
387        } else {
388            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
389            // The total number of receive buffers may not evenly divide among
390            // the active queues. Any extra buffers are given to the last
391            // queue, so redirect any larger values there.
392            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
393            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
394            true
395        }
396    }
397}
398
/// Reasons why a receive buffer configuration is rejected by
/// `RxBufferRanges::validate_params`.
#[derive(Debug, Error)]
enum RxBufferConfigError {
    /// The requested queue count was zero.
    #[error("queue_count must be at least 1")]
    ZeroQueueCount,
    /// Fewer buffers than `RX_RESERVED_CONTROL_BUFFERS` were supplied.
    #[error("buffer_count must be >= RX_RESERVED_CONTROL_BUFFERS")]
    InsufficientBuffers,
    /// After setting the reserved buffers aside, fewer than one buffer per
    /// queue remained.
    #[error("buffers_per_queue must not be 0")]
    ZeroBuffersPerQueue,
}
408
/// Describes how receive buffer IDs are split across queues.
struct RxBufferRanges {
    /// Number of non-reserved buffers assigned to each queue (integer
    /// division of the non-reserved count by the queue count).
    buffers_per_queue: u32,
    /// One sender per queue, indexed by queue.
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
413
414impl RxBufferRanges {
415    /// Validates that the given parameters produce a valid RX buffer configuration.
416    fn validate_params(buffer_count: u32, queue_count: u32) -> Result<u32, WorkerError> {
417        if queue_count == 0 {
418            return Err(WorkerError::InvalidRxBufferConfig(
419                RxBufferConfigError::ZeroQueueCount,
420            ));
421        }
422        if buffer_count < RX_RESERVED_CONTROL_BUFFERS {
423            return Err(WorkerError::InvalidRxBufferConfig(
424                RxBufferConfigError::InsufficientBuffers,
425            ));
426        }
427        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
428        if buffers_per_queue == 0 {
429            return Err(WorkerError::InvalidRxBufferConfig(
430                RxBufferConfigError::ZeroBuffersPerQueue,
431            ));
432        }
433        Ok(buffers_per_queue)
434    }
435
436    fn new(
437        buffer_count: u32,
438        queue_count: u32,
439    ) -> Result<(Self, Vec<mpsc::UnboundedReceiver<u32>>), WorkerError> {
440        let buffers_per_queue = Self::validate_params(buffer_count, queue_count)?;
441        #[expect(clippy::disallowed_methods)] // TODO
442        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
443        Ok((
444            Self {
445                buffers_per_queue,
446                buffer_id_send: send,
447            },
448            recv,
449        ))
450    }
451}
452
/// RSS state held by the primary channel.
/// NOTE(review): presumably the hash key and indirection table configured by
/// the guest -- confirm.
struct RssState {
    /// 40-byte key.
    key: [u8; 40],
    indirection_table: Vec<u16>,
}
457
/// The internal channel state.
struct NetChannel<T: RingMem> {
    /// The adapter this channel belongs to.
    adapter: Arc<Adapter>,
    /// The ring queue; reported as `ring` when inspected.
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: PacketSize,
    pending_send_size: usize,
    /// NOTE(review): presumably a coordinator message that is still to be
    /// sent -- confirm.
    restart: Option<CoordinatorMessage>,
    /// Reported as `can_use_ring_size_optimization` when inspected.
    can_use_ring_size_opt: bool,
    /// Reported as `packet_filter` when inspected.
    packet_filter: u32,
}
469
// Use an enum to give the compiler more visibility into the packet size.
/// The packet size in use on a channel; each variant stands for one of the
/// protocol's packet size constants.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// [`protocol::PACKET_SIZE_V1`]
    V1,
    /// [`protocol::PACKET_SIZE_V61`]
    V61,
}
478
/// Buffers used during packet processing.
struct ProcessingData {
    /// Transmit segments; cleared by
    /// `ReadyState::reset_tx_after_endpoint_stop`.
    tx_segments: Vec<TxSegment>,
    /// Reset to 0 whenever `tx_segments` is cleared.
    /// NOTE(review): presumably how many of `tx_segments` were already
    /// submitted -- confirm.
    tx_segments_sent: usize,
    /// Fixed-size buffer of transmit IDs.
    tx_done: Box<[TxId]>,
    /// Fixed-size buffer of receive IDs, `RX_BATCH_SIZE` entries long.
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
489
490impl ProcessingData {
491    fn new() -> Self {
492        Self {
493            tx_segments: Vec::new(),
494            tx_segments_sent: 0,
495            tx_done: vec![TxId(0); 8192].into(),
496            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
497            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
498            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
499            external_data: MultiPagedRangeBuf::new(),
500        }
501    }
502}
503
/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
///
/// Guest memory and the receive/send buffers are skipped when inspecting.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    /// Optional, unlike the receive buffer.
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
519
/// An ID assigned to a control message. This is also its receive buffer index.
///
/// The primary channel's free list starts out holding the IDs
/// `0..RX_RESERVED_CONTROL_BUFFERS`.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
523
/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    /// State that exists only when present.
    /// NOTE(review): presumably set only for the primary channel -- confirm.
    primary: Option<PrimaryChannelState>,

    /// Transmit packet slots; a slot's index is used as its `TxId`.
    pending_tx_packets: Vec<PendingTxPacket>,
    /// IDs of free transmit slots.
    free_tx_packets: Vec<TxId>,
    /// Transmit completions that are queued but not yet delivered.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    /// Counters that are merged into the queue's inspect output.
    stats: QueueStats,
}
537
/// Per-queue counters and histograms. All start at their default values and
/// are merged into the queue's inspect output.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_vlan_packets: Counter,
    rx_vlan_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
554
/// A transmit completion that has been queued but not yet delivered.
#[derive(Debug)]
struct PendingTxCompletion {
    /// Transaction ID of the transmit being completed.
    transaction_id: u64,
    /// The transmit slot associated with this completion, if any.
    tx_id: Option<TxId>,
    /// Status reported with the completion.
    status: protocol::Status,
}
561
/// Guest VF state tracked by the primary channel.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available and a data path switch was pending, but
    /// the VF is no longer available.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available and in use as the data path, but it is
    /// no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and guest has requested a data path switch.
    /// `result` is `None` while the switch is in progress and otherwise says
    /// whether it succeeded.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}
594
595impl std::fmt::Display for PrimaryChannelGuestVfState {
596    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
597        match self {
598            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
599            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
600                write!(f, "restoring")
601            }
602            PrimaryChannelGuestVfState::Restoring(
603                saved_state::GuestVfState::AvailableAdvertised,
604            ) => write!(f, "restoring from guest notified of vfid"),
605            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
606                write!(f, "restoring from vf present")
607            }
608            PrimaryChannelGuestVfState::Restoring(
609                saved_state::GuestVfState::DataPathSwitchPending {
610                    to_guest, result, ..
611                },
612            ) => {
613                write!(
614                    f,
615                    "restoring from client requested data path switch: to {} {}",
616                    if *to_guest { "guest" } else { "synthetic" },
617                    if let Some(result) = result {
618                        if *result { "succeeded\"" } else { "failed\"" }
619                    } else {
620                        "in progress\""
621                    }
622                )
623            }
624            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
625                write!(f, "restoring from data path in guest")
626            }
627            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
628            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
629                write!(f, "\"unavailable (previously available)\"")
630            }
631            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
632                write!(f, "unavailable (previously switching data path)")
633            }
634            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
635                write!(f, "\"unavailable (previously using guest VF)\"")
636            }
637            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
638            PrimaryChannelGuestVfState::AvailableAdvertised => {
639                write!(f, "\"available, guest notified\"")
640            }
641            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
642            PrimaryChannelGuestVfState::DataPathSwitchPending {
643                to_guest, result, ..
644            } => {
645                write!(
646                    f,
647                    "\"switching to {} {}",
648                    if *to_guest { "guest" } else { "synthetic" },
649                    if let Some(result) = result {
650                        if *result { "succeeded\"" } else { "failed\"" }
651                    } else {
652                        "in progress\""
653                    }
654                )
655            }
656            PrimaryChannelGuestVfState::DataPathSwitched => {
657                write!(f, "\"available and data path switched\"")
658            }
659            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
660                f,
661                "\"available but data path switched back to synthetic due to external state change\""
662            ),
663        }
664    }
665}
666
/// Inspection reports the state as a single value: its `Display` text.
impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}
672
/// Pending link action tracked by the primary channel. When inspected,
/// `Default` is shown as "None" and the other variants are shown together with
/// their `bool` payload.
/// NOTE(review): the payload presumably is the link-up state to apply --
/// confirm.
#[derive(Debug)]
enum PendingLinkAction {
    Default,
    Active(bool),
    Delay(bool),
}
679
/// State that is kept in `ActiveState::primary`.
struct PrimaryChannelState {
    /// Current guest VF state.
    guest_vf_state: PrimaryChannelGuestVfState,
    /// Reported as `data_path_switched` when inspected; starts out as `None`.
    is_data_path_switched: Option<bool>,
    /// Queued control messages; only their count is reported when inspected.
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    /// Control message IDs that are currently free.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    /// Starts out as 1.
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    /// Starts out as `true`.
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
    /// The serial number sent in the most recent VF association message, so
    /// that the matching disassociation message uses the same serial number.
    advertised_vf_serial_number: Option<u32>,
}
698
699impl Inspect for PrimaryChannelState {
700    fn inspect(&self, req: inspect::Request<'_>) {
701        req.respond()
702            .sensitivity_field(
703                "guest_vf_state",
704                SensitivityLevel::Safe,
705                self.guest_vf_state,
706            )
707            .sensitivity_field(
708                "data_path_switched",
709                SensitivityLevel::Safe,
710                self.is_data_path_switched,
711            )
712            .sensitivity_field(
713                "pending_control_messages",
714                SensitivityLevel::Safe,
715                self.control_messages.len(),
716            )
717            .sensitivity_field(
718                "free_control_message_buffers",
719                SensitivityLevel::Safe,
720                self.free_control_buffers.len(),
721            )
722            .sensitivity_field(
723                "pending_offload_change",
724                SensitivityLevel::Safe,
725                self.pending_offload_change,
726            )
727            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
728            .sensitivity_field(
729                "offload_config",
730                SensitivityLevel::Safe,
731                &self.offload_config,
732            )
733            .sensitivity_field(
734                "tx_spread_sent",
735                SensitivityLevel::Safe,
736                self.tx_spread_sent,
737            )
738            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
739            .sensitivity_field(
740                "pending_link_action",
741                SensitivityLevel::Safe,
742                match &self.pending_link_action {
743                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
744                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
745                    PendingLinkAction::Default => "None".to_string(),
746                },
747            );
748    }
749}
750
/// Offload settings: checksum offload for each direction plus large send
/// offload (LSO) for IPv4 and IPv6. All fields are marked safe for inspection.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    /// Transmit checksum offload.
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    /// Receive checksum offload.
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// Large send offload over IPv4.
    #[inspect(safe)]
    lso4: bool,
    /// Large send offload over IPv6.
    #[inspect(safe)]
    lso6: bool,
}
762
763impl OffloadConfig {
764    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
765        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
766        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
767        self.lso4 &= supported.lso4;
768        self.lso6 &= supported.lso6;
769    }
770}
771
/// Checksum offload switches for one direction. All fields are marked safe
/// for inspection.
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    /// IPv4 header checksum.
    #[inspect(safe)]
    ipv4_header: bool,
    /// TCP checksum over IPv4.
    #[inspect(safe)]
    tcp4: bool,
    /// UDP checksum over IPv4.
    #[inspect(safe)]
    udp4: bool,
    /// TCP checksum over IPv6.
    #[inspect(safe)]
    tcp6: bool,
    /// UDP checksum over IPv6.
    #[inspect(safe)]
    udp6: bool,
}
785
786impl ChecksumOffloadConfig {
787    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
788        self.ipv4_header &= supported.ipv4_header;
789        self.tcp4 &= supported.tcp4;
790        self.udp4 &= supported.udp4;
791        self.tcp6 &= supported.tcp6;
792        self.udp6 &= supported.udp6;
793    }
794
795    fn flags(
796        &self,
797    ) -> (
798        rndisprot::Ipv4ChecksumOffload,
799        rndisprot::Ipv6ChecksumOffload,
800    ) {
801        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
802        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
803        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
804        if self.ipv4_header {
805            v4.set_ip_options_supported(on);
806            v4.set_ip_checksum(on);
807        }
808        if self.tcp4 {
809            v4.set_ip_options_supported(on);
810            v4.set_tcp_options_supported(on);
811            v4.set_tcp_checksum(on);
812        }
813        if self.tcp6 {
814            v6.set_ip_extension_headers_supported(on);
815            v6.set_tcp_options_supported(on);
816            v6.set_tcp_checksum(on);
817        }
818        if self.udp4 {
819            v4.set_ip_options_supported(on);
820            v4.set_udp_checksum(on);
821        }
822        if self.udp6 {
823            v6.set_ip_extension_headers_supported(on);
824            v6.set_udp_checksum(on);
825        }
826        (v4, v6)
827    }
828}
829
830impl OffloadConfig {
831    fn ndis_offload(&self) -> rndisprot::NdisOffload {
832        let checksum = {
833            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
834            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
835            rndisprot::TcpIpChecksumOffload {
836                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
837                ipv4_tx_flags,
838                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
839                ipv4_rx_flags,
840                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
841                ipv6_tx_flags,
842                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
843                ipv6_rx_flags,
844            }
845        };
846
847        let lso_v2 = {
848            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
849            if self.lso4 {
850                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
851                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
852                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
853            }
854            if self.lso6 {
855                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
856                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
857                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
858                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
859                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
860                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
861            }
862            lso
863        };
864
865        rndisprot::NdisOffload {
866            header: rndisprot::NdisObjectHeader {
867                object_type: rndisprot::NdisObjectType::OFFLOAD,
868                revision: 3,
869                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
870            },
871            checksum,
872            lso_v2,
873            ..FromZeros::new_zeroed()
874        }
875    }
876}
877
/// The RNDIS state tracked for the primary channel (see
/// `PrimaryChannelState::rndis_state`).
///
/// A new primary channel starts out as `Initializing`; on restore the value
/// is mapped one-to-one from `saved_state::RndisState`.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// The state a freshly created primary channel starts in.
    Initializing,
    // NOTE(review): presumably entered once RNDIS initialization completes;
    // the transition is not visible in this part of the file.
    Operational,
    // NOTE(review): presumably entered when the guest halts the RNDIS device;
    // the transition is not visible in this part of the file.
    Halted,
}
884
885impl PrimaryChannelState {
886    fn new(offload_config: OffloadConfig) -> Self {
887        Self {
888            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
889            is_data_path_switched: None,
890            control_messages: VecDeque::new(),
891            control_messages_len: 0,
892            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
893                .map(ControlMessageId)
894                .collect(),
895            rss_state: None,
896            requested_num_queues: 1,
897            rndis_state: RndisState::Initializing,
898            pending_offload_change: false,
899            offload_config,
900            tx_spread_sent: false,
901            guest_link_up: true,
902            pending_link_action: PendingLinkAction::Default,
903            advertised_vf_serial_number: None,
904        }
905    }
906
907    fn restore(
908        guest_vf_state: &saved_state::GuestVfState,
909        rndis_state: &saved_state::RndisState,
910        offload_config: &saved_state::OffloadConfig,
911        pending_offload_change: bool,
912        advertised_vf_serial_number: Option<u32>,
913        num_queues: u16,
914        indirection_table_size: u16,
915        rx_bufs: &RxBuffers,
916        control_messages: Vec<saved_state::IncomingControlMessage>,
917        rss_state: Option<saved_state::RssState>,
918        tx_spread_sent: bool,
919        guest_link_down: bool,
920        pending_link_action: Option<bool>,
921    ) -> Result<Self, NetRestoreError> {
922        // Restore control messages.
923        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();
924
925        let control_messages = control_messages
926            .into_iter()
927            .map(|msg| ControlMessage {
928                message_type: msg.message_type,
929                data: msg.data.into(),
930            })
931            .collect();
932
933        // Compute the free control buffers.
934        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
935            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
936            .collect();
937
938        let rss_state = rss_state
939            .map(|mut rss| {
940                if rss.indirection_table.len() > indirection_table_size as usize {
941                    // Dynamic reduction of indirection table can cause unexpected and hard to investigate issues
942                    // with performance and processor overloading.
943                    return Err(NetRestoreError::ReducedIndirectionTableSize);
944                }
945                if rss.indirection_table.len() < indirection_table_size as usize {
946                    tracing::warn!(
947                        saved_indirection_table_size = rss.indirection_table.len(),
948                        adapter_indirection_table_size = indirection_table_size,
949                        "increasing indirection table size",
950                    );
951                    // Dynamic increase of indirection table is done by duplicating the existing entries until
952                    // the desired size is reached.
953                    let table_clone = rss.indirection_table.clone();
954                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
955                    rss.indirection_table
956                        .extend(table_clone.iter().cycle().take(num_to_add));
957                }
958                Ok(RssState {
959                    key: rss
960                        .key
961                        .try_into()
962                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
963                    indirection_table: rss.indirection_table,
964                })
965            })
966            .transpose()?;
967
968        let rndis_state = match rndis_state {
969            saved_state::RndisState::Initializing => RndisState::Initializing,
970            saved_state::RndisState::Operational => RndisState::Operational,
971            saved_state::RndisState::Halted => RndisState::Halted,
972        };
973
974        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
975        let offload_config = OffloadConfig {
976            checksum_tx: ChecksumOffloadConfig {
977                ipv4_header: offload_config.checksum_tx.ipv4_header,
978                tcp4: offload_config.checksum_tx.tcp4,
979                udp4: offload_config.checksum_tx.udp4,
980                tcp6: offload_config.checksum_tx.tcp6,
981                udp6: offload_config.checksum_tx.udp6,
982            },
983            checksum_rx: ChecksumOffloadConfig {
984                ipv4_header: offload_config.checksum_rx.ipv4_header,
985                tcp4: offload_config.checksum_rx.tcp4,
986                udp4: offload_config.checksum_rx.udp4,
987                tcp6: offload_config.checksum_rx.tcp6,
988                udp6: offload_config.checksum_rx.udp6,
989            },
990            lso4: offload_config.lso4,
991            lso6: offload_config.lso6,
992        };
993
994        let pending_link_action = if let Some(pending) = pending_link_action {
995            PendingLinkAction::Active(pending)
996        } else {
997            PendingLinkAction::Default
998        };
999
1000        Ok(Self {
1001            guest_vf_state,
1002            is_data_path_switched: None,
1003            control_messages,
1004            control_messages_len,
1005            free_control_buffers,
1006            rss_state,
1007            requested_num_queues: num_queues,
1008            rndis_state,
1009            pending_offload_change,
1010            offload_config,
1011            tx_spread_sent,
1012            guest_link_up: !guest_link_down,
1013            pending_link_action,
1014            advertised_vf_serial_number,
1015        })
1016    }
1017}
1018
/// A control message held in the primary channel's `control_messages` queue.
///
/// On restore these are rebuilt from `saved_state::IncomingControlMessage`.
struct ControlMessage {
    // Raw message type value, carried over unchanged from the saved message.
    message_type: u32,
    // Message payload; its length is what `control_messages_len` sums up.
    data: Box<[u8]>,
}
1023
/// The number of tx packet slots per channel.
///
/// This sizes both `ActiveState::pending_tx_packets` and the initial pool of
/// free `TxId`s, and is the basis for the adapter's free tx packet threshold.
const TX_PACKET_QUOTA: usize = 1024;
1025
1026impl ActiveState {
1027    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
1028        Self {
1029            primary,
1030            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
1031            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
1032            pending_tx_completions: VecDeque::new(),
1033            pending_rx_packets: VecDeque::new(),
1034            rx_bufs: RxBuffers::new(recv_buffer_count),
1035            stats: Default::default(),
1036        }
1037    }
1038
1039    fn restore(
1040        channel: &saved_state::Channel,
1041        recv_buffer_count: u32,
1042    ) -> Result<Self, NetRestoreError> {
1043        let mut active = Self::new(None, recv_buffer_count);
1044        let saved_state::Channel {
1045            pending_tx_completions,
1046            in_use_rx,
1047        } = channel;
1048        for rx in in_use_rx {
1049            active
1050                .rx_bufs
1051                .allocate(rx.buffers.as_slice().iter().copied())?;
1052        }
1053        for &transaction_id in pending_tx_completions {
1054            // Consume tx quota if any is available. If not, still
1055            // allow the restore since tx quota might change from
1056            // release to release.
1057            let tx_id = active.free_tx_packets.pop();
1058            if let Some(id) = tx_id {
1059                // This shouldn't be referenced, but set it in case it is in the future.
1060                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
1061            }
1062            // Save/Restore does not preserve the status of pending tx completions,
1063            // completing any pending completions with 'success' to avoid making changes to saved_state.
1064            active
1065                .pending_tx_completions
1066                .push_back(PendingTxCompletion {
1067                    transaction_id,
1068                    tx_id,
1069                    status: protocol::Status::SUCCESS,
1070                });
1071        }
1072        Ok(active)
1073    }
1074}
1075
/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
///
/// One of these exists per tx slot (`TX_PACKET_QUOTA` in total), indexed by
/// the value of the slot's `TxId`.
#[derive(Default, Clone)]
struct PendingTxPacket {
    // NOTE(review): presumably the number of backend packets still
    // outstanding for this rndis packet — confirm against the tx path.
    pending_packet_count: usize,
    // The transaction ID to complete once the packet is done; on restore
    // this is filled in from the saved pending tx completions.
    transaction_id: u64,
}
1083
/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
///
/// Buffer ids `0..RX_RESERVED_CONTROL_BUFFERS` are the ones handed out as
/// `ControlMessageId`s by the primary channel state.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1092
/// A network adapter.
///
/// Created through [`Nic::builder`] and [`NicBuilder::build`].
pub struct Nic {
    /// The instance ID reported in the channel offer; also used in task
    /// names and log messages.
    instance_id: Guid,
    /// Device resources handed over via `install`; defaulted until then.
    resources: DeviceResources,
    /// The coordinator task, which owns the endpoint and the per-queue
    /// workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    /// Sender for messages to the coordinator. `None` until a coordinator is
    /// inserted; cloned into every worker.
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    /// Adapter-wide configuration and state shared with the coordinator and
    /// workers.
    adapter: Arc<Adapter>,
    /// Source of the drivers that the per-queue workers run on.
    driver_source: VmTaskDriverSource,
}
1102
/// Builder for [`Nic`], obtained from [`Nic::builder`].
pub struct NicBuilder {
    /// Optional virtual function handed to the coordinator state on build.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Whether to limit the effective size of the outgoing ring buffer.
    limit_ring_buffer: bool,
    /// Requested upper bound on the number of queues; further bounded at
    /// build time by what the endpoint and the protocol support.
    max_queues: u16,
    /// Optional callback returning the guest OS ID; stored in the adapter on
    /// build.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1109
1110impl NicBuilder {
1111    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
1112        self.limit_ring_buffer = limit;
1113        self
1114    }
1115
1116    pub fn max_queues(mut self, max_queues: u16) -> Self {
1117        self.max_queues = max_queues;
1118        self
1119    }
1120
1121    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
1122        self.virtual_function = Some(virtual_function);
1123        self
1124    }
1125
1126    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
1127        self.get_guest_os_id = Some(os_type);
1128        self
1129    }
1130
1131    /// Creates a new NIC.
1132    pub fn build(
1133        self,
1134        driver_source: &VmTaskDriverSource,
1135        instance_id: Guid,
1136        endpoint: Box<dyn Endpoint>,
1137        mac_address: MacAddress,
1138        adapter_index: u32,
1139    ) -> Nic {
1140        let multiqueue = endpoint.multiqueue_support();
1141
1142        let max_queues = self.max_queues.clamp(
1143            1,
1144            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
1145        );
1146
1147        // If requested, limit the effective size of the outgoing ring buffer.
1148        // In a configuration where the NIC is processed synchronously, this
1149        // will ensure that we don't process incoming rx packets and tx packet
1150        // completions until the guest has processed the data it already has.
1151        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };
1152
1153        // If the endpoint completes tx packets quickly, then avoid polling the
1154        // incoming ring (and thus avoid arming the signal from the guest) as
1155        // long as there are any tx packets in flight. This can significantly
1156        // reduce the signal rate from the guest, improving batching.
1157        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
1158            TX_PACKET_QUOTA
1159        } else {
1160            // Avoid getting into a situation where there is always barely
1161            // enough quota.
1162            TX_PACKET_QUOTA / 4
1163        };
1164
1165        let tx_offloads = endpoint.tx_offload_support();
1166
1167        // Always claim support for rx offloads since we can mark any given
1168        // packet as having unknown checksum state.
1169        let offload_support = OffloadConfig {
1170            checksum_rx: ChecksumOffloadConfig {
1171                ipv4_header: true,
1172                tcp4: true,
1173                udp4: true,
1174                tcp6: true,
1175                udp6: true,
1176            },
1177            checksum_tx: ChecksumOffloadConfig {
1178                ipv4_header: tx_offloads.ipv4_header,
1179                tcp4: tx_offloads.tcp,
1180                tcp6: tx_offloads.tcp,
1181                udp4: tx_offloads.udp,
1182                udp6: tx_offloads.udp,
1183            },
1184            // LSOv4 requires both TSO and IPv4 header checksum support,
1185            // because the TAP/virtio GSO engine needs a valid IPv4 header
1186            // checksum that NDIS LSO packets don't provide.
1187            lso4: tx_offloads.tso && tx_offloads.ipv4_header,
1188            lso6: tx_offloads.tso,
1189        };
1190
1191        let driver = driver_source.simple();
1192        let adapter = Arc::new(Adapter {
1193            driver,
1194            mac_address,
1195            max_queues,
1196            indirection_table_size: multiqueue.indirection_table_size,
1197            offload_support,
1198            free_tx_packet_threshold,
1199            ring_size_limit: ring_size_limit.into(),
1200            tx_fast_completions: endpoint.tx_fast_completions(),
1201            adapter_index,
1202            get_guest_os_id: self.get_guest_os_id,
1203            num_sub_channels_opened: AtomicUsize::new(0),
1204            link_speed: endpoint.link_speed(),
1205        });
1206
1207        let coordinator = TaskControl::new(CoordinatorState {
1208            endpoint,
1209            adapter: adapter.clone(),
1210            virtual_function: self.virtual_function,
1211            pending_vf_state: CoordinatorStatePendingVfState::Ready,
1212        });
1213
1214        Nic {
1215            instance_id,
1216            resources: Default::default(),
1217            coordinator,
1218            coordinator_send: None,
1219            adapter,
1220            driver_source: driver_source.clone(),
1221        }
1222    }
1223}
1224
1225fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1226    let Some(guest_os_id) = guest_os_id else {
1227        // guest os id not available.
1228        return false;
1229    };
1230
1231    if !queue.split().0.supports_pending_send_size() {
1232        // guest does not support pending send size.
1233        return false;
1234    }
1235
1236    let Some(open_source_os) = guest_os_id.open_source() else {
1237        // guest os is proprietary (ex: Windows)
1238        return true;
1239    };
1240
1241    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1242        // Although FreeBSD indicates support for `pending send size`, it doesn't
1243        // implement it correctly. This was fixed in FreeBSD version `1400097`.
1244        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1245        // Linux kernels prior to 3.11 have issues with pending send size optimization
1246        // which can affect certain Asynchronous I/O (AIO) network operations.
1247        // Disable ring size optimization for these older kernels to avoid flow control issues.
1248        HvGuestOsOpenSourceType::LINUX => {
1249            // Linux version is encoded as: ((major << 16) | (minor << 8) | patch)
1250            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424
1251            open_source_os.version() >= 199424
1252        }
1253        _ => true,
1254    }
1255}
1256
1257impl Nic {
1258    pub fn builder() -> NicBuilder {
1259        NicBuilder {
1260            virtual_function: None,
1261            limit_ring_buffer: false,
1262            max_queues: !0,
1263            get_guest_os_id: None,
1264        }
1265    }
1266
1267    pub fn shutdown(self) -> (Box<dyn Endpoint>, MacAddress) {
1268        let (state, _) = self.coordinator.into_inner();
1269        (state.endpoint, self.adapter.mac_address)
1270    }
1271}
1272
/// Inspection of the NIC is delegated entirely to the coordinator task.
impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        // Forward the request unchanged to the coordinator.
        self.coordinator.inspect_mut(req);
    }
}
1278
/// VMBus device implementation: channel offer, open/close, VP retargeting,
/// and start/stop of the coordinator and its workers.
#[async_trait]
impl VmbusDevice for Nic {
    /// Describes the channel offer for this NIC: the "net" interface with this
    /// NIC's instance ID, a fixed interface ID, and a 100 microsecond MNF
    /// interrupt latency.
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    /// Returns the adapter's maximum queue count.
    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    /// Stores the device resources for later use when channels are opened.
    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    /// Opens channel `channel_idx` and starts a worker for it.
    ///
    /// Opening the primary channel (index 0) creates a fresh coordinator with
    /// a single queue and a worker in the initial state. Opening a subchannel
    /// stops the coordinator and creates a worker that is immediately ready,
    /// sharing the buffers established on the primary channel.
    ///
    /// Returns an error if the worker could not be inserted; if that happens
    /// for the primary channel, the coordinator is removed again.
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, None);
            WorkerState::Init(None)
        } else {
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            // N.B. Both unwraps rely on the primary channel being open and
            //      having its buffers set up before any subchannel is opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        // Count this channel as opened; `num_opened` is the count before this
        // open.
        // NOTE(review): the counter is incremented before `insert_worker` and
        // is not decremented here if it fails — confirm that is intended.
        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        // Once this subchannel brings the open count up to the number of
        // queues, stop the primary channel's worker; it is restarted along
        // with the coordinator below.
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            // The primary channel failed to open, so drop the coordinator that
            // was inserted above.
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    /// Closes channel `channel_idx`, stopping and removing its worker.
    ///
    /// Closing the primary channel (index 0) additionally clears the queue
    /// state of every worker, stops the endpoint, and removes the coordinator.
    /// Closing a subchannel restarts the coordinator if it was running.
    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!(
                channel_idx,
                instance_id = %self.instance_id,
                "Close called while vmbus channel is already closed"
            );
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        // This channel no longer counts as opened.
        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator
                .task_mut()
                .endpoint
                .stop()
                .instrument(tracing::info_span!(
                    "stopping coordinator endpoint",
                    instance_id = %self.instance_id,
                ))
                .await;

            // Keep any VF's added to the guest. This is required to keep guest compat as
            // some apps (such as DPDK) relies on the VF sticking around even after vmbus
            // channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    /// Moves processing for channel `channel_idx` to `target_vp`.
    ///
    /// Does nothing if no coordinator exists. Otherwise the coordinator and
    /// the channel's worker are stopped, the worker's driver is retargeted,
    /// and the coordinator is restarted if it was running before.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = Some(target_vp);
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    /// Starts the coordinator if it is not already running.
    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    /// Stops the coordinator and then, if a coordinator exists, all of its
    /// workers.
    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    /// The NIC always supports save/restore.
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}
1448
1449#[async_trait]
1450impl SaveRestoreVmbusDevice for Nic {
1451    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1452        let state = self.saved_state();
1453        Ok(SavedStateBlob::new(state))
1454    }
1455
1456    async fn restore(
1457        &mut self,
1458        control: RestoreControl<'_>,
1459        state: SavedStateBlob,
1460    ) -> Result<(), RestoreError> {
1461        let state: saved_state::SavedState = state.parse()?;
1462        if let Err(err) = self.restore_state(control, state).await {
1463            tracing::error!(
1464                error = &err as &dyn std::error::Error,
1465                instance_id = %self.instance_id,
1466                "Failed restoring network vmbus state"
1467            );
1468            Err(err.into())
1469        } else {
1470            Ok(())
1471        }
1472    }
1473}
1474
1475impl Nic {
1476    /// Allocates and inserts a worker.
1477    ///
1478    /// The coordinator must be stopped.
1479    fn insert_worker(
1480        &mut self,
1481        channel_idx: u16,
1482        open_request: &OpenRequest,
1483        state: WorkerState,
1484        start: bool,
1485    ) -> Result<(), OpenError> {
1486        let coordinator = self.coordinator.state_mut().unwrap();
1487
1488        // Retarget the driver now that the channel is open.
1489        // N.B. VMBus doesn't provide a target VP if the channel is not using interrupts. Run on VP
1490        //      0 in that case.
1491        let driver = coordinator.workers[channel_idx as usize]
1492            .task()
1493            .driver
1494            .clone();
1495        driver.retarget_vp(open_request.open_data.target_vp.unwrap_or_default());
1496
1497        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
1498            .map_err(OpenError::Ring)?;
1499        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
1500        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
1501        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
1502        let worker = Worker {
1503            channel_idx,
1504            target_vp: open_request.open_data.target_vp,
1505            mem: self
1506                .resources
1507                .offer_resources
1508                .guest_memory(open_request)
1509                .clone(),
1510            channel: NetChannel {
1511                adapter: self.adapter.clone(),
1512                queue,
1513                gpadl_map: self.resources.gpadl_map.clone(),
1514                packet_size: PacketSize::V1,
1515                pending_send_size: 0,
1516                restart: None,
1517                can_use_ring_size_opt,
1518                packet_filter: coordinator.active_packet_filter,
1519            },
1520            state,
1521            coordinator_send: self.coordinator_send.clone().unwrap(),
1522        };
1523        let instance_id = self.instance_id;
1524        let worker_task = &mut coordinator.workers[channel_idx as usize];
1525        worker_task.insert(
1526            driver,
1527            format!("netvsp-{}-{}", instance_id, channel_idx),
1528            worker,
1529        );
1530        if start {
1531            worker_task.start();
1532        }
1533        Ok(())
1534    }
1535}
1536
/// Coordinator state carried over from a saved state, passed to
/// `Nic::insert_coordinator` when restoring.
struct RestoreCoordinatorState {
    // The packet filter the coordinator should start out with.
    active_packet_filter: u32,
}
1540
1541impl Nic {
1542    /// If `restoring`, then restart the queues as soon as the coordinator starts.
1543    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
1544        let mut driver_builder = self.driver_source.builder();
1545        // Target each driver to VP 0 initially. This will be updated when the
1546        // channel is opened.
1547        driver_builder.target_vp(0);
1548        // If tx completions arrive quickly, then just do tx processing
1549        // on whatever processor the guest happens to signal from.
1550        // Subsequent transmits will be pulled from the completion
1551        // processor.
1552        driver_builder.run_on_target(!self.adapter.tx_fast_completions);
1553
1554        #[expect(clippy::disallowed_methods)] // TODO
1555        let (send, recv) = mpsc::channel(1);
1556        self.coordinator_send = Some(send);
1557        self.coordinator.insert(
1558            &self.adapter.driver,
1559            format!("netvsp-{}-coordinator", self.instance_id),
1560            Coordinator {
1561                recv,
1562                channel_control: self.resources.channel_control.clone(),
1563                restart: restoring.is_some(),
1564                workers: (0..self.adapter.max_queues)
1565                    .map(|i| {
1566                        TaskControl::new(NetQueue {
1567                            queue_state: None,
1568                            driver: driver_builder
1569                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
1570                        })
1571                    })
1572                    .collect(),
1573                buffers: None,
1574                num_queues,
1575                active_packet_filter: restoring
1576                    .map(|r| r.active_packet_filter)
1577                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
1578                sleep_deadline: None,
1579            },
1580        );
1581    }
1582}
1583
/// Errors that can occur while restoring the NIC from saved state.
///
/// These are reported to VMBus as `RestoreError::InvalidSavedState`.
#[derive(Debug, Error)]
enum NetRestoreError {
    /// The saved protocol version was not accepted by `check_version`.
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    /// A saved send or receive buffer refers to a GPADL ID that is not known.
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    /// Restoring the VMBus channels themselves failed.
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    /// Wraps a `BufferError`, forwarding its message unchanged.
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    /// Re-allocating a saved in-use receive buffer was rejected.
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    /// Wraps an `OpenError`, forwarding its message unchanged.
    #[error(transparent)]
    Open(#[from] OpenError),
    /// The saved RSS key could not be converted to the expected key type.
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    /// The saved RSS indirection table is larger than the adapter's.
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1603
1604impl From<NetRestoreError> for RestoreError {
1605    fn from(err: NetRestoreError) -> Self {
1606        RestoreError::InvalidSavedState(anyhow::Error::new(err))
1607    }
1608}
1609
1610impl Nic {
1611    async fn restore_state(
1612        &mut self,
1613        mut control: RestoreControl<'_>,
1614        state: saved_state::SavedState,
1615    ) -> Result<(), NetRestoreError> {
1616        let mut saved_packet_filter = 0u32;
1617        if let Some(state) = state.open {
1618            let open = match &state.primary {
1619                saved_state::Primary::Version => vec![true],
1620                saved_state::Primary::Init(_) => vec![true],
1621                saved_state::Primary::Ready(ready) => {
1622                    ready.channels.iter().map(|x| x.is_some()).collect()
1623                }
1624            };
1625
1626            let mut states: Vec<_> = open.iter().map(|_| None).collect();
1627
1628            // N.B. This will restore the vmbus view of open channels, so any
1629            //      failures after this point could result in inconsistent
1630            //      state (vmbus believes the channel is open/active). There
1631            //      are a number of failure paths after this point because this
1632            //      call also restores vmbus device state, like the GPADL map.
1633            let requests = control
1634                .restore(&open)
1635                .await
1636                .map_err(NetRestoreError::Channel)?;
1637
1638            match state.primary {
1639                saved_state::Primary::Version => {
1640                    states[0] = Some(WorkerState::Init(None));
1641                }
1642                saved_state::Primary::Init(init) => {
1643                    let version = check_version(init.version)
1644                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;
1645
1646                    let recv_buffer = init
1647                        .receive_buffer
1648                        .map(|recv_buffer| {
1649                            ReceiveBuffer::new(
1650                                &self.resources.gpadl_map,
1651                                recv_buffer.gpadl_id,
1652                                recv_buffer.id,
1653                                recv_buffer.sub_allocation_size,
1654                            )
1655                        })
1656                        .transpose()?;
1657
1658                    let send_buffer = init
1659                        .send_buffer
1660                        .map(|send_buffer| {
1661                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
1662                        })
1663                        .transpose()?;
1664
1665                    let state = InitState {
1666                        version,
1667                        ndis_config: init.ndis_config.map(
1668                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
1669                                mtu,
1670                                capabilities: capabilities.into(),
1671                            },
1672                        ),
1673                        ndis_version: init.ndis_version.map(
1674                            |saved_state::NdisVersion { major, minor }| NdisVersion {
1675                                major,
1676                                minor,
1677                            },
1678                        ),
1679                        recv_buffer,
1680                        send_buffer,
1681                    };
1682                    states[0] = Some(WorkerState::Init(Some(state)));
1683                }
1684                saved_state::Primary::Ready(ready) => {
1685                    let saved_state::ReadyPrimary {
1686                        version,
1687                        receive_buffer,
1688                        send_buffer,
1689                        mut control_messages,
1690                        mut rss_state,
1691                        channels,
1692                        ndis_version,
1693                        ndis_config,
1694                        rndis_state,
1695                        guest_vf_state,
1696                        offload_config,
1697                        pending_offload_change,
1698                        tx_spread_sent,
1699                        guest_link_down,
1700                        pending_link_action,
1701                        packet_filter,
1702                        advertised_vf_serial_number,
1703                    } = ready;
1704
1705                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
1706                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1707
1708                    let version = check_version(version)
1709                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1710
1711                    let request = requests[0].as_ref().unwrap();
1712                    let buffers = Arc::new(ChannelBuffers {
1713                        version,
1714                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1715                        recv_buffer: ReceiveBuffer::new(
1716                            &self.resources.gpadl_map,
1717                            receive_buffer.gpadl_id,
1718                            receive_buffer.id,
1719                            receive_buffer.sub_allocation_size,
1720                        )?,
1721                        send_buffer: {
1722                            if let Some(send_buffer) = send_buffer {
1723                                Some(SendBuffer::new(
1724                                    &self.resources.gpadl_map,
1725                                    send_buffer.gpadl_id,
1726                                )?)
1727                            } else {
1728                                None
1729                            }
1730                        },
1731                        ndis_version: {
1732                            let saved_state::NdisVersion { major, minor } = ndis_version;
1733                            NdisVersion { major, minor }
1734                        },
1735                        ndis_config: {
1736                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1737                            NdisConfig {
1738                                mtu,
1739                                capabilities: capabilities.into(),
1740                            }
1741                        },
1742                    });
1743
1744                    for (channel_idx, channel) in channels.iter().enumerate() {
1745                        let channel = if let Some(channel) = channel {
1746                            channel
1747                        } else {
1748                            continue;
1749                        };
1750
1751                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1752
1753                        // Restore primary channel state.
1754                        if channel_idx == 0 {
1755                            let primary = PrimaryChannelState::restore(
1756                                &guest_vf_state,
1757                                &rndis_state,
1758                                &offload_config,
1759                                pending_offload_change,
1760                                advertised_vf_serial_number,
1761                                channels.len() as u16,
1762                                self.adapter.indirection_table_size,
1763                                &active.rx_bufs,
1764                                std::mem::take(&mut control_messages),
1765                                rss_state.take(),
1766                                tx_spread_sent,
1767                                guest_link_down,
1768                                pending_link_action,
1769                            )?;
1770                            active.primary = Some(primary);
1771                        }
1772
1773                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1774                            buffers: buffers.clone(),
1775                            state: active,
1776                            data: ProcessingData::new(),
1777                        }));
1778                    }
1779                }
1780            }
1781
1782            // Insert the coordinator and mark that it should try to start the
1783            // network endpoint when it starts running.
1784            self.insert_coordinator(
1785                states.len() as u16,
1786                Some(RestoreCoordinatorState {
1787                    active_packet_filter: saved_packet_filter,
1788                }),
1789            );
1790
1791            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1792                if let Some(state) = state {
1793                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1794                }
1795            }
1796        } else {
1797            control
1798                .restore(&[false])
1799                .await
1800                .map_err(NetRestoreError::Channel)?;
1801        }
1802        Ok(())
1803    }
1804
    /// Builds the saved-state representation of this device.
    ///
    /// The returned `SavedState::open` is `None` when there is no coordinator
    /// state. Otherwise it describes the primary channel (worker 0) as one of
    /// `Primary::Version`, `Primary::Init`, or `Primary::Ready`, depending on
    /// that worker's current `WorkerState`.
    ///
    /// # Panics
    ///
    /// Panics if worker 0 has no state, if a ready worker 0 has no primary
    /// channel state, or if worker 0 is in `WaitingForCoordinator(None)`.
    fn saved_state(&self) -> saved_state::SavedState {
        // Without coordinator state there is nothing open to describe.
        let open = if let Some(coordinator) = self.coordinator.state() {
            // Worker 0 is the primary channel; its state selects which
            // `saved_state::Primary` variant is emitted.
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                // No init state has been recorded yet.
                WorkerState::Init(None) => saved_state::Primary::Version,
                // Initialization is in progress: save whatever has been
                // recorded so far (each piece is optional).
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                // A worker waiting for the coordinator with a ready state is
                // saved exactly like a fully ready worker.
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    // Shadows the outer `primary` (the worker) with the
                    // primary channel state for the rest of this arm.
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    // Field-by-field copy of the offload configuration into
                    // its saved-state counterpart.
                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued control messages are saved with their payload
                    // copied into an owned `Vec`.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // `Active` and `Delay` both save only the action itself,
                    // so the distinction between them is not preserved.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // One entry per worker among the first `num_queues`
                    // workers; the entry is `None` when that worker has no
                    // state.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // In flight tx will be considered as dropped packets through save/restore, but need
                                    // to complete the requests back to the guest.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        // Each allocated rx entry is saved as
                                        // the list of buffer ids it yields.
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    // A worker that is not ready has no
                                    // per-channel tx/rx state to save.
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Several runtime VF states collapse onto one saved
                    // variant; the "UnavailableFrom*" states are saved as the
                    // state they came from.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        // A state that is still being restored is saved back
                        // unchanged.
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The saved packet filter is the one currently set on
                    // worker 0's channel.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        // The saved flag has the opposite polarity of the
                        // runtime `guest_link_up` flag.
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                        advertised_vf_serial_number: primary.advertised_vf_serial_number,
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
2007}
2008
/// Error type returned by worker-side operations (for example
/// `NetChannel::send_completion`).
///
/// Variants marked `#[from]` can be produced implicitly by `?`; the others
/// are constructed explicitly at the failure site.
#[derive(Debug, Error)]
enum WorkerError {
    /// A packet could not be parsed; see [`PacketError`] for the cause.
    #[error("packet error")]
    Packet(#[source] PacketError),
    /// A packet arrived in an order that is not allowed; see
    /// [`PacketOrderError`].
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    /// The vmbus queue reported an error.
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("guest buffers error")]
    GuestBuffers(#[source] buffers::GuestBuffersError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    /// A write to the outgoing ring found it full (for example when sending
    /// a completion).
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    /// A send or receive buffer could not be set up; see [`BufferError`].
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset {0}")]
    InvalidTcpHeaderOffset(u16),
    /// The task was cancelled; produced through the
    /// `From<task_control::Cancelled>` implementation.
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
    #[error("invalid rx buffer configuration: {0}")]
    InvalidRxBufferConfig(#[source] RxBufferConfigError),
}
2064
2065impl From<task_control::Cancelled> for WorkerError {
2066    fn from(value: task_control::Cancelled) -> Self {
2067        Self::Cancelled(value)
2068    }
2069}
2070
/// Failures while establishing the ring buffer or the vmbus queue.
// NOTE(review): presumably returned from the channel-open path; the use site
// is not visible in this part of the file.
#[derive(Debug, Error)]
enum OpenError {
    /// The GPADL-backed ring buffer could not be established.
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    /// The vmbus queue could not be established.
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
2078
/// Errors produced while parsing an incoming packet (see `parse_packet`).
#[derive(Debug, Error)]
enum PacketError {
    /// The message type is not recognized, or is not accepted at the
    /// currently negotiated protocol version. Carries the raw message type.
    #[error("UnknownType {0}")]
    UnknownType(u32),
    /// Reading the message header or payload failed.
    #[error("Access")]
    Access(#[source] AccessError),
    /// Reading the packet's external data ranges failed.
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    /// An RNDIS packet referenced a send buffer section, but either no send
    /// buffer is available or the section lies outside of it.
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
2090
/// Reasons a packet is rejected for arriving at an unexpected point in the
/// protocol; carried by [`WorkerError::UnexpectedPacketOrder`].
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2110
/// The decoded body of an incoming packet, tagged by message kind.
///
/// Produced by `parse_packet` from the message type in the packet header
/// (or, for the two payload-less completion variants, from the completion's
/// transaction id).
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer),
    RevokeSendBuffer(protocol::Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    /// Completion packet whose header carries the
    /// "send RNDIS packet complete" message type.
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    /// Completion whose transaction id is `VF_ASSOCIATION_TRANSACTION_ID`.
    SendVfAssociationCompletion,
    /// Completion whose transaction id is `SWITCH_DATA_PATH_TRANSACTION_ID`.
    SwitchDataPathCompletion,
}
2128
/// A parsed incoming packet, as returned by `parse_packet`.
#[derive(Debug)]
struct Packet<'a> {
    /// The decoded message.
    data: PacketData,
    /// The packet's transaction id, if it has one. Always `Some` for
    /// completion packets.
    transaction_id: Option<u64>,
    /// The guest memory ranges associated with the packet, borrowed from the
    /// caller-provided buffer that `parse_packet` fills in.
    external_data: &'a MultiPagedRangeBuf,
}
2135
/// Reader over the paged ranges of a packet's external data (the type
/// returned by `Packet::rndis_reader`).
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2137
2138impl Packet<'_> {
2139    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2140        PagedRanges::new(self.external_data.iter()).reader(mem)
2141    }
2142}
2143
2144fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2145    reader: &mut impl MemoryRead,
2146) -> Result<T, PacketError> {
2147    reader.read_plain().map_err(PacketError::Access)
2148}
2149
2150fn parse_packet<'a, T: RingMem>(
2151    packet_ref: &queue::PacketRef<'_, T>,
2152    send_buffer: Option<&SendBuffer>,
2153    version: Option<Version>,
2154    external_data: &'a mut MultiPagedRangeBuf,
2155) -> Result<Packet<'a>, PacketError> {
2156    external_data.clear();
2157    let packet = match packet_ref.as_ref() {
2158        IncomingPacket::Data(data) => data,
2159        IncomingPacket::Completion(completion) => {
2160            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2161                PacketData::SendVfAssociationCompletion
2162            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2163                PacketData::SwitchDataPathCompletion
2164            } else {
2165                let mut reader = completion.reader();
2166                let header: protocol::MessageHeader =
2167                    reader.read_plain().map_err(PacketError::Access)?;
2168                match header.message_type {
2169                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2170                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2171                    }
2172                    typ => return Err(PacketError::UnknownType(typ)),
2173                }
2174            };
2175            return Ok(Packet {
2176                data,
2177                transaction_id: Some(completion.transaction_id()),
2178                external_data,
2179            });
2180        }
2181    };
2182
2183    let mut reader = packet.reader();
2184    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2185    let data = match header.message_type {
2186        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2187        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2188            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2189        }
2190        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2191            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2192        }
2193        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2194            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2195        }
2196        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2197            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2198        }
2199        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2200            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2201        }
2202        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2203            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
2204            if message.send_buffer_section_index != 0xffffffff {
2205                let send_buffer_suballocation = send_buffer
2206                    .ok_or(PacketError::InvalidSendBufferIndex)?
2207                    .gpadl
2208                    .first()
2209                    .unwrap()
2210                    .try_subrange(
2211                        message.send_buffer_section_index as usize * 6144,
2212                        message.send_buffer_section_size as usize,
2213                    )
2214                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2215
2216                external_data.push_range(send_buffer_suballocation);
2217            }
2218            PacketData::RndisPacket(message)
2219        }
2220        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2221            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2222        }
2223        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2224            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2225        }
2226        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2227            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2228        }
2229        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2230            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2231        }
2232        typ => return Err(PacketError::UnknownType(typ)),
2233    };
2234    packet
2235        .read_external_ranges(external_data)
2236        .map_err(PacketError::ExternalData)?;
2237    Ok(Packet {
2238        data,
2239        transaction_id: packet.transaction_id(),
2240        external_data,
2241    })
2242}
2243
/// An outgoing message: a `protocol::MessageHeader` followed by the message
/// payload, stored in a fixed-size, zero-padded buffer.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    /// Backing storage, kept as `u64` words so that it can be handed out as
    /// an 8-byte-aligned payload. Sized for `protocol::PACKET_SIZE_V61`.
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    /// Selects how much of `buf` `aligned_payload` returns.
    size: PacketSize,
}
2249
2250impl NvspMessage {
2251    fn new<P: IntoBytes + Immutable + KnownLayout>(
2252        size: PacketSize,
2253        message_type: u32,
2254        data: P,
2255    ) -> Self {
2256        // Assert at compile time that the packet will fit in the message
2257        // buffer. Note that we are checking against the v1 message size here.
2258        // It's possible this is a v6.1+ message, in which case we could compare
2259        // against the larger size. So far this has not been necessary. If
2260        // needed, make a `new_v61` method that does the more relaxed check,
2261        // rater than weakening this one.
2262        const {
2263            assert!(
2264                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
2265                "packet might not fit in message"
2266            )
2267        };
2268        let mut message = NvspMessage {
2269            buf: [0; protocol::PACKET_SIZE_V61 / 8],
2270            size,
2271        };
2272        let header = protocol::MessageHeader { message_type };
2273        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
2274        data.write_to_prefix(
2275            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
2276        )
2277        .unwrap();
2278        message
2279    }
2280
2281    fn aligned_payload(&self) -> &[u64] {
2282        // Note that vmbus packets are always 8-byte multiples, so round the
2283        // protocol package size up.
2284        let len = match self.size {
2285            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.div_ceil(8) },
2286            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.div_ceil(8) },
2287        };
2288        &self.buf[..len]
2289    }
2290}
2291
2292impl<T: RingMem> NetChannel<T> {
2293    fn message<P: IntoBytes + Immutable + KnownLayout>(
2294        &self,
2295        message_type: u32,
2296        data: P,
2297    ) -> NvspMessage {
2298        NvspMessage::new(self.packet_size, message_type, data)
2299    }
2300
2301    fn send_completion(
2302        &mut self,
2303        transaction_id: Option<u64>,
2304        message: Option<&NvspMessage>,
2305    ) -> Result<(), WorkerError> {
2306        match transaction_id {
2307            None => Ok(()),
2308            Some(transaction_id) => Ok(self
2309                .queue
2310                .split()
2311                .1
2312                .batched()
2313                .try_write_aligned(
2314                    transaction_id,
2315                    OutgoingPacketType::Completion,
2316                    message.map_or(&[], |m| m.aligned_payload()),
2317                )
2318                .map_err(|err| match err {
2319                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2320                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2321                })?),
2322        }
2323    }
2324}
2325
/// The protocol versions that `check_version` accepts.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2334
2335fn check_version(requested_version: u32) -> Option<Version> {
2336    SUPPORTED_VERSIONS
2337        .iter()
2338        .find(|version| **version as u32 == requested_version)
2339        .copied()
2340}
2341
/// A receive buffer backed by a GPADL and divided into equally sized
/// sub-allocations.
#[derive(Debug)]
struct ReceiveBuffer {
    /// The mapped GPADL that backs the buffer.
    gpadl: GpadlView,
    /// The buffer id supplied when the buffer was created; stored and
    /// written back out by `saved_state`.
    id: u16,
    /// Number of whole sub-allocations that fit in the buffer.
    count: u32,
    /// Size of each sub-allocation, in bytes.
    sub_allocation_size: u32,
}
2349
/// Errors from setting up a send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    /// The requested sub-allocation size is below the supported minimum, or
    /// is too large for even one sub-allocation to fit in the buffer.
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    /// The GPADL is not a single contiguous, aligned range.
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    /// The GPADL id was not found in the GPADL map.
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2359
2360impl ReceiveBuffer {
2361    fn new(
2362        gpadl_map: &GpadlMapView,
2363        gpadl_id: GpadlId,
2364        id: u16,
2365        sub_allocation_size: u32,
2366    ) -> Result<Self, BufferError> {
2367        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2368            return Err(BufferError::UnsupportedSuballocationSize(
2369                sub_allocation_size,
2370            ));
2371        }
2372        let gpadl = gpadl_map.map(gpadl_id)?;
2373        let range = gpadl
2374            .contiguous_aligned()
2375            .ok_or(BufferError::UnalignedGpadl)?;
2376        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2377        if num_sub_allocations == 0 {
2378            return Err(BufferError::UnsupportedSuballocationSize(
2379                sub_allocation_size,
2380            ));
2381        }
2382        let recv_buffer = Self {
2383            gpadl,
2384            id,
2385            count: num_sub_allocations,
2386            sub_allocation_size,
2387        };
2388        Ok(recv_buffer)
2389    }
2390
2391    fn range(&self, index: u32) -> PagedRange<'_> {
2392        self.gpadl.first().unwrap().subrange(
2393            (index * self.sub_allocation_size) as usize,
2394            self.sub_allocation_size as usize,
2395        )
2396    }
2397
2398    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2399        assert!(len <= self.sub_allocation_size as usize);
2400        ring::TransferPageRange {
2401            byte_offset: index * self.sub_allocation_size,
2402            byte_count: len as u32,
2403        }
2404    }
2405
2406    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2407        saved_state::ReceiveBuffer {
2408            gpadl_id: self.gpadl.id(),
2409            id: self.id,
2410            sub_allocation_size: self.sub_allocation_size,
2411        }
2412    }
2413}
2414
/// A send buffer backed by a mapped GPADL.
#[derive(Debug)]
struct SendBuffer {
    /// View of the GPADL that backs the buffer.
    gpadl: GpadlView,
}
2419
2420impl SendBuffer {
2421    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2422        let gpadl = gpadl_map.map(gpadl_id)?;
2423        gpadl
2424            .contiguous_aligned()
2425            .ok_or(BufferError::UnalignedGpadl)?;
2426        Ok(Self { gpadl })
2427    }
2428}
2429
2430impl<T: RingMem> NetChannel<T> {
2431    /// Process a single non-packet RNDIS message.
2432    fn handle_rndis_message(
2433        &mut self,
2434        state: &mut ActiveState,
2435        message_type: u32,
2436        mut reader: PacketReader<'_>,
2437    ) -> Result<(), WorkerError> {
2438        assert_ne!(
2439            message_type,
2440            rndisprot::MESSAGE_TYPE_PACKET_MSG,
2441            "handled elsewhere"
2442        );
2443        let control = state
2444            .primary
2445            .as_mut()
2446            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;
2447
2448        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
2449            // Currently ignored and does not require a response.
2450            return Ok(());
2451        }
2452
2453        // This is a control message that needs a response. Responding
2454        // will require a suballocation to be available, which it may
2455        // not be right now. Enqueue the suballocation to a queue and
2456        // process the queue as suballocations become available.
2457        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2458        if reader.len() == 0 {
2459            return Err(WorkerError::RndisMessageTooSmall);
2460        }
2461        // Do not let the queue get too large--the guest should not be
2462        // sending very many control messages at a time.
2463        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2464            return Err(WorkerError::TooManyControlMessages);
2465        }
2466
2467        control.control_messages_len += reader.len();
2468        control.control_messages.push_back(ControlMessage {
2469            message_type,
2470            data: reader.read_all()?.into(),
2471        });
2472
2473        // The control message queue will be processed in the main dispatch
2474        // loop.
2475        Ok(())
2476    }
2477
2478    /// Process RNDIS packet messages, which may contain multiple RNDIS packets
2479    /// in a single vmbus message.
2480    ///
2481    /// On entry, the reader has already read the RNDIS message header of the
2482    /// first RNDIS packet in the message.
2483    fn handle_rndis_packet_messages(
2484        &mut self,
2485        buffers: &ChannelBuffers,
2486        state: &mut ActiveState,
2487        id: TxId,
2488        mut message_len: usize,
2489        mut reader: PacketReader<'_>,
2490        segments: &mut Vec<TxSegment>,
2491    ) -> Result<usize, WorkerError> {
2492        // There may be multiple RNDIS packets in a single message, concatenated
2493        // with each other. Consume them until there is no more data in the
2494        // RNDIS message.
2495        let mut num_packets = 0;
2496        loop {
2497            let next_message_offset = message_len
2498                .checked_sub(size_of::<rndisprot::MessageHeader>())
2499                .ok_or(WorkerError::RndisMessageTooSmall)?;
2500
2501            self.handle_rndis_packet_message(
2502                id,
2503                reader.clone(),
2504                &buffers.mem,
2505                segments,
2506                &mut state.stats,
2507            )?;
2508            num_packets += 1;
2509
2510            reader.skip(next_message_offset)?;
2511            if reader.len() == 0 {
2512                break;
2513            }
2514            let header: rndisprot::MessageHeader = reader.read_plain()?;
2515            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
2516                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
2517            }
2518            message_len = header.message_length as usize;
2519        }
2520        Ok(num_packets)
2521    }
2522
2523    /// Process an RNDIS package message (used to send an Ethernet frame).
2524    fn handle_rndis_packet_message(
2525        &mut self,
2526        id: TxId,
2527        reader: PacketReader<'_>,
2528        mem: &GuestMemory,
2529        segments: &mut Vec<TxSegment>,
2530        stats: &mut QueueStats,
2531    ) -> Result<(), WorkerError> {
2532        // Headers are guaranteed to be in a single PagedRange.
2533        let headers = reader
2534            .clone()
2535            .into_inner()
2536            .paged_ranges()
2537            .next()
2538            .ok_or(WorkerError::RndisMessageTooSmall)?;
2539        let mut data = reader.into_inner();
2540        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2541        if request.num_oob_data_elements != 0
2542            || request.oob_data_length != 0
2543            || request.oob_data_offset != 0
2544            || request.vc_handle != 0
2545        {
2546            return Err(WorkerError::UnsupportedRndisBehavior);
2547        }
2548
2549        if data.len() < request.data_offset as usize
2550            || (data.len() - request.data_offset as usize) < request.data_length as usize
2551            || request.data_length == 0
2552        {
2553            return Err(WorkerError::RndisMessageTooSmall);
2554        }
2555
2556        data.skip(request.data_offset as usize);
2557        data.truncate(request.data_length as usize);
2558
2559        let mut metadata = net_backend::TxMetadata {
2560            id,
2561            len: request.data_length,
2562            ..Default::default()
2563        };
2564
2565        if request.per_packet_info_length != 0 {
2566            let mut ppi = headers
2567                .try_subrange(
2568                    request.per_packet_info_offset as usize,
2569                    request.per_packet_info_length as usize,
2570                )
2571                .ok_or(WorkerError::RndisMessageTooSmall)?;
2572            while !ppi.is_empty() {
2573                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2574                if h.size == 0 {
2575                    return Err(WorkerError::RndisMessageTooSmall);
2576                }
2577                let (this, rest) = ppi
2578                    .try_split(h.size as usize)
2579                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2580                let (_, d) = this
2581                    .try_split(h.per_packet_information_offset as usize)
2582                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2583                match h.typ {
2584                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2585                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2586
2587                        metadata.flags.set_offload_tcp_checksum(
2588                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2589                        );
2590                        metadata.flags.set_offload_udp_checksum(
2591                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2592                        );
2593                        metadata
2594                            .flags
2595                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2596                        metadata.flags.set_is_ipv4(n.is_ipv4());
2597                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2598                        metadata.transport_header_offset = n.tcp_header_offset();
2599                    }
2600                    rndisprot::PPI_LSO => {
2601                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2602
2603                        metadata.flags.set_offload_tcp_segmentation(true);
2604                        metadata.flags.set_offload_tcp_checksum(true);
2605                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2606                        metadata.flags.set_is_ipv4(n.is_ipv4());
2607                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2608                        metadata.max_segment_size = n.mss() as u16;
2609                        metadata.transport_header_offset = n.tcp_header_offset();
2610                    }
2611                    rndisprot::PPI_VLAN => {
2612                        let n: rndisprot::EthVlanInfo = d.reader(mem).read_plain()?;
2613
2614                        metadata.vlan = Some(n.into());
2615                    }
2616                    _ => {}
2617                }
2618                ppi = rest;
2619            }
2620
2621            // The frame data always has a 14-byte Ethernet header; when
2622            // VLAN is present it arrives out-of-band in the PPI (not inline
2623            // in the frame), so l2_len is unconditionally 14. If the guest
2624            // does present a different ethernet header length, then the checksum
2625            // will fail and the send won't work, but that's really on the guest.
2626            metadata.l2_len = net_backend::ETHERNET_HEADER_LEN as u8;
2627
2628            if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2629                // We can determine header length from other means, and presume there's
2630                // no additional data. If there is, the packet will fail checksums but that's
2631                // on the guest for not providing a specific length. This matches extant behavior.
2632                metadata.l3_len = if metadata.transport_header_offset == 0 {
2633                    if metadata.flags.is_ipv4() {
2634                        net_backend::IPV4_MIN_HEADER_LEN
2635                    } else if metadata.flags.is_ipv6() {
2636                        net_backend::IPV6_MIN_HEADER_LEN
2637                    } else {
2638                        unreachable!("this packet is neither v4 nor v6?");
2639                    }
2640                } else if (metadata.transport_header_offset < metadata.l2_len as u16)
2641                    || (metadata.flags.is_ipv4()
2642                        && metadata.transport_header_offset
2643                            < (metadata.l2_len as u16 + net_backend::IPV4_MIN_HEADER_LEN))
2644                    || (metadata.flags.is_ipv6()
2645                        && metadata.transport_header_offset
2646                            < (metadata.l2_len as u16 + net_backend::IPV6_MIN_HEADER_LEN))
2647                    || (metadata.transport_header_offset as u32 >= request.data_length)
2648                {
2649                    return Err(WorkerError::InvalidTcpHeaderOffset(
2650                        metadata.transport_header_offset,
2651                    ));
2652                } else {
2653                    metadata.transport_header_offset - metadata.l2_len as u16
2654                }
2655            }
2656
2657            if metadata.flags.offload_tcp_segmentation() {
2658                const TCP_DOFF_BYTE_OFFSET: u32 = 12;
2659                let tcp_hdr_doff_offset =
2660                    u32::from(metadata.transport_header_offset) + TCP_DOFF_BYTE_OFFSET;
2661                // Validate TCP header Data Offset 4 bit nibble within the packet data bounds.
2662                if tcp_hdr_doff_offset >= request.data_length {
2663                    return Err(WorkerError::InvalidTcpHeaderOffset(
2664                        metadata.transport_header_offset,
2665                    ));
2666                }
2667                metadata.l4_len = {
2668                    let mut reader = data.clone().reader(mem);
2669                    reader.skip(tcp_hdr_doff_offset as usize)?;
2670                    let mut b = 0;
2671                    reader.read(std::slice::from_mut(&mut b))?;
2672                    (b >> 4) * 4
2673                };
2674
2675                if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2676                    // Not strictly enforced.
2677                    stats.tx_invalid_lso_packets.increment();
2678                }
2679            }
2680
2681            // Issue #3453: USO support is not present. (https://github.com/microsoft/openvmm/issues/3453)
2682        }
2683
2684        let start = segments.len();
2685        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2686            let range = range.map_err(WorkerError::InvalidGpadl)?;
2687            segments.push(TxSegment {
2688                ty: net_backend::TxSegmentType::Tail,
2689                gpa: range.start,
2690                len: range.len() as u32,
2691            });
2692        }
2693
2694        metadata.segment_count = (segments.len() - start) as u8;
2695
2696        stats.tx_packets.increment();
2697        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2698            stats.tx_checksum_packets.increment();
2699        }
2700        if metadata.flags.offload_tcp_segmentation() {
2701            stats.tx_lso_packets.increment();
2702        }
2703        if metadata.vlan.is_some() {
2704            stats.tx_vlan_packets.increment();
2705        }
2706
2707        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2708
2709        Ok(())
2710    }
2711
2712    /// Notify the adapter that the guest VF state has changed and it may
2713    /// need to send a message to the guest.
2714    /// Pass `vfid: Some(id)` to advertise VF availability; if an association
2715    /// message is queued successfully, its serial number is stored in
2716    /// `primary.advertised_vf_serial_number`.
2717    /// Pass `vfid: None` to send a disassociation; the stored serial number
2718    /// from the most recent association is reused.
2719    fn guest_vf_is_available(
2720        &mut self,
2721        primary: &mut PrimaryChannelState,
2722        vfid: Option<u32>,
2723        version: Version,
2724        config: NdisConfig,
2725    ) -> Result<bool, WorkerError> {
2726        let (serial_number, available) = if let Some(vfid) = vfid {
2727            (self.adapter.get_guest_vf_serial_number(vfid), true)
2728        } else {
2729            (primary.advertised_vf_serial_number.unwrap_or(0), false)
2730        };
2731        if version >= Version::V4 && config.capabilities.sriov() {
2732            tracing::info!(available, serial_number, "sending VF association message");
2733            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2734            let message = {
2735                self.message(
2736                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2737                    protocol::Message4SendVfAssociation {
2738                        vf_allocated: if available { 1 } else { 0 },
2739                        serial_number,
2740                    },
2741                )
2742            };
2743            self.queue
2744                .split()
2745                .1
2746                .batched()
2747                .try_write_aligned(
2748                    VF_ASSOCIATION_TRANSACTION_ID,
2749                    OutgoingPacketType::InBandWithCompletion,
2750                    message.aligned_payload(),
2751                )
2752                .map_err(|err| match err {
2753                    queue::TryWriteError::Full(len) => {
2754                        tracing::error!(len, "failed to write vf association message");
2755                        WorkerError::OutOfSpace
2756                    }
2757                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2758                })?;
2759
2760            // Update the advertised VF serial number once the message has been successfully queued.
2761            if available {
2762                primary.advertised_vf_serial_number = Some(serial_number);
2763            } else {
2764                primary.advertised_vf_serial_number = None;
2765            }
2766            Ok(true)
2767        } else {
2768            tracing::info!(
2769                available,
2770                serial_number,
2771                major = version.major(),
2772                minor = version.minor(),
2773                sriov_capable = config.capabilities.sriov(),
2774                "Skipping NvspMessage4TypeSendVFAssociation message"
2775            );
2776            Ok(false)
2777        }
2778    }
2779
2780    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2781    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2782        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2783        if version < Version::V5 {
2784            return;
2785        }
2786
2787        #[repr(C)]
2788        #[derive(IntoBytes, Immutable, KnownLayout)]
2789        struct SendIndirectionMsg {
2790            pub message: protocol::Message5SendIndirectionTable,
2791            pub send_indirection_table:
2792                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2793        }
2794
2795        // The offset to the send indirection table from the beginning of the NVSP message.
2796        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2797            + size_of::<protocol::MessageHeader>();
2798        let mut data = SendIndirectionMsg {
2799            message: protocol::Message5SendIndirectionTable {
2800                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2801                table_offset: send_indirection_table_offset as u32,
2802            },
2803            send_indirection_table: Default::default(),
2804        };
2805
2806        for i in 0..data.send_indirection_table.len() {
2807            data.send_indirection_table[i] = i as u32 % num_channels_opened;
2808        }
2809
2810        let header = protocol::MessageHeader {
2811            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2812        };
2813        let result = self
2814            .queue
2815            .split()
2816            .1
2817            .try_write(&queue::OutgoingPacket {
2818                transaction_id: 0,
2819                packet_type: OutgoingPacketType::InBandNoCompletion,
2820                payload: &[header.as_bytes(), data.as_bytes()],
2821            })
2822            .map_err(|err| match err {
2823                queue::TryWriteError::Full(len) => {
2824                    tracing::error!(len, "failed to write send indirection table message");
2825                    WorkerError::OutOfSpace
2826                }
2827                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2828            });
2829        if let Err(err) = result {
2830            tracing::error!(
2831                error = &err as &dyn std::error::Error,
2832                "Failed to notify guest about the send indirection table"
2833            );
2834        }
2835    }
2836
2837    /// Notify the guest that the data path has been switched back to synthetic
2838    /// due to some external state change.
2839    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2840        let header = protocol::MessageHeader {
2841            message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2842        };
2843        let data = protocol::Message4SwitchDataPath {
2844            active_data_path: protocol::DataPath::SYNTHETIC.0,
2845        };
2846        let result = self
2847            .queue
2848            .split()
2849            .1
2850            .try_write(&queue::OutgoingPacket {
2851                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2852                packet_type: OutgoingPacketType::InBandWithCompletion,
2853                payload: &[header.as_bytes(), data.as_bytes()],
2854            })
2855            .map_err(|err| match err {
2856                queue::TryWriteError::Full(len) => {
2857                    tracing::error!(len, "failed to write switch data path message");
2858                    WorkerError::OutOfSpace
2859                }
2860                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2861            });
2862        if let Err(err) = result {
2863            tracing::error!(
2864                error = &err as &dyn std::error::Error,
2865                "Failed to notify guest that data path is now synthetic"
2866            );
2867        } else {
2868            tracing::info!("Switched data path to synthetic")
2869        }
2870    }
2871
    /// Process an internal state change
    ///
    /// Examines `primary.guest_vf_state` and sends whatever guest
    /// notifications the current state calls for, advancing the state as each
    /// step completes.
    ///
    /// Returns `Ok(Some(CoordinatorMessage::Update(..)))` with
    /// `guest_vf_state: true` when a newly available VF was advertised to the
    /// guest, and `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from `guest_vf_is_available` and `send_completion`.
    ///
    /// # Panics
    ///
    /// Panics if the state is `Initializing` or `Restoring`, or if a
    /// `DataPathSwitchPending` state has no `result` yet.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
        //      1. completion of switch data path request
        //      2. Switch data path notification (back to synthetic)
        //      3. Disassociate VF adapter.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Notify guest that a VF capability has recently arrived.
            // The notification is only attempted once RNDIS is operational;
            // otherwise the state is left as `Available`.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(
                    primary,
                    Some(vfid),
                    buffers.version,
                    buffers.ndis_config,
                )? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    // The association message was skipped, yet the data path
                    // is recorded as switched.
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Each arm performs one step and yields the next state; the loop
        // repeats until a state with no pending work is reached.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Notify guest that the VF is unavailable. It has already been surprise removed.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(
                            primary,
                            None,
                            buffers.version,
                            buffers.ndis_config,
                        )?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;

                    match (to_guest, result) {
                        // Switching to guest VF successful.
                        (true, true) => PrimaryChannelGuestVfState::DataPathSwitched,
                        // Switching to guest VF failed, stay synthetic.
                        (true, false) => {
                            tracing::error!(
                                "Failure switching to guest VF, remaining on synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        }
                        // Switching to synthetic successful.
                        (false, true) => PrimaryChannelGuestVfState::Ready,
                        // Switching to synthetic failed, assume VF remains active.
                        (false, false) => {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                // These states are treated as invalid at this point.
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // Every other state has nothing left to do.
                _ => break,
            };
        }
        Ok(None)
    }
2982
2983    /// Process a control message, writing the response to the provided receive
2984    /// buffer suballocation.
2985    fn handle_rndis_control_message(
2986        &mut self,
2987        primary: &mut PrimaryChannelState,
2988        buffers: &ChannelBuffers,
2989        message_type: u32,
2990        mut reader: impl MemoryRead + Clone,
2991        id: u32,
2992    ) -> Result<(), WorkerError> {
2993        let mem = &buffers.mem;
2994        let buffer_range = &buffers.recv_buffer.range(id);
2995        match message_type {
2996            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
2997                if primary.rndis_state != RndisState::Initializing {
2998                    return Err(WorkerError::InvalidRndisState);
2999                }
3000
3001                let request: rndisprot::InitializeRequest = reader.read_plain()?;
3002
3003                tracing::trace!(
3004                    ?request,
3005                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
3006                );
3007
3008                primary.rndis_state = RndisState::Operational;
3009
3010                let mut writer = buffer_range.writer(mem);
3011                let message_length = write_rndis_message(
3012                    &mut writer,
3013                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
3014                    0,
3015                    &rndisprot::InitializeComplete {
3016                        request_id: request.request_id,
3017                        status: rndisprot::STATUS_SUCCESS,
3018                        major_version: rndisprot::MAJOR_VERSION,
3019                        minor_version: rndisprot::MINOR_VERSION,
3020                        device_flags: rndisprot::DF_CONNECTIONLESS,
3021                        medium: rndisprot::MEDIUM_802_3,
3022                        max_packets_per_message: 8,
3023                        max_transfer_size: 0xEFFFFFFF,
3024                        packet_alignment_factor: 3,
3025                        af_list_offset: 0,
3026                        af_list_size: 0,
3027                    },
3028                )?;
3029                self.send_rndis_control_message(buffers, id, message_length)?;
3030                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
3031                    if self.guest_vf_is_available(
3032                        primary,
3033                        Some(vfid),
3034                        buffers.version,
3035                        buffers.ndis_config,
3036                    )? {
3037                        // Ideally the VF would not be presented to the guest
3038                        // until the completion packet has arrived, so that the
3039                        // guest is prepared. This is most interesting for the
3040                        // case of a VF associated with multiple guest
3041                        // adapters, using a concept like vports. In this
3042                        // scenario it would be better if all of the adapters
3043                        // were aware a VF was available before the device
3044                        // arrived. This is not currently possible because the
3045                        // Linux netvsc driver ignores the completion requested
3046                        // flag on inband packets and won't send a completion
3047                        // packet.
3048                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
3049                        self.send_coordinator_update_vf();
3050                    } else if let Some(true) = primary.is_data_path_switched {
3051                        tracing::error!(
3052                            "Data path switched, but current guest negotiation does not support VTL0 VF"
3053                        );
3054                    }
3055                }
3056            }
3057            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
3058                let request: rndisprot::QueryRequest = reader.read_plain()?;
3059
3060                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");
3061
3062                let (header, body) = buffer_range
3063                    .try_split(
3064                        size_of::<rndisprot::MessageHeader>()
3065                            + size_of::<rndisprot::QueryComplete>(),
3066                    )
3067                    .ok_or(WorkerError::RndisMessageTooSmall)?;
3068                let (status, tx) = match self.adapter.handle_oid_query(
3069                    buffers,
3070                    primary,
3071                    request.oid,
3072                    body.writer(mem),
3073                ) {
3074                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
3075                    Err(err) => (err.as_status(), 0),
3076                };
3077
3078                let message_length = write_rndis_message(
3079                    &mut header.writer(mem),
3080                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
3081                    tx,
3082                    &rndisprot::QueryComplete {
3083                        request_id: request.request_id,
3084                        status,
3085                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
3086                        information_buffer_length: tx as u32,
3087                    },
3088                )?;
3089                self.send_rndis_control_message(buffers, id, message_length)?;
3090            }
3091            rndisprot::MESSAGE_TYPE_SET_MSG => {
3092                let request: rndisprot::SetRequest = reader.read_plain()?;
3093
3094                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");
3095
3096                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
3097                    Ok((restart_endpoint, packet_filter)) => {
3098                        // Restart the endpoint if the OID changed some critical
3099                        // endpoint property.
3100                        if restart_endpoint {
3101                            self.restart = Some(CoordinatorMessage::Restart { channel_idx: 0 });
3102                        }
3103                        if let Some(filter) = packet_filter {
3104                            if self.packet_filter != filter {
3105                                self.packet_filter = filter;
3106                                self.send_coordinator_update_filter();
3107                            }
3108                        }
3109                        rndisprot::STATUS_SUCCESS
3110                    }
3111                    Err(err) => {
3112                        tracelimit::warn_ratelimited!(
3113                            error = &err as &dyn std::error::Error,
3114                            oid = ?request.oid,
3115                            "oid set failure"
3116                        );
3117                        err.as_status()
3118                    }
3119                };
3120
3121                let message_length = write_rndis_message(
3122                    &mut buffer_range.writer(mem),
3123                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
3124                    0,
3125                    &rndisprot::SetComplete {
3126                        request_id: request.request_id,
3127                        status,
3128                    },
3129                )?;
3130                self.send_rndis_control_message(buffers, id, message_length)?;
3131            }
3132            rndisprot::MESSAGE_TYPE_RESET_MSG => {
3133                return Err(WorkerError::RndisMessageTypeNotImplemented);
3134            }
3135            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
3136                return Err(WorkerError::RndisMessageTypeNotImplemented);
3137            }
3138            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
3139                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;
3140
3141                tracing::trace!(
3142                    ?request,
3143                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
3144                );
3145
3146                let message_length = write_rndis_message(
3147                    &mut buffer_range.writer(mem),
3148                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
3149                    0,
3150                    &rndisprot::KeepaliveComplete {
3151                        request_id: request.request_id,
3152                        status: rndisprot::STATUS_SUCCESS,
3153                    },
3154                )?;
3155                self.send_rndis_control_message(buffers, id, message_length)?;
3156            }
3157            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
3158                return Err(WorkerError::RndisMessageTypeNotImplemented);
3159            }
3160            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
3161        };
3162        Ok(())
3163    }
3164
3165    fn try_send_rndis_message(
3166        &mut self,
3167        transaction_id: u64,
3168        channel_type: u32,
3169        recv_buffer_id: u16,
3170        transfer_pages: &[ring::TransferPageRange],
3171    ) -> Result<Option<usize>, WorkerError> {
3172        let message = self.message(
3173            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
3174            protocol::Message1SendRndisPacket {
3175                channel_type,
3176                send_buffer_section_index: 0xffffffff,
3177                send_buffer_section_size: 0,
3178            },
3179        );
3180        let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3181            transaction_id,
3182            OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3183            message.aligned_payload(),
3184        ) {
3185            Ok(()) => None,
3186            Err(queue::TryWriteError::Full(n)) => Some(n),
3187            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3188        };
3189        Ok(pending_send_size)
3190    }
3191
3192    fn send_rndis_control_message(
3193        &mut self,
3194        buffers: &ChannelBuffers,
3195        id: u32,
3196        message_length: usize,
3197    ) -> Result<(), WorkerError> {
3198        let result = self.try_send_rndis_message(
3199            id as u64,
3200            protocol::CONTROL_CHANNEL_TYPE,
3201            buffers.recv_buffer.id,
3202            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3203        )?;
3204
3205        // Ring size is checked before control messages are processed, so failure to write is unexpected.
3206        match result {
3207            None => Ok(()),
3208            Some(len) => {
3209                tracelimit::error_ratelimited!(len, "failed to write control message completion");
3210                Err(WorkerError::OutOfSpace)
3211            }
3212        }
3213    }
3214
3215    fn indicate_status(
3216        &mut self,
3217        buffers: &ChannelBuffers,
3218        id: u32,
3219        status: u32,
3220        payload: &[u8],
3221    ) -> Result<(), WorkerError> {
3222        let buffer = &buffers.recv_buffer.range(id);
3223        let mut writer = buffer.writer(&buffers.mem);
3224        let message_length = write_rndis_message(
3225            &mut writer,
3226            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3227            payload.len(),
3228            &rndisprot::IndicateStatus {
3229                status,
3230                status_buffer_length: payload.len() as u32,
3231                status_buffer_offset: if payload.is_empty() {
3232                    0
3233                } else {
3234                    size_of::<rndisprot::IndicateStatus>() as u32
3235                },
3236            },
3237        )?;
3238        writer.write(payload)?;
3239        self.send_rndis_control_message(buffers, id, message_length)?;
3240        Ok(())
3241    }
3242
3243    /// Processes pending control messages until all are processed or there are
3244    /// no available suballocations.
3245    fn process_control_messages(
3246        &mut self,
3247        buffers: &ChannelBuffers,
3248        state: &mut ActiveState,
3249    ) -> Result<(), WorkerError> {
3250        let Some(primary) = &mut state.primary else {
3251            return Ok(());
3252        };
3253
3254        while !primary.control_messages.is_empty()
3255            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
3256        {
3257            // Ensure the ring buffer has enough room to successfully complete control message handling.
3258            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
3259                self.pending_send_size = MIN_CONTROL_RING_SIZE;
3260                break;
3261            }
3262            let Some(id) = primary.free_control_buffers.pop() else {
3263                break;
3264            };
3265
3266            // Mark the receive buffer in use to allow the guest to release it.
3267            assert!(state.rx_bufs.is_free(id.0));
3268            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
3269
3270            if let Some(message) = primary.control_messages.pop_front() {
3271                primary.control_messages_len -= message.data.len();
3272                self.handle_rndis_control_message(
3273                    primary,
3274                    buffers,
3275                    message.message_type,
3276                    message.data.as_ref(),
3277                    id.0,
3278                )?;
3279            } else if primary.pending_offload_change
3280                && primary.rndis_state == RndisState::Operational
3281            {
3282                let ndis_offload = primary.offload_config.ndis_offload();
3283                self.indicate_status(
3284                    buffers,
3285                    id.0,
3286                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
3287                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
3288                )?;
3289                primary.pending_offload_change = false;
3290            } else {
3291                unreachable!();
3292            }
3293        }
3294        Ok(())
3295    }
3296
3297    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3298        if self.restart.is_none() {
3299            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3300                guest_vf_state: guest_vf,
3301                filter_state: packet_filter,
3302            }));
3303        } else if let Some(CoordinatorMessage::Restart { .. }) = self.restart {
3304            // If a restart message is pending, do nothing.
3305            // A restart will try to switch the data path based on primary.guest_vf_state.
3306            // A restart will apply packet filter changes.
3307        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3308            // Add the new update to the existing message.
3309            update.guest_vf_state |= guest_vf;
3310            update.filter_state |= packet_filter;
3311        }
3312    }
3313
3314    fn send_coordinator_update_vf(&mut self) {
3315        self.send_coordinator_update_message(true, false);
3316    }
3317
3318    fn send_coordinator_update_filter(&mut self) {
3319        self.send_coordinator_update_message(false, true);
3320    }
3321}
3322
3323/// Writes an RNDIS message to `writer`.
3324fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3325    writer: &mut impl MemoryWrite,
3326    message_type: u32,
3327    extra: usize,
3328    payload: &T,
3329) -> Result<usize, AccessError> {
3330    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3331    writer.write(
3332        rndisprot::MessageHeader {
3333            message_type,
3334            message_length: message_length as u32,
3335        }
3336        .as_bytes(),
3337    )?;
3338    writer.write(payload.as_bytes())?;
3339    Ok(message_length)
3340}
3341
3342#[derive(Debug, Error)]
3343enum OidError {
3344    #[error(transparent)]
3345    Access(#[from] AccessError),
3346    #[error("unknown oid")]
3347    UnknownOid,
3348    #[error("invalid oid input, bad field {0}")]
3349    InvalidInput(&'static str),
3350    #[error("bad ndis version")]
3351    BadVersion,
3352    #[error("feature {0} not supported")]
3353    NotSupported(&'static str),
3354}
3355
3356impl OidError {
3357    fn as_status(&self) -> u32 {
3358        match self {
3359            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3360            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3361            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3362            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3363        }
3364    }
3365}
3366
/// Default MTU in bytes.
///
/// NOTE(review): presumably counts the Ethernet header, since
/// `OID_GEN_MAXIMUM_FRAME_SIZE` is reported as the configured MTU minus
/// `ETHERNET_HEADER_LEN` -- confirm.
const DEFAULT_MTU: u32 = 1514;
/// Minimum MTU in bytes; currently the same as the default.
const MIN_MTU: u32 = DEFAULT_MTU;
/// Maximum MTU in bytes.
const MAX_MTU: u32 = 9216;
3370
3371impl Adapter {
3372    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3373        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3374            // For enlightened guests (which is only Windows at the moment), send the
3375            // adapter index, which was previously set as the vport serial number.
3376            if guest_os_id
3377                .microsoft()
3378                .unwrap_or(HvGuestOsMicrosoft::from(0))
3379                .os_id()
3380                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3381            {
3382                self.adapter_index
3383            } else {
3384                vfid
3385            }
3386        } else {
3387            vfid
3388        }
3389    }
3390
3391    fn handle_oid_query(
3392        &self,
3393        buffers: &ChannelBuffers,
3394        primary: &PrimaryChannelState,
3395        oid: rndisprot::Oid,
3396        mut writer: impl MemoryWrite,
3397    ) -> Result<usize, OidError> {
3398        tracing::debug!(?oid, "oid query");
3399        let available_len = writer.len();
3400        match oid {
3401            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
3402                let supported_oids_common = &[
3403                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
3404                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
3405                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
3406                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
3407                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
3408                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
3409                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
3410                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
3411                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
3412                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
3413                    rndisprot::Oid::OID_GEN_LINK_SPEED,
3414                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
3415                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
3416                    rndisprot::Oid::OID_GEN_VENDOR_ID,
3417                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
3418                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
3419                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
3420                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
3421                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
3422                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
3423                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
3424                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
3425                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
3426                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
3427                    // Ethernet objects operation characteristics
3428                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
3429                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
3430                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
3431                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
3432                    // Ethernet objects statistics
3433                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
3434                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
3435                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
3436                    // PNP operations characteristics */
3437                    // rndisprot::Oid::OID_PNP_SET_POWER,
3438                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
3439                    // RNDIS OIDS
3440                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
3441                ];
3442
3443                // NDIS6 OIDs
3444                let supported_oids_6 = &[
3445                    // Link State OID
3446                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
3447                    rndisprot::Oid::OID_GEN_LINK_STATE,
3448                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
3449                    // NDIS 6 statistics OID
3450                    rndisprot::Oid::OID_GEN_BYTES_RCV,
3451                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
3452                    // Offload related OID
3453                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
3454                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
3455                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
3456                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
3457                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
3458                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
3459                ];
3460
3461                let supported_oids_63 = &[
3462                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
3463                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
3464                ];
3465
3466                match buffers.ndis_version.major {
3467                    5 => {
3468                        writer.write(supported_oids_common.as_bytes())?;
3469                    }
3470                    6 => {
3471                        writer.write(supported_oids_common.as_bytes())?;
3472                        writer.write(supported_oids_6.as_bytes())?;
3473                        if buffers.ndis_version.minor >= 30 {
3474                            writer.write(supported_oids_63.as_bytes())?;
3475                        }
3476                    }
3477                    _ => return Err(OidError::BadVersion),
3478                }
3479            }
3480            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
3481                let status: u32 = 0; // NdisHardwareStatusReady
3482                writer.write(status.as_bytes())?;
3483            }
3484            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
3485                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
3486            }
3487            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
3488            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
3489            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
3490                let len: u32 = buffers.ndis_config.mtu - net_backend::ETHERNET_HEADER_LEN;
3491                writer.write(len.as_bytes())?;
3492            }
3493            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
3494            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
3495            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
3496                let len: u32 = buffers.ndis_config.mtu;
3497                writer.write(len.as_bytes())?;
3498            }
3499            rndisprot::Oid::OID_GEN_LINK_SPEED => {
3500                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
3501                writer.write(speed.as_bytes())?;
3502            }
3503            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
3504            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
3505                // This value is meaningless for virtual NICs. Return what vmswitch returns.
3506                writer.write((256u32 * 1024).as_bytes())?
3507            }
3508            rndisprot::Oid::OID_GEN_VENDOR_ID => {
3509                // Like vmswitch, use the first N bytes of Microsoft's MAC address
3510                // prefix as the vendor ID.
3511                writer.write(0x0000155du32.as_bytes())?;
3512            }
3513            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
3514            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
3515            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
3516                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
3517            }
3518            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
3519            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
3520                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
3521                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
3522                    | rndisprot::MAC_OPTION_NO_LOOPBACK
3523                    | rndisprot::MAC_OPTION_8021P_PRIORITY
3524                    | rndisprot::MAC_OPTION_8021Q_VLAN;
3525                writer.write(options.as_bytes())?;
3526            }
3527            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
3528                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
3529            }
3530            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
3531            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
3532                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
3533                let mut name = rndisprot::FriendlyName::new_zeroed();
3534                name.name[..name16.len()].copy_from_slice(&name16);
3535                writer.write(name.as_bytes())?
3536            }
3537            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
3538            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
3539                writer.write(&self.mac_address.to_bytes())?
3540            }
3541            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
3542                writer.write(0u32.as_bytes())?;
3543            }
3544            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
3545            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
3546            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,
3547
3548            // NDIS6 OIDs:
3549            rndisprot::Oid::OID_GEN_LINK_STATE => {
3550                let link_state = rndisprot::LinkState {
3551                    header: rndisprot::NdisObjectHeader {
3552                        object_type: rndisprot::NdisObjectType::DEFAULT,
3553                        revision: 1,
3554                        size: size_of::<rndisprot::LinkState>() as u16,
3555                    },
3556                    media_connect_state: 1, /* MediaConnectStateConnected */
3557                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
3558                    padding: 0,
3559                    xmit_link_speed: self.link_speed,
3560                    rcv_link_speed: self.link_speed,
3561                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
3562                    auto_negotiation_flags: 0,
3563                };
3564                writer.write(link_state.as_bytes())?;
3565            }
3566            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
3567                let link_speed = rndisprot::LinkSpeed {
3568                    xmit: self.link_speed,
3569                    rcv: self.link_speed,
3570                };
3571                writer.write(link_speed.as_bytes())?;
3572            }
3573            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
3574                let ndis_offload = self.offload_support.ndis_offload();
3575                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3576            }
3577            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
3578                let ndis_offload = primary.offload_config.ndis_offload();
3579                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3580            }
3581            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3582                writer.write(
3583                    &rndisprot::NdisOffloadEncapsulation {
3584                        header: rndisprot::NdisObjectHeader {
3585                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3586                            revision: 1,
3587                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
3588                        },
3589                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3590                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3591                        ipv4_header_size: net_backend::ETHERNET_HEADER_LEN,
3592                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3593                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3594                        ipv6_header_size: net_backend::ETHERNET_HEADER_LEN,
3595                    }
3596                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
3597                )?;
3598            }
3599            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
3600                writer.write(
3601                    &rndisprot::NdisReceiveScaleCapabilities {
3602                        header: rndisprot::NdisObjectHeader {
3603                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
3604                            revision: 2,
3605                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
3606                                as u16,
3607                        },
3608                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
3609                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
3610                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
3611                        number_of_interrupt_messages: 1,
3612                        number_of_receive_queues: self.max_queues.into(),
3613                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
3614                            self.indirection_table_size
3615                        } else {
3616                            // DPDK gets confused if the table size is zero,
3617                            // even if there is only one queue.
3618                            128
3619                        },
3620                        padding: 0,
3621                    }
3622                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
3623                )?;
3624            }
3625            _ => {
3626                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
3627                return Err(OidError::UnknownOid);
3628            }
3629        };
3630        Ok(available_len - writer.len())
3631    }
3632
3633    fn handle_oid_set(
3634        &self,
3635        primary: &mut PrimaryChannelState,
3636        oid: rndisprot::Oid,
3637        reader: impl MemoryRead + Clone,
3638    ) -> Result<(bool, Option<u32>), OidError> {
3639        tracing::debug!(?oid, "oid set");
3640
3641        let mut restart_endpoint = false;
3642        let mut packet_filter = None;
3643        match oid {
3644            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3645                packet_filter = self.oid_set_packet_filter(reader)?;
3646            }
3647            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3648                self.oid_set_offload_parameters(reader, primary)?;
3649            }
3650            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3651                self.oid_set_offload_encapsulation(reader)?;
3652            }
3653            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3654                self.oid_set_rndis_config_parameter(reader, primary)?;
3655            }
3656            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3657                // TODO
3658            }
3659            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3660                let rss_was_enabled = self.oid_set_rss_parameters(reader, primary)?;
3661
3662                // Endpoints cannot currently change RSS parameters without
3663                // being restarted. This was a limitation driven by some DPDK
3664                // PMDs, and should be fixed.
3665                //
3666                // Skip the restart if RSS was already disabled and the guest
3667                // is disabling it again — nothing has changed.
3668                if rss_was_enabled || primary.rss_state.is_some() {
3669                    restart_endpoint = true;
3670                }
3671            }
3672            _ => {
3673                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3674                return Err(OidError::UnknownOid);
3675            }
3676        }
3677        Ok((restart_endpoint, packet_filter))
3678    }
3679
3680    fn oid_set_rss_parameters(
3681        &self,
3682        mut reader: impl MemoryRead + Clone,
3683        primary: &mut PrimaryChannelState,
3684    ) -> Result<bool, OidError> {
3685        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
3686        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3687        let len = reader.len().min(size_of_val(&params));
3688        reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3689
3690        let rss_was_enabled = primary.rss_state.is_some();
3691
3692        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3693            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3694        {
3695            primary.rss_state = None;
3696            return Ok(rss_was_enabled);
3697        }
3698
3699        if params.hash_secret_key_size != 40 {
3700            return Err(OidError::InvalidInput("hash_secret_key_size"));
3701        }
3702        if params.indirection_table_size % 4 != 0 || params.indirection_table_size == 0 {
3703            return Err(OidError::InvalidInput("indirection_table_size"));
3704        }
3705        let indirection_table_size =
3706            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3707        let mut key = [0; 40];
3708        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3709        reader
3710            .clone()
3711            .skip(params.hash_secret_key_offset as usize)?
3712            .read(&mut key)?;
3713        reader
3714            .skip(params.indirection_table_offset as usize)?
3715            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3716        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3717        if indirection_table
3718            .iter()
3719            .any(|&x| x >= self.max_queues as u32)
3720        {
3721            return Err(OidError::InvalidInput("indirection_table"));
3722        }
3723        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3724        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3725            .flatten()
3726            .zip(indir_uninit)
3727        {
3728            *dest = src;
3729        }
3730        primary.rss_state = Some(RssState {
3731            key,
3732            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3733        });
3734        Ok(rss_was_enabled)
3735    }
3736
3737    fn oid_set_packet_filter(
3738        &self,
3739        reader: impl MemoryRead + Clone,
3740    ) -> Result<Option<u32>, OidError> {
3741        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3742        tracing::debug!(filter, "set packet filter");
3743        Ok(Some(filter))
3744    }
3745
3746    fn oid_set_offload_parameters(
3747        &self,
3748        reader: impl MemoryRead + Clone,
3749        primary: &mut PrimaryChannelState,
3750    ) -> Result<(), OidError> {
3751        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3752            reader,
3753            rndisprot::NdisObjectType::DEFAULT,
3754            1,
3755            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3756        )?;
3757
3758        tracing::debug!(?offload, "offload parameters");
3759        let rndisprot::NdisOffloadParameters {
3760            header: _,
3761            ipv4_checksum,
3762            tcp4_checksum,
3763            udp4_checksum,
3764            tcp6_checksum,
3765            udp6_checksum,
3766            lsov1,
3767            ipsec_v1: _,
3768            lsov2_ipv4,
3769            lsov2_ipv6,
3770            tcp_connection_ipv4: _,
3771            tcp_connection_ipv6: _,
3772            reserved: _,
3773            flags: _,
3774        } = offload;
3775
3776        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3777            return Err(OidError::NotSupported("lsov1"));
3778        }
3779        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3780            primary.offload_config.checksum_tx.ipv4_header = tx;
3781            primary.offload_config.checksum_rx.ipv4_header = rx;
3782        }
3783        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3784            primary.offload_config.checksum_tx.tcp4 = tx;
3785            primary.offload_config.checksum_rx.tcp4 = rx;
3786        }
3787        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3788            primary.offload_config.checksum_tx.tcp6 = tx;
3789            primary.offload_config.checksum_rx.tcp6 = rx;
3790        }
3791        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3792            primary.offload_config.checksum_tx.udp4 = tx;
3793            primary.offload_config.checksum_rx.udp4 = rx;
3794        }
3795        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3796            primary.offload_config.checksum_tx.udp6 = tx;
3797            primary.offload_config.checksum_rx.udp6 = rx;
3798        }
3799        if let Some(enable) = lsov2_ipv4.enable() {
3800            primary.offload_config.lso4 = enable;
3801        }
3802        if let Some(enable) = lsov2_ipv6.enable() {
3803            primary.offload_config.lso6 = enable;
3804        }
3805        primary
3806            .offload_config
3807            .mask_to_supported(&self.offload_support);
3808        primary.pending_offload_change = true;
3809        Ok(())
3810    }
3811
3812    fn oid_set_offload_encapsulation(
3813        &self,
3814        reader: impl MemoryRead + Clone,
3815    ) -> Result<(), OidError> {
3816        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3817            reader,
3818            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3819            1,
3820            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3821        )?;
3822        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3823            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3824                || encap.ipv4_header_size != net_backend::ETHERNET_HEADER_LEN)
3825        {
3826            return Err(OidError::NotSupported("ipv4 encap"));
3827        }
3828        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3829            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3830                || encap.ipv6_header_size != net_backend::ETHERNET_HEADER_LEN)
3831        {
3832            return Err(OidError::NotSupported("ipv6 encap"));
3833        }
3834        Ok(())
3835    }
3836
3837    fn oid_set_rndis_config_parameter(
3838        &self,
3839        reader: impl MemoryRead + Clone,
3840        primary: &mut PrimaryChannelState,
3841    ) -> Result<(), OidError> {
3842        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3843        if info.name_length > 255 {
3844            return Err(OidError::InvalidInput("name_length"));
3845        }
3846        if info.value_length > 255 {
3847            return Err(OidError::InvalidInput("value_length"));
3848        }
3849        let name = reader
3850            .clone()
3851            .skip(info.name_offset as usize)?
3852            .read_n::<u16>(info.name_length as usize / 2)?;
3853        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3854        let mut value = reader;
3855        value.skip(info.value_offset as usize)?;
3856        let mut value = value.limit(info.value_length as usize);
3857        match info.parameter_type {
3858            rndisprot::NdisParameterType::STRING => {
3859                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3860                let value =
3861                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3862                let as_num = value
3863                    .as_bytes()
3864                    .first()
3865                    .map(|c| c.wrapping_sub(b'0'))
3866                    .filter(|&c| c <= 9)
3867                    .ok_or(OidError::InvalidInput("value as num"))?;
3868                let tx = as_num & 1 != 0;
3869                let rx = as_num & 2 != 0;
3870
3871                tracing::debug!(name, value, "rndis config");
3872                match name.as_str() {
3873                    "*IPChecksumOffloadIPv4" => {
3874                        primary.offload_config.checksum_tx.ipv4_header = tx;
3875                        primary.offload_config.checksum_rx.ipv4_header = rx;
3876                    }
3877                    "*LsoV2IPv4" => {
3878                        primary.offload_config.lso4 = as_num != 0;
3879                    }
3880                    "*LsoV2IPv6" => {
3881                        primary.offload_config.lso6 = as_num != 0;
3882                    }
3883                    "*TCPChecksumOffloadIPv4" => {
3884                        primary.offload_config.checksum_tx.tcp4 = tx;
3885                        primary.offload_config.checksum_rx.tcp4 = rx;
3886                    }
3887                    "*TCPChecksumOffloadIPv6" => {
3888                        primary.offload_config.checksum_tx.tcp6 = tx;
3889                        primary.offload_config.checksum_rx.tcp6 = rx;
3890                    }
3891                    "*UDPChecksumOffloadIPv4" => {
3892                        primary.offload_config.checksum_tx.udp4 = tx;
3893                        primary.offload_config.checksum_rx.udp4 = rx;
3894                    }
3895                    "*UDPChecksumOffloadIPv6" => {
3896                        primary.offload_config.checksum_tx.udp6 = tx;
3897                        primary.offload_config.checksum_rx.udp6 = rx;
3898                    }
3899                    _ => {}
3900                }
3901                primary
3902                    .offload_config
3903                    .mask_to_supported(&self.offload_support);
3904            }
3905            rndisprot::NdisParameterType::INTEGER => {
3906                let value: u32 = value.read_plain()?;
3907                tracing::debug!(name, value, "rndis config");
3908            }
3909            parameter_type => tracelimit::warn_ratelimited!(
3910                name,
3911                ?parameter_type,
3912                "unhandled rndis config parameter type"
3913            ),
3914        }
3915        Ok(())
3916    }
3917}
3918
3919fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3920    mut reader: impl MemoryRead,
3921    object_type: rndisprot::NdisObjectType,
3922    min_revision: u8,
3923    min_size: usize,
3924) -> Result<T, OidError> {
3925    let mut buffer = T::new_zeroed();
3926    let sent_size = reader.len();
3927    let len = sent_size.min(size_of_val(&buffer));
3928    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3929    validate_ndis_object_header(
3930        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3931            .unwrap()
3932            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3933        sent_size,
3934        object_type,
3935        min_revision,
3936        min_size,
3937    )?;
3938    Ok(buffer)
3939}
3940
3941fn validate_ndis_object_header(
3942    header: &rndisprot::NdisObjectHeader,
3943    sent_size: usize,
3944    object_type: rndisprot::NdisObjectType,
3945    min_revision: u8,
3946    min_size: usize,
3947) -> Result<(), OidError> {
3948    if header.object_type != object_type {
3949        return Err(OidError::InvalidInput("header.object_type"));
3950    }
3951    if sent_size < header.size as usize {
3952        return Err(OidError::InvalidInput("header.size"));
3953    }
3954    if header.revision < min_revision {
3955        return Err(OidError::InvalidInput("header.revision"));
3956    }
3957    if (header.size as usize) < min_size {
3958        return Err(OidError::InvalidInput("header.size"));
3959    }
3960    Ok(())
3961}
3962
/// State of the coordinator task, which supervises the per-channel worker
/// tasks and reacts to messages from them, from the endpoint, and from the
/// virtual function.
struct Coordinator {
    /// Incoming `CoordinatorMessage`s. `process` leaves its loop when this
    /// channel disconnects.
    recv: mpsc::Receiver<CoordinatorMessage>,
    channel_control: ChannelControl,
    /// When set, `process` calls `restart_worker_queues` at the top of its
    /// next iteration; cleared once that restart has run.
    restart: bool,
    /// Worker tasks. Index 0 is treated as the primary channel's worker.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of leading `workers` entries reported under `queues` by inspect.
    num_queues: u16,
    /// Packet filter copied from the primary worker's channel and propagated
    /// to the remaining workers on a filter update.
    active_packet_filter: u32,
    /// Deadline set by `CoordinatorMessage::StartTimer`. Cleared when the
    /// timer expires or a link status notification arrives.
    sleep_deadline: Option<Instant>,
}
3973
/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        /// Timer used to sleep until `delay_until`.
        timer: PolledTimer,
        /// Point in time at which the delay ends.
        delay_until: Instant,
    },
    /// A VF update is pending. While in this state the coordinator awaits
    /// `guest_ready_for_device` on the VF and then returns to `Ready`.
    Pending,
}
3988
/// State passed to the coordinator task alongside [`Coordinator`].
struct CoordinatorState {
    /// Backend endpoint. The coordinator waits on it for `EndpointAction`s and
    /// queries it for the data path state.
    endpoint: Box<dyn Endpoint>,
    /// Shared adapter state (MAC address, queue limits, offload support, ring
    /// size limit, driver).
    adapter: Arc<Adapter>,
    /// Optional virtual function. Must be `Some` while `pending_vf_state` is
    /// `Pending`.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of any in-flight VF offer/update.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3995
3996impl InspectTaskMut<Coordinator> for CoordinatorState {
3997    fn inspect_mut(
3998        &mut self,
3999        req: inspect::Request<'_>,
4000        mut coordinator: Option<&mut Coordinator>,
4001    ) {
4002        let mut resp = req.respond();
4003
4004        let adapter = self.adapter.as_ref();
4005        resp.field("mac_address", adapter.mac_address)
4006            .field("max_queues", adapter.max_queues)
4007            .sensitivity_field(
4008                "offload_support",
4009                SensitivityLevel::Safe,
4010                &adapter.offload_support,
4011            )
4012            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
4013                if let Some(v) = v {
4014                    let v: usize = v.parse()?;
4015                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
4016                    // Bounce each task so that it sees the new value.
4017                    if let Some(this) = &mut coordinator {
4018                        for worker in &mut this.workers {
4019                            worker.update_with(|_, _| ());
4020                        }
4021                    }
4022                }
4023                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
4024            });
4025
4026        resp.field("endpoint_type", self.endpoint.endpoint_type())
4027            .field(
4028                "endpoint_max_queues",
4029                self.endpoint.multiqueue_support().max_queues,
4030            )
4031            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());
4032
4033        if let Some(coordinator) = coordinator {
4034            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
4035                let mut resp = req.respond();
4036                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
4037                    .iter_mut()
4038                    .enumerate()
4039                {
4040                    resp.field_mut(&i.to_string(), q);
4041                }
4042            });
4043
4044            // Get the shared channel state from the primary channel.
4045            resp.merge(inspect::adhoc_mut(|req| {
4046                let deferred = req.defer();
4047                coordinator.workers[0].update_with(|_, worker| {
4048                    let Some(worker) = worker.as_deref() else {
4049                        return;
4050                    };
4051                    if let Some(state) = worker.state.ready() {
4052                        deferred.respond(|resp| {
4053                            resp.merge(&state.buffers);
4054                            resp.sensitivity_field(
4055                                "primary_channel_state",
4056                                SensitivityLevel::Safe,
4057                                &state.state.primary,
4058                            )
4059                            .sensitivity_field(
4060                                "packet_filter",
4061                                SensitivityLevel::Safe,
4062                                inspect::AsHex(worker.channel.packet_filter),
4063                            );
4064                        });
4065                    }
4066                })
4067            }));
4068        }
4069    }
4070}
4071
4072impl AsyncRun<Coordinator> for CoordinatorState {
4073    async fn run(
4074        &mut self,
4075        stop: &mut StopTask<'_>,
4076        coordinator: &mut Coordinator,
4077    ) -> Result<(), task_control::Cancelled> {
4078        coordinator.process(stop, self).await
4079    }
4080}
4081
4082impl Coordinator {
4083    async fn process(
4084        &mut self,
4085        stop: &mut StopTask<'_>,
4086        state: &mut CoordinatorState,
4087    ) -> Result<(), task_control::Cancelled> {
4088        loop {
4089            // `self.restart` is set in a prior iteration when:
4090            // `CoordinatorMessage::Restart` from Primary or sub-channel worker.
4091            // `EndpointAction::RestartRequired`.
4092            // Or in `insert_coordinator` when Restoring from saved state.
4093            if self.restart {
4094                self.restart_worker_queues(stop, state).await?;
4095            }
4096
4097            // Ensure that all workers except the primary are started. The
4098            // primary is started below if there are no outstanding messages.
4099            for worker in &mut self.workers[1..] {
4100                worker.start();
4101            }
4102            if !self.workers[0].is_running()
4103                && self.workers[0].state().is_none_or(|worker| {
4104                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
4105                })
4106            {
4107                self.workers[0].start();
4108            }
4109
4110            enum Message {
4111                Internal(CoordinatorMessage),
4112                ChannelDisconnected,
4113                UpdateFromEndpoint(EndpointAction),
4114                UpdateFromVf(Rpc<(), ()>),
4115                OfferVfDevice,
4116                PendingVfStateComplete,
4117                TimerExpired,
4118            }
4119            let message = if matches!(
4120                state.pending_vf_state,
4121                CoordinatorStatePendingVfState::Pending
4122            ) {
4123                // guest_ready_for_device is not restartable, so do not poll on
4124                // stop.
4125                state
4126                    .virtual_function
4127                    .as_mut()
4128                    .expect("Pending requires a VF")
4129                    .guest_ready_for_device()
4130                    .await;
4131                Message::PendingVfStateComplete
4132            } else {
4133                let timer_sleep = async {
4134                    if let Some(deadline) = self.sleep_deadline {
4135                        let mut timer = PolledTimer::new(&state.adapter.driver);
4136                        timer.sleep_until(deadline).await;
4137                    } else {
4138                        pending::<()>().await;
4139                    }
4140                    Message::TimerExpired
4141                };
4142                let wait_for_message = async {
4143                    let internal_msg = self
4144                        .recv
4145                        .next()
4146                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
4147                    let endpoint_restart = state
4148                        .endpoint
4149                        .wait_for_endpoint_action()
4150                        .map(Message::UpdateFromEndpoint);
4151                    if let Some(vf) = state.virtual_function.as_mut() {
4152                        match state.pending_vf_state {
4153                            CoordinatorStatePendingVfState::Ready
4154                            | CoordinatorStatePendingVfState::Delay { .. } => {
4155                                let offer_device = async {
4156                                    if let CoordinatorStatePendingVfState::Delay {
4157                                        timer,
4158                                        delay_until,
4159                                    } = &mut state.pending_vf_state
4160                                    {
4161                                        timer.sleep_until(*delay_until).await;
4162                                    } else {
4163                                        pending::<()>().await;
4164                                    }
4165                                    Message::OfferVfDevice
4166                                };
4167                                (
4168                                    internal_msg,
4169                                    offer_device,
4170                                    endpoint_restart,
4171                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
4172                                    timer_sleep,
4173                                )
4174                                    .race()
4175                                    .await
4176                            }
4177                            CoordinatorStatePendingVfState::Pending => unreachable!(),
4178                        }
4179                    } else {
4180                        (internal_msg, endpoint_restart, timer_sleep).race().await
4181                    }
4182                };
4183
4184                stop.until_stopped(wait_for_message).await?
4185            };
4186            match message {
4187                Message::Internal(msg) => {
4188                    self.handle_coordinator_message(msg, state).await;
4189                    // If a restart message has been queued, handle it now
4190                    // to ensure worker queues are restarted prior to
4191                    // `worker.start()` in the next loop.
4192                    self.handle_queued_coordinator_messages(state).await;
4193                }
4194                Message::UpdateFromVf(rpc) => {
4195                    rpc.handle(async |_| {
4196                        self.update_guest_vf_state(state).await;
4197                    })
4198                    .await;
4199                }
4200                Message::OfferVfDevice => {
4201                    self.stop_primary_worker().await;
4202                    if let Some(primary) = self.primary_mut() {
4203                        if matches!(
4204                            primary.guest_vf_state,
4205                            PrimaryChannelGuestVfState::AvailableAdvertised
4206                        ) {
4207                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
4208                        }
4209                    }
4210
4211                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4212                }
4213                Message::PendingVfStateComplete => {
4214                    // Worker state unchanged, no worker needs to be stopped.
4215                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
4216                }
4217                Message::TimerExpired => {
4218                    // Kick the worker as requested.
4219                    self.stop_primary_worker().await;
4220                    if let Some(primary) = self.primary_mut() {
4221                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
4222                            primary.pending_link_action = PendingLinkAction::Active(up);
4223                        }
4224                    }
4225                    self.sleep_deadline = None;
4226                }
4227                Message::UpdateFromEndpoint(endpoint_action) => {
4228                    self.handle_endpoint_action(endpoint_action).await;
4229                }
4230                Message::ChannelDisconnected => {
4231                    break;
4232                }
4233            };
4234        }
4235        Ok(())
4236    }
4237
4238    async fn handle_endpoint_action(&mut self, action: EndpointAction) {
4239        match action {
4240            EndpointAction::RestartRequired => self.restart = true,
4241            EndpointAction::LinkStatusNotify(connect) => {
4242                self.stop_primary_worker().await;
4243                // These are the only link state transitions that are tracked.
4244                // 1. up -> down or down -> up
4245                // 2. up -> down -> up or down -> up -> down.
4246                // All other state transitions are coalesced into one of the above cases.
4247                // For example, up -> down -> up -> down is treated as up -> down.
4248                // N.B - Always queue up the incoming state to minimize the effects of loss
4249                //       of any notifications (for example, during vtl2 servicing).
4250                if let Some(primary) = self.primary_mut() {
4251                    primary.pending_link_action = PendingLinkAction::Active(connect);
4252                }
4253
4254                // If there is any existing sleep timer running, cancel it out.
4255                self.sleep_deadline = None;
4256            }
4257        }
4258    }
4259
4260    /// Called from the [`Self::process`] loop when either `CoordinatorMessage::Restart`
4261    /// or `EndpointAction::RestartRequired` is observed.
4262    async fn restart_worker_queues(
4263        &mut self,
4264        stop: &mut StopTask<'_>,
4265        state: &mut CoordinatorState,
4266    ) -> Result<(), task_control::Cancelled> {
4267        stop.until_stopped(self.stop_workers()).await?;
4268
4269        // All workers are stopped and cannot push new messages.
4270        // Drain any messages that arrived prior to or during the stop.
4271        // Coalesce restart messages. Handle non-restart Primary messages.
4272        self.handle_queued_coordinator_messages(state).await;
4273
4274        // Best-effort attempt to coalesce any `RestartRequired` endpoint
4275        // action into this restart operation.
4276        while let Some(action) = state.endpoint.wait_for_endpoint_action().now_or_never() {
4277            self.handle_endpoint_action(action).await;
4278        }
4279
4280        // The queue restart operation is not restartable; do not poll on stop here.
4281        if let Err(err) = self
4282            .restart_queues(state)
4283            .instrument(tracing::info_span!("netvsp_restart_queues"))
4284            .await
4285        {
4286            tracing::error!(
4287                error = &err as &dyn std::error::Error,
4288                "failed to restart queues"
4289            );
4290        }
4291        if let Some(primary) = self.primary_mut() {
4292            primary.is_data_path_switched = state.endpoint.get_data_path_to_guest_vf().await.ok();
4293            tracing::info!(
4294                is_data_path_switched = primary.is_data_path_switched,
4295                "Query data path state"
4296            );
4297        }
4298        self.restore_guest_vf_state(state).await;
4299        self.restart = false;
4300        Ok(())
4301    }
4302
4303    async fn handle_queued_coordinator_messages(&mut self, state: &mut CoordinatorState) {
4304        while let Ok(Some(msg)) = self.recv.try_next() {
4305            self.handle_coordinator_message(msg, state).await;
4306        }
4307    }
4308
4309    async fn handle_coordinator_message(
4310        &mut self,
4311        msg: CoordinatorMessage,
4312        state: &mut CoordinatorState,
4313    ) {
4314        match msg {
4315            CoordinatorMessage::Restart { channel_idx } if channel_idx != 0 => {
4316                tracelimit::event_ratelimited!(
4317                    tracing::Level::DEBUG,
4318                    channel_idx,
4319                    "sub-channel triggered restart"
4320                );
4321                self.restart = true;
4322            }
4323            _ => self.handle_primary_message(msg, state).await,
4324        }
4325    }
4326
4327    async fn handle_primary_message(
4328        &mut self,
4329        msg: CoordinatorMessage,
4330        state: &mut CoordinatorState,
4331    ) {
4332        self.stop_primary_worker().await;
4333        if let Some(worker) = self.workers[0].state_mut() {
4334            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
4335                let WorkerState::WaitingForCoordinator(Some(ready)) =
4336                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
4337                else {
4338                    unreachable!("valid ready state")
4339                };
4340                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
4341            }
4342        }
4343        match msg {
4344            CoordinatorMessage::Update(update_type) => {
4345                if update_type.filter_state {
4346                    self.stop_workers().await;
4347                    self.active_packet_filter =
4348                        self.workers[0].state().unwrap().channel.packet_filter;
4349                    self.workers.iter_mut().skip(1).for_each(|worker| {
4350                        if let Some(state) = worker.state_mut() {
4351                            state.channel.packet_filter = self.active_packet_filter;
4352                            tracing::debug!(
4353                                packet_filter = ?self.active_packet_filter,
4354                                channel_idx = state.channel_idx,
4355                                "update packet filter"
4356                            );
4357                        }
4358                    });
4359                }
4360
4361                if update_type.guest_vf_state {
4362                    self.update_guest_vf_state(state).await;
4363                }
4364            }
4365            CoordinatorMessage::StartTimer(deadline) => {
4366                self.sleep_deadline = Some(deadline);
4367            }
4368            CoordinatorMessage::Restart { channel_idx } => {
4369                assert_eq!(channel_idx, 0);
4370                self.restart = true;
4371            }
4372        }
4373    }
4374
4375    async fn stop_workers(&mut self) {
4376        for worker in &mut self.workers {
4377            worker.stop().await;
4378        }
4379    }
4380
4381    async fn stop_primary_worker(&mut self) {
4382        self.workers[0].stop().await;
4383    }
4384
4385    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4386        let primary = match self.primary_mut() {
4387            Some(primary) => primary,
4388            None => return,
4389        };
4390
4391        // Update guest VF state based on current endpoint properties.
4392        let virtual_function = c_state.virtual_function.as_mut();
4393        let guest_vf_id = match &virtual_function {
4394            Some(vf) => vf.id().await,
4395            None => None,
4396        };
4397        if let Some(guest_vf_id) = guest_vf_id {
4398            // Ensure guest VF is in proper state.
4399            match primary.guest_vf_state {
4400                PrimaryChannelGuestVfState::AvailableAdvertised
4401                | PrimaryChannelGuestVfState::Restoring(
4402                    saved_state::GuestVfState::AvailableAdvertised,
4403                ) => {
4404                    if !primary.is_data_path_switched.unwrap_or(false) {
4405                        let timer = PolledTimer::new(&c_state.adapter.driver);
4406                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
4407                            timer,
4408                            delay_until: Instant::now() + VF_DEVICE_DELAY,
4409                        };
4410                    }
4411                }
4412                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4413                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4414                | PrimaryChannelGuestVfState::Ready
4415                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
4416                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4417                | PrimaryChannelGuestVfState::Restoring(
4418                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4419                )
4420                | PrimaryChannelGuestVfState::DataPathSwitched
4421                | PrimaryChannelGuestVfState::Restoring(
4422                    saved_state::GuestVfState::DataPathSwitched,
4423                )
4424                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4425                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4426                }
4427                _ => (),
4428            };
4429            // ensure data path is switched as expected
4430            if let PrimaryChannelGuestVfState::Restoring(
4431                saved_state::GuestVfState::DataPathSwitchPending {
4432                    to_guest,
4433                    id,
4434                    result,
4435                },
4436            ) = primary.guest_vf_state
4437            {
4438                // If the save was after the data path switch already occurred, don't do it again.
4439                if result.is_some() {
4440                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
4441                        to_guest,
4442                        id,
4443                        result,
4444                    };
4445                    return;
4446                }
4447            }
4448            primary.guest_vf_state = match primary.guest_vf_state {
4449                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4450                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4451                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4452                | PrimaryChannelGuestVfState::Restoring(
4453                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4454                )
4455                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4456                    let (to_guest, id) = match primary.guest_vf_state {
4457                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4458                            to_guest,
4459                            id,
4460                        }
4461                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
4462                            to_guest, id, ..
4463                        }
4464                        | PrimaryChannelGuestVfState::Restoring(
4465                            saved_state::GuestVfState::DataPathSwitchPending {
4466                                to_guest, id, ..
4467                            },
4468                        ) => (to_guest, id),
4469                        _ => (true, None),
4470                    };
4471                    // Cancel any outstanding delay timers for VF offers if the data path is
4472                    // getting switched, since the guest is already issuing
4473                    // commands assuming a VF.
4474                    if matches!(
4475                        c_state.pending_vf_state,
4476                        CoordinatorStatePendingVfState::Delay { .. }
4477                    ) {
4478                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4479                    }
4480                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
4481                    let result = if let Err(err) = result {
4482                        tracing::error!(
4483                            err = err.as_ref() as &dyn std::error::Error,
4484                            to_guest,
4485                            "Failed to switch guest VF data path"
4486                        );
4487                        false
4488                    } else {
4489                        primary.is_data_path_switched = Some(to_guest);
4490                        true
4491                    };
4492                    match primary.guest_vf_state {
4493                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4494                            ..
4495                        }
4496                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4497                        | PrimaryChannelGuestVfState::Restoring(
4498                            saved_state::GuestVfState::DataPathSwitchPending { .. },
4499                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
4500                            to_guest,
4501                            id,
4502                            result: Some(result),
4503                        },
4504                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
4505                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
4506                    }
4507                }
4508                PrimaryChannelGuestVfState::Initializing
4509                | PrimaryChannelGuestVfState::Unavailable
4510                | PrimaryChannelGuestVfState::UnavailableFromAvailable
4511                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
4512                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
4513                }
4514                PrimaryChannelGuestVfState::AvailableAdvertised
4515                | PrimaryChannelGuestVfState::Restoring(
4516                    saved_state::GuestVfState::AvailableAdvertised,
4517                ) => {
4518                    if !primary.is_data_path_switched.unwrap_or(false) {
4519                        PrimaryChannelGuestVfState::AvailableAdvertised
4520                    } else {
4521                        // A previous instantiation already switched the data
4522                        // path.
4523                        PrimaryChannelGuestVfState::DataPathSwitched
4524                    }
4525                }
4526                PrimaryChannelGuestVfState::DataPathSwitched
4527                | PrimaryChannelGuestVfState::Restoring(
4528                    saved_state::GuestVfState::DataPathSwitched,
4529                ) => PrimaryChannelGuestVfState::DataPathSwitched,
4530                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4531                    PrimaryChannelGuestVfState::Ready
4532                }
4533                _ => primary.guest_vf_state,
4534            };
4535        } else {
4536            // If the device was just removed, make sure the data path is synthetic.
4537            match primary.guest_vf_state {
4538                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
4539                | PrimaryChannelGuestVfState::Restoring(
4540                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
4541                ) => {
4542                    if !to_guest {
4543                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4544                            tracing::warn!(
4545                                err = err.as_ref() as &dyn std::error::Error,
4546                                "Failed setting data path back to synthetic after guest VF was removed."
4547                            );
4548                        }
4549                        primary.is_data_path_switched = Some(false);
4550                    }
4551                }
4552                PrimaryChannelGuestVfState::DataPathSwitched
4553                | PrimaryChannelGuestVfState::Restoring(
4554                    saved_state::GuestVfState::DataPathSwitched,
4555                ) => {
4556                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4557                        tracing::warn!(
4558                            err = err.as_ref() as &dyn std::error::Error,
4559                            "Failed setting data path back to synthetic after guest VF was removed."
4560                        );
4561                    }
4562                    primary.is_data_path_switched = Some(false);
4563                }
4564                _ => (),
4565            }
4566            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
4567                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
4568            }
4569            // Notify guest if VF is no longer available
4570            primary.guest_vf_state = match primary.guest_vf_state {
4571                PrimaryChannelGuestVfState::Initializing
4572                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
4573                | PrimaryChannelGuestVfState::Available { .. } => {
4574                    PrimaryChannelGuestVfState::Unavailable
4575                }
4576                PrimaryChannelGuestVfState::AvailableAdvertised
4577                | PrimaryChannelGuestVfState::Restoring(
4578                    saved_state::GuestVfState::AvailableAdvertised,
4579                )
4580                | PrimaryChannelGuestVfState::Ready
4581                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4582                    PrimaryChannelGuestVfState::UnavailableFromAvailable
4583                }
4584                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
4585                | PrimaryChannelGuestVfState::Restoring(
4586                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
4587                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4588                    to_guest,
4589                    id,
4590                },
4591                PrimaryChannelGuestVfState::DataPathSwitched
4592                | PrimaryChannelGuestVfState::Restoring(
4593                    saved_state::GuestVfState::DataPathSwitched,
4594                )
4595                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4596                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4597                }
4598                _ => primary.guest_vf_state,
4599            }
4600        }
4601    }
4602
4603    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4604        // Pre-compute the active queue count and validate the rx buffer configuration
4605        // before continuing with the queue restart work in this function.
4606        // Invalid configurations are returned to the caller as errors.
4607        let (num_queues, active_queues, active_queue_count) = if let Some(state) = self.workers[0]
4608            .state()
4609            .and_then(|worker| worker.state.ready())
4610        {
4611            let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4612            let mut active_queues = Vec::new();
4613            let active_queue_count = if let Some(rss_state) =
4614                state.state.primary.as_ref().unwrap().rss_state.as_ref()
4615            {
4616                active_queues.clone_from(&rss_state.indirection_table);
4617                active_queues.sort();
4618                active_queues.dedup();
4619                active_queues = active_queues
4620                    .into_iter()
4621                    .filter(|&index| index < num_queues)
4622                    .collect::<Vec<_>>();
4623                if !active_queues.is_empty() {
4624                    active_queues.len() as u16
4625                } else {
4626                    tracelimit::warn_ratelimited!(
4627                        num_queues,
4628                        indirection_table_len = rss_state.indirection_table.len(),
4629                        "RSS indirection table has no entries within the valid queue range, falling back to num_queues",
4630                    );
4631                    num_queues
4632                }
4633            } else {
4634                num_queues
4635            };
4636
4637            RxBufferRanges::validate_params(
4638                state.buffers.recv_buffer.count,
4639                active_queue_count.into(),
4640            )?;
4641
4642            GuestBuffers::validate_config(
4643                &state.buffers.recv_buffer.gpadl,
4644                state.buffers.recv_buffer.sub_allocation_size,
4645                state.buffers.ndis_config.mtu,
4646            )
4647            .map_err(WorkerError::GuestBuffers)?;
4648
4649            (num_queues, active_queues, active_queue_count)
4650        } else {
4651            // No ready state; restart_queues will return Ok(()).
4652            (0, Vec::new(), 0)
4653        };
4654
4655        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
4656        let drivers = self
4657            .workers
4658            .iter_mut()
4659            .map(|worker| {
4660                let task = worker.task_mut();
4661                task.queue_state = None;
4662                task.driver.clone()
4663            })
4664            .collect::<Vec<_>>();
4665
4666        c_state.endpoint.stop().await;
4667
4668        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4669        {
4670            (primary, sub)
4671        } else {
4672            unreachable!()
4673        };
4674
4675        let state = primary_worker
4676            .state_mut()
4677            .and_then(|worker| worker.state.ready_mut());
4678
4679        let state = if let Some(state) = state {
4680            state
4681        } else {
4682            return Ok(());
4683        };
4684
4685        // Save the channel buffers for use in the subchannel workers.
4686        self.buffers = Some(state.buffers.clone());
4687
4688        // Distribute the rx buffers to only the active queues.
4689        let (ranges, mut remote_buffer_id_recvs) =
4690            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into())?;
4691        let ranges = Arc::new(ranges);
4692
4693        let mut queues = Vec::new();
4694        let mut rx_buffers = Vec::new();
4695        let mut per_queue_rx: Vec<Vec<RxId>> = Vec::new();
4696        let guest_buffers;
4697        {
4698            let buffers = &state.buffers;
4699            guest_buffers = Arc::new(
4700                GuestBuffers::new(
4701                    buffers.mem.clone(),
4702                    buffers.recv_buffer.gpadl.clone(),
4703                    buffers.recv_buffer.sub_allocation_size,
4704                    buffers.ndis_config.mtu,
4705                )
4706                .map_err(WorkerError::GuestBuffers)?,
4707            );
4708
4709            // Get the list of free rx buffers from each task, then partition
4710            // the list per-active-queue, and produce the queue configuration.
4711            let mut queue_config = Vec::new();
4712            let initial_rx;
4713            {
4714                let states = std::iter::once(Some(&*state)).chain(
4715                    subworkers
4716                        .iter()
4717                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4718                );
4719
4720                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4721                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
4722                    .map(RxId)
4723                    .collect::<Vec<_>>();
4724
4725                let mut initial_rx = initial_rx.as_slice();
4726                let mut range_start = 0;
4727                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4728                let first_queue = if !primary_queue_excluded {
4729                    0
4730                } else {
4731                    // If the primary queue is excluded from the guest supplied
4732                    // indirection table, it is assigned just the reserved
4733                    // buffers.
4734                    queue_config.push(QueueConfig {
4735                        driver: Box::new(drivers[0].clone()),
4736                    });
4737                    per_queue_rx.push(Vec::new());
4738                    rx_buffers.push(RxBufferRange::new(
4739                        ranges.clone(),
4740                        0..RX_RESERVED_CONTROL_BUFFERS,
4741                        None,
4742                    ));
4743                    range_start = RX_RESERVED_CONTROL_BUFFERS;
4744                    1
4745                };
4746                for queue_index in first_queue..num_queues {
4747                    let queue_active = active_queues.is_empty()
4748                        || active_queues.binary_search(&queue_index).is_ok();
4749                    let (range_end, end, buffer_id_recv) = if queue_active {
4750                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4751                            // The last queue gets all the remaining buffers.
4752                            state.buffers.recv_buffer.count
4753                        } else if queue_index == 0 {
4754                            // Queue zero always includes the reserved buffers.
4755                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4756                        } else {
4757                            range_start + ranges.buffers_per_queue
4758                        };
4759                        (
4760                            range_end,
4761                            initial_rx.partition_point(|id| id.0 < range_end),
4762                            Some(remote_buffer_id_recvs.remove(0)),
4763                        )
4764                    } else {
4765                        (range_start, 0, None)
4766                    };
4767
4768                    let (this, rest) = initial_rx.split_at(end);
4769                    queue_config.push(QueueConfig {
4770                        driver: Box::new(drivers[queue_index as usize].clone()),
4771                    });
4772                    per_queue_rx.push(this.to_vec());
4773                    initial_rx = rest;
4774                    rx_buffers.push(RxBufferRange::new(
4775                        ranges.clone(),
4776                        range_start..range_end,
4777                        buffer_id_recv,
4778                    ));
4779
4780                    range_start = range_end;
4781                }
4782            }
4783
4784            let primary = state.state.primary.as_mut().unwrap();
4785            tracing::debug!(num_queues, "enabling endpoint");
4786
4787            let rss = primary
4788                .rss_state
4789                .as_ref()
4790                .map(|rss| net_backend::RssConfig {
4791                    key: &rss.key,
4792                    indirection_table: &rss.indirection_table,
4793                    flags: 0,
4794                });
4795
4796            c_state
4797                .endpoint
4798                .get_queues(queue_config, rss.as_ref(), &mut queues)
4799                .instrument(tracing::info_span!("netvsp_get_queues"))
4800                .await
4801                .map_err(WorkerError::Endpoint)?;
4802
4803            assert_eq!(queues.len(), num_queues as usize);
4804
4805            // Set the subchannel count.
4806            self.channel_control
4807                .enable_subchannels(num_queues - 1)
4808                .expect("already validated");
4809
4810            self.num_queues = num_queues;
4811        }
4812
4813        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4814        // Provide the queue and receive buffer ranges for each worker.
4815        for (((worker, mut queue), rx_buffer), initial) in self
4816            .workers
4817            .iter_mut()
4818            .zip(queues)
4819            .zip(rx_buffers)
4820            .zip(per_queue_rx)
4821        {
4822            let mut pool = BufferPool::new(guest_buffers.clone());
4823            if !initial.is_empty() {
4824                queue.rx_avail(&mut pool, &initial);
4825            }
4826            worker.task_mut().queue_state = Some(QueueState {
4827                queue,
4828                pool,
4829                target_vp_set: false,
4830                rx_buffer_range: rx_buffer,
4831            });
4832            // Update the receive packet filter for the subchannel worker.
4833            if let Some(worker) = worker.state_mut() {
4834                worker.channel.packet_filter = self.active_packet_filter;
4835                // Clear any pending RxIds as buffers were redistributed
4836                // and reset TX tracking after the endpoint stop.
4837                if let Some(ready_state) = worker.state.ready_mut() {
4838                    ready_state.state.pending_rx_packets.clear();
4839                    ready_state.reset_tx_after_endpoint_stop();
4840                }
4841            }
4842        }
4843
4844        Ok(())
4845    }
4846
4847    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4848        self.workers[0]
4849            .state_mut()
4850            .unwrap()
4851            .state
4852            .ready_mut()?
4853            .state
4854            .primary
4855            .as_mut()
4856    }
4857
4858    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4859        self.stop_primary_worker().await;
4860        self.restore_guest_vf_state(c_state).await;
4861    }
4862}
4863
4864impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4865    async fn run(
4866        &mut self,
4867        stop: &mut StopTask<'_>,
4868        worker: &mut Worker<T>,
4869    ) -> Result<(), task_control::Cancelled> {
4870        match worker.process(stop, self).await {
4871            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4872            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4873            Err(err) => {
4874                tracing::error!(
4875                    error = &err as &dyn std::error::Error,
4876                    channel_idx = worker.channel_idx,
4877                    "netvsp error"
4878                );
4879            }
4880        }
4881        Ok(())
4882    }
4883}
4884
4885impl<T: RingMem + 'static> Worker<T> {
4886    async fn process(
4887        &mut self,
4888        stop: &mut StopTask<'_>,
4889        queue: &mut NetQueue,
4890    ) -> Result<(), WorkerError> {
4891        // Be careful not to wait on actions with unbounded blocking time (e.g.
4892        // guest actions, or waiting for network packets to arrive) without
4893        // wrapping the wait on `stop.until_stopped`.
4894        loop {
4895            match &mut self.state {
4896                WorkerState::Init(initializing) => {
4897                    assert_eq!(self.channel_idx, 0);
4898
4899                    tracelimit::info_ratelimited!("network accepted");
4900
4901                    let (buffers, state) = stop
4902                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
4903                        .await??;
4904
4905                    let state = ReadyState {
4906                        buffers: Arc::new(buffers),
4907                        state,
4908                        data: ProcessingData::new(),
4909                    };
4910
4911                    // Wake up the coordinator task to start the queues.
4912                    if let Err(err) = self
4913                        .coordinator_send
4914                        .try_send(CoordinatorMessage::Restart { channel_idx: 0 })
4915                    {
4916                        tracelimit::error_ratelimited!(
4917                            error = &err as &dyn std::error::Error,
4918                            channel_idx = self.channel_idx,
4919                            "failed to send restart message to coordinator"
4920                        );
4921                    }
4922
4923                    tracelimit::info_ratelimited!("network initialized");
4924                    self.state = WorkerState::WaitingForCoordinator(Some(state));
4925                }
4926                WorkerState::WaitingForCoordinator(_) => {
4927                    assert_eq!(self.channel_idx, 0);
4928                    // Waiting for the coordinator to process a message and
4929                    // restart the primary worker.
4930                    stop.until_stopped(pending()).await?
4931                }
4932                WorkerState::Ready(state) => {
4933                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
4934                        if !queue_state.target_vp_set {
4935                            if let Some(target_vp) = self.target_vp {
4936                                tracing::debug!(
4937                                    channel_idx = self.channel_idx,
4938                                    target_vp,
4939                                    "updating target VP"
4940                                );
4941                                queue_state.queue.update_target_vp(target_vp).await;
4942                                queue_state.target_vp_set = true;
4943                            }
4944                        }
4945
4946                        queue_state
4947                    } else {
4948                        // This task will be restarted when the queues are ready.
4949                        stop.until_stopped(pending()).await?
4950                    };
4951
4952                    let result = self.channel.main_loop(stop, state, queue_state).await;
4953                    let msg = match result {
4954                        Ok(restart) => {
4955                            assert_eq!(self.channel_idx, 0);
4956                            restart
4957                        }
4958                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
4959                            tracelimit::warn_ratelimited!(
4960                                err = err.as_ref() as &dyn std::error::Error,
4961                                channel_idx = self.channel_idx,
4962                                "Endpoint requires queues to restart",
4963                            );
4964                            CoordinatorMessage::Restart {
4965                                channel_idx: self.channel_idx,
4966                            }
4967                        }
4968                        Err(err) => return Err(err),
4969                    };
4970
4971                    // Only the Primary channel transitions to `WaitingForCoordinator`.
4972                    // Sub-channels stay in `Ready(_)`.
4973                    if self.channel_idx == 0 {
4974                        let WorkerState::Ready(ready) = std::mem::replace(
4975                            &mut self.state,
4976                            WorkerState::WaitingForCoordinator(None),
4977                        ) else {
4978                            unreachable!("must be running in ready state")
4979                        };
4980                        let _ = std::mem::replace(
4981                            &mut self.state,
4982                            WorkerState::WaitingForCoordinator(Some(ready)),
4983                        );
4984                    }
4985                    self.coordinator_send
4986                        .try_send(msg)
4987                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
4988                    stop.until_stopped(pending()).await?
4989                }
4990            }
4991        }
4992    }
4993}
4994
4995impl<T: 'static + RingMem> NetChannel<T> {
4996    fn try_next_packet<'a>(
4997        &mut self,
4998        send_buffer: Option<&SendBuffer>,
4999        version: Option<Version>,
5000        external_data: &'a mut MultiPagedRangeBuf,
5001    ) -> Result<Option<Packet<'a>>, WorkerError> {
5002        let (mut read, _) = self.queue.split();
5003        let packet = match read.try_read() {
5004            Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
5005                .map_err(WorkerError::Packet)?,
5006            Err(queue::TryReadError::Empty) => return Ok(None),
5007            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
5008        };
5009
5010        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
5011        Ok(Some(packet))
5012    }
5013
5014    async fn next_packet<'a>(
5015        &mut self,
5016        send_buffer: Option<&'a SendBuffer>,
5017        version: Option<Version>,
5018        external_data: &'a mut MultiPagedRangeBuf,
5019    ) -> Result<Packet<'a>, WorkerError> {
5020        let (mut read, _) = self.queue.split();
5021        let mut packet_ref = read.read().await?;
5022        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
5023            .map_err(WorkerError::Packet)?;
5024        if matches!(packet.data, PacketData::RndisPacket(_)) {
5025            // In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG
5026            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
5027            packet_ref.revert();
5028        }
5029        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
5030        Ok(packet)
5031    }
5032
5033    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
5034        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
5035            && initializing.ndis_version.is_some()
5036            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
5037            && initializing.recv_buffer.is_some()
5038    }
5039
5040    async fn initialize(
5041        &mut self,
5042        initializing: &mut Option<InitState>,
5043        mem: GuestMemory,
5044    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
5045        let mut has_init_packet_arrived = false;
5046        loop {
5047            if let Some(initializing) = &mut *initializing {
5048                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
5049                    let recv_buffer = initializing.recv_buffer.take().unwrap();
5050                    let send_buffer = initializing.send_buffer.take();
5051                    let state = ActiveState::new(
5052                        Some(PrimaryChannelState::new(
5053                            self.adapter.offload_support.clone(),
5054                        )),
5055                        recv_buffer.count,
5056                    );
5057                    let buffers = ChannelBuffers {
5058                        version: initializing.version,
5059                        mem,
5060                        recv_buffer,
5061                        send_buffer,
5062                        ndis_version: initializing.ndis_version.take().unwrap(),
5063                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
5064                            mtu: DEFAULT_MTU,
5065                            capabilities: protocol::NdisConfigCapabilities::new(),
5066                        }),
5067                    };
5068
5069                    break Ok((buffers, state));
5070                }
5071            }
5072
5073            // Wait for enough room in the ring to avoid needing to track
5074            // completion packets.
5075            self.queue
5076                .split()
5077                .1
5078                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
5079                .await?;
5080
5081            let mut external_data = MultiPagedRangeBuf::new();
5082            let packet = self
5083                .next_packet(
5084                    None,
5085                    initializing.as_ref().map(|x| x.version),
5086                    &mut external_data,
5087                )
5088                .await?;
5089
5090            if let Some(initializing) = &mut *initializing {
5091                match packet.data {
5092                    PacketData::SendNdisConfig(config) => {
5093                        if initializing.ndis_config.is_some() {
5094                            return Err(WorkerError::UnexpectedPacketOrder(
5095                                PacketOrderError::SendNdisConfigExists,
5096                            ));
5097                        }
5098
5099                        // As in the vmswitch, if the MTU is invalid then use the default.
5100                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
5101                            config.mtu
5102                        } else {
5103                            DEFAULT_MTU
5104                        };
5105
5106                        // The UEFI client expects a completion packet, which can be empty.
5107                        self.send_completion(packet.transaction_id, None)?;
5108                        initializing.ndis_config = Some(NdisConfig {
5109                            mtu,
5110                            capabilities: config.capabilities,
5111                        });
5112                    }
5113                    PacketData::SendNdisVersion(version) => {
5114                        if initializing.ndis_version.is_some() {
5115                            return Err(WorkerError::UnexpectedPacketOrder(
5116                                PacketOrderError::SendNdisVersionExists,
5117                            ));
5118                        }
5119
5120                        // The UEFI client expects a completion packet, which can be empty.
5121                        self.send_completion(packet.transaction_id, None)?;
5122                        initializing.ndis_version = Some(NdisVersion {
5123                            major: version.ndis_major_version,
5124                            minor: version.ndis_minor_version,
5125                        });
5126                    }
5127                    PacketData::SendReceiveBuffer(message) => {
5128                        if initializing.recv_buffer.is_some() {
5129                            return Err(WorkerError::UnexpectedPacketOrder(
5130                                PacketOrderError::SendReceiveBufferExists,
5131                            ));
5132                        }
5133
5134                        let mtu = if let Some(cfg) = &initializing.ndis_config {
5135                            cfg.mtu
5136                        } else if initializing.version < Version::V2 {
5137                            DEFAULT_MTU
5138                        } else {
5139                            return Err(WorkerError::UnexpectedPacketOrder(
5140                                PacketOrderError::SendReceiveBufferMissingMTU,
5141                            ));
5142                        };
5143
5144                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);
5145
5146                        let recv_buffer = ReceiveBuffer::new(
5147                            &self.gpadl_map,
5148                            message.gpadl_handle,
5149                            message.id,
5150                            sub_allocation_size,
5151                        )?;
5152
5153                        self.send_completion(
5154                            packet.transaction_id,
5155                            Some(&self.message(
5156                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
5157                                protocol::Message1SendReceiveBufferComplete {
5158                                    status: protocol::Status::SUCCESS,
5159                                    num_sections: 1,
5160                                    sections: [protocol::ReceiveBufferSection {
5161                                        offset: 0,
5162                                        sub_allocation_size: recv_buffer.sub_allocation_size,
5163                                        num_sub_allocations: recv_buffer.count,
5164                                        end_offset: recv_buffer.sub_allocation_size
5165                                            * recv_buffer.count,
5166                                    }],
5167                                },
5168                            )),
5169                        )?;
5170                        initializing.recv_buffer = Some(recv_buffer);
5171                    }
5172                    PacketData::SendSendBuffer(message) => {
5173                        if initializing.send_buffer.is_some() {
5174                            return Err(WorkerError::UnexpectedPacketOrder(
5175                                PacketOrderError::SendSendBufferExists,
5176                            ));
5177                        }
5178
5179                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
5180                        self.send_completion(
5181                            packet.transaction_id,
5182                            Some(&self.message(
5183                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
5184                                protocol::Message1SendSendBufferComplete {
5185                                    status: protocol::Status::SUCCESS,
5186                                    section_size: 6144,
5187                                },
5188                            )),
5189                        )?;
5190
5191                        initializing.send_buffer = Some(send_buffer);
5192                    }
5193                    PacketData::RndisPacket(rndis_packet) => {
5194                        if !Self::is_ready_to_initialize(initializing, true) {
5195                            return Err(WorkerError::UnexpectedPacketOrder(
5196                                PacketOrderError::UnexpectedRndisPacket,
5197                            ));
5198                        }
5199                        tracing::debug!(
5200                            channel_type = rndis_packet.channel_type,
5201                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
5202                        );
5203                        has_init_packet_arrived = true;
5204                    }
5205                    _ => {
5206                        return Err(WorkerError::UnexpectedPacketOrder(
5207                            PacketOrderError::InvalidPacketData,
5208                        ));
5209                    }
5210                }
5211            } else {
5212                match packet.data {
5213                    PacketData::Init(init) => {
5214                        let requested_version = init.protocol_version;
5215                        let version = check_version(requested_version);
5216                        let mut data = protocol::MessageInitComplete {
5217                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
5218                            maximum_mdl_chain_length: 34,
5219                            status: protocol::Status::NONE,
5220                        };
5221                        if let Some(version) = version {
5222                            if version == Version::V1 {
5223                                data.deprecated = Version::V1 as u32;
5224                            }
5225                            data.status = protocol::Status::SUCCESS;
5226                        } else {
5227                            tracing::debug!(requested_version, "unrecognized version");
5228                        }
5229                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
5230                        self.send_completion(packet.transaction_id, Some(&message))?;
5231
5232                        if let Some(version) = version {
5233                            tracelimit::info_ratelimited!(?version, "network negotiated");
5234
5235                            if version >= Version::V61 {
5236                                // Update the packet size so that the appropriate padding is
5237                                // appended for picky Windows guests.
5238                                self.packet_size = PacketSize::V61;
5239                            }
5240                            *initializing = Some(InitState {
5241                                version,
5242                                ndis_config: None,
5243                                ndis_version: None,
5244                                recv_buffer: None,
5245                                send_buffer: None,
5246                            });
5247                        }
5248                    }
5249                    _ => unreachable!(),
5250                }
5251            }
5252        }
5253    }
5254
5255    async fn main_loop(
5256        &mut self,
5257        stop: &mut StopTask<'_>,
5258        ready_state: &mut ReadyState,
5259        queue_state: &mut QueueState,
5260    ) -> Result<CoordinatorMessage, WorkerError> {
5261        let buffers = &ready_state.buffers;
5262        let state = &mut ready_state.state;
5263        let data = &mut ready_state.data;
5264
5265        let ring_spare_capacity = {
5266            let (_, send) = self.queue.split();
5267            let mut limit = if self.can_use_ring_size_opt {
5268                self.adapter.ring_size_limit.load(Ordering::Relaxed)
5269            } else {
5270                0
5271            };
5272            if limit == 0 {
5273                limit = send.capacity() - 2048;
5274            }
5275            send.capacity() - limit
5276        };
5277
5278        // If the packet filter has changed to allow rx packets, add any pended RxIds.
5279        if !state.pending_rx_packets.is_empty()
5280            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
5281        {
5282            let (front, back) = state.pending_rx_packets.as_slices();
5283            queue_state.queue.rx_avail(&mut queue_state.pool, front);
5284            queue_state.queue.rx_avail(&mut queue_state.pool, back);
5285            state.pending_rx_packets.clear();
5286        }
5287
5288        // Handle any guest state changes since last run.
5289        if let Some(primary) = state.primary.as_mut() {
5290            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
5291                let num_channels_opened =
5292                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
5293                if num_channels_opened == primary.requested_num_queues as usize {
5294                    let (_, mut send) = self.queue.split();
5295                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5296                        .await??;
5297                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
5298                    primary.tx_spread_sent = true;
5299                }
5300            }
5301            if let PendingLinkAction::Active(up) = primary.pending_link_action {
5302                let (_, mut send) = self.queue.split();
5303                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5304                    .await??;
5305                if let Some(id) = primary.free_control_buffers.pop() {
5306                    let connect = if primary.guest_link_up != up {
5307                        primary.pending_link_action = PendingLinkAction::Default;
5308                        up
5309                    } else {
5310                        // For the up -> down -> up OR down -> up -> down case, the first transition
5311                        // is sent immediately and the second transition is queued with a delay.
5312                        primary.pending_link_action =
5313                            PendingLinkAction::Delay(primary.guest_link_up);
5314                        !primary.guest_link_up
5315                    };
5316                    // Mark the receive buffer in use to allow the guest to release it.
5317                    assert!(state.rx_bufs.is_free(id.0));
5318                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
5319                    let state_to_send = if connect {
5320                        rndisprot::STATUS_MEDIA_CONNECT
5321                    } else {
5322                        rndisprot::STATUS_MEDIA_DISCONNECT
5323                    };
5324                    tracing::info!(
5325                        connect,
5326                        mac_address = %self.adapter.mac_address,
5327                        "sending link status"
5328                    );
5329
5330                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
5331                    primary.guest_link_up = connect;
5332                } else {
5333                    primary.pending_link_action = PendingLinkAction::Delay(up);
5334                }
5335
5336                match primary.pending_link_action {
5337                    PendingLinkAction::Delay(_) => {
5338                        return Ok(CoordinatorMessage::StartTimer(
5339                            Instant::now() + LINK_DELAY_DURATION,
5340                        ));
5341                    }
5342                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
5343                    _ => {}
5344                }
5345            }
5346            match primary.guest_vf_state {
5347                PrimaryChannelGuestVfState::Available { .. }
5348                | PrimaryChannelGuestVfState::UnavailableFromAvailable
5349                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
5350                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
5351                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
5352                | PrimaryChannelGuestVfState::DataPathSynthetic => {
5353                    let (_, mut send) = self.queue.split();
5354                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5355                        .await??;
5356                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
5357                        return Ok(message);
5358                    }
5359                }
5360                _ => (),
5361            }
5362        }
5363
5364        loop {
5365            // If the ring is almost full, then do not poll the endpoint,
5366            // since this will cause us to spend time dropping packets and
5367            // reprogramming receive buffers. It's more efficient to let the
5368            // backend drop packets.
5369            let ring_full = {
5370                let (_, mut send) = self.queue.split();
5371                !send.can_write(ring_spare_capacity)?
5372            };
5373
5374            let did_some_work = (!ring_full
5375                && self.process_endpoint_rx(
5376                    buffers,
5377                    state,
5378                    data,
5379                    queue_state.queue.as_mut(),
5380                    &mut queue_state.pool,
5381                )?)
5382                | self.process_ring_buffer(buffers, state, data, queue_state)?
5383                | (!ring_full
5384                    && self.process_endpoint_tx(
5385                        state,
5386                        data,
5387                        queue_state.queue.as_mut(),
5388                        &mut queue_state.pool,
5389                    )?)
5390                | self.transmit_pending_segments(state, data, queue_state)?
5391                | self.send_pending_packets(state)?;
5392
5393            if !did_some_work {
5394                state.stats.spurious_wakes.increment();
5395            }
5396
5397            // Process any outstanding control messages before sleeping in case
5398            // there are now suballocations available for use.
5399            self.process_control_messages(buffers, state)?;
5400
5401            // This should be the only await point waiting on network traffic or
5402            // guest actions. Wrap it in `stop.until_stopped` to allow
5403            // cancellation.
5404            let restart = stop
5405                .until_stopped(std::future::poll_fn(
5406                    |cx| -> Poll<Option<CoordinatorMessage>> {
5407                        // If the ring is almost full, then don't wait for endpoint
5408                        // interrupts. This allows the interrupt rate to fall when the
5409                        // guest cannot keep up with the load.
5410                        if !ring_full {
5411                            // Check the network endpoint for tx completion or rx.
5412                            if queue_state
5413                                .queue
5414                                .poll_ready(cx, &mut queue_state.pool)
5415                                .is_ready()
5416                            {
5417                                tracing::trace!("endpoint ready");
5418                                return Poll::Ready(None);
5419                            }
5420                        }
5421
5422                        // Check the incoming ring for tx, but only if there are enough
5423                        // free tx packets and no pending tx segments.
5424                        let (mut recv, mut send) = self.queue.split();
5425                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
5426                            && data.tx_segments.is_empty()
5427                            && recv.poll_ready(cx).is_ready()
5428                        {
5429                            tracing::trace!("incoming ring ready");
5430                            return Poll::Ready(None);
5431                        }
5432
5433                        // Check the outgoing ring for space to send rx completions or
5434                        // control message sends, if any are pending.
5435                        //
5436                        // Also, if endpoint processing has been suspended due to the
5437                        // ring being nearly full, then ask the guest to wake us up when
5438                        // there is space again.
5439                        let mut pending_send_size = self.pending_send_size;
5440                        if ring_full {
5441                            pending_send_size = ring_spare_capacity;
5442                        }
5443                        if pending_send_size != 0
5444                            && send.poll_ready(cx, pending_send_size).is_ready()
5445                        {
5446                            tracing::trace!("outgoing ring ready");
5447                            return Poll::Ready(None);
5448                        }
5449
5450                        // Collect any of this queue's receive buffers that were
5451                        // completed by a remote channel. This only happens when the
5452                        // subchannel count changes, so that receive buffer ownership
5453                        // moves between queues while some receive buffers are still in
5454                        // use.
5455                        if let Some(remote_buffer_id_recv) =
5456                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
5457                        {
5458                            while let Poll::Ready(Some(id)) =
5459                                remote_buffer_id_recv.poll_next_unpin(cx)
5460                            {
5461                                if id >= RX_RESERVED_CONTROL_BUFFERS {
5462                                    queue_state
5463                                        .queue
5464                                        .rx_avail(&mut queue_state.pool, &[RxId(id)]);
5465                                } else {
5466                                    state
5467                                        .primary
5468                                        .as_mut()
5469                                        .unwrap()
5470                                        .free_control_buffers
5471                                        .push(ControlMessageId(id));
5472                                }
5473                            }
5474                        }
5475
5476                        if let Some(restart) = self.restart.take() {
5477                            return Poll::Ready(Some(restart));
5478                        }
5479
5480                        tracing::trace!("network waiting");
5481                        Poll::Pending
5482                    },
5483                ))
5484                .await?;
5485
5486            if let Some(restart) = restart {
5487                break Ok(restart);
5488            }
5489        }
5490    }
5491
5492    fn process_endpoint_rx(
5493        &mut self,
5494        buffers: &ChannelBuffers,
5495        state: &mut ActiveState,
5496        data: &mut ProcessingData,
5497        epqueue: &mut dyn net_backend::Queue,
5498        pool: &mut BufferPool,
5499    ) -> Result<bool, WorkerError> {
5500        let n = epqueue
5501            .rx_poll(pool, &mut data.rx_ready)
5502            .map_err(WorkerError::Endpoint)?;
5503
5504        if n == 0 {
5505            return Ok(false);
5506        }
5507
5508        state.stats.rx_packets_per_wake.add_sample(n as u64);
5509        state.stats.rx_vlan_packets.add(pool.take_rx_vlan_count());
5510
5511        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
5512            tracing::trace!(
5513                packet_filter = self.packet_filter,
5514                "rx packets dropped due to packet filter"
5515            );
5516            // Pend the newly available RxIds until the packet filter is updated.
5517            // Under high load this will eventually lead to no available RxIds,
5518            // which will cause the backend to drop the packets instead of
5519            // processing them here.
5520            state.pending_rx_packets.extend(&data.rx_ready[..n]);
5521            state.stats.rx_dropped_filtered.add(n as u64);
5522            return Ok(false);
5523        }
5524
5525        let transaction_id = data.rx_ready[0].0.into();
5526        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);
5527
5528        state.rx_bufs.allocate(ready_ids.clone()).unwrap();
5529
5530        // Always use the full suballocation size to avoid tracking the
5531        // message length. See RxBuf::header() for details.
5532        let len = buffers.recv_buffer.sub_allocation_size as usize;
5533        data.transfer_pages.clear();
5534        data.transfer_pages
5535            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));
5536
5537        match self.try_send_rndis_message(
5538            transaction_id,
5539            protocol::DATA_CHANNEL_TYPE,
5540            buffers.recv_buffer.id,
5541            &data.transfer_pages,
5542        )? {
5543            None => {
5544                // packet was sent
5545                state.stats.rx_packets.add(n as u64);
5546            }
5547            Some(_) => {
5548                // Ring buffer is full. Drop the packets and free the rx
5549                // buffers. When the ring has limited space, the main loop will
5550                // stop polling for receive packets.
5551                state.stats.rx_dropped_ring_full.add(n as u64);
5552
5553                state.rx_bufs.free(data.rx_ready[0].0);
5554                epqueue.rx_avail(pool, &data.rx_ready[..n]);
5555            }
5556        }
5557
5558        Ok(true)
5559    }
5560
5561    fn process_endpoint_tx(
5562        &mut self,
5563        state: &mut ActiveState,
5564        data: &mut ProcessingData,
5565        epqueue: &mut dyn net_backend::Queue,
5566        pool: &mut BufferPool,
5567    ) -> Result<bool, WorkerError> {
5568        // Drain completed transmits.
5569        let result = epqueue.tx_poll(pool, &mut data.tx_done);
5570
5571        match result {
5572            Ok(n) => {
5573                if n == 0 {
5574                    return Ok(false);
5575                }
5576
5577                for &id in &data.tx_done[..n] {
5578                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5579                    assert!(tx_packet.pending_packet_count > 0);
5580                    tx_packet.pending_packet_count -= 1;
5581                    if tx_packet.pending_packet_count == 0 {
5582                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5583                    }
5584                }
5585
5586                Ok(true)
5587            }
5588            Err(TxError::TryRestart(err)) => {
5589                // In-flight TX packets will be cleaned up by
5590                // `reset_tx_after_endpoint_stop` during the queue restart.
5591                Err(WorkerError::EndpointRequiresQueueRestart(err))
5592            }
5593            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
5594        }
5595    }
5596
5597    fn switch_data_path(
5598        &mut self,
5599        state: &mut ActiveState,
5600        use_guest_vf: bool,
5601        transaction_id: Option<u64>,
5602    ) -> Result<(), WorkerError> {
5603        let primary = state.primary.as_mut().unwrap();
5604        let mut queue_switch_operation = false;
5605        match primary.guest_vf_state {
5606            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5607                // Allow the guest to switch to VTL0, or if the current state
5608                // of the data path is unknown, allow a switch away from VTL0.
5609                // The latter case handles the scenario where the data path has
5610                // been switched but the synthetic device has been restarted.
5611                // The device is queried for the current state of the data path
5612                // but if it is unknown (upgraded from older version that
5613                // doesn't track this data, or failure during query) then the
5614                // safest option is to pass the request through.
5615                if use_guest_vf || primary.is_data_path_switched.is_none() {
5616                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5617                        to_guest: use_guest_vf,
5618                        id: transaction_id,
5619                        result: None,
5620                    };
5621                    queue_switch_operation = true;
5622                }
5623            }
5624            PrimaryChannelGuestVfState::DataPathSwitched => {
5625                if !use_guest_vf {
5626                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5627                        to_guest: false,
5628                        id: transaction_id,
5629                        result: None,
5630                    };
5631                    queue_switch_operation = true;
5632                }
5633            }
5634            _ if use_guest_vf => {
5635                tracing::warn!(
5636                    state = %primary.guest_vf_state,
5637                    use_guest_vf,
5638                    "Data path switch requested while device is in wrong state"
5639                );
5640            }
5641            _ => (),
5642        };
5643        if queue_switch_operation {
5644            self.send_coordinator_update_vf();
5645        } else {
5646            self.send_completion(transaction_id, None)?;
5647        }
5648        Ok(())
5649    }
5650
5651    fn process_ring_buffer(
5652        &mut self,
5653        buffers: &ChannelBuffers,
5654        state: &mut ActiveState,
5655        data: &mut ProcessingData,
5656        queue_state: &mut QueueState,
5657    ) -> Result<bool, WorkerError> {
5658        if !data.tx_segments.is_empty() {
5659            // There are still segments pending transmission. Skip polling the
5660            // ring until they are sent to increase backpressure and minimize
5661            // unnecessary wakeups.
5662            return Ok(false);
5663        }
5664        let mut total_packets = 0;
5665        let mut did_some_work = false;
5666        loop {
5667            if state.free_tx_packets.is_empty() {
5668                break;
5669            }
5670            let packet = if let Some(packet) = self.try_next_packet(
5671                buffers.send_buffer.as_ref(),
5672                Some(buffers.version),
5673                &mut data.external_data,
5674            )? {
5675                packet
5676            } else {
5677                break;
5678            };
5679
5680            did_some_work = true;
5681            match packet.data {
5682                PacketData::RndisPacket(_) => {
5683                    let id = state.free_tx_packets.pop().unwrap();
5684                    let result: Result<usize, WorkerError> =
5685                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
5686                    match result {
5687                        Ok(num_packets) => {
5688                            total_packets += num_packets as u64;
5689                            if num_packets == 0 {
5690                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5691                            }
5692                        }
5693                        Err(err) => {
5694                            tracelimit::error_ratelimited!(
5695                                error = &err as &dyn std::error::Error,
5696                                "failed to handle RNDIS packet"
5697                            );
5698                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
5699                        }
5700                    };
5701                }
5702                PacketData::RndisPacketComplete(_completion) => {
5703                    data.rx_done.clear();
5704                    state
5705                        .release_recv_buffers(
5706                            packet
5707                                .transaction_id
5708                                .expect("completion packets have transaction id by construction"),
5709                            &queue_state.rx_buffer_range,
5710                            &mut data.rx_done,
5711                        )
5712                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
5713                    queue_state
5714                        .queue
5715                        .rx_avail(&mut queue_state.pool, &data.rx_done);
5716                }
5717                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
5718                    let mut subchannel_count = 0;
5719                    // The number of requested subchannels has to stay below the maximum queue limit
5720                    // because one queue is always reserved for the primary channel. In other words,
5721                    // the subchannels plus the primary channel must fit within the max_queues value,
5722                    // which means subchannels + 1 ≤ max_queues, so the subchannel count must be
5723                    // strictly less than max_queues.
5724                    let num_queues = request.num_sub_channels + 1;
5725                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
5726                        && request.num_sub_channels < self.adapter.max_queues.into()
5727                        && RxBufferRanges::validate_params(buffers.recv_buffer.count, num_queues)
5728                            .is_ok()
5729                    {
5730                        subchannel_count = request.num_sub_channels;
5731                        protocol::Status::SUCCESS
5732                    } else {
5733                        tracelimit::warn_ratelimited!(
5734                            operation = ?request.operation,
5735                            request_sub_channels = request.num_sub_channels,
5736                            max_supported_sub_channels = self.adapter.max_queues - 1,
5737                            recv_buffer_count = buffers.recv_buffer.count,
5738                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
5739                        );
5740                        protocol::Status::FAILURE
5741                    };
5742
5743                    tracing::debug!(?status, subchannel_count, "subchannel request");
5744                    self.send_completion(
5745                        packet.transaction_id,
5746                        Some(&self.message(
5747                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
5748                            protocol::Message5SubchannelComplete {
5749                                status,
5750                                num_sub_channels: subchannel_count,
5751                            },
5752                        )),
5753                    )?;
5754
5755                    if subchannel_count > 0 {
5756                        let primary = state.primary.as_mut().unwrap();
5757                        primary.requested_num_queues = subchannel_count as u16 + 1;
5758                        primary.tx_spread_sent = false;
5759                        self.restart = Some(CoordinatorMessage::Restart { channel_idx: 0 });
5760                    }
5761                }
5762                PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id })
5763                | PacketData::RevokeSendBuffer(protocol::Message1RevokeSendBuffer { id })
5764                    if state.primary.is_some() =>
5765                {
5766                    tracing::debug!(
5767                        id,
5768                        "receive/send buffer revoked, terminating channel processing"
5769                    );
5770                    return Err(WorkerError::BufferRevoked);
5771                }
5772                // No operation for VF association completion packets as not all clients send them
5773                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
5774                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
5775                    self.switch_data_path(
5776                        state,
5777                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
5778                        packet.transaction_id,
5779                    )?;
5780                }
5781                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
5782                PacketData::OidQueryEx(oid_query) => {
5783                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
5784                    self.send_completion(
5785                        packet.transaction_id,
5786                        Some(&self.message(
5787                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
5788                            protocol::Message5OidQueryExComplete {
5789                                status: rndisprot::STATUS_NOT_SUPPORTED,
5790                                bytes: 0,
5791                            },
5792                        )),
5793                    )?;
5794                }
5795                p => {
5796                    tracing::warn!(request = ?p, "unexpected packet");
5797                    return Err(WorkerError::UnexpectedPacketOrder(
5798                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
5799                    ));
5800                }
5801            }
5802        }
5803        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
5804            state.stats.tx_stalled.increment();
5805        }
5806        state.stats.tx_packets_per_wake.add_sample(total_packets);
5807        Ok(did_some_work)
5808    }
5809
5810    // Transmit any pending segments. Returns Ok(true) if work was done--if any
5811    // segments were transmitted.
5812    fn transmit_pending_segments(
5813        &mut self,
5814        state: &mut ActiveState,
5815        data: &mut ProcessingData,
5816        queue_state: &mut QueueState,
5817    ) -> Result<bool, WorkerError> {
5818        if data.tx_segments.is_empty() {
5819            return Ok(false);
5820        }
5821        let sent = data.tx_segments_sent;
5822        let did_work =
5823            self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5824        Ok(did_work)
5825    }
5826
5827    /// Returns true if all pending segments were transmitted.
5828    fn transmit_segments(
5829        &mut self,
5830        state: &mut ActiveState,
5831        data: &mut ProcessingData,
5832        queue_state: &mut QueueState,
5833    ) -> Result<bool, WorkerError> {
5834        let segments = &data.tx_segments[data.tx_segments_sent..];
5835        let (sync, segments_sent) = queue_state
5836            .queue
5837            .tx_avail(&mut queue_state.pool, segments)
5838            .map_err(WorkerError::Endpoint)?;
5839
5840        let mut segments = &segments[..segments_sent];
5841        data.tx_segments_sent += segments_sent;
5842
5843        if sync {
5844            // Complete the packets now.
5845            while let Some(head) = segments.first() {
5846                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
5847                    unreachable!()
5848                };
5849                let id = metadata.id;
5850                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5851                pending_tx_packet.pending_packet_count -= 1;
5852                if pending_tx_packet.pending_packet_count == 0 {
5853                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5854                }
5855                segments = &segments[metadata.segment_count as usize..];
5856            }
5857        }
5858
5859        let all_sent = data.tx_segments_sent == data.tx_segments.len();
5860        if all_sent {
5861            data.tx_segments.clear();
5862            data.tx_segments_sent = 0;
5863        }
5864        Ok(all_sent)
5865    }
5866
5867    fn handle_rndis(
5868        &mut self,
5869        buffers: &ChannelBuffers,
5870        id: TxId,
5871        state: &mut ActiveState,
5872        packet: &Packet<'_>,
5873        segments: &mut Vec<TxSegment>,
5874    ) -> Result<usize, WorkerError> {
5875        let mut total_packets = 0;
5876        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5877        assert_eq!(tx_packet.pending_packet_count, 0);
5878        tx_packet.transaction_id = packet
5879            .transaction_id
5880            .ok_or(WorkerError::MissingTransactionId)?;
5881
5882        // Probe the data to catch accesses that are out of bounds. This
5883        // simplifies error handling for backends that use
5884        // [`GuestMemory::iova`].
5885        packet
5886            .external_data
5887            .iter()
5888            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
5889            .map_err(WorkerError::GpaDirectError)?;
5890
5891        let mut reader = packet.rndis_reader(&buffers.mem);
5892        let header: rndisprot::MessageHeader = reader.read_plain()?;
5893        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
5894            let start = segments.len();
5895            match self.handle_rndis_packet_messages(
5896                buffers,
5897                state,
5898                id,
5899                header.message_length as usize,
5900                reader,
5901                segments,
5902            ) {
5903                Ok(n) => {
5904                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
5905                    total_packets += n;
5906                }
5907                Err(err) => {
5908                    // Roll back any segments added for this message.
5909                    segments.truncate(start);
5910                    return Err(err);
5911                }
5912            }
5913        } else {
5914            self.handle_rndis_message(state, header.message_type, reader)?;
5915        }
5916
5917        Ok(total_packets)
5918    }
5919
5920    fn try_send_tx_packet(
5921        &mut self,
5922        transaction_id: u64,
5923        status: protocol::Status,
5924    ) -> Result<bool, WorkerError> {
5925        let message = self.message(
5926            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5927            protocol::Message1SendRndisPacketComplete { status },
5928        );
5929        let result = self.queue.split().1.batched().try_write_aligned(
5930            transaction_id,
5931            OutgoingPacketType::Completion,
5932            message.aligned_payload(),
5933        );
5934        let sent = match result {
5935            Ok(()) => true,
5936            Err(queue::TryWriteError::Full(n)) => {
5937                self.pending_send_size = n;
5938                false
5939            }
5940            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5941        };
5942        Ok(sent)
5943    }
5944
5945    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5946        let mut did_some_work = false;
5947        while let Some(pending) = state.pending_tx_completions.front() {
5948            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5949                return Ok(did_some_work);
5950            }
5951            did_some_work = true;
5952            if let Some(id) = pending.tx_id {
5953                state.free_tx_packets.push(id);
5954            }
5955            tracing::trace!(?pending, "sent tx completion");
5956            state.pending_tx_completions.pop_front();
5957        }
5958
5959        self.pending_send_size = 0;
5960        Ok(did_some_work)
5961    }
5962
5963    fn complete_tx_packet(
5964        &mut self,
5965        state: &mut ActiveState,
5966        id: TxId,
5967        status: protocol::Status,
5968    ) -> Result<(), WorkerError> {
5969        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5970        assert_eq!(tx_packet.pending_packet_count, 0);
5971        if self.pending_send_size == 0
5972            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5973        {
5974            tracing::trace!(id = id.0, "sent tx completion");
5975            state.free_tx_packets.push(id);
5976        } else {
5977            tracing::trace!(id = id.0, "pended tx completion");
5978            state.pending_tx_completions.push_back(PendingTxCompletion {
5979                transaction_id: tx_packet.transaction_id,
5980                tx_id: Some(id),
5981                status,
5982            });
5983        }
5984        Ok(())
5985    }
5986}
5987
5988impl ActiveState {
5989    fn release_recv_buffers(
5990        &mut self,
5991        transaction_id: u64,
5992        rx_buffer_range: &RxBufferRange,
5993        done: &mut Vec<RxId>,
5994    ) -> Option<()> {
5995        // The transaction ID specifies the first rx buffer ID.
5996        let first_id: u32 = transaction_id.try_into().ok()?;
5997        let ids = self.rx_bufs.free(first_id)?;
5998        for id in ids {
5999            if !rx_buffer_range.send_if_remote(id) {
6000                if id >= RX_RESERVED_CONTROL_BUFFERS {
6001                    done.push(RxId(id));
6002                } else {
6003                    self.primary
6004                        .as_mut()?
6005                        .free_control_buffers
6006                        .push(ControlMessageId(id));
6007                }
6008            }
6009        }
6010        Some(())
6011    }
6012}