// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The user-mode netvsp VMBus device implementation.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod buffers;
mod protocol;
pub mod resolver;
mod rndisprot;
mod rx_bufs;
mod saved_state;
mod test;

use crate::buffers::GuestBuffers;
use crate::protocol::Message1RevokeReceiveBuffer;
use crate::protocol::Message1RevokeSendBuffer;
use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
use crate::protocol::Version;
use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
use async_trait::async_trait;
pub use buffers::BufferPool;
use buffers::sub_allocation_size_for_mtu;
use futures::FutureExt;
use futures::StreamExt;
use futures::channel::mpsc;
use futures::channel::mpsc::TrySendError;
use futures_concurrency::future::Race;
use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRanges;
use guestmem::ranges::PagedRangesReader;
use guid::Guid;
use hvdef::hypercall::HvGuestOsId;
use hvdef::hypercall::HvGuestOsMicrosoft;
use hvdef::hypercall::HvGuestOsMicrosoftIds;
use hvdef::hypercall::HvGuestOsOpenSourceType;
use inspect::Inspect;
use inspect::InspectMut;
use inspect::SensitivityLevel;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use mesh::rpc::Rpc;
use net_backend::Endpoint;
use net_backend::EndpointAction;
use net_backend::QueueConfig;
use net_backend::RxId;
use net_backend::TxError;
use net_backend::TxId;
use net_backend::TxSegment;
use net_backend_resources::mac_address::MacAddress;
use pal_async::timer::Instant;
use pal_async::timer::PolledTimer;
use ring::gparange::MultiPagedRangeIter;
use rx_bufs::RxBuffers;
use rx_bufs::SubAllocationInUse;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::future::pending;
use std::mem::offset_of;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Poll;
use std::time::Duration;
use task_control::AsyncRun;
use task_control::InspectTaskMut;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing::Instrument;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::channel::ChannelControl;
use vmbus_channel::channel::ChannelOpenError;
use vmbus_channel::channel::ChannelRestoreError;
use vmbus_channel::channel::DeviceResources;
use vmbus_channel::channel::RestoreControl;
use vmbus_channel::channel::SaveRestoreVmbusDevice;
use vmbus_channel::channel::VmbusDevice;
use vmbus_channel::gpadl::GpadlId;
use vmbus_channel::gpadl::GpadlMapView;
use vmbus_channel::gpadl::GpadlView;
use vmbus_channel::gpadl::UnknownGpadlId;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_channel::gpadl_ring::gpadl_channel;
use vmbus_ring as ring;
use vmbus_ring::OutgoingPacketType;
use vmbus_ring::RingMem;
use vmbus_ring::gparange::MultiPagedRangeBuf;
use vmcore::save_restore::RestoreError;
use vmcore::save_restore::SaveError;
use vmcore::save_restore::SavedStateBlob;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but also need room for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. The worst case requires a completion message plus
// two additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION).
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic NIC: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the accelerated networking (VF) device is
// offered too quickly, it could get the "eth0" name. In provisioning
// scenarios, scripts make assumptions about which interface should be used,
// with eth0 being the default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);

#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}

#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}

struct Worker<T: RingMem> {
    channel_idx: u16,
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}

impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}

enum WorkerState {
    Init(Option<InitState>),
    Ready(ReadyState),
    WaitingForCoordinator(Option<ReadyState>),
}

impl WorkerState {
    fn ready(&self) -> Option<&ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }

    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }
}

struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}

/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
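///
/// # Example
///
/// A minimal sketch of an implementation that never exposes a VF (the
/// `NullVf` type and its behavior are illustrative only, not part of this
/// crate):
///
/// ```ignore
/// struct NullVf;
///
/// #[async_trait]
/// impl VirtualFunction for NullVf {
///     async fn id(&self) -> Option<u32> {
///         // No VF is ever available.
///         None
///     }
///
///     async fn guest_ready_for_device(&mut self) {}
///
///     async fn wait_for_state_change(&mut self) -> Rpc<(), ()> {
///         // Availability never changes, so never return.
///         std::future::pending().await
///     }
/// }
/// ```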
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}

struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    max_queues: u16,
    indirection_table_size: u16,
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}

struct RxBufferRange {
    id_range: Range<u32>,
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}

impl RxBufferRange {
    fn new(
        ranges: Arc<RxBufferRanges>,
        id_range: Range<u32>,
        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    ) -> Self {
        Self {
            id_range,
            remote_buffer_id_recv,
            remote_ranges: ranges,
        }
    }

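    /// Forwards buffer `id` to the queue that owns it if it falls outside
    /// this range. Returns `true` if the ID was forwarded.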
    fn send_if_remote(&self, id: u32) -> bool {
        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
        // ID is owned by the current range.
        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
            false
        } else {
            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
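            // For example (hypothetical numbers): with 16 reserved control
            // buffers and 100 buffers per queue, buffer ID 250 maps to
            // queue index (250 - 16) / 100 = 2.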
            // The total number of receive buffers may not evenly divide among
            // the active queues. Any extra buffers are given to the last
            // queue, so redirect any larger values there.
            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
            true
        }
    }
}

struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}

impl RxBufferRanges {
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}

struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// The internal channel state.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: PacketSize,
    pending_send_size: usize,
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}

// Use an enum to give the compiler more visibility into the packet size.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// [`protocol::PACKET_SIZE_V1`]
    V1,
    /// [`protocol::PACKET_SIZE_V61`]
    V61,
}

/// Buffers used during packet processing.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_segments_sent: usize,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}

impl ProcessingData {
    fn new() -> Self {
        Self {
            tx_segments: Vec::new(),
            tx_segments_sent: 0,
            tx_done: vec![TxId(0); 8192].into(),
            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
            external_data: MultiPagedRangeBuf::new(),
        }
    }
}

/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    primary: Option<PrimaryChannelState>,

    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    stats: QueueStats,
}

#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    tx_id: Option<TxId>,
    status: protocol::Status,
}

#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet.
    Initializing,
    /// State is being restored from a save.
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest.
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available and a data path switch was in progress,
    /// but the VF is no longer available.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available with the data path switched to it, but
    /// it is no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and the guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}

impl std::fmt::Display for PrimaryChannelGuestVfState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
                    "\"restoring from client-requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "\"unavailable (previously switching data path)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}

impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}

#[derive(Debug)]
enum PendingLinkAction {
    Default,
    Active(bool),
    Delay(bool),
}

struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}

impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}

#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}

impl OffloadConfig {
    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
        self.lso4 &= supported.lso4;
        self.lso6 &= supported.lso6;
    }
}

#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}

impl ChecksumOffloadConfig {
    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
        self.ipv4_header &= supported.ipv4_header;
        self.tcp4 &= supported.tcp4;
        self.udp4 &= supported.udp4;
        self.tcp6 &= supported.tcp6;
        self.udp6 &= supported.udp6;
    }

    fn flags(
        &self,
    ) -> (
        rndisprot::Ipv4ChecksumOffload,
        rndisprot::Ipv6ChecksumOffload,
    ) {
        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
        if self.ipv4_header {
            v4.set_ip_options_supported(on);
            v4.set_ip_checksum(on);
        }
        if self.tcp4 {
            v4.set_ip_options_supported(on);
            v4.set_tcp_options_supported(on);
            v4.set_tcp_checksum(on);
        }
        if self.tcp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_tcp_options_supported(on);
            v6.set_tcp_checksum(on);
        }
        if self.udp4 {
            v4.set_ip_options_supported(on);
            v4.set_udp_checksum(on);
        }
        if self.udp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_udp_checksum(on);
        }
        (v4, v6)
    }
}

impl OffloadConfig {
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}

#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    Initializing,
    Operational,
    Halted,
}

impl PrimaryChannelState {
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of the indirection table can cause unexpected and hard-to-investigate
                    // issues with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of the indirection table is done by duplicating the existing entries
                    // until the desired size is reached.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}

struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

const TX_PACKET_QUOTA: usize = 1024;

impl ActiveState {
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Consume tx quota if any is available. If not, still
            // allow the restore since tx quota might change from
            // release to release.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                // This shouldn't be referenced, but set it in case it is in the future.
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            // Save/restore does not preserve the status of pending tx completions, so
            // complete any pending completions with 'success' to avoid changing the
            // saved_state format.
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}

/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    pending_packet_count: usize,
    transaction_id: u64,
}

/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;

/// A network adapter.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
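/// A builder for [`Nic`].
///
/// # Example
///
/// A minimal sketch; `driver_source`, `instance_id`, `endpoint`, and
/// `mac_address` are assumed to be in scope:
///
/// ```ignore
/// let nic: Nic = Nic::builder()
///     .max_queues(4)
///     .build(&driver_source, instance_id, endpoint, mac_address, 0);
/// ```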
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}

impl NicBuilder {
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
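/// Determines whether the pending-send-size ring optimization can be used,
/// based on the ring's capabilities and the guest OS ID reported by the guest.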
fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
    let Some(guest_os_id) = guest_os_id else {
        // The guest OS ID is not available.
        return false;
    };

    if !queue.split().0.supports_pending_send_size() {
        // The guest does not support pending send size.
        return false;
    }

    let Some(open_source_os) = guest_os_id.open_source() else {
        // The guest OS is proprietary (e.g. Windows).
        return true;
    };

    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        // Although FreeBSD indicates support for `pending send size`, it doesn't
        // implement it correctly. This was fixed in FreeBSD version `1400097`.
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        // Linux kernels prior to 3.11 have issues with the pending send size optimization
        // that can affect certain asynchronous I/O (AIO) network operations.
        // Disable the ring size optimization for these older kernels to avoid flow control issues.
        HvGuestOsOpenSourceType::LINUX => {
            // The Linux version is encoded as ((major << 16) | (minor << 8) | patch).
            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424.
            open_source_os.version() >= 199424
        }
        _ => true,
    }
}

impl Nic {
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            max_queues: !0,
            get_guest_os_id: None,
        }
    }
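    /// Consumes the NIC, returning the underlying backend endpoint.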
    pub fn shutdown(self) -> Box<dyn Endpoint> {
        let (state, _) = self.coordinator.into_inner();
        state.endpoint
    }
}

impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}

#[async_trait]
impl VmbusDevice for Nic {
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, None);
            WorkerState::Init(None)
        } else {
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while the vmbus channel is already closed");
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VFs added to the guest. This is required for guest compatibility,
            // as some apps (such as DPDK) rely on the VF sticking around even after the
            // vmbus channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}

#[async_trait]
impl SaveRestoreVmbusDevice for Nic {
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        let state = self.saved_state();
        Ok(SavedStateBlob::new(state))
    }

    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        let state: saved_state::SavedState = state.parse()?;
        if let Err(err) = self.restore_state(control, state).await {
            tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
            Err(err.into())
        } else {
            Ok(())
        }
    }
}

impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}

struct RestoreCoordinatorState {
    active_packet_filter: u32,
}

impl Nic {
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}

#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}

impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}

impl Nic {
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // N.B. This will restore the vmbus view of open channels, so any
            //      failures after this point could result in inconsistent
            //      state (vmbus believes the channel is open/active). There
            //      are a number of failure paths after this point because this
            //      call also restores vmbus device state, like the GPADL map.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
1588                        mut control_messages,
1589                        mut rss_state,
1590                        channels,
1591                        ndis_version,
1592                        ndis_config,
1593                        rndis_state,
1594                        guest_vf_state,
1595                        offload_config,
1596                        pending_offload_change,
1597                        tx_spread_sent,
1598                        guest_link_down,
1599                        pending_link_action,
1600                        packet_filter,
1601                    } = ready;
1602
1603                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
1604                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1605
1606                    let version = check_version(version)
1607                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1608
1609                    let request = requests[0].as_ref().unwrap();
1610                    let buffers = Arc::new(ChannelBuffers {
1611                        version,
1612                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1613                        recv_buffer: ReceiveBuffer::new(
1614                            &self.resources.gpadl_map,
1615                            receive_buffer.gpadl_id,
1616                            receive_buffer.id,
1617                            receive_buffer.sub_allocation_size,
1618                        )?,
1619                        send_buffer: {
1620                            if let Some(send_buffer) = send_buffer {
1621                                Some(SendBuffer::new(
1622                                    &self.resources.gpadl_map,
1623                                    send_buffer.gpadl_id,
1624                                )?)
1625                            } else {
1626                                None
1627                            }
1628                        },
1629                        ndis_version: {
1630                            let saved_state::NdisVersion { major, minor } = ndis_version;
1631                            NdisVersion { major, minor }
1632                        },
1633                        ndis_config: {
1634                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1635                            NdisConfig {
1636                                mtu,
1637                                capabilities: capabilities.into(),
1638                            }
1639                        },
1640                    });
1641
1642                    for (channel_idx, channel) in channels.iter().enumerate() {
1643                        let Some(channel) = channel else {
1644                            continue;
1645                        };
1648
1649                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1650
1651                        // Restore primary channel state.
1652                        if channel_idx == 0 {
1653                            let primary = PrimaryChannelState::restore(
1654                                &guest_vf_state,
1655                                &rndis_state,
1656                                &offload_config,
1657                                pending_offload_change,
1658                                channels.len() as u16,
1659                                self.adapter.indirection_table_size,
1660                                &active.rx_bufs,
1661                                std::mem::take(&mut control_messages),
1662                                rss_state.take(),
1663                                tx_spread_sent,
1664                                guest_link_down,
1665                                pending_link_action,
1666                            )?;
1667                            active.primary = Some(primary);
1668                        }
1669
1670                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1671                            buffers: buffers.clone(),
1672                            state: active,
1673                            data: ProcessingData::new(),
1674                        }));
1675                    }
1676                }
1677            }
1678
1679            // Insert the coordinator and mark that it should try to start the
1680            // network endpoint when it starts running.
1681            self.insert_coordinator(
1682                states.len() as u16,
1683                Some(RestoreCoordinatorState {
1684                    active_packet_filter: saved_packet_filter,
1685                }),
1686            );
1687
1688            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1689                if let Some(state) = state {
1690                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1691                }
1692            }
1693        } else {
1694            control
1695                .restore(&[false])
1696                .await
1697                .map_err(NetRestoreError::Channel)?;
1698        }
1699        Ok(())
1700    }
1701
1702    fn saved_state(&self) -> saved_state::SavedState {
1703        let open = if let Some(coordinator) = self.coordinator.state() {
1704            let primary = coordinator.workers[0].state().unwrap();
1705            let primary = match &primary.state {
1706                WorkerState::Init(None) => saved_state::Primary::Version,
1707                WorkerState::Init(Some(init)) => {
1708                    saved_state::Primary::Init(saved_state::InitPrimary {
1709                        version: init.version as u32,
1710                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
1711                            saved_state::NdisConfig {
1712                                mtu,
1713                                capabilities: capabilities.into(),
1714                            }
1715                        }),
1716                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
1717                            saved_state::NdisVersion { major, minor }
1718                        }),
1719                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
1720                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
1721                            gpadl_id: x.gpadl.id(),
1722                        }),
1723                    })
1724                }
1725                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
1726                    let primary = ready.state.primary.as_ref().unwrap();
1727
1728                    let rndis_state = match primary.rndis_state {
1729                        RndisState::Initializing => saved_state::RndisState::Initializing,
1730                        RndisState::Operational => saved_state::RndisState::Operational,
1731                        RndisState::Halted => saved_state::RndisState::Halted,
1732                    };
1733
1734                    let offload_config = saved_state::OffloadConfig {
1735                        checksum_tx: saved_state::ChecksumOffloadConfig {
1736                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
1737                            tcp4: primary.offload_config.checksum_tx.tcp4,
1738                            udp4: primary.offload_config.checksum_tx.udp4,
1739                            tcp6: primary.offload_config.checksum_tx.tcp6,
1740                            udp6: primary.offload_config.checksum_tx.udp6,
1741                        },
1742                        checksum_rx: saved_state::ChecksumOffloadConfig {
1743                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
1744                            tcp4: primary.offload_config.checksum_rx.tcp4,
1745                            udp4: primary.offload_config.checksum_rx.udp4,
1746                            tcp6: primary.offload_config.checksum_rx.tcp6,
1747                            udp6: primary.offload_config.checksum_rx.udp6,
1748                        },
1749                        lso4: primary.offload_config.lso4,
1750                        lso6: primary.offload_config.lso6,
1751                    };
1752
1753                    let control_messages = primary
1754                        .control_messages
1755                        .iter()
1756                        .map(|message| saved_state::IncomingControlMessage {
1757                            message_type: message.message_type,
1758                            data: message.data.to_vec(),
1759                        })
1760                        .collect();
1761
1762                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
1763                        key: rss.key.into(),
1764                        indirection_table: rss.indirection_table.clone(),
1765                    });
1766
1767                    let pending_link_action = match primary.pending_link_action {
1768                        PendingLinkAction::Default => None,
1769                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
1770                            Some(action)
1771                        }
1772                    };
1773
1774                    let channels = coordinator.workers[..coordinator.num_queues as usize]
1775                        .iter()
1776                        .map(|worker| {
1777                            worker.state().map(|worker| {
1778                                if let Some(ready) = worker.state.ready() {
1779                                    // In-flight tx packets are treated as dropped across save/restore,
1780                                    // but their completions must still be delivered to the guest.
1781                                    let pending_tx_completions = ready
1782                                        .state
1783                                        .pending_tx_completions
1784                                        .iter()
1785                                        .map(|pending| pending.transaction_id)
1786                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
1787                                            |inflight| {
1788                                                (inflight.pending_packet_count > 0)
1789                                                    .then_some(inflight.transaction_id)
1790                                            },
1791                                        ))
1792                                        .collect();
1793
1794                                    saved_state::Channel {
1795                                        pending_tx_completions,
1796                                        in_use_rx: {
1797                                            ready
1798                                                .state
1799                                                .rx_bufs
1800                                                .allocated()
1801                                                .map(|id| saved_state::Rx {
1802                                                    buffers: id.collect(),
1803                                                })
1804                                                .collect()
1805                                        },
1806                                    }
1807                                } else {
1808                                    saved_state::Channel {
1809                                        pending_tx_completions: Vec::new(),
1810                                        in_use_rx: Vec::new(),
1811                                    }
1812                                }
1813                            })
1814                        })
1815                        .collect();
1816
1817                    let guest_vf_state = match primary.guest_vf_state {
1818                        PrimaryChannelGuestVfState::Initializing
1819                        | PrimaryChannelGuestVfState::Unavailable
1820                        | PrimaryChannelGuestVfState::Available { .. } => {
1821                            saved_state::GuestVfState::NoState
1822                        }
1823                        PrimaryChannelGuestVfState::UnavailableFromAvailable
1824                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
1825                            saved_state::GuestVfState::AvailableAdvertised
1826                        }
1827                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
1828                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
1829                            to_guest,
1830                            id,
1831                        } => saved_state::GuestVfState::DataPathSwitchPending {
1832                            to_guest,
1833                            id,
1834                            result: None,
1835                        },
1836                        PrimaryChannelGuestVfState::DataPathSwitchPending {
1837                            to_guest,
1838                            id,
1839                            result,
1840                        } => saved_state::GuestVfState::DataPathSwitchPending {
1841                            to_guest,
1842                            id,
1843                            result,
1844                        },
1845                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
1846                        | PrimaryChannelGuestVfState::DataPathSwitched
1847                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
1848                            saved_state::GuestVfState::DataPathSwitched
1849                        }
1850                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
1851                    };
1852
1853                    let worker_0_packet_filter = coordinator.workers[0]
1854                        .state()
1855                        .unwrap()
1856                        .channel
1857                        .packet_filter;
1858                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
1859                        version: ready.buffers.version as u32,
1860                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
1861                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
1862                            saved_state::SendBuffer {
1863                                gpadl_id: sb.gpadl.id(),
1864                            }
1865                        }),
1866                        rndis_state,
1867                        guest_vf_state,
1868                        offload_config,
1869                        pending_offload_change: primary.pending_offload_change,
1870                        control_messages,
1871                        rss_state,
1872                        channels,
1873                        ndis_config: {
1874                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
1875                            saved_state::NdisConfig {
1876                                mtu,
1877                                capabilities: capabilities.into(),
1878                            }
1879                        },
1880                        ndis_version: {
1881                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
1882                            saved_state::NdisVersion { major, minor }
1883                        },
1884                        tx_spread_sent: primary.tx_spread_sent,
1885                        guest_link_down: !primary.guest_link_up,
1886                        pending_link_action,
1887                        packet_filter: Some(worker_0_packet_filter),
1888                    })
1889                }
1890                WorkerState::WaitingForCoordinator(None) => {
1891                    unreachable!("WaitingForCoordinator always holds a ready state")
1892                }
1893            };
1894
1895            let state = saved_state::OpenState { primary };
1896            Some(state)
1897        } else {
1898            None
1899        };
1900
1901        saved_state::SavedState { open }
1902    }
1903}
1904
1905#[derive(Debug, Error)]
1906enum WorkerError {
1907    #[error("packet error")]
1908    Packet(#[source] PacketError),
1909    #[error("unexpected packet order: {0}")]
1910    UnexpectedPacketOrder(#[source] PacketOrderError),
1911    #[error("unknown rndis message type: {0}")]
1912    UnknownRndisMessageType(u32),
1913    #[error("junk after rndis packet message: {0:#x}")]
1914    NonRndisPacketAfterPacket(u32),
1915    #[error("memory access error")]
1916    Access(#[from] AccessError),
1917    #[error("rndis message too small")]
1918    RndisMessageTooSmall,
1919    #[error("unsupported rndis behavior")]
1920    UnsupportedRndisBehavior,
1921    #[error("vmbus queue error")]
1922    Queue(#[from] queue::Error),
1923    #[error("too many control messages")]
1924    TooManyControlMessages,
1925    #[error("invalid rndis packet completion")]
1926    InvalidRndisPacketCompletion,
1927    #[error("missing transaction id")]
1928    MissingTransactionId,
1929    #[error("invalid gpadl")]
1930    InvalidGpadl(#[source] guestmem::InvalidGpn),
1931    #[error("gpadl error")]
1932    GpadlError(#[source] GuestMemoryError),
1933    #[error("gpa direct error")]
1934    GpaDirectError(#[source] GuestMemoryError),
1935    #[error("endpoint")]
1936    Endpoint(#[source] anyhow::Error),
1937    #[error("message not supported on sub channel: {0}")]
1938    NotSupportedOnSubChannel(u32),
1939    #[error("the ring buffer ran out of space, which should not be possible")]
1940    OutOfSpace,
1941    #[error("send/receive buffer error")]
1942    Buffer(#[from] BufferError),
1943    #[error("invalid rndis state")]
1944    InvalidRndisState,
1945    #[error("rndis message type not implemented")]
1946    RndisMessageTypeNotImplemented,
1947    #[error("invalid TCP header offset")]
1948    InvalidTcpHeaderOffset,
1949    #[error("cancelled")]
1950    Cancelled(task_control::Cancelled),
1951    #[error("tearing down because send/receive buffer is revoked")]
1952    BufferRevoked,
1953    #[error("endpoint requires queue restart: {0}")]
1954    EndpointRequiresQueueRestart(#[source] anyhow::Error),
1955    #[error("failed to send message to coordinator")]
1956    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
1957}
1958
1959impl From<task_control::Cancelled> for WorkerError {
1960    fn from(value: task_control::Cancelled) -> Self {
1961        Self::Cancelled(value)
1962    }
1963}
1964
1965#[derive(Debug, Error)]
1966enum OpenError {
1967    #[error("error establishing ring buffer")]
1968    Ring(#[source] vmbus_channel::gpadl_ring::Error),
1969    #[error("error establishing vmbus queue")]
1970    Queue(#[source] queue::Error),
1971}
1972
1973#[derive(Debug, Error)]
1974enum PacketError {
1975    #[error("UnknownType {0}")]
1976    UnknownType(u32),
1977    #[error("Access")]
1978    Access(#[source] AccessError),
1979    #[error("ExternalData")]
1980    ExternalData(#[source] ExternalDataError),
1981    #[error("InvalidSendBufferIndex")]
1982    InvalidSendBufferIndex,
1983}
1984
1985#[derive(Debug, Error)]
1986enum PacketOrderError {
1987    #[error("Invalid PacketData")]
1988    InvalidPacketData,
1989    #[error("Unexpected RndisPacket")]
1990    UnexpectedRndisPacket,
1991    #[error("SendNdisVersion already exists")]
1992    SendNdisVersionExists,
1993    #[error("SendNdisConfig already exists")]
1994    SendNdisConfigExists,
1995    #[error("SendReceiveBuffer already exists")]
1996    SendReceiveBufferExists,
1997    #[error("SendReceiveBuffer missing MTU")]
1998    SendReceiveBufferMissingMTU,
1999    #[error("SendSendBuffer already exists")]
2000    SendSendBufferExists,
2001    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
2002    SwitchDataPathCompletionPrimaryChannelState,
2003}
2004
2005#[derive(Debug)]
2006enum PacketData {
2007    Init(protocol::MessageInit),
2008    SendNdisVersion(protocol::Message1SendNdisVersion),
2009    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
2010    SendSendBuffer(protocol::Message1SendSendBuffer),
2011    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
2012    RevokeSendBuffer(Message1RevokeSendBuffer),
2013    RndisPacket(protocol::Message1SendRndisPacket),
2014    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
2015    SendNdisConfig(protocol::Message2SendNdisConfig),
2016    SwitchDataPath(protocol::Message4SwitchDataPath),
2017    OidQueryEx(protocol::Message5OidQueryEx),
2018    SubChannelRequest(protocol::Message5SubchannelRequest),
2019    SendVfAssociationCompletion,
2020    SwitchDataPathCompletion,
2021}
2022
2023#[derive(Debug)]
2024struct Packet<'a> {
2025    data: PacketData,
2026    transaction_id: Option<u64>,
2027    external_data: &'a MultiPagedRangeBuf,
2028}
2029
2030type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2031
2032impl Packet<'_> {
2033    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2034        PagedRanges::new(self.external_data.iter()).reader(mem)
2035    }
2036}
2037
2038fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2039    reader: &mut impl MemoryRead,
2040) -> Result<T, PacketError> {
2041    reader.read_plain().map_err(PacketError::Access)
2042}
2043
2044fn parse_packet<'a, T: RingMem>(
2045    packet_ref: &queue::PacketRef<'_, T>,
2046    send_buffer: Option<&SendBuffer>,
2047    version: Option<Version>,
2048    external_data: &'a mut MultiPagedRangeBuf,
2049) -> Result<Packet<'a>, PacketError> {
2050    external_data.clear();
2051    let packet = match packet_ref.as_ref() {
2052        IncomingPacket::Data(data) => data,
2053        IncomingPacket::Completion(completion) => {
2054            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2055                PacketData::SendVfAssociationCompletion
2056            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2057                PacketData::SwitchDataPathCompletion
2058            } else {
2059                let mut reader = completion.reader();
2060                let header: protocol::MessageHeader =
2061                    reader.read_plain().map_err(PacketError::Access)?;
2062                match header.message_type {
2063                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2064                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2065                    }
2066                    typ => return Err(PacketError::UnknownType(typ)),
2067                }
2068            };
2069            return Ok(Packet {
2070                data,
2071                transaction_id: Some(completion.transaction_id()),
2072                external_data,
2073            });
2074        }
2075    };
2076
2077    let mut reader = packet.reader();
2078    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2079    let data = match header.message_type {
2080        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2081        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2082            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2083        }
2084        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2085            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2086        }
2087        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2088            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2089        }
2090        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2091            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2092        }
2093        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2094            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2095        }
2096        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2097            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
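            // A section index of 0xffffffff means the RNDIS message is not in
            // the send buffer. Otherwise, resolve the section's data within
            // the send buffer GPADL; sections are laid out at a fixed
            // 6144-byte stride.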
2098            if message.send_buffer_section_index != 0xffffffff {
2099                let send_buffer_suballocation = send_buffer
2100                    .ok_or(PacketError::InvalidSendBufferIndex)?
2101                    .gpadl
2102                    .first()
2103                    .unwrap()
2104                    .try_subrange(
2105                        message.send_buffer_section_index as usize * 6144,
2106                        message.send_buffer_section_size as usize,
2107                    )
2108                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2109
2110                external_data.push_range(send_buffer_suballocation);
2111            }
2112            PacketData::RndisPacket(message)
2113        }
2114        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2115            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2116        }
2117        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2118            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2119        }
2120        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2121            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2122        }
2123        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2124            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2125        }
2126        typ => return Err(PacketError::UnknownType(typ)),
2127    };
2128    packet
2129        .read_external_ranges(external_data)
2130        .map_err(PacketError::ExternalData)?;
2131    Ok(Packet {
2132        data,
2133        transaction_id: packet.transaction_id(),
2134        external_data,
2135    })
2136}
2137
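/// An outgoing NVSP message held in an 8-byte-aligned buffer sized for the
/// largest supported packet (v6.1).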
2138#[derive(Debug, Copy, Clone)]
2139struct NvspMessage {
2140    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
2141    size: PacketSize,
2142}
2143
2144impl NvspMessage {
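    /// Builds a message from a header type and body, rejecting at compile
    /// time any body type `P` that cannot fit in a v1-sized packet.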
2145    fn new<P: IntoBytes + Immutable + KnownLayout>(
2146        size: PacketSize,
2147        message_type: u32,
2148        data: P,
2149    ) -> Self {
2150        // Assert at compile time that the packet will fit in the message
2151        // buffer. Note that we are checking against the v1 message size here.
2152        // It's possible this is a v6.1+ message, in which case we could compare
2153        // against the larger size. So far this has not been necessary. If
2154        // needed, make a `new_v61` method that does the more relaxed check,
2155        // rather than weakening this one.
2156        const {
2157            assert!(
2158                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
2159                "packet might not fit in message"
2160            )
2161        };
2162        let mut message = NvspMessage {
2163            buf: [0; protocol::PACKET_SIZE_V61 / 8],
2164            size,
2165        };
2166        let header = protocol::MessageHeader { message_type };
2167        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
2168        data.write_to_prefix(
2169            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
2170        )
2171        .unwrap();
2172        message
2173    }
2174
2175    fn aligned_payload(&self) -> &[u64] {
2176        // Note that vmbus packets are always 8-byte multiples, so round the
2177        // protocol packet size up.
2178        let len = match self.size {
2179            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.next_multiple_of(8) / 8 },
2180            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.next_multiple_of(8) / 8 },
2181        };
2182        &self.buf[..len]
2183    }
2184}
2185
2186impl<T: RingMem> NetChannel<T> {
2187    fn message<P: IntoBytes + Immutable + KnownLayout>(
2188        &self,
2189        message_type: u32,
2190        data: P,
2191    ) -> NvspMessage {
2192        NvspMessage::new(self.packet_size, message_type, data)
2193    }
2194
2195    fn send_completion(
2196        &mut self,
2197        transaction_id: Option<u64>,
2198        message: Option<&NvspMessage>,
2199    ) -> Result<(), WorkerError> {
2200        match transaction_id {
2201            None => Ok(()),
2202            Some(transaction_id) => Ok(self
2203                .queue
2204                .split()
2205                .1
2206                .batched()
2207                .try_write_aligned(
2208                    transaction_id,
2209                    OutgoingPacketType::Completion,
2210                    message.map_or(&[], |m| m.aligned_payload()),
2211                )
2212                .map_err(|err| match err {
2213                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2214                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2215                })?),
2216        }
2217    }
2218}
2219
2220static SUPPORTED_VERSIONS: &[Version] = &[
2221    Version::V1,
2222    Version::V2,
2223    Version::V4,
2224    Version::V5,
2225    Version::V6,
2226    Version::V61,
2227];
2228
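/// Maps the guest's requested protocol version (raw wire value) to a
/// supported [`Version`], or returns `None` if the version is not in
/// [`SUPPORTED_VERSIONS`] (note that V3 is absent from the list).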
2229fn check_version(requested_version: u32) -> Option<Version> {
2230    SUPPORTED_VERSIONS
2231        .iter()
2232        .find(|version| **version as u32 == requested_version)
2233        .copied()
2234}
2235
2236#[derive(Debug)]
2237struct ReceiveBuffer {
2238    gpadl: GpadlView,
2239    id: u16,
2240    count: u32,
2241    sub_allocation_size: u32,
2242}
2243
2244#[derive(Debug, Error)]
2245enum BufferError {
2246    #[error("unsupported suballocation size {0}")]
2247    UnsupportedSuballocationSize(u32),
2248    #[error("unaligned gpadl")]
2249    UnalignedGpadl,
2250    #[error("unknown gpadl ID")]
2251    UnknownGpadlId(#[from] UnknownGpadlId),
2252}
2253
2254impl ReceiveBuffer {
2255    fn new(
2256        gpadl_map: &GpadlMapView,
2257        gpadl_id: GpadlId,
2258        id: u16,
2259        sub_allocation_size: u32,
2260    ) -> Result<Self, BufferError> {
2261        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2262            return Err(BufferError::UnsupportedSuballocationSize(
2263                sub_allocation_size,
2264            ));
2265        }
2266        let gpadl = gpadl_map.map(gpadl_id)?;
2267        let range = gpadl
2268            .contiguous_aligned()
2269            .ok_or(BufferError::UnalignedGpadl)?;
2270        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2271        if num_sub_allocations == 0 {
2272            return Err(BufferError::UnsupportedSuballocationSize(
2273                sub_allocation_size,
2274            ));
2275        }
2276        let recv_buffer = Self {
2277            gpadl,
2278            id,
2279            count: num_sub_allocations,
2280            sub_allocation_size,
2281        };
2282        Ok(recv_buffer)
2283    }
2284
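    /// Returns the guest memory range of suballocation `index`: the
    /// `sub_allocation_size`-byte window at byte offset
    /// `index * sub_allocation_size` within the receive buffer GPADL.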
2285    fn range(&self, index: u32) -> PagedRange<'_> {
2286        self.gpadl.first().unwrap().subrange(
2287            (index * self.sub_allocation_size) as usize,
2288            self.sub_allocation_size as usize,
2289        )
2290    }
2291
2292    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2293        assert!(len <= self.sub_allocation_size as usize);
2294        ring::TransferPageRange {
2295            byte_offset: index * self.sub_allocation_size,
2296            byte_count: len as u32,
2297        }
2298    }
2299
2300    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2301        saved_state::ReceiveBuffer {
2302            gpadl_id: self.gpadl.id(),
2303            id: self.id,
2304            sub_allocation_size: self.sub_allocation_size,
2305        }
2306    }
2307}
2308
2309#[derive(Debug)]
2310struct SendBuffer {
2311    gpadl: GpadlView,
2312}
2313
2314impl SendBuffer {
2315    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2316        let gpadl = gpadl_map.map(gpadl_id)?;
2317        gpadl
2318            .contiguous_aligned()
2319            .ok_or(BufferError::UnalignedGpadl)?;
2320        Ok(Self { gpadl })
2321    }
2322}
2323
2324impl<T: RingMem> NetChannel<T> {
2325    /// Process a single non-packet RNDIS message.
2326    fn handle_rndis_message(
2327        &mut self,
2328        state: &mut ActiveState,
2329        message_type: u32,
2330        mut reader: PacketReader<'_>,
2331    ) -> Result<(), WorkerError> {
2332        assert_ne!(
2333            message_type,
2334            rndisprot::MESSAGE_TYPE_PACKET_MSG,
2335            "handled elsewhere"
2336        );
2337        let control = state
2338            .primary
2339            .as_mut()
2340            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;
2341
2342        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
2343            // Currently ignored and does not require a response.
2344            return Ok(());
2345        }
2346
2347        // This is a control message that needs a response. Responding
2348        // will require a suballocation to be available, which it may
2349        // not be right now. Enqueue the message and process the queue as
2350        // suballocations become available.
2351        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2352        if reader.len() == 0 {
2353            return Err(WorkerError::RndisMessageTooSmall);
2354        }
2355        // Do not let the queue get too large--the guest should not be
2356        // sending very many control messages at a time.
2357        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2358            return Err(WorkerError::TooManyControlMessages);
2359        }
2360
2361        control.control_messages_len += reader.len();
2362        control.control_messages.push_back(ControlMessage {
2363            message_type,
2364            data: reader.read_all()?.into(),
2365        });
2366
2367        // The control message queue will be processed in the main dispatch
2368        // loop.
2369        Ok(())
2370    }
2371
2372    /// Process RNDIS packet messages, which may contain multiple RNDIS packets
2373    /// in a single vmbus message.
2374    ///
2375    /// On entry, the reader has already read the RNDIS message header of the
2376    /// first RNDIS packet in the message.
2377    fn handle_rndis_packet_messages(
2378        &mut self,
2379        buffers: &ChannelBuffers,
2380        state: &mut ActiveState,
2381        id: TxId,
2382        mut message_len: usize,
2383        mut reader: PacketReader<'_>,
2384        segments: &mut Vec<TxSegment>,
2385    ) -> Result<usize, WorkerError> {
2386        // There may be multiple RNDIS packets in a single message, concatenated
2387        // with each other. Consume them until there is no more data in the
2388        // RNDIS message.
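        //
        // Illustrative layout (hypothetical sizes):
        //
        //     [hdr|packet 0 ...][hdr|packet 1 ...][hdr|packet 2 ...]
        //
        // Each header's `message_length` spans that header and its packet,
        // which yields the offset of the next header.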
2389        let mut num_packets = 0;
2390        loop {
2391            let next_message_offset = message_len
2392                .checked_sub(size_of::<rndisprot::MessageHeader>())
2393                .ok_or(WorkerError::RndisMessageTooSmall)?;
2394
2395            self.handle_rndis_packet_message(
2396                id,
2397                reader.clone(),
2398                &buffers.mem,
2399                segments,
2400                &mut state.stats,
2401            )?;
2402            num_packets += 1;
2403
2404            reader.skip(next_message_offset)?;
2405            if reader.len() == 0 {
2406                break;
2407            }
2408            let header: rndisprot::MessageHeader = reader.read_plain()?;
2409            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
2410                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
2411            }
2412            message_len = header.message_length as usize;
2413        }
2414        Ok(num_packets)
2415    }
2416
2417    /// Process an RNDIS packet message (used to send an Ethernet frame).
2418    fn handle_rndis_packet_message(
2419        &mut self,
2420        id: TxId,
2421        reader: PacketReader<'_>,
2422        mem: &GuestMemory,
2423        segments: &mut Vec<TxSegment>,
2424        stats: &mut QueueStats,
2425    ) -> Result<(), WorkerError> {
2426        // Headers are guaranteed to be in a single PagedRange.
2427        let headers = reader
2428            .clone()
2429            .into_inner()
2430            .paged_ranges()
2431            .next()
2432            .ok_or(WorkerError::RndisMessageTooSmall)?;
2433        let mut data = reader.into_inner();
2434        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2435        if request.num_oob_data_elements != 0
2436            || request.oob_data_length != 0
2437            || request.oob_data_offset != 0
2438            || request.vc_handle != 0
2439        {
2440            return Err(WorkerError::UnsupportedRndisBehavior);
2441        }
2442
2443        if data.len() < request.data_offset as usize
2444            || (data.len() - request.data_offset as usize) < request.data_length as usize
2445            || request.data_length == 0
2446        {
2447            return Err(WorkerError::RndisMessageTooSmall);
2448        }
2449
2450        data.skip(request.data_offset as usize);
2451        data.truncate(request.data_length as usize);
2452
2453        let mut metadata = net_backend::TxMetadata {
2454            id,
2455            len: request.data_length,
2456            ..Default::default()
2457        };
2458
2459        if request.per_packet_info_length != 0 {
2460            let mut ppi = headers
2461                .try_subrange(
2462                    request.per_packet_info_offset as usize,
2463                    request.per_packet_info_length as usize,
2464                )
2465                .ok_or(WorkerError::RndisMessageTooSmall)?;
2466            while !ppi.is_empty() {
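                // Per-packet info is a sequence of variable-length records:
                // each record's `size` covers the whole record, and its
                // payload begins at `per_packet_information_offset` from the
                // start of the record.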
2467                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2468                if h.size == 0 {
2469                    return Err(WorkerError::RndisMessageTooSmall);
2470                }
2471                let (this, rest) = ppi
2472                    .try_split(h.size as usize)
2473                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2474                let (_, d) = this
2475                    .try_split(h.per_packet_information_offset as usize)
2476                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2477                match h.typ {
2478                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2479                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2480
2481                        metadata.flags.set_offload_tcp_checksum(
2482                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2483                        );
2484                        metadata.flags.set_offload_udp_checksum(
2485                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2486                        );
2487                        metadata
2488                            .flags
2489                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2490                        metadata.flags.set_is_ipv4(n.is_ipv4());
2491                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2492                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2493                        if metadata.flags.offload_tcp_checksum()
2494                            || metadata.flags.offload_udp_checksum()
2495                        {
2496                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2497                                n.tcp_header_offset() - metadata.l2_len as u16
2498                            } else if n.is_ipv4() {
2499                                let mut reader = data.clone().reader(mem);
2500                                reader.skip(metadata.l2_len as usize)?;
2501                                let mut b = 0;
2502                                reader.read(std::slice::from_mut(&mut b))?;
2503                                (b as u16 & 0xf) * 4 // IPv4 IHL (low nibble), in 32-bit words
2504                            } else {
2505                                // Hope there are no extensions.
2506                                40
2507                            };
2508                        }
2509                    }
2510                    rndisprot::PPI_LSO => {
2511                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2512
2513                        metadata.flags.set_offload_tcp_segmentation(true);
2514                        metadata.flags.set_offload_tcp_checksum(true);
2515                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2516                        metadata.flags.set_is_ipv4(n.is_ipv4());
2517                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2518                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2519                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2520                            return Err(WorkerError::InvalidTcpHeaderOffset);
2521                        }
2522                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2523                        metadata.l4_len = {
2524                            let mut reader = data.clone().reader(mem);
2525                            reader
2526                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2527                            let mut b = 0;
2528                            reader.read(std::slice::from_mut(&mut b))?;
2529                                (b >> 4) * 4 // TCP data offset field (high nibble), in 32-bit words
2530                        };
2531                        metadata.max_tcp_segment_size = n.mss() as u16;
2532
2533                        if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2534                            // Not strictly enforced.
2535                            stats.tx_invalid_lso_packets.increment();
2536                        }
2537                    }
2538                    _ => {}
2539                }
2540                ppi = rest;
2541            }
2542        }
2543
2544        let start = segments.len();
2545        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2546            let range = range.map_err(WorkerError::InvalidGpadl)?;
2547            segments.push(TxSegment {
2548                ty: net_backend::TxSegmentType::Tail,
2549                gpa: range.start,
2550                len: range.len() as u32,
2551            });
2552        }
2553
2554        metadata.segment_count = (segments.len() - start) as u8;
2555
2556        stats.tx_packets.increment();
2557        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2558            stats.tx_checksum_packets.increment();
2559        }
2560        if metadata.flags.offload_tcp_segmentation() {
2561            stats.tx_lso_packets.increment();
2562        }
2563
2564        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2565
2566        Ok(())
2567    }
2568
2569    /// Notify the adapter that the guest VF state has changed and it may
2570    /// need to send a message to the guest.
2571    fn guest_vf_is_available(
2572        &mut self,
2573        guest_vf_id: Option<u32>,
2574        version: Version,
2575        config: NdisConfig,
2576    ) -> Result<bool, WorkerError> {
2577        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2578        if version >= Version::V4 && config.capabilities.sriov() {
2579            tracing::info!(
2580                available = serial_number.is_some(),
2581                serial_number,
2582                "sending VF association message."
2583            );
2584            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2585            let message = {
2586                self.message(
2587                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2588                    protocol::Message4SendVfAssociation {
2589                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2590                        serial_number: serial_number.unwrap_or(0),
2591                    },
2592                )
2593            };
2594            self.queue
2595                .split()
2596                .1
2597                .batched()
2598                .try_write_aligned(
2599                    VF_ASSOCIATION_TRANSACTION_ID,
2600                    OutgoingPacketType::InBandWithCompletion,
2601                    message.aligned_payload(),
2602                )
2603                .map_err(|err| match err {
2604                    queue::TryWriteError::Full(len) => {
2605                        tracing::error!(len, "failed to write vf association message");
2606                        WorkerError::OutOfSpace
2607                    }
2608                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2609                })?;
2610            Ok(true)
2611        } else {
2612            tracing::info!(
2613                available = serial_number.is_some(),
2614                serial_number,
2615                major = version.major(),
2616                minor = version.minor(),
2617                sriov_capable = config.capabilities.sriov(),
2618                "Skipping NvspMessage4TypeSendVFAssociation message"
2619            );
2620            Ok(false)
2621        }
2622    }
2623
2624    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2625    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2626        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2627        if version < Version::V5 {
2628            return;
2629        }
2630
2631        #[repr(C)]
2632        #[derive(IntoBytes, Immutable, KnownLayout)]
2633        struct SendIndirectionMsg {
2634            pub message: protocol::Message5SendIndirectionTable,
2635            pub send_indirection_table:
2636                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2637        }
2638
2639        // The offset to the send indirection table from the beginning of the NVSP message.
2640        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2641            + size_of::<protocol::MessageHeader>();
2642        let mut data = SendIndirectionMsg {
2643            message: protocol::Message5SendIndirectionTable {
2644                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2645                table_offset: send_indirection_table_offset as u32,
2646            },
2647            send_indirection_table: Default::default(),
2648        };
2649
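        // Fill the table round-robin across the opened channels; e.g. with
        // four channels (hypothetical count) the entries read
        // 0, 1, 2, 3, 0, 1, ...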
2650        for (i, entry) in data.send_indirection_table.iter_mut().enumerate() {
2651            *entry = i as u32 % num_channels_opened;
2652        }
2653
2654        let header = protocol::MessageHeader {
2655            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2656        };
2657        let result = self
2658            .queue
2659            .split()
2660            .1
2661            .try_write(&queue::OutgoingPacket {
2662                transaction_id: 0,
2663                packet_type: OutgoingPacketType::InBandNoCompletion,
2664                payload: &[header.as_bytes(), data.as_bytes()],
2665            })
2666            .map_err(|err| match err {
2667                queue::TryWriteError::Full(len) => {
2668                    tracing::error!(len, "failed to write send indirection table message");
2669                    WorkerError::OutOfSpace
2670                }
2671                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2672            });
2673        if let Err(err) = result {
2674            tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2675        }
2676    }
2677
2678    /// Notify the guest that the data path has been switched back to synthetic
2679    /// due to some external state change.
2680    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2681        let header = protocol::MessageHeader {
2682            message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2683        };
2684        let data = protocol::Message4SwitchDataPath {
2685            active_data_path: protocol::DataPath::SYNTHETIC.0,
2686        };
2687        let result = self
2688            .queue
2689            .split()
2690            .1
2691            .try_write(&queue::OutgoingPacket {
2692                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2693                packet_type: OutgoingPacketType::InBandWithCompletion,
2694                payload: &[header.as_bytes(), data.as_bytes()],
2695            })
2696            .map_err(|err| match err {
2697                queue::TryWriteError::Full(len) => {
2698                    tracing::error!(
2699                        len,
2700                        "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2701                    );
2702                    WorkerError::OutOfSpace
2703                }
2704                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2705            });
2706        if let Err(err) = result {
2707            tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2708        }
2709    }
2710
2711    /// Process an internal state change
2712    async fn handle_state_change(
2713        &mut self,
2714        primary: &mut PrimaryChannelState,
2715        buffers: &ChannelBuffers,
2716    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
2717        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
2718        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
2719        //      1. completion of switch data path request
2720        //      2. Switch data path notification (back to synthetic)
2721        //      3. Disassociate VF adapter.
2722        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2723            // Notify guest that a VF capability has recently arrived.
2724            if primary.rndis_state == RndisState::Operational {
2725                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
2726                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2727                    return Ok(Some(CoordinatorMessage::Update(
2728                        CoordinatorMessageUpdateType {
2729                            guest_vf_state: true,
2730                            ..Default::default()
2731                        },
2732                    )));
2733                } else if let Some(true) = primary.is_data_path_switched {
2734                    tracing::error!(
2735                        "Data path switched, but current guest negotiation does not support VTL0 VF"
2736                    );
2737                }
2738            }
2739            return Ok(None);
2740        }
2741        loop {
2742            primary.guest_vf_state = match primary.guest_vf_state {
2743                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
2744                    // Notify guest that the VF is unavailable. It has already been surprise removed.
2745                    if primary.rndis_state == RndisState::Operational {
2746                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
2747                    }
2748                    PrimaryChannelGuestVfState::Unavailable
2749                }
2750                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
2751                    to_guest,
2752                    id,
2753                } => {
2754                    // Complete the data path switch request.
2755                    self.send_completion(id, None)?;
2756                    if to_guest {
2757                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
2758                    } else {
2759                        PrimaryChannelGuestVfState::UnavailableFromAvailable
2760                    }
2761                }
2762                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
2763                    // Notify guest that the data path is now synthetic.
2764                    self.guest_vf_data_path_switched_to_synthetic();
2765                    PrimaryChannelGuestVfState::UnavailableFromAvailable
2766                }
2767                PrimaryChannelGuestVfState::DataPathSynthetic => {
2768                    // Notify guest that the data path is now synthetic.
2769                    self.guest_vf_data_path_switched_to_synthetic();
2770                    PrimaryChannelGuestVfState::Ready
2771                }
2772                PrimaryChannelGuestVfState::DataPathSwitchPending {
2773                    to_guest,
2774                    id,
2775                    result,
2776                } => {
2777                    let result = result.expect("DataPathSwitchPending should have been processed");
2778                    // Complete the data path switch request.
2779                    self.send_completion(id, None)?;
2780                    if result {
2781                        if to_guest {
2782                            PrimaryChannelGuestVfState::DataPathSwitched
2783                        } else {
2784                            PrimaryChannelGuestVfState::Ready
2785                        }
2786                    } else if to_guest {
2787                        PrimaryChannelGuestVfState::DataPathSynthetic
2788                    } else {
2789                        tracing::error!(
2790                            "Failure when guest requested switch back to synthetic"
2791                        );
2792                        PrimaryChannelGuestVfState::DataPathSwitched
2793                    }
2796                }
2797                PrimaryChannelGuestVfState::Initializing
2798                | PrimaryChannelGuestVfState::Restoring(_) => {
2799                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
2800                }
2801                _ => break,
2802            };
2803        }
2804        Ok(None)
2805    }
2806
2807    /// Processes a control message, writing the response to the provided receive
2808    /// buffer suballocation.
2809    fn handle_rndis_control_message(
2810        &mut self,
2811        primary: &mut PrimaryChannelState,
2812        buffers: &ChannelBuffers,
2813        message_type: u32,
2814        mut reader: impl MemoryRead + Clone,
2815        id: u32,
2816    ) -> Result<(), WorkerError> {
2817        let mem = &buffers.mem;
2818        let buffer_range = &buffers.recv_buffer.range(id);
2819        match message_type {
2820            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
2821                if primary.rndis_state != RndisState::Initializing {
2822                    return Err(WorkerError::InvalidRndisState);
2823                }
2824
2825                let request: rndisprot::InitializeRequest = reader.read_plain()?;
2826
2827                tracing::trace!(
2828                    ?request,
2829                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
2830                );
2831
2832                primary.rndis_state = RndisState::Operational;
2833
2834                let mut writer = buffer_range.writer(mem);
2835                let message_length = write_rndis_message(
2836                    &mut writer,
2837                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
2838                    0,
2839                    &rndisprot::InitializeComplete {
2840                        request_id: request.request_id,
2841                        status: rndisprot::STATUS_SUCCESS,
2842                        major_version: rndisprot::MAJOR_VERSION,
2843                        minor_version: rndisprot::MINOR_VERSION,
2844                        device_flags: rndisprot::DF_CONNECTIONLESS,
2845                        medium: rndisprot::MEDIUM_802_3,
2846                        max_packets_per_message: 8,
2847                        max_transfer_size: 0xEFFFFFFF,
2848                        packet_alignment_factor: 3,
2849                        af_list_offset: 0,
2850                        af_list_size: 0,
2851                    },
2852                )?;
2853                self.send_rndis_control_message(buffers, id, message_length)?;
2854                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2855                    if self.guest_vf_is_available(
2856                        Some(vfid),
2857                        buffers.version,
2858                        buffers.ndis_config,
2859                    )? {
2860                        // Ideally the VF would not be presented to the guest
2861                        // until the completion packet has arrived, so that the
2862                        // guest is prepared. This is most interesting for the
2863                        // case of a VF associated with multiple guest
2864                        // adapters, using a concept like vports. In this
2865                        // scenario it would be better if all of the adapters
2866                        // were aware a VF was available before the device
2867                        // arrived. This is not currently possible because the
2868                        // Linux netvsc driver ignores the completion requested
2869                        // flag on inband packets and won't send a completion
2870                        // packet.
2871                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2872                        self.send_coordinator_update_vf();
2873                    } else if let Some(true) = primary.is_data_path_switched {
2874                        tracing::error!(
2875                            "Data path switched, but current guest negotiation does not support VTL0 VF"
2876                        );
2877                    }
2878                }
2879            }
2880            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
2881                let request: rndisprot::QueryRequest = reader.read_plain()?;
2882
2883                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");
2884
2885                let (header, body) = buffer_range
2886                    .try_split(
2887                        size_of::<rndisprot::MessageHeader>()
2888                            + size_of::<rndisprot::QueryComplete>(),
2889                    )
2890                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2891                let (status, tx) = match self.adapter.handle_oid_query(
2892                    buffers,
2893                    primary,
2894                    request.oid,
2895                    body.writer(mem),
2896                ) {
2897                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
2898                    Err(err) => (err.as_status(), 0),
2899                };
2900
2901                let message_length = write_rndis_message(
2902                    &mut header.writer(mem),
2903                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
2904                    tx,
2905                    &rndisprot::QueryComplete {
2906                        request_id: request.request_id,
2907                        status,
2908                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
2909                        information_buffer_length: tx as u32,
2910                    },
2911                )?;
2912                self.send_rndis_control_message(buffers, id, message_length)?;
2913            }
2914            rndisprot::MESSAGE_TYPE_SET_MSG => {
2915                let request: rndisprot::SetRequest = reader.read_plain()?;
2916
2917                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");
2918
2919                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
2920                    Ok((restart_endpoint, packet_filter)) => {
2921                        // Restart the endpoint if the OID changed some critical
2922                        // endpoint property.
2923                        if restart_endpoint {
2924                            self.restart = Some(CoordinatorMessage::Restart);
2925                        }
2926                        if let Some(filter) = packet_filter {
2927                            if self.packet_filter != filter {
2928                                self.packet_filter = filter;
2929                                self.send_coordinator_update_filter();
2930                            }
2931                        }
2932                        rndisprot::STATUS_SUCCESS
2933                    }
2934                    Err(err) => {
2935                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
2936                        err.as_status()
2937                    }
2938                };
2939
2940                let message_length = write_rndis_message(
2941                    &mut buffer_range.writer(mem),
2942                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
2943                    0,
2944                    &rndisprot::SetComplete {
2945                        request_id: request.request_id,
2946                        status,
2947                    },
2948                )?;
2949                self.send_rndis_control_message(buffers, id, message_length)?;
2950            }
2951            rndisprot::MESSAGE_TYPE_RESET_MSG => {
2952                return Err(WorkerError::RndisMessageTypeNotImplemented);
2953            }
2954            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
2955                return Err(WorkerError::RndisMessageTypeNotImplemented);
2956            }
2957            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
2958                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;
2959
2960                tracing::trace!(
2961                    ?request,
2962                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
2963                );
2964
2965                let message_length = write_rndis_message(
2966                    &mut buffer_range.writer(mem),
2967                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
2968                    0,
2969                    &rndisprot::KeepaliveComplete {
2970                        request_id: request.request_id,
2971                        status: rndisprot::STATUS_SUCCESS,
2972                    },
2973                )?;
2974                self.send_rndis_control_message(buffers, id, message_length)?;
2975            }
2976            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
2977                return Err(WorkerError::RndisMessageTypeNotImplemented);
2978            }
2979            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
2980        };
2981        Ok(())
2982    }
2983
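    /// Attempts to write a MESSAGE1_TYPE_SEND_RNDIS_PACKET message referencing
    /// the given receive buffer ranges. Returns `Ok(None)` on success, or
    /// `Ok(Some(pending_send_size))` if the outgoing ring is full.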
2984    fn try_send_rndis_message(
2985        &mut self,
2986        transaction_id: u64,
2987        channel_type: u32,
2988        recv_buffer_id: u16,
2989        transfer_pages: &[ring::TransferPageRange],
2990    ) -> Result<Option<usize>, WorkerError> {
2991        let message = self.message(
2992            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2993            protocol::Message1SendRndisPacket {
2994                channel_type,
2995                send_buffer_section_index: 0xffffffff,
2996                send_buffer_section_size: 0,
2997            },
2998        );
2999        let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3000            transaction_id,
3001            OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3002            message.aligned_payload(),
3003        ) {
3004            Ok(()) => None,
3005            Err(queue::TryWriteError::Full(n)) => Some(n),
3006            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3007        };
3008        Ok(pending_send_size)
3009    }
3010
3011    fn send_rndis_control_message(
3012        &mut self,
3013        buffers: &ChannelBuffers,
3014        id: u32,
3015        message_length: usize,
3016    ) -> Result<(), WorkerError> {
3017        let result = self.try_send_rndis_message(
3018            id as u64,
3019            protocol::CONTROL_CHANNEL_TYPE,
3020            buffers.recv_buffer.id,
3021            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3022        )?;
3023
3024        // Ring size is checked before control messages are processed, so failure to write is unexpected.
3025        match result {
3026            None => Ok(()),
3027            Some(len) => {
3028                tracelimit::error_ratelimited!(len, "failed to write control message completion");
3029                Err(WorkerError::OutOfSpace)
3030            }
3031        }
3032    }
3033
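    /// Sends an RNDIS INDICATE_STATUS message with the given status code and
    /// payload through the control channel, using receive buffer `id`.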
3034    fn indicate_status(
3035        &mut self,
3036        buffers: &ChannelBuffers,
3037        id: u32,
3038        status: u32,
3039        payload: &[u8],
3040    ) -> Result<(), WorkerError> {
3041        let buffer = &buffers.recv_buffer.range(id);
3042        let mut writer = buffer.writer(&buffers.mem);
3043        let message_length = write_rndis_message(
3044            &mut writer,
3045            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3046            payload.len(),
3047            &rndisprot::IndicateStatus {
3048                status,
3049                status_buffer_length: payload.len() as u32,
3050                status_buffer_offset: if payload.is_empty() {
3051                    0
3052                } else {
3053                    size_of::<rndisprot::IndicateStatus>() as u32
3054                },
3055            },
3056        )?;
3057        writer.write(payload)?;
3058        self.send_rndis_control_message(buffers, id, message_length)?;
3059        Ok(())
3060    }
3061
3062    /// Processes pending control messages until all are processed or there are
3063    /// no available suballocations.
3064    fn process_control_messages(
3065        &mut self,
3066        buffers: &ChannelBuffers,
3067        state: &mut ActiveState,
3068    ) -> Result<(), WorkerError> {
3069        let Some(primary) = &mut state.primary else {
3070            return Ok(());
3071        };
3072
3073        while !primary.control_messages.is_empty()
3074            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
3075        {
3076            // Ensure the ring buffer has enough room to successfully complete control message handling.
3077            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
3078                self.pending_send_size = MIN_CONTROL_RING_SIZE;
3079                break;
3080            }
3081            let Some(id) = primary.free_control_buffers.pop() else {
3082                break;
3083            };
3084
3085            // Mark the receive buffer in use to allow the guest to release it.
3086            assert!(state.rx_bufs.is_free(id.0));
3087            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
3088
3089            if let Some(message) = primary.control_messages.pop_front() {
3090                primary.control_messages_len -= message.data.len();
3091                self.handle_rndis_control_message(
3092                    primary,
3093                    buffers,
3094                    message.message_type,
3095                    message.data.as_ref(),
3096                    id.0,
3097                )?;
3098            } else if primary.pending_offload_change
3099                && primary.rndis_state == RndisState::Operational
3100            {
3101                let ndis_offload = primary.offload_config.ndis_offload();
3102                self.indicate_status(
3103                    buffers,
3104                    id.0,
3105                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
3106                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
3107                )?;
3108                primary.pending_offload_change = false;
3109            } else {
3110                unreachable!();
3111            }
3112        }
3113        Ok(())
3114    }
3115
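    /// Queues an update message for the coordinator, merging it with any
    /// update that is already pending. A pending restart subsumes both update
    /// types, so nothing else is queued in that case.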
3116    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3117        if self.restart.is_none() {
3118            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3119                guest_vf_state: guest_vf,
3120                filter_state: packet_filter,
3121            }));
3122        } else if let Some(CoordinatorMessage::Restart) = self.restart {
3123            // If a restart message is pending, do nothing: the restart will
3124            // switch the data path based on primary.guest_vf_state and will
3125            // apply any packet filter changes.
3126        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3127            // Add the new update to the existing message.
3128            update.guest_vf_state |= guest_vf;
3129            update.filter_state |= packet_filter;
3130        }
3131    }
3132
3133    fn send_coordinator_update_vf(&mut self) {
3134        self.send_coordinator_update_message(true, false);
3135    }
3136
3137    fn send_coordinator_update_filter(&mut self) {
3138        self.send_coordinator_update_message(false, true);
3139    }
3140}
3141
3142/// Writes an RNDIS message to `writer`.
3143fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3144    writer: &mut impl MemoryWrite,
3145    message_type: u32,
3146    extra: usize,
3147    payload: &T,
3148) -> Result<usize, AccessError> {
3149    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3150    writer.write(
3151        rndisprot::MessageHeader {
3152            message_type,
3153            message_length: message_length as u32,
3154        }
3155        .as_bytes(),
3156    )?;
3157    writer.write(payload.as_bytes())?;
3158    Ok(message_length)
3159}
3160
3161#[derive(Debug, Error)]
3162enum OidError {
3163    #[error(transparent)]
3164    Access(#[from] AccessError),
3165    #[error("unknown oid")]
3166    UnknownOid,
3167    #[error("invalid oid input, bad field {0}")]
3168    InvalidInput(&'static str),
3169    #[error("bad ndis version")]
3170    BadVersion,
3171    #[error("feature {0} not supported")]
3172    NotSupported(&'static str),
3173}
3174
3175impl OidError {
3176    fn as_status(&self) -> u32 {
3177        match self {
3178            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3179            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3180            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3181            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3182        }
3183    }
3184}
3185
3186const DEFAULT_MTU: u32 = 1514;
3187const MIN_MTU: u32 = DEFAULT_MTU;
3188const MAX_MTU: u32 = 9216;
3189
3190const ETHERNET_HEADER_LEN: u32 = 14;
3191
3192impl Adapter {
3193    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3194        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3195            // For enlightened guests (which is only Windows at the moment), send the
3196            // adapter index, which was previously set as the vport serial number.
3197            if guest_os_id
3198                .microsoft()
3199                .unwrap_or(HvGuestOsMicrosoft::from(0))
3200                .os_id()
3201                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3202            {
3203                self.adapter_index
3204            } else {
3205                vfid
3206            }
3207        } else {
3208            vfid
3209        }
3210    }
3211
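    /// Handles an RNDIS OID query, writing the result to `writer` and
    /// returning the number of bytes written.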
3212    fn handle_oid_query(
3213        &self,
3214        buffers: &ChannelBuffers,
3215        primary: &PrimaryChannelState,
3216        oid: rndisprot::Oid,
3217        mut writer: impl MemoryWrite,
3218    ) -> Result<usize, OidError> {
3219        tracing::debug!(?oid, "oid query");
3220        let available_len = writer.len();
3221        match oid {
3222            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
3223                let supported_oids_common = &[
3224                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
3225                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
3226                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
3227                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
3228                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
3229                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
3230                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
3231                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
3232                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
3233                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
3234                    rndisprot::Oid::OID_GEN_LINK_SPEED,
3235                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
3236                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
3237                    rndisprot::Oid::OID_GEN_VENDOR_ID,
3238                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
3239                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
3240                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
3241                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
3242                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
3243                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
3244                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
3245                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
3246                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
3247                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
3248                    // Ethernet objects operation characteristics
3249                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
3250                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
3251                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
3252                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
3253                    // Ethernet objects statistics
3254                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
3255                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
3256                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
3257                    // PNP operation characteristics
3258                    // rndisprot::Oid::OID_PNP_SET_POWER,
3259                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
3260                    // RNDIS OIDS
3261                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
3262                ];
3263
3264                // NDIS6 OIDs
3265                let supported_oids_6 = &[
3266                    // Link State OID
3267                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
3268                    rndisprot::Oid::OID_GEN_LINK_STATE,
3269                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
3270                    // NDIS 6 statistics OID
3271                    rndisprot::Oid::OID_GEN_BYTES_RCV,
3272                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
3273                    // Offload related OID
3274                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
3275                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
3276                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
3277                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
3278                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
3279                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
3280                ];
3281
3282                let supported_oids_63 = &[
3283                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
3284                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
3285                ];
3286
3287                match buffers.ndis_version.major {
3288                    5 => {
3289                        writer.write(supported_oids_common.as_bytes())?;
3290                    }
3291                    6 => {
3292                        writer.write(supported_oids_common.as_bytes())?;
3293                        writer.write(supported_oids_6.as_bytes())?;
3294                        if buffers.ndis_version.minor >= 30 {
3295                            writer.write(supported_oids_63.as_bytes())?;
3296                        }
3297                    }
3298                    _ => return Err(OidError::BadVersion),
3299                }
3300            }
3301            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
3302                let status: u32 = 0; // NdisHardwareStatusReady
3303                writer.write(status.as_bytes())?;
3304            }
3305            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
3306                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
3307            }
3308            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
3309            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
3310            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
3311                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
3312                writer.write(len.as_bytes())?;
3313            }
3314            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
3315            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
3316            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
3317                let len: u32 = buffers.ndis_config.mtu;
3318                writer.write(len.as_bytes())?;
3319            }
3320            rndisprot::Oid::OID_GEN_LINK_SPEED => {
3321                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
3322                writer.write(speed.as_bytes())?;
3323            }
3324            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
3325            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
3326                // This value is meaningless for virtual NICs. Return what vmswitch returns.
3327                writer.write((256u32 * 1024).as_bytes())?
3328            }
3329            rndisprot::Oid::OID_GEN_VENDOR_ID => {
3330                // Like vmswitch, use the first three bytes of Microsoft's MAC
3331                // address prefix (00-15-5D) as the vendor ID.
3332                writer.write(0x0000155du32.as_bytes())?;
3333            }
3334            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
3335            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
3336            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
3337                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
3338            }
3339            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
3340            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
3341                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
3342                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
3343                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
3344                writer.write(options.as_bytes())?;
3345            }
3346            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
3347                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
3348            }
3349            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
3350            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
3351                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
3352                let mut name = rndisprot::FriendlyName::new_zeroed();
3353                name.name[..name16.len()].copy_from_slice(&name16);
3354                writer.write(name.as_bytes())?
3355            }
3356            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
3357            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
3358                writer.write(&self.mac_address.to_bytes())?
3359            }
3360            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
3361                writer.write(0u32.as_bytes())?;
3362            }
3363            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
3364            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
3365            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,
3366
3367            // NDIS6 OIDs:
3368            rndisprot::Oid::OID_GEN_LINK_STATE => {
3369                let link_state = rndisprot::LinkState {
3370                    header: rndisprot::NdisObjectHeader {
3371                        object_type: rndisprot::NdisObjectType::DEFAULT,
3372                        revision: 1,
3373                        size: size_of::<rndisprot::LinkState>() as u16,
3374                    },
3375                    media_connect_state: 1, /* MediaConnectStateConnected */
3376                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
3377                    padding: 0,
3378                    xmit_link_speed: self.link_speed,
3379                    rcv_link_speed: self.link_speed,
3380                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
3381                    auto_negotiation_flags: 0,
3382                };
3383                writer.write(link_state.as_bytes())?;
3384            }
3385            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
3386                let link_speed = rndisprot::LinkSpeed {
3387                    xmit: self.link_speed,
3388                    rcv: self.link_speed,
3389                };
3390                writer.write(link_speed.as_bytes())?;
3391            }
3392            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
3393                let ndis_offload = self.offload_support.ndis_offload();
3394                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3395            }
3396            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
3397                let ndis_offload = primary.offload_config.ndis_offload();
3398                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3399            }
3400            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3401                writer.write(
3402                    &rndisprot::NdisOffloadEncapsulation {
3403                        header: rndisprot::NdisObjectHeader {
3404                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3405                            revision: 1,
3406                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
3407                        },
3408                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3409                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3410                        ipv4_header_size: ETHERNET_HEADER_LEN,
3411                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3412                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3413                        ipv6_header_size: ETHERNET_HEADER_LEN,
3414                    }
3415                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
3416                )?;
3417            }
3418            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
3419                writer.write(
3420                    &rndisprot::NdisReceiveScaleCapabilities {
3421                        header: rndisprot::NdisObjectHeader {
3422                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
3423                            revision: 2,
3424                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
3425                                as u16,
3426                        },
3427                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
3428                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
3429                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
3430                        number_of_interrupt_messages: 1,
3431                        number_of_receive_queues: self.max_queues.into(),
3432                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
3433                            self.indirection_table_size
3434                        } else {
3435                            // DPDK gets confused if the table size is zero,
3436                            // even if there is only one queue.
3437                            128
3438                        },
3439                        padding: 0,
3440                    }
3441                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
3442                )?;
3443            }
3444            _ => {
3445                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
3446                return Err(OidError::UnknownOid);
3447            }
3448        };
3449        Ok(available_len - writer.len())
3450    }
3451
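    /// Handles an RNDIS OID set request. Returns whether the endpoint needs to
    /// be restarted and the new packet filter, if one was set.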
3452    fn handle_oid_set(
3453        &self,
3454        primary: &mut PrimaryChannelState,
3455        oid: rndisprot::Oid,
3456        reader: impl MemoryRead + Clone,
3457    ) -> Result<(bool, Option<u32>), OidError> {
3458        tracing::debug!(?oid, "oid set");
3459
3460        let mut restart_endpoint = false;
3461        let mut packet_filter = None;
3462        match oid {
3463            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3464                packet_filter = self.oid_set_packet_filter(reader)?;
3465            }
3466            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3467                self.oid_set_offload_parameters(reader, primary)?;
3468            }
3469            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3470                self.oid_set_offload_encapsulation(reader)?;
3471            }
3472            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3473                self.oid_set_rndis_config_parameter(reader, primary)?;
3474            }
3475            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3476                // TODO
3477            }
3478            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3479                self.oid_set_rss_parameters(reader, primary)?;
3480
3481                // Endpoints cannot currently change RSS parameters without
3482                // being restarted. This was a limitation driven by some DPDK
3483                // PMDs, and should be fixed.
3484                restart_endpoint = true;
3485            }
3486            _ => {
3487                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3488                return Err(OidError::UnknownOid);
3489            }
3490        }
3491        Ok((restart_endpoint, packet_filter))
3492    }
3493
3494    fn oid_set_rss_parameters(
3495        &self,
3496        mut reader: impl MemoryRead + Clone,
3497        primary: &mut PrimaryChannelState,
3498    ) -> Result<(), OidError> {
3499        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
3500        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3501        let len = reader.len().min(size_of_val(&params));
3502        reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3503
3504        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3505            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3506        {
3507            primary.rss_state = None;
3508            return Ok(());
3509        }
3510
3511        if params.hash_secret_key_size != 40 {
3512            return Err(OidError::InvalidInput("hash_secret_key_size"));
3513        }
3514        if params.indirection_table_size % 4 != 0 {
3515            return Err(OidError::InvalidInput("indirection_table_size"));
3516        }
3517        let indirection_table_size =
3518            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3519        let mut key = [0; 40];
3520        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3521        reader
3522            .clone()
3523            .skip(params.hash_secret_key_offset as usize)?
3524            .read(&mut key)?;
3525        reader
3526            .skip(params.indirection_table_offset as usize)?
3527            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3528        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3529        if indirection_table
3530            .iter()
3531            .any(|&x| x >= self.max_queues as u32)
3532        {
3533            return Err(OidError::InvalidInput("indirection_table"));
3534        }
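        // The guest may provide fewer entries than the endpoint's table size;
        // fill the remainder by repeating the provided entries in order.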
3535        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3536        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3537            .flatten()
3538            .zip(indir_uninit)
3539        {
3540            *dest = src;
3541        }
3542        primary.rss_state = Some(RssState {
3543            key,
3544            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3545        });
3546        Ok(())
3547    }
3548
3549    fn oid_set_packet_filter(
3550        &self,
3551        reader: impl MemoryRead + Clone,
3552    ) -> Result<Option<u32>, OidError> {
3553        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3554        tracing::debug!(filter, "set packet filter");
3555        Ok(Some(filter))
3556    }
3557
3558    fn oid_set_offload_parameters(
3559        &self,
3560        reader: impl MemoryRead + Clone,
3561        primary: &mut PrimaryChannelState,
3562    ) -> Result<(), OidError> {
3563        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3564            reader,
3565            rndisprot::NdisObjectType::DEFAULT,
3566            1,
3567            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3568        )?;
3569
3570        tracing::debug!(?offload, "offload parameters");
3571        let rndisprot::NdisOffloadParameters {
3572            header: _,
3573            ipv4_checksum,
3574            tcp4_checksum,
3575            udp4_checksum,
3576            tcp6_checksum,
3577            udp6_checksum,
3578            lsov1,
3579            ipsec_v1: _,
3580            lsov2_ipv4,
3581            lsov2_ipv6,
3582            tcp_connection_ipv4: _,
3583            tcp_connection_ipv6: _,
3584            reserved: _,
3585            flags: _,
3586        } = offload;
3587
3588        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3589            return Err(OidError::NotSupported("lsov1"));
3590        }
3591        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3592            primary.offload_config.checksum_tx.ipv4_header = tx;
3593            primary.offload_config.checksum_rx.ipv4_header = rx;
3594        }
3595        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3596            primary.offload_config.checksum_tx.tcp4 = tx;
3597            primary.offload_config.checksum_rx.tcp4 = rx;
3598        }
3599        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3600            primary.offload_config.checksum_tx.tcp6 = tx;
3601            primary.offload_config.checksum_rx.tcp6 = rx;
3602        }
3603        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3604            primary.offload_config.checksum_tx.udp4 = tx;
3605            primary.offload_config.checksum_rx.udp4 = rx;
3606        }
3607        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3608            primary.offload_config.checksum_tx.udp6 = tx;
3609            primary.offload_config.checksum_rx.udp6 = rx;
3610        }
3611        if let Some(enable) = lsov2_ipv4.enable() {
3612            primary.offload_config.lso4 = enable;
3613        }
3614        if let Some(enable) = lsov2_ipv6.enable() {
3615            primary.offload_config.lso6 = enable;
3616        }
3617        primary
3618            .offload_config
3619            .mask_to_supported(&self.offload_support);
3620        primary.pending_offload_change = true;
3621        Ok(())
3622    }
3623
3624    fn oid_set_offload_encapsulation(
3625        &self,
3626        reader: impl MemoryRead + Clone,
3627    ) -> Result<(), OidError> {
3628        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3629            reader,
3630            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3631            1,
3632            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3633        )?;
3634        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3635            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3636                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3637        {
3638            return Err(OidError::NotSupported("ipv4 encap"));
3639        }
3640        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3641            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3642                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3643        {
3644            return Err(OidError::NotSupported("ipv6 encap"));
3645        }
3646        Ok(())
3647    }
3648
3649    fn oid_set_rndis_config_parameter(
3650        &self,
3651        reader: impl MemoryRead + Clone,
3652        primary: &mut PrimaryChannelState,
3653    ) -> Result<(), OidError> {
3654        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3655        if info.name_length > 255 {
3656            return Err(OidError::InvalidInput("name_length"));
3657        }
3658        if info.value_length > 255 {
3659            return Err(OidError::InvalidInput("value_length"));
3660        }
3661        let name = reader
3662            .clone()
3663            .skip(info.name_offset as usize)?
3664            .read_n::<u16>(info.name_length as usize / 2)?;
3665        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3666        let mut value = reader;
3667        value.skip(info.value_offset as usize)?;
3668        let mut value = value.limit(info.value_length as usize);
3669        match info.parameter_type {
3670            rndisprot::NdisParameterType::STRING => {
3671                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3672                let value =
3673                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
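                // The value is an ASCII digit encoding a bitmask: bit 0
                // enables the offload for tx, bit 1 enables it for rx.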
3674                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3675                let tx = as_num & 1 != 0;
3676                let rx = as_num & 2 != 0;
3677
3678                tracing::debug!(name, value, "rndis config");
3679                match name.as_str() {
3680                    "*IPChecksumOffloadIPv4" => {
3681                        primary.offload_config.checksum_tx.ipv4_header = tx;
3682                        primary.offload_config.checksum_rx.ipv4_header = rx;
3683                    }
3684                    "*LsoV2IPv4" => {
3685                        primary.offload_config.lso4 = as_num != 0;
3686                    }
3687                    "*LsoV2IPv6" => {
3688                        primary.offload_config.lso6 = as_num != 0;
3689                    }
3690                    "*TCPChecksumOffloadIPv4" => {
3691                        primary.offload_config.checksum_tx.tcp4 = tx;
3692                        primary.offload_config.checksum_rx.tcp4 = rx;
3693                    }
3694                    "*TCPChecksumOffloadIPv6" => {
3695                        primary.offload_config.checksum_tx.tcp6 = tx;
3696                        primary.offload_config.checksum_rx.tcp6 = rx;
3697                    }
3698                    "*UDPChecksumOffloadIPv4" => {
3699                        primary.offload_config.checksum_tx.udp4 = tx;
3700                        primary.offload_config.checksum_rx.udp4 = rx;
3701                    }
3702                    "*UDPChecksumOffloadIPv6" => {
3703                        primary.offload_config.checksum_tx.udp6 = tx;
3704                        primary.offload_config.checksum_rx.udp6 = rx;
3705                    }
3706                    _ => {}
3707                }
3708                primary
3709                    .offload_config
3710                    .mask_to_supported(&self.offload_support);
3711            }
3712            rndisprot::NdisParameterType::INTEGER => {
3713                let value: u32 = value.read_plain()?;
3714                tracing::debug!(name, value, "rndis config");
3715            }
3716            parameter_type => tracelimit::warn_ratelimited!(
3717                name,
3718                ?parameter_type,
3719                "unhandled rndis config parameter type"
3720            ),
3721        }
3722        Ok(())
3723    }
3724}
3725
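/// Reads an NDIS object of type `T` from `reader`, zero-filling any bytes the
/// sender did not provide, and validates the embedded NDIS object header.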
3726fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3727    mut reader: impl MemoryRead,
3728    object_type: rndisprot::NdisObjectType,
3729    min_revision: u8,
3730    min_size: usize,
3731) -> Result<T, OidError> {
3732    let mut buffer = T::new_zeroed();
3733    let sent_size = reader.len();
3734    let len = sent_size.min(size_of_val(&buffer));
3735    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3736    validate_ndis_object_header(
3737        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3738            .unwrap()
3739            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3740        sent_size,
3741        object_type,
3742        min_revision,
3743        min_size,
3744    )?;
3745    Ok(buffer)
3746}
3747
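/// Validates an NDIS object header against the expected object type, minimum
/// revision, and minimum size, and checks that the sender provided at least
/// `header.size` bytes.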
3748fn validate_ndis_object_header(
3749    header: &rndisprot::NdisObjectHeader,
3750    sent_size: usize,
3751    object_type: rndisprot::NdisObjectType,
3752    min_revision: u8,
3753    min_size: usize,
3754) -> Result<(), OidError> {
3755    if header.object_type != object_type {
3756        return Err(OidError::InvalidInput("header.object_type"));
3757    }
3758    if sent_size < header.size as usize {
3759        return Err(OidError::InvalidInput("header.size"));
3760    }
3761    if header.revision < min_revision {
3762        return Err(OidError::InvalidInput("header.revision"));
3763    }
3764    if (header.size as usize) < min_size {
3765        return Err(OidError::InvalidInput("header.size"));
3766    }
3767    Ok(())
3768}
3769
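/// The coordinator task object: owns the per-queue worker tasks and reacts to
/// control-plane events such as endpoint restarts, VF changes, and link status
/// notifications.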
3770struct Coordinator {
3771    recv: mpsc::Receiver<CoordinatorMessage>,
3772    channel_control: ChannelControl,
3773    restart: bool,
3774    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
3775    buffers: Option<Arc<ChannelBuffers>>,
3776    num_queues: u16,
3777    active_packet_filter: u32,
3778    sleep_deadline: Option<Instant>,
3779}
3780
3781/// Removing the VF may result in the guest sending messages to switch the data
3782/// path, so these operations need to happen asynchronously with message
3783/// processing.
3784enum CoordinatorStatePendingVfState {
3785    /// No pending updates.
3786    Ready,
3787    /// Delay before adding VF.
3788    Delay {
3789        timer: PolledTimer,
3790        delay_until: Instant,
3791    },
3792    /// A VF update is pending.
3793    Pending,
3794}
3795
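/// State used to run the coordinator task: the backend endpoint, the shared
/// adapter configuration, and the optional virtual function.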
3796struct CoordinatorState {
3797    endpoint: Box<dyn Endpoint>,
3798    adapter: Arc<Adapter>,
3799    virtual_function: Option<Box<dyn VirtualFunction>>,
3800    pending_vf_state: CoordinatorStatePendingVfState,
3801}
3802
3803impl InspectTaskMut<Coordinator> for CoordinatorState {
3804    fn inspect_mut(
3805        &mut self,
3806        req: inspect::Request<'_>,
3807        mut coordinator: Option<&mut Coordinator>,
3808    ) {
3809        let mut resp = req.respond();
3810
3811        let adapter = self.adapter.as_ref();
3812        resp.field("mac_address", adapter.mac_address)
3813            .field("max_queues", adapter.max_queues)
3814            .sensitivity_field(
3815                "offload_support",
3816                SensitivityLevel::Safe,
3817                &adapter.offload_support,
3818            )
3819            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
3820                if let Some(v) = v {
3821                    let v: usize = v.parse()?;
3822                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
3823                    // Bounce each task so that it sees the new value.
3824                    if let Some(this) = &mut coordinator {
3825                        for worker in &mut this.workers {
3826                            worker.update_with(|_, _| ());
3827                        }
3828                    }
3829                }
3830                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
3831            });
3832
3833        resp.field("endpoint_type", self.endpoint.endpoint_type())
3834            .field(
3835                "endpoint_max_queues",
3836                self.endpoint.multiqueue_support().max_queues,
3837            )
3838            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());
3839
3840        if let Some(coordinator) = coordinator {
3841            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
3842                let mut resp = req.respond();
3843                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
3844                    .iter_mut()
3845                    .enumerate()
3846                {
3847                    resp.field_mut(&i.to_string(), q);
3848                }
3849            });
3850
3851            // Get the shared channel state from the primary channel.
3852            resp.merge(inspect::adhoc_mut(|req| {
3853                let deferred = req.defer();
3854                coordinator.workers[0].update_with(|_, worker| {
3855                    let Some(worker) = worker.as_deref() else {
3856                        return;
3857                    };
3858                    if let Some(state) = worker.state.ready() {
3859                        deferred.respond(|resp| {
3860                            resp.merge(&state.buffers);
3861                            resp.sensitivity_field(
3862                                "primary_channel_state",
3863                                SensitivityLevel::Safe,
3864                                &state.state.primary,
3865                            )
3866                            .sensitivity_field(
3867                                "packet_filter",
3868                                SensitivityLevel::Safe,
3869                                inspect::AsHex(worker.channel.packet_filter),
3870                            );
3871                        });
3872                    }
3873                })
3874            }));
3875        }
3876    }
3877}
3878
3879impl AsyncRun<Coordinator> for CoordinatorState {
3880    async fn run(
3881        &mut self,
3882        stop: &mut StopTask<'_>,
3883        coordinator: &mut Coordinator,
3884    ) -> Result<(), task_control::Cancelled> {
3885        coordinator.process(stop, self).await
3886    }
3887}
3888
3889impl Coordinator {
3890    async fn process(
3891        &mut self,
3892        stop: &mut StopTask<'_>,
3893        state: &mut CoordinatorState,
3894    ) -> Result<(), task_control::Cancelled> {
3895        loop {
3896            if self.restart {
3897                stop.until_stopped(self.stop_workers()).await?;
3898                // The queue restart operation is not restartable, so do not
3899                // poll on `stop` here.
3900                if let Err(err) = self
3901                    .restart_queues(state)
3902                    .instrument(tracing::info_span!("netvsp_restart_queues"))
3903                    .await
3904                {
3905                    tracing::error!(
3906                        error = &err as &dyn std::error::Error,
3907                        "failed to restart queues"
3908                    );
3909                }
3910                if let Some(primary) = self.primary_mut() {
3911                    primary.is_data_path_switched =
3912                        state.endpoint.get_data_path_to_guest_vf().await.ok();
3913                    tracing::info!(
3914                        is_data_path_switched = primary.is_data_path_switched,
3915                        "Query data path state"
3916                    );
3917                }
3918                self.restore_guest_vf_state(state).await;
3919                self.restart = false;
3920            }
3921
3922            // Ensure that all workers except the primary are started. The
3923            // primary is started below if there are no outstanding messages.
3924            for worker in &mut self.workers[1..] {
3925                worker.start();
3926            }
3927            if !self.workers[0].is_running()
3928                && self.workers[0].state().is_none_or(|worker| {
3929                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
3930                })
3931            {
3932                self.workers[0].start();
3933            }
3934
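            // Wakeup sources for the coordinator loop.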
3935            enum Message {
3936                Internal(CoordinatorMessage),
3937                ChannelDisconnected,
3938                UpdateFromEndpoint(EndpointAction),
3939                UpdateFromVf(Rpc<(), ()>),
3940                OfferVfDevice,
3941                PendingVfStateComplete,
3942                TimerExpired,
3943            }
3944            let message = if matches!(
3945                state.pending_vf_state,
3946                CoordinatorStatePendingVfState::Pending
3947            ) {
3948                // guest_ready_for_device is not restartable, so do not poll on
3949                // stop.
3950                state
3951                    .virtual_function
3952                    .as_mut()
3953                    .expect("Pending requires a VF")
3954                    .guest_ready_for_device()
3955                    .await;
3956                Message::PendingVfStateComplete
3957            } else {
3958                let timer_sleep = async {
3959                    if let Some(deadline) = self.sleep_deadline {
3960                        let mut timer = PolledTimer::new(&state.adapter.driver);
3961                        timer.sleep_until(deadline).await;
3962                    } else {
3963                        pending::<()>().await;
3964                    }
3965                    Message::TimerExpired
3966                };
3967                let wait_for_message = async {
3968                    let internal_msg = self
3969                        .recv
3970                        .next()
3971                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
3972                    let endpoint_restart = state
3973                        .endpoint
3974                        .wait_for_endpoint_action()
3975                        .map(Message::UpdateFromEndpoint);
3976                    if let Some(vf) = state.virtual_function.as_mut() {
3977                        match state.pending_vf_state {
3978                            CoordinatorStatePendingVfState::Ready
3979                            | CoordinatorStatePendingVfState::Delay { .. } => {
3980                                let offer_device = async {
3981                                    if let CoordinatorStatePendingVfState::Delay {
3982                                        timer,
3983                                        delay_until,
3984                                    } = &mut state.pending_vf_state
3985                                    {
3986                                        timer.sleep_until(*delay_until).await;
3987                                    } else {
3988                                        pending::<()>().await;
3989                                    }
3990                                    Message::OfferVfDevice
3991                                };
3992                                (
3993                                    internal_msg,
3994                                    offer_device,
3995                                    endpoint_restart,
3996                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
3997                                    timer_sleep,
3998                                )
3999                                    .race()
4000                                    .await
4001                            }
4002                            CoordinatorStatePendingVfState::Pending => unreachable!(),
4003                        }
4004                    } else {
4005                        (internal_msg, endpoint_restart, timer_sleep).race().await
4006                    }
4007                };
4008
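                // All of the candidate waits above are combined with `Race`
                // from `futures_concurrency`: every future in the tuple is
                // polled and the first one to complete wins, with the losers
                // dropped. A minimal sketch of the pattern (hypothetical
                // futures `a` and `b`, each already mapped to `Message`):
                //
                //     let message = (a, b).race().await;
                //
                // Mapping each arm to the shared `Message` enum first gives
                // the race a single output type.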
4009                stop.until_stopped(wait_for_message).await?
4010            };
4011            match message {
4012                Message::Internal(msg) => {
4013                    self.handle_coordinator_message(msg, state).await;
4014                }
4015                Message::UpdateFromVf(rpc) => {
4016                    rpc.handle(async |_| {
4017                        self.update_guest_vf_state(state).await;
4018                    })
4019                    .await;
4020                }
4021                Message::OfferVfDevice => {
4022                    self.workers[0].stop().await;
4023                    if let Some(primary) = self.primary_mut() {
4024                        if matches!(
4025                            primary.guest_vf_state,
4026                            PrimaryChannelGuestVfState::AvailableAdvertised
4027                        ) {
4028                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
4029                        }
4030                    }
4031
4032                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4033                }
4034                Message::PendingVfStateComplete => {
4035                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
4036                }
4037                Message::TimerExpired => {
4038                    // Kick the worker as requested.
4039                    self.workers[0].stop().await;
4040                    if let Some(primary) = self.primary_mut() {
4041                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
4042                            primary.pending_link_action = PendingLinkAction::Active(up);
4043                        }
4044                    }
4045                    self.sleep_deadline = None;
4046                }
4047                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
4048                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
4049                    self.workers[0].stop().await;
4050
4051                    // These are the only link state transitions that are tracked.
4052                    // 1. up -> down or down -> up
4053                    // 2. up -> down -> up or down -> up -> down.
4054                    // All other state transitions are coalesced into one of the above cases.
4055                    // For example, up -> down -> up -> down is treated as up -> down.
4056                    // N.B. Always queue up the incoming state to minimize the effect of
4057                    //      losing notifications (for example, during VTL2 servicing).
4058                    if let Some(primary) = self.primary_mut() {
4059                        primary.pending_link_action = PendingLinkAction::Active(connect);
4060                    }
4061
4062                    // If an existing sleep timer is running, cancel it.
4063                    self.sleep_deadline = None;
4064                }
4065                Message::ChannelDisconnected => {
4066                    break;
4067                }
4068            };
4069        }
4070        Ok(())
4071    }
4072
4073    async fn handle_coordinator_message(
4074        &mut self,
4075        msg: CoordinatorMessage,
4076        state: &mut CoordinatorState,
4077    ) {
4078        self.workers[0].stop().await;
4079        if let Some(worker) = self.workers[0].state_mut() {
4080            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
4081                let WorkerState::WaitingForCoordinator(Some(ready)) =
4082                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
4083                else {
4084                    unreachable!("valid ready state")
4085                };
4086                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
4087            }
4088        }
4089        match msg {
4090            CoordinatorMessage::Update(update_type) => {
4091                if update_type.filter_state {
4092                    self.stop_workers().await;
4093                    self.active_packet_filter =
4094                        self.workers[0].state().unwrap().channel.packet_filter;
4095                    self.workers.iter_mut().skip(1).for_each(|worker| {
4096                        if let Some(state) = worker.state_mut() {
4097                            state.channel.packet_filter = self.active_packet_filter;
4098                            tracing::debug!(
4099                                packet_filter = ?self.active_packet_filter,
4100                                channel_idx = state.channel_idx,
4101                                "update packet filter"
4102                            );
4103                        }
4104                    });
4105                }
4106
4107                if update_type.guest_vf_state {
4108                    self.update_guest_vf_state(state).await;
4109                }
4110            }
4111            CoordinatorMessage::StartTimer(deadline) => {
4112                self.sleep_deadline = Some(deadline);
4113            }
4114            CoordinatorMessage::Restart => self.restart = true,
4115        }
4116    }
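    // Note that `StartTimer` does not arm a timer here; it only records the
    // deadline in `self.sleep_deadline`, which the `timer_sleep` future in
    // the coordinator loop above turns into `Message::TimerExpired`. A rough
    // sketch of how a worker requests a wakeup (hypothetical `delay` value):
    //
    //     let _ = coordinator_send
    //         .try_send(CoordinatorMessage::StartTimer(Instant::now() + delay));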
4117
4118    async fn stop_workers(&mut self) {
4119        for worker in &mut self.workers {
4120            worker.stop().await;
4121        }
4122    }
4123
4124    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4125        let primary = match self.primary_mut() {
4126            Some(primary) => primary,
4127            None => return,
4128        };
4129
4130        // Update guest VF state based on current endpoint properties.
4131        let virtual_function = c_state.virtual_function.as_mut();
4132        let guest_vf_id = match &virtual_function {
4133            Some(vf) => vf.id().await,
4134            None => None,
4135        };
4136        if let Some(guest_vf_id) = guest_vf_id {
4137            // Ensure the guest VF is in the proper state.
4138            match primary.guest_vf_state {
4139                PrimaryChannelGuestVfState::AvailableAdvertised
4140                | PrimaryChannelGuestVfState::Restoring(
4141                    saved_state::GuestVfState::AvailableAdvertised,
4142                ) => {
4143                    if !primary.is_data_path_switched.unwrap_or(false) {
4144                        let timer = PolledTimer::new(&c_state.adapter.driver);
4145                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
4146                            timer,
4147                            delay_until: Instant::now() + VF_DEVICE_DELAY,
4148                        };
4149                    }
4150                }
4151                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4152                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4153                | PrimaryChannelGuestVfState::Ready
4154                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
4155                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4156                | PrimaryChannelGuestVfState::Restoring(
4157                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4158                )
4159                | PrimaryChannelGuestVfState::DataPathSwitched
4160                | PrimaryChannelGuestVfState::Restoring(
4161                    saved_state::GuestVfState::DataPathSwitched,
4162                )
4163                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4164                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4165                }
4166                _ => (),
4167            };
4168            // Ensure the data path is switched as expected.
4169            if let PrimaryChannelGuestVfState::Restoring(
4170                saved_state::GuestVfState::DataPathSwitchPending {
4171                    to_guest,
4172                    id,
4173                    result,
4174                },
4175            ) = primary.guest_vf_state
4176            {
4177                // If the save was taken after the data path switch had already completed, don't switch again.
4178                if result.is_some() {
4179                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
4180                        to_guest,
4181                        id,
4182                        result,
4183                    };
4184                    return;
4185                }
4186            }
4187            primary.guest_vf_state = match primary.guest_vf_state {
4188                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4189                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4190                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4191                | PrimaryChannelGuestVfState::Restoring(
4192                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4193                )
4194                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4195                    let (to_guest, id) = match primary.guest_vf_state {
4196                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4197                            to_guest,
4198                            id,
4199                        }
4200                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
4201                            to_guest, id, ..
4202                        }
4203                        | PrimaryChannelGuestVfState::Restoring(
4204                            saved_state::GuestVfState::DataPathSwitchPending {
4205                                to_guest, id, ..
4206                            },
4207                        ) => (to_guest, id),
4208                        _ => (true, None),
4209                    };
4210                    // Cancel any outstanding delay timers for VF offers if the data path is
4211                    // getting switched, since the guest is already issuing
4212                    // commands assuming a VF.
4213                    if matches!(
4214                        c_state.pending_vf_state,
4215                        CoordinatorStatePendingVfState::Delay { .. }
4216                    ) {
4217                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4218                    }
4219                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
4220                    let result = if let Err(err) = result {
4221                        tracing::error!(
4222                            err = err.as_ref() as &dyn std::error::Error,
4223                            to_guest,
4224                            "Failed to switch guest VF data path"
4225                        );
4226                        false
4227                    } else {
4228                        primary.is_data_path_switched = Some(to_guest);
4229                        true
4230                    };
4231                    match primary.guest_vf_state {
4232                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4233                            ..
4234                        }
4235                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4236                        | PrimaryChannelGuestVfState::Restoring(
4237                            saved_state::GuestVfState::DataPathSwitchPending { .. },
4238                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
4239                            to_guest,
4240                            id,
4241                            result: Some(result),
4242                        },
4243                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
4244                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
4245                    }
4246                }
4247                PrimaryChannelGuestVfState::Initializing
4248                | PrimaryChannelGuestVfState::Unavailable
4249                | PrimaryChannelGuestVfState::UnavailableFromAvailable
4250                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
4251                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
4252                }
4253                PrimaryChannelGuestVfState::AvailableAdvertised
4254                | PrimaryChannelGuestVfState::Restoring(
4255                    saved_state::GuestVfState::AvailableAdvertised,
4256                ) => {
4257                    if !primary.is_data_path_switched.unwrap_or(false) {
4258                        PrimaryChannelGuestVfState::AvailableAdvertised
4259                    } else {
4260                        // A previous instantiation already switched the data
4261                        // path.
4262                        PrimaryChannelGuestVfState::DataPathSwitched
4263                    }
4264                }
4265                PrimaryChannelGuestVfState::DataPathSwitched
4266                | PrimaryChannelGuestVfState::Restoring(
4267                    saved_state::GuestVfState::DataPathSwitched,
4268                ) => PrimaryChannelGuestVfState::DataPathSwitched,
4269                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4270                    PrimaryChannelGuestVfState::Ready
4271                }
4272                _ => primary.guest_vf_state,
4273            };
4274        } else {
4275            // If the device was just removed, make sure the data path is synthetic.
4276            match primary.guest_vf_state {
4277                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
4278                | PrimaryChannelGuestVfState::Restoring(
4279                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
4280                ) => {
4281                    if !to_guest {
4282                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4283                            tracing::warn!(
4284                                err = err.as_ref() as &dyn std::error::Error,
4285                                "Failed setting data path back to synthetic after guest VF was removed."
4286                            );
4287                        }
4288                        primary.is_data_path_switched = Some(false);
4289                    }
4290                }
4291                PrimaryChannelGuestVfState::DataPathSwitched
4292                | PrimaryChannelGuestVfState::Restoring(
4293                    saved_state::GuestVfState::DataPathSwitched,
4294                ) => {
4295                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4296                        tracing::warn!(
4297                            err = err.as_ref() as &dyn std::error::Error,
4298                            "Failed setting data path back to synthetic after guest VF was removed."
4299                        );
4300                    }
4301                    primary.is_data_path_switched = Some(false);
4302                }
4303                _ => (),
4304            }
4305            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
4306                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
4307            }
4308            // Notify the guest if the VF is no longer available.
4309            primary.guest_vf_state = match primary.guest_vf_state {
4310                PrimaryChannelGuestVfState::Initializing
4311                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
4312                | PrimaryChannelGuestVfState::Available { .. } => {
4313                    PrimaryChannelGuestVfState::Unavailable
4314                }
4315                PrimaryChannelGuestVfState::AvailableAdvertised
4316                | PrimaryChannelGuestVfState::Restoring(
4317                    saved_state::GuestVfState::AvailableAdvertised,
4318                )
4319                | PrimaryChannelGuestVfState::Ready
4320                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4321                    PrimaryChannelGuestVfState::UnavailableFromAvailable
4322                }
4323                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
4324                | PrimaryChannelGuestVfState::Restoring(
4325                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
4326                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4327                    to_guest,
4328                    id,
4329                },
4330                PrimaryChannelGuestVfState::DataPathSwitched
4331                | PrimaryChannelGuestVfState::Restoring(
4332                    saved_state::GuestVfState::DataPathSwitched,
4333                )
4334                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4335                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4336                }
4337                _ => primary.guest_vf_state,
4338            }
4339        }
4340    }
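    // A condensed view of what `restore_guest_vf_state` reconciles (informal,
    // not exhaustive):
    //
    //     // VF id present, switch pending or synthetic  -> re-issue the data
    //     //     path switch, then DataPathSwitchPending / DataPathSwitched.
    //     // VF id present, no prior state               -> Available { vfid }.
    //     // VF id present, advertised + already switched -> DataPathSwitched.
    //     // VF id absent, data path was switched        -> force the path
    //     //     back to synthetic and report UnavailableFrom* to the guest.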
4341
4342    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4343        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
4344        let drivers = self
4345            .workers
4346            .iter_mut()
4347            .map(|worker| {
4348                let task = worker.task_mut();
4349                task.queue_state = None;
4350                task.driver.clone()
4351            })
4352            .collect::<Vec<_>>();
4353
4354        c_state.endpoint.stop().await;
4355
4356        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4357        {
4358            (primary, sub)
4359        } else {
4360            unreachable!()
4361        };
4362
4363        let state = primary_worker
4364            .state_mut()
4365            .and_then(|worker| worker.state.ready_mut());
4366
4367        let state = if let Some(state) = state {
4368            state
4369        } else {
4370            return Ok(());
4371        };
4372
4373        // Save the channel buffers for use in the subchannel workers.
4374        self.buffers = Some(state.buffers.clone());
4375
4376        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4377        let mut active_queues = Vec::new();
4378        let active_queue_count =
4379            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
4380                // The active queue count is the number of unique, in-range entries in the indirection table.
4381                active_queues.clone_from(&rss_state.indirection_table);
4382                active_queues.sort();
4383                active_queues.dedup();
4384                active_queues = active_queues
4385                    .into_iter()
4386                    .filter(|&index| index < num_queues)
4387                    .collect::<Vec<_>>();
4388                if !active_queues.is_empty() {
4389                    active_queues.len() as u16
4390                } else {
4391                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
4392                    num_queues
4393                }
4394            } else {
4395                num_queues
4396            };
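        // Worked example (hypothetical values): with num_queues = 4 and an
        // indirection table of [0, 2, 2, 5], the unique entries are
        // [0, 2, 5]; 5 is filtered out as out of range, leaving
        // active_queues = [0, 2] and active_queue_count = 2. If nothing
        // survives filtering, the table is treated as invalid and all
        // num_queues queues stay active.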
4397
4398        // Distribute the rx buffers to only the active queues.
4399        let (ranges, mut remote_buffer_id_recvs) =
4400            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
4401        let ranges = Arc::new(ranges);
4402
4403        let mut queues = Vec::new();
4404        let mut rx_buffers = Vec::new();
4405        {
4406            let buffers = &state.buffers;
4407            let guest_buffers = Arc::new(
4408                GuestBuffers::new(
4409                    buffers.mem.clone(),
4410                    buffers.recv_buffer.gpadl.clone(),
4411                    buffers.recv_buffer.sub_allocation_size,
4412                    buffers.ndis_config.mtu,
4413                )
4414                .map_err(WorkerError::GpadlError)?,
4415            );
4416
4417            // Get the list of free rx buffers from each task, then partition
4418            // the list per-active-queue, and produce the queue configuration.
4419            let mut queue_config = Vec::new();
4420            let initial_rx;
4421            {
4422                let states = std::iter::once(Some(&*state)).chain(
4423                    subworkers
4424                        .iter()
4425                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4426                );
4427
4428                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4429                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
4430                    .map(RxId)
4431                    .collect::<Vec<_>>();
4432
4433                let mut initial_rx = initial_rx.as_slice();
4434                let mut range_start = 0;
4435                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4436                let first_queue = if !primary_queue_excluded {
4437                    0
4438                } else {
4439                        // If the primary queue is excluded from the guest-supplied
4440                        // indirection table, it is assigned just the reserved
4441                        // buffers.
4442                    queue_config.push(QueueConfig {
4443                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4444                        initial_rx: &[],
4445                        driver: Box::new(drivers[0].clone()),
4446                    });
4447                    rx_buffers.push(RxBufferRange::new(
4448                        ranges.clone(),
4449                        0..RX_RESERVED_CONTROL_BUFFERS,
4450                        None,
4451                    ));
4452                    range_start = RX_RESERVED_CONTROL_BUFFERS;
4453                    1
4454                };
4455                for queue_index in first_queue..num_queues {
4456                    let queue_active = active_queues.is_empty()
4457                        || active_queues.binary_search(&queue_index).is_ok();
4458                    let (range_end, end, buffer_id_recv) = if queue_active {
4459                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4460                            // The last queue gets all the remaining buffers.
4461                            state.buffers.recv_buffer.count
4462                        } else if queue_index == 0 {
4463                            // Queue zero always includes the reserved buffers.
4464                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4465                        } else {
4466                            range_start + ranges.buffers_per_queue
4467                        };
4468                        (
4469                            range_end,
4470                            initial_rx.partition_point(|id| id.0 < range_end),
4471                            Some(remote_buffer_id_recvs.remove(0)),
4472                        )
4473                    } else {
4474                        (range_start, 0, None)
4475                    };
4476
4477                    let (this, rest) = initial_rx.split_at(end);
4478                    queue_config.push(QueueConfig {
4479                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4480                        initial_rx: this,
4481                        driver: Box::new(drivers[queue_index as usize].clone()),
4482                    });
4483                    initial_rx = rest;
4484                    rx_buffers.push(RxBufferRange::new(
4485                        ranges.clone(),
4486                        range_start..range_end,
4487                        buffer_id_recv,
4488                    ));
4489
4490                    range_start = range_end;
4491                }
4492            }
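            // Worked example (hypothetical values): with two active queues,
            // recv_buffer.count = 1024, and buffers_per_queue = 512, queue 0
            // covers ids [0, RX_RESERVED_CONTROL_BUFFERS + 512) and the last
            // active queue takes everything remaining up to 1024. Inactive
            // queues receive an empty range and no initial RxIds, and
            // `partition_point` splits the sorted free-id list at each range
            // boundary so every queue starts with only its own buffers.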
4493
4494            let primary = state.state.primary.as_mut().unwrap();
4495            tracing::debug!(num_queues, "enabling endpoint");
4496
4497            let rss = primary
4498                .rss_state
4499                .as_ref()
4500                .map(|rss| net_backend::RssConfig {
4501                    key: &rss.key,
4502                    indirection_table: &rss.indirection_table,
4503                    flags: 0,
4504                });
4505
4506            c_state
4507                .endpoint
4508                .get_queues(queue_config, rss.as_ref(), &mut queues)
4509                .instrument(tracing::info_span!("netvsp_get_queues"))
4510                .await
4511                .map_err(WorkerError::Endpoint)?;
4512
4513            assert_eq!(queues.len(), num_queues as usize);
4514
4515            // Set the subchannel count.
4516            self.channel_control
4517                .enable_subchannels(num_queues - 1)
4518                .expect("already validated");
4519
4520            self.num_queues = num_queues;
4521        }
4522
4523        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4524        // Provide the queue and receive buffer ranges for each worker.
4525        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
4526            worker.task_mut().queue_state = Some(QueueState {
4527                queue,
4528                target_vp_set: false,
4529                rx_buffer_range: rx_buffer,
4530            });
4531            // Update the receive packet filter for each subchannel worker.
4532            if let Some(worker) = worker.state_mut() {
4533                worker.channel.packet_filter = self.active_packet_filter;
4534                // Clear any pending RxIds as buffers were redistributed.
4535                if let Some(ready_state) = worker.state.ready_mut() {
4536                    ready_state.state.pending_rx_packets.clear();
4537                }
4538            }
4539        }
4540
4541        Ok(())
4542    }
4543
4544    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4545        self.workers[0]
4546            .state_mut()
4547            .unwrap()
4548            .state
4549            .ready_mut()?
4550            .state
4551            .primary
4552            .as_mut()
4553    }
4554
4555    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4556        self.workers[0].stop().await;
4557        self.restore_guest_vf_state(c_state).await;
4558    }
4559}
4560
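// The worker task entry point: `run` deliberately swallows non-cancellation
// errors after logging them, so a failed channel does not tear down the
// whole device, and `WorkerError::BufferRevoked` is treated as a normal
// shutdown signal rather than an error. Only `task_control::Cancelled`
// propagates, letting `TaskControl` stop and restart the task cleanly.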
4561impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4562    async fn run(
4563        &mut self,
4564        stop: &mut StopTask<'_>,
4565        worker: &mut Worker<T>,
4566    ) -> Result<(), task_control::Cancelled> {
4567        match worker.process(stop, self).await {
4568            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4569            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4570            Err(err) => {
4571                tracing::error!(
4572                    channel_idx = worker.channel_idx,
4573                    error = &err as &dyn std::error::Error,
4574                    "netvsp error"
4575                );
4576            }
4577        }
4578        Ok(())
4579    }
4580}
4581
4582impl<T: RingMem + 'static> Worker<T> {
4583    async fn process(
4584        &mut self,
4585        stop: &mut StopTask<'_>,
4586        queue: &mut NetQueue,
4587    ) -> Result<(), WorkerError> {
4588        // Be careful not to wait on actions with unbounded blocking time (e.g.
4589        // guest actions, or waiting for network packets to arrive) without
4590        // wrapping the wait in `stop.until_stopped`.
4591        loop {
4592            match &mut self.state {
4593                WorkerState::Init(initializing) => {
4594                    assert_eq!(self.channel_idx, 0);
4595
4596                    tracelimit::info_ratelimited!("network accepted");
4597
4598                    let (buffers, state) = stop
4599                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
4600                        .await??;
4601
4602                    let state = ReadyState {
4603                        buffers: Arc::new(buffers),
4604                        state,
4605                        data: ProcessingData::new(),
4606                    };
4607
4608                    // Wake up the coordinator task to start the queues.
4609                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);
4610
4611                    tracelimit::info_ratelimited!("network initialized");
4612                    self.state = WorkerState::WaitingForCoordinator(Some(state));
4613                }
4614                WorkerState::WaitingForCoordinator(_) => {
4615                    assert_eq!(self.channel_idx, 0);
4616                    // Waiting for the coordinator to process a message and
4617                    // restart the primary worker.
4618                    stop.until_stopped(pending()).await?
4619                }
4620                WorkerState::Ready(state) => {
4621                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
4622                        if !queue_state.target_vp_set
4623                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
4624                        {
4625                            tracing::debug!(
4626                                channel_idx = self.channel_idx,
4627                                target_vp = self.target_vp,
4628                                "updating target VP"
4629                            );
4630                            queue_state.queue.update_target_vp(self.target_vp).await;
4631                            queue_state.target_vp_set = true;
4632                        }
4633
4634                        queue_state
4635                    } else {
4636                        // This task will be restarted when the queues are ready.
4637                        stop.until_stopped(pending()).await?
4638                    };
4639
4640                    let result = self.channel.main_loop(stop, state, queue_state).await;
4641                    let msg = match result {
4642                        Ok(restart) => {
4643                            assert_eq!(self.channel_idx, 0);
4644                            restart
4645                        }
4646                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
4647                            tracelimit::warn_ratelimited!(
4648                                err = err.as_ref() as &dyn std::error::Error,
4649                                "Endpoint requires queues to restart",
4650                            );
4651                            CoordinatorMessage::Restart
4652                        }
4653                        Err(err) => return Err(err),
4654                    };
4655
4656                    let WorkerState::Ready(ready) = std::mem::replace(
4657                        &mut self.state,
4658                        WorkerState::WaitingForCoordinator(None),
4659                    ) else {
4660                        unreachable!("must be running in ready state")
4661                    };
4662                    let _ = std::mem::replace(
4663                        &mut self.state,
4664                        WorkerState::WaitingForCoordinator(Some(ready)),
4665                    );
4666                    self.coordinator_send
4667                        .try_send(msg)
4668                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
4669                    stop.until_stopped(pending()).await?
4670                }
4671            }
4672        }
4673    }
4674}
4675
4676impl<T: 'static + RingMem> NetChannel<T> {
4677    fn try_next_packet<'a>(
4678        &mut self,
4679        send_buffer: Option<&SendBuffer>,
4680        version: Option<Version>,
4681        external_data: &'a mut MultiPagedRangeBuf,
4682    ) -> Result<Option<Packet<'a>>, WorkerError> {
4683        let (mut read, _) = self.queue.split();
4684        let packet = match read.try_read() {
4685            Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4686                .map_err(WorkerError::Packet)?,
4687            Err(queue::TryReadError::Empty) => return Ok(None),
4688            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4689        };
4690
4691        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4692        Ok(Some(packet))
4693    }
4694
4695    async fn next_packet<'a>(
4696        &mut self,
4697        send_buffer: Option<&'a SendBuffer>,
4698        version: Option<Version>,
4699        external_data: &'a mut MultiPagedRangeBuf,
4700    ) -> Result<Packet<'a>, WorkerError> {
4701        let (mut read, _) = self.queue.split();
4702        let mut packet_ref = read.read().await?;
4703        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
4704            .map_err(WorkerError::Packet)?;
4705        if matches!(packet.data, PacketData::RndisPacket(_)) {
4706            // In WorkerState::Init, if an RNDIS packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG.
4707            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
4708            packet_ref.revert();
4709        }
4710        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4711        Ok(packet)
4712    }
4713
4714    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4715        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4716            && initializing.ndis_version.is_some()
4717            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4718            && initializing.recv_buffer.is_some()
4719    }
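    // Reading the predicate: the NDIS config message is optional for pre-V2
    // guests, the send buffer is optional only when
    // `allow_missing_send_buffer` is set (used by the RNDIS fallback in
    // `initialize` below), and the NDIS version and receive buffer are always
    // required. For example (informal):
    //
    //     // V1 guest, no ndis_config, recv+send buffers set -> true
    //     // V2 guest, no ndis_config                        -> false
    //     // any guest, no recv_buffer                       -> false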
4720
4721    async fn initialize(
4722        &mut self,
4723        initializing: &mut Option<InitState>,
4724        mem: GuestMemory,
4725    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
4726        let mut has_init_packet_arrived = false;
4727        loop {
4728            if let Some(initializing) = &mut *initializing {
4729                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
4730                    let recv_buffer = initializing.recv_buffer.take().unwrap();
4731                    let send_buffer = initializing.send_buffer.take();
4732                    let state = ActiveState::new(
4733                        Some(PrimaryChannelState::new(
4734                            self.adapter.offload_support.clone(),
4735                        )),
4736                        recv_buffer.count,
4737                    );
4738                    let buffers = ChannelBuffers {
4739                        version: initializing.version,
4740                        mem,
4741                        recv_buffer,
4742                        send_buffer,
4743                        ndis_version: initializing.ndis_version.take().unwrap(),
4744                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
4745                            mtu: DEFAULT_MTU,
4746                            capabilities: protocol::NdisConfigCapabilities::new(),
4747                        }),
4748                    };
4749
4750                    break Ok((buffers, state));
4751                }
4752            }
4753
4754            // Wait for enough room in the ring to avoid needing to track
4755            // completion packets.
4756            self.queue
4757                .split()
4758                .1
4759                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
4760                .await?;
4761
4762            let mut external_data = MultiPagedRangeBuf::new();
4763            let packet = self
4764                .next_packet(
4765                    None,
4766                    initializing.as_ref().map(|x| x.version),
4767                    &mut external_data,
4768                )
4769                .await?;
4770
4771            if let Some(initializing) = &mut *initializing {
4772                match packet.data {
4773                    PacketData::SendNdisConfig(config) => {
4774                        if initializing.ndis_config.is_some() {
4775                            return Err(WorkerError::UnexpectedPacketOrder(
4776                                PacketOrderError::SendNdisConfigExists,
4777                            ));
4778                        }
4779
4780                        // As in the vmswitch, if the MTU is invalid then use the default.
4781                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
4782                            config.mtu
4783                        } else {
4784                            DEFAULT_MTU
4785                        };
4786
4787                        // The UEFI client expects a completion packet, which can be empty.
4788                        self.send_completion(packet.transaction_id, None)?;
4789                        initializing.ndis_config = Some(NdisConfig {
4790                            mtu,
4791                            capabilities: config.capabilities,
4792                        });
4793                    }
4794                    PacketData::SendNdisVersion(version) => {
4795                        if initializing.ndis_version.is_some() {
4796                            return Err(WorkerError::UnexpectedPacketOrder(
4797                                PacketOrderError::SendNdisVersionExists,
4798                            ));
4799                        }
4800
4801                        // The UEFI client expects a completion packet, which can be empty.
4802                        self.send_completion(packet.transaction_id, None)?;
4803                        initializing.ndis_version = Some(NdisVersion {
4804                            major: version.ndis_major_version,
4805                            minor: version.ndis_minor_version,
4806                        });
4807                    }
4808                    PacketData::SendReceiveBuffer(message) => {
4809                        if initializing.recv_buffer.is_some() {
4810                            return Err(WorkerError::UnexpectedPacketOrder(
4811                                PacketOrderError::SendReceiveBufferExists,
4812                            ));
4813                        }
4814
4815                        let mtu = if let Some(cfg) = &initializing.ndis_config {
4816                            cfg.mtu
4817                        } else if initializing.version < Version::V2 {
4818                            DEFAULT_MTU
4819                        } else {
4820                            return Err(WorkerError::UnexpectedPacketOrder(
4821                                PacketOrderError::SendReceiveBufferMissingMTU,
4822                            ));
4823                        };
4824
4825                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);
4826
4827                        let recv_buffer = ReceiveBuffer::new(
4828                            &self.gpadl_map,
4829                            message.gpadl_handle,
4830                            message.id,
4831                            sub_allocation_size,
4832                        )?;
4833
4834                        self.send_completion(
4835                            packet.transaction_id,
4836                            Some(&self.message(
4837                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
4838                                protocol::Message1SendReceiveBufferComplete {
4839                                    status: protocol::Status::SUCCESS,
4840                                    num_sections: 1,
4841                                    sections: [protocol::ReceiveBufferSection {
4842                                        offset: 0,
4843                                        sub_allocation_size: recv_buffer.sub_allocation_size,
4844                                        num_sub_allocations: recv_buffer.count,
4845                                        end_offset: recv_buffer.sub_allocation_size
4846                                            * recv_buffer.count,
4847                                    }],
4848                                },
4849                            )),
4850                        )?;
4851                        initializing.recv_buffer = Some(recv_buffer);
4852                    }
4853                    PacketData::SendSendBuffer(message) => {
4854                        if initializing.send_buffer.is_some() {
4855                            return Err(WorkerError::UnexpectedPacketOrder(
4856                                PacketOrderError::SendSendBufferExists,
4857                            ));
4858                        }
4859
4860                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
4861                        self.send_completion(
4862                            packet.transaction_id,
4863                            Some(&self.message(
4864                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
4865                                protocol::Message1SendSendBufferComplete {
4866                                    status: protocol::Status::SUCCESS,
4867                                    section_size: 6144,
4868                                },
4869                            )),
4870                        )?;
4871
4872                        initializing.send_buffer = Some(send_buffer);
4873                    }
4874                    PacketData::RndisPacket(rndis_packet) => {
4875                        if !Self::is_ready_to_initialize(initializing, true) {
4876                            return Err(WorkerError::UnexpectedPacketOrder(
4877                                PacketOrderError::UnexpectedRndisPacket,
4878                            ));
4879                        }
4880                        tracing::debug!(
4881                            channel_type = rndis_packet.channel_type,
4882                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
4883                        );
4884                        has_init_packet_arrived = true;
4885                    }
4886                    _ => {
4887                        return Err(WorkerError::UnexpectedPacketOrder(
4888                            PacketOrderError::InvalidPacketData,
4889                        ));
4890                    }
4891                }
4892            } else {
4893                match packet.data {
4894                    PacketData::Init(init) => {
4895                        let requested_version = init.protocol_version;
4896                        let version = check_version(requested_version);
4897                        let mut data = protocol::MessageInitComplete {
4898                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
4899                            maximum_mdl_chain_length: 34,
4900                            status: protocol::Status::NONE,
4901                        };
4902                        if let Some(version) = version {
4903                            if version == Version::V1 {
4904                                data.deprecated = Version::V1 as u32;
4905                            }
4906                            data.status = protocol::Status::SUCCESS;
4907                        } else {
4908                            tracing::debug!(requested_version, "unrecognized version");
4909                        }
4910                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
4911                        self.send_completion(packet.transaction_id, Some(&message))?;
4912
4913                        if let Some(version) = version {
4914                            tracelimit::info_ratelimited!(?version, "network negotiated");
4915
4916                            if version >= Version::V61 {
4917                                // Update the packet size so that the appropriate padding is
4918                                // appended for picky Windows guests.
4919                                self.packet_size = PacketSize::V61;
4920                            }
4921                            *initializing = Some(InitState {
4922                                version,
4923                                ndis_config: None,
4924                                ndis_version: None,
4925                                recv_buffer: None,
4926                                send_buffer: None,
4927                            });
4928                        }
4929                    }
4930                    _ => unreachable!(),
4931                }
4932            }
4933        }
4934    }
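    // The negotiation above, in the order a typical guest drives it:
    //
    //     // 1. Init                 -> MESSAGE_TYPE_INIT_COMPLETE
    //     // 2. SendNdisVersion      -> empty completion
    //     // 3. SendNdisConfig (V2+) -> empty completion
    //     // 4. SendReceiveBuffer    -> SEND_RECEIVE_BUFFER_COMPLETE
    //     // 5. SendSendBuffer       -> SEND_SEND_BUFFER_COMPLETE
    //
    // Steps 2-5 may arrive in any order; `is_ready_to_initialize` gates the
    // transition to the ready state, and an early RNDIS packet (assumed to be
    // MESSAGE_TYPE_INITIALIZE_MSG) short-circuits the wait for a send buffer.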
4935
4936    async fn main_loop(
4937        &mut self,
4938        stop: &mut StopTask<'_>,
4939        ready_state: &mut ReadyState,
4940        queue_state: &mut QueueState,
4941    ) -> Result<CoordinatorMessage, WorkerError> {
4942        let buffers = &ready_state.buffers;
4943        let state = &mut ready_state.state;
4944        let data = &mut ready_state.data;
4945
4946        let ring_spare_capacity = {
4947            let (_, send) = self.queue.split();
4948            let mut limit = if self.can_use_ring_size_opt {
4949                self.adapter.ring_size_limit.load(Ordering::Relaxed)
4950            } else {
4951                0
4952            };
4953            if limit == 0 {
4954                limit = send.capacity() - 2048;
4955            }
4956            send.capacity() - limit
4957        };
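        // Worked example: with a ring capacity of 0x10000 bytes and no
        // configured limit, limit = 0x10000 - 2048, so ring_spare_capacity is
        // 2048 bytes and the loop below treats the ring as "full" once less
        // than that much space remains. A nonzero `ring_size_limit` shrinks
        // the usable region and grows the spare region accordingly.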
4958
4959        // If the packet filter has changed to allow rx packets, add any pended RxIds.
4960        if !state.pending_rx_packets.is_empty()
4961            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
4962        {
4963            let epqueue = queue_state.queue.as_mut();
4964            let (front, back) = state.pending_rx_packets.as_slices();
4965            epqueue.rx_avail(front);
4966            epqueue.rx_avail(back);
4967            state.pending_rx_packets.clear();
4968        }
4969
4970        // Handle any guest state changes since last run.
4971        if let Some(primary) = state.primary.as_mut() {
4972            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
4973                let num_channels_opened =
4974                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
4975                if num_channels_opened == primary.requested_num_queues as usize {
4976                    let (_, mut send) = self.queue.split();
4977                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4978                        .await??;
4979                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
4980                    primary.tx_spread_sent = true;
4981                }
4982            }
4983            if let PendingLinkAction::Active(up) = primary.pending_link_action {
4984                let (_, mut send) = self.queue.split();
4985                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4986                    .await??;
4987                if let Some(id) = primary.free_control_buffers.pop() {
4988                    let connect = if primary.guest_link_up != up {
4989                        primary.pending_link_action = PendingLinkAction::Default;
4990                        up
4991                    } else {
4992                        // For the up -> down -> up OR down -> up -> down case, the first transition
4993                        // is sent immediately and the second transition is queued with a delay.
4994                        primary.pending_link_action =
4995                            PendingLinkAction::Delay(primary.guest_link_up);
4996                        !primary.guest_link_up
4997                    };
4998                    // Mark the receive buffer in use to allow the guest to release it.
4999                    assert!(state.rx_bufs.is_free(id.0));
5000                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
5001                    let state_to_send = if connect {
5002                        rndisprot::STATUS_MEDIA_CONNECT
5003                    } else {
5004                        rndisprot::STATUS_MEDIA_DISCONNECT
5005                    };
5006                    tracing::info!(
5007                        connect,
5008                        mac_address = %self.adapter.mac_address,
5009                        "sending link status"
5010                    );
5011
5012                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
5013                    primary.guest_link_up = connect;
5014                } else {
5015                    primary.pending_link_action = PendingLinkAction::Delay(up);
5016                }
5017
5018                match primary.pending_link_action {
5019                    PendingLinkAction::Delay(_) => {
5020                        return Ok(CoordinatorMessage::StartTimer(
5021                            Instant::now() + LINK_DELAY_DURATION,
5022                        ));
5023                    }
5024                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
5025                    _ => {}
5026                }
5027            }
5028            match primary.guest_vf_state {
5029                PrimaryChannelGuestVfState::Available { .. }
5030                | PrimaryChannelGuestVfState::UnavailableFromAvailable
5031                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
5032                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
5033                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
5034                | PrimaryChannelGuestVfState::DataPathSynthetic => {
5035                    let (_, mut send) = self.queue.split();
5036                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5037                        .await??;
5038                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
5039                        return Ok(message);
5040                    }
5041                }
5042                _ => (),
5043            }
5044        }
5045
5046        loop {
5047            // If the ring is almost full, then do not poll the endpoint,
5048            // since this will cause us to spend time dropping packets and
5049            // reprogramming receive buffers. It's more efficient to let the
5050            // backend drop packets.
5051            let ring_full = {
5052                let (_, mut send) = self.queue.split();
5053                !send.can_write(ring_spare_capacity)?
5054            };
5055
5056            let did_some_work = (!ring_full
5057                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
5058                | self.process_ring_buffer(buffers, state, data, queue_state)?
5059                | (!ring_full
5060                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
5061                | self.transmit_pending_segments(state, data, queue_state)?
5062                | self.send_pending_packets(state)?;
5063
5064            if !did_some_work {
5065                state.stats.spurious_wakes.increment();
5066            }
5067
5068            // Process any outstanding control messages before sleeping in case
5069            // there are now suballocations available for use.
5070            self.process_control_messages(buffers, state)?;
5071
5072            // This should be the only await point waiting on network traffic or
5073            // guest actions. Wrap it in `stop.until_stopped` to allow
            // cancellation.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        // If the ring is almost full, then don't wait for endpoint
                        // interrupts. This allows the interrupt rate to fall when the
                        // guest cannot keep up with the load.
                        if !ring_full {
                            // Check the network endpoint for tx completion or rx.
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // Check the incoming ring for tx, but only if there are enough
                        // free tx packets and no pending tx segments.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // Check the outgoing ring for space to send rx completions or
                        // control message sends, if any are pending.
                        //
                        // Also, if endpoint processing has been suspended due to the
                        // ring being nearly full, then ask the guest to wake us up when
                        // there is space again.
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Collect any of this queue's receive buffers that were
                        // completed by a remote channel. This only happens when the
                        // subchannel count changes, so that receive buffer ownership
                        // moves between queues while some receive buffers are still in
                        // use.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }

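    // Illustrative sketch (hypothetical names, not part of this device): the
    // wait above folds several readiness sources into one poll_fn so a single
    // waker registration covers them all, checking the highest-priority
    // source first. A reduced form of the same pattern over two opaque
    // sources, assuming both are Unpin futures:
    #[allow(dead_code)]
    async fn example_race_two_sources(
        mut endpoint: impl std::future::Future<Output = ()> + Unpin,
        mut ring: impl std::future::Future<Output = ()> + Unpin,
    ) {
        std::future::poll_fn(|cx| {
            // Poll in priority order; the first ready source wins. Each
            // future registers the shared waker when it returns Pending.
            if endpoint.poll_unpin(cx).is_ready() {
                return Poll::Ready(());
            }
            if ring.poll_unpin(cx).is_ready() {
                return Poll::Ready(());
            }
            Poll::Pending
        })
        .await
    }
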
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Pend the newly available RxIds until the packet filter is updated.
            // Under high load this will eventually lead to no available RxIds,
            // which will cause the backend to drop the packets instead of
            // processing them here.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        // Always use the full suballocation size to avoid tracking the
        // message length. See RxBuf::header() for details.
        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                // packet was sent
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring buffer is full. Drop the packets and free the rx
                // buffers. When the ring has limited space, the main loop will
                // stop polling for receive packets.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }

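    // Illustrative sketch (hypothetical helper, not this file's API): receive
    // buffer IDs index fixed-size sub-allocations of the guest-supplied
    // receive buffer, so the byte range behind an ID is assumed to reduce to
    // a simple multiply, which is what makes using the full `len` above cheap.
    #[allow(dead_code)]
    fn example_sub_allocation_range(id: u32, sub_allocation_size: u32) -> Range<usize> {
        let start = id as usize * sub_allocation_size as usize;
        start..start + sub_allocation_size as usize
    }
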
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Drain completed transmits.
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Complete any pending tx prior to restarting queues.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }

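    // Illustrative sketch (IDs reduced to indices): the completion
    // bookkeeping above in miniature. A guest transmit may fan out into
    // several backend packets; its guest-visible completion fires only when
    // the last backend completion drops the per-packet pending count to zero.
    #[allow(dead_code)]
    fn example_drain_tx_done(done: &[usize], pending_counts: &mut [u32]) -> Vec<usize> {
        let mut completed = Vec::new();
        for &id in done {
            assert!(pending_counts[id] > 0);
            pending_counts[id] -= 1;
            if pending_counts[id] == 0 {
                completed.push(id);
            }
        }
        completed
    }
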
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // Allow the guest to switch to VTL0, or, if the current state
                // of the data path is unknown, allow a switch away from VTL0.
                // The latter case handles the scenario where the data path has
                // been switched but the synthetic device has been restarted.
                // The device is queried for the current state of the data
                // path, but if it is unknown (e.g. upgraded from an older
                // version that doesn't track this data, or the query failed),
                // the safest option is to pass the request through.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while device is in wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, None)?;
        }
        Ok(())
    }

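    // Illustrative sketch (a simplified approximation; the real guard above
    // also depends on the VF state machine): a switch is queued when it would
    // change the effective data path, and also when the current path is
    // unknown and the request must be passed through to be safe.
    #[allow(dead_code)]
    fn example_should_queue_switch(is_switched_to_vf: Option<bool>, to_vf: bool) -> bool {
        match is_switched_to_vf {
            // Unknown (e.g. restored from an older saved state): pass through.
            None => true,
            // Known: only queue if the direction actually changes.
            Some(current) => current != to_vf,
        }
    }
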
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if !data.tx_segments.is_empty() {
            // There are still segments pending transmission. Skip polling the
            // ring until they are sent to increase backpressure and minimize
            // unnecessary wakeups.
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    // One queue is always reserved for the primary channel, so
                    // the subchannels plus the primary channel must fit within
                    // max_queues: subchannels + 1 <= max_queues, i.e. the
                    // requested subchannel count must be strictly less than
                    // max_queues.
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: unsupported operation, or more subchannels requested than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                // No operation for VF association completion packets, as not
                // all clients send them.
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }

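    // Illustrative sketch: the subchannel validation above, isolated. With
    // max_queues total queues and one always reserved for the primary
    // channel, an ALLOCATE request for n subchannels is acceptable only when
    // n < max_queues.
    #[allow(dead_code)]
    fn example_subchannel_request_ok(requested: u32, max_queues: u16) -> bool {
        requested < max_queues as u32
    }
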
    // Transmit any pending segments. Returns Ok(true) if work was done, that
    // is, if any segments were transmitted.
    fn transmit_pending_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if data.tx_segments.is_empty() {
            return Ok(false);
        }
        let sent = data.tx_segments_sent;
        let did_work =
            self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
        Ok(did_work)
    }

    /// Returns true if all pending segments were transmitted.
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(segments)
            .map_err(WorkerError::Endpoint)?;

        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        if sync {
            // Complete the packets now.
            while let Some(head) = segments.first() {
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }

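    // Illustrative sketch (segments flattened to (is_head, segment_count)
    // pairs): the synchronous completion walk above advances head-to-head
    // through a flat segment list, where each head records the total number
    // of segments, itself included, that its packet occupies.
    #[allow(dead_code)]
    fn example_count_packets(segments: &[(bool, usize)]) -> usize {
        let mut packets = 0;
        let mut rest = segments;
        while let Some(&(is_head, segment_count)) = rest.first() {
            assert!(is_head && segment_count >= 1, "walk must land on a head");
            packets += 1;
            rest = &rest[segment_count..];
        }
        packets
    }
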
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut total_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe the data to catch accesses that are out of bounds. This
        // simplifies error handling for backends that use
        // [`GuestMemory::iova`].
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        let header: rndisprot::MessageHeader = reader.read_plain()?;
        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
            let start = segments.len();
            match self.handle_rndis_packet_messages(
                buffers,
                state,
                id,
                header.message_length as usize,
                reader,
                segments,
            ) {
                Ok(n) => {
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
                    total_packets += n;
                }
                Err(err) => {
                    // Roll back any segments added for this message.
                    segments.truncate(start);
                    return Err(err);
                }
            }
        } else {
            self.handle_rndis_message(state, header.message_type, reader)?;
        }

        Ok(total_packets)
    }

    fn try_send_tx_packet(
        &mut self,
        transaction_id: u64,
        status: protocol::Status,
    ) -> Result<bool, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
            protocol::Message1SendRndisPacketComplete { status },
        );
        let result = self.queue.split().1.batched().try_write_aligned(
            transaction_id,
            OutgoingPacketType::Completion,
            message.aligned_payload(),
        );
        let sent = match result {
            Ok(()) => true,
            Err(queue::TryWriteError::Full(n)) => {
                self.pending_send_size = n;
                false
            }
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(sent)
    }

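    // Illustrative sketch (the ring write reduced to a closure): the try-send
    // pattern above. On a full ring the required size is recorded rather than
    // the message being queued anywhere; the main wait loop then asks the
    // guest for that much free space before the completion is retried.
    #[allow(dead_code)]
    fn example_try_write(
        mut write: impl FnMut() -> Result<(), usize>,
        pending_send_size: &mut usize,
    ) -> bool {
        match write() {
            Ok(()) => true,
            Err(bytes_needed) => {
                *pending_send_size = bytes_needed;
                false
            }
        }
    }
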
    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
        let mut did_some_work = false;
        while let Some(pending) = state.pending_tx_completions.front() {
            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
                return Ok(did_some_work);
            }
            did_some_work = true;
            if let Some(id) = pending.tx_id {
                state.free_tx_packets.push(id);
            }
            tracing::trace!(?pending, "sent tx completion");
            state.pending_tx_completions.pop_front();
        }

        self.pending_send_size = 0;
        Ok(did_some_work)
    }

    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
}

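// Illustrative sketch (hypothetical free function, IDs reduced to u64): tx
// completions above are pended in FIFO order and drained front-first, so the
// guest observes completions in the order the device finished the packets,
// even across a full ring.
#[allow(dead_code)]
fn example_drain_fifo(pending: &mut VecDeque<u64>, mut try_send: impl FnMut(u64) -> bool) {
    while let Some(&transaction_id) = pending.front() {
        if !try_send(transaction_id) {
            // Ring is full; leave the rest queued for the next wakeup.
            break;
        }
        pending.pop_front();
    }
}
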
impl ActiveState {
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID specifies the first rx buffer ID.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
}
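
// Illustrative sketch (hypothetical free function): a freed batch above is
// named only by its first sub-allocation ID, and the full ID set is assumed
// to be recovered from the RxBuffers table. Each local ID is then routed back
// to either the data rx pool or, below the reserved threshold, the
// control-buffer pool, mirroring release_recv_buffers.
#[allow(dead_code)]
fn example_route_freed_ids(ids: &[u32], reserved: u32) -> (Vec<u32>, Vec<u32>) {
    // Returns (data_ids, control_ids).
    ids.iter().copied().partition(|&id| id >= reserved)
}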