netvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The user-mode netvsp VMBus device implementation.
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::Message1RevokeReceiveBuffer;
19use crate::protocol::Message1RevokeSendBuffer;
20use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
21use crate::protocol::Version;
22use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
23use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
24use async_trait::async_trait;
25pub use buffers::BufferPool;
26use buffers::sub_allocation_size_for_mtu;
27use futures::FutureExt;
28use futures::StreamExt;
29use futures::channel::mpsc;
30use futures_concurrency::future::Race;
31use guestmem::AccessError;
32use guestmem::GuestMemory;
33use guestmem::GuestMemoryError;
34use guestmem::MemoryRead;
35use guestmem::MemoryWrite;
36use guestmem::ranges::PagedRange;
37use guestmem::ranges::PagedRanges;
38use guestmem::ranges::PagedRangesReader;
39use guid::Guid;
40use hvdef::hypercall::HvGuestOsId;
41use hvdef::hypercall::HvGuestOsMicrosoft;
42use hvdef::hypercall::HvGuestOsMicrosoftIds;
43use hvdef::hypercall::HvGuestOsOpenSourceType;
44use inspect::Inspect;
45use inspect::InspectMut;
46use inspect::SensitivityLevel;
47use inspect_counters::Counter;
48use inspect_counters::Histogram;
49use mesh::rpc::Rpc;
50use net_backend::Endpoint;
51use net_backend::EndpointAction;
52use net_backend::QueueConfig;
53use net_backend::RxId;
54use net_backend::TxError;
55use net_backend::TxId;
56use net_backend::TxSegment;
57use net_backend_resources::mac_address::MacAddress;
58use pal_async::timer::Instant;
59use pal_async::timer::PolledTimer;
60use ring::gparange::MultiPagedRangeIter;
61use rx_bufs::RxBuffers;
62use rx_bufs::SubAllocationInUse;
63use std::collections::VecDeque;
64use std::fmt::Debug;
65use std::future::pending;
66use std::mem::offset_of;
67use std::ops::Range;
68use std::sync::Arc;
69use std::sync::atomic::AtomicUsize;
70use std::sync::atomic::Ordering;
71use std::task::Poll;
72use std::time::Duration;
73use task_control::AsyncRun;
74use task_control::InspectTaskMut;
75use task_control::StopTask;
76use task_control::TaskControl;
77use thiserror::Error;
78use tracing::Instrument;
79use vmbus_async::queue;
80use vmbus_async::queue::ExternalDataError;
81use vmbus_async::queue::IncomingPacket;
82use vmbus_async::queue::Queue;
83use vmbus_channel::bus::OfferParams;
84use vmbus_channel::bus::OpenRequest;
85use vmbus_channel::channel::ChannelControl;
86use vmbus_channel::channel::ChannelOpenError;
87use vmbus_channel::channel::ChannelRestoreError;
88use vmbus_channel::channel::DeviceResources;
89use vmbus_channel::channel::RestoreControl;
90use vmbus_channel::channel::SaveRestoreVmbusDevice;
91use vmbus_channel::channel::VmbusDevice;
92use vmbus_channel::gpadl::GpadlId;
93use vmbus_channel::gpadl::GpadlMapView;
94use vmbus_channel::gpadl::GpadlView;
95use vmbus_channel::gpadl::UnknownGpadlId;
96use vmbus_channel::gpadl_ring::GpadlRingMem;
97use vmbus_channel::gpadl_ring::gpadl_channel;
98use vmbus_ring as ring;
99use vmbus_ring::OutgoingPacketType;
100use vmbus_ring::RingMem;
101use vmbus_ring::gparange::MultiPagedRangeBuf;
102use vmcore::save_restore::RestoreError;
103use vmcore::save_restore::SaveError;
104use vmcore::save_restore::SavedStateBlob;
105use vmcore::vm_task::VmTaskDriver;
106use vmcore::vm_task::VmTaskDriverSource;
107use zerocopy::FromBytes;
108use zerocopy::FromZeros;
109use zerocopy::Immutable;
110use zerocopy::IntoBytes;
111use zerocopy::KnownLayout;
112
// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but also need room for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. Worst case requires a completion message plus two
// additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION)
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// Upper bound on the number of subchannels offered per vNIC (the primary
// channel is not counted in this limit).
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the AN device is offered too quickly, it
// could get the "eth0" name. In provisioning scenarios, the scripts will make
// assumptions about which interface should be used, with eth0 being the
// default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
146
/// Flags describing which pieces of network state a
/// [`CoordinatorMessage::Update`] should refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}
155
/// Messages sent from the per-channel workers to the coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}
166
/// Per-channel worker: couples one vmbus channel with its protocol state and
/// a sender back to the coordinator task.
struct Worker<T: RingMem> {
    /// Index of this channel within the device's channel set.
    channel_idx: u16,
    /// Target virtual processor for this channel — presumably used for
    /// interrupt/task affinity; confirm against the coordinator code.
    target_vp: u32,
    /// Guest memory accessor used for packet data transfer.
    mem: GuestMemory,
    /// The vmbus channel plus ring/packet bookkeeping.
    channel: NetChannel<T>,
    /// Protocol progress: negotiating ([`WorkerState::Init`]) or running
    /// ([`WorkerState::Ready`]).
    state: WorkerState,
    /// Channel for sending restart/update/timer requests to the coordinator.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
175
/// The backend queue and driver associated with a worker; kept separate from
/// [`Worker`] so it can be inspected even while the worker is not running.
struct NetQueue {
    driver: VmTaskDriver,
    /// `None` until the channel has completed setup and a backend queue is
    /// attached.
    queue_state: Option<QueueState>,
}
180
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Reports diagnostics for the queue and, when available, the worker
    /// servicing it.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report: no worker attached and no backend queue.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                // Map the worker's negotiation progress to a short label.
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    // Packets handed to the backend but not yet completed:
                    // total slots minus the ones sitting on the free list.
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
232
/// Channel protocol state: `Init` while version/buffer negotiation is in
/// progress (the inner `Option` is `None` before a version is agreed),
/// `Ready` once packet processing can run.
#[expect(clippy::large_enum_variant)]
enum WorkerState {
    Init(Option<InitState>),
    Ready(ReadyState),
}
238
239impl WorkerState {
240    fn ready(&self) -> Option<&ReadyState> {
241        if let Self::Ready(state) = self {
242            Some(state)
243        } else {
244            None
245        }
246    }
247
248    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249        if let Self::Ready(state) = self {
250            Some(state)
251        } else {
252            None
253        }
254    }
255}
256
/// State accumulated during protocol negotiation; each field is populated as
/// the corresponding message arrives from the guest.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
264
/// NDIS version reported by the guest, split into major/minor parts.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
272
/// NDIS configuration negotiated with the guest: MTU plus capability flags.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
280
/// Post-negotiation channel state: shared buffer descriptions, mutable
/// per-channel state, and reusable processing scratch space.
struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}
286
/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
301
/// Per-adapter (vNIC) configuration and counters shared across channels.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    /// Maximum number of queues/channels the adapter supports.
    max_queues: u16,
    indirection_table_size: u16,
    /// Offloads the backend endpoint supports; guest requests are masked
    /// against this.
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Optional callback to query the registered guest OS ID (used to apply
    /// OS-specific behavior).
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    /// Reported link speed. Units are not evident here — presumably bits per
    /// second; confirm against the protocol definitions.
    link_speed: u64,
}
316
/// A backend queue bound to a channel, along with the receive-buffer ID range
/// it owns and whether its target VP has been applied.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}
322
/// The receive-buffer IDs owned by one queue, plus channels for exchanging
/// buffer IDs that complete on the wrong queue.
struct RxBufferRange {
    /// Buffer IDs this queue owns.
    id_range: Range<u32>,
    /// Receives buffer IDs forwarded from other queues; `None` when this
    /// queue never receives remote completions.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    /// Shared view of all queues' ranges, used to route remote IDs.
    remote_ranges: Arc<RxBufferRanges>,
}
328
329impl RxBufferRange {
330    fn new(
331        ranges: Arc<RxBufferRanges>,
332        id_range: Range<u32>,
333        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
334    ) -> Self {
335        Self {
336            id_range,
337            remote_buffer_id_recv,
338            remote_ranges: ranges,
339        }
340    }
341
342    fn send_if_remote(&self, id: u32) -> bool {
343        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
344        // ID is owned by the current range.
345        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
346            false
347        } else {
348            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
349            // The total number of receive buffers may not evenly divide among
350            // the active queues. Any extra buffers are given to the last
351            // queue, so redirect any larger values there.
352            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
353            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
354            true
355        }
356    }
357}
358
/// Shared routing table for receive-buffer IDs: the per-queue range width and
/// a sender for each queue's forwarded IDs.
struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
363
impl RxBufferRanges {
    /// Splits the non-reserved buffers evenly across `queue_count` queues,
    /// returning the shared routing state plus one receiver per queue.
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        // Reserved control buffers are excluded from the per-queue split.
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}
378
/// Receive-side-scaling parameters supplied by the guest: the 40-byte hash
/// key and the processor indirection table.
struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}
383
/// The internal channel state.
struct NetChannel<T: RingMem> {
    /// Shared adapter configuration.
    adapter: Arc<Adapter>,
    /// The vmbus ring-buffer queue for this channel.
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    /// Negotiated per-packet header size (see [`PacketSize`]).
    packet_size: PacketSize,
    /// Outgoing ring space to wait for before resuming sends.
    pending_send_size: usize,
    /// Pending coordinator message to deliver, if any.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}
395
// Use an enum to give the compiler more visibility into the packet size.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// [`protocol::PACKET_SIZE_V1`]
    V1,
    /// [`protocol::PACKET_SIZE_V61`]
    V61,
}
404
/// Buffers used during packet processing.
struct ProcessingData {
    /// Transmit segments staged for the backend.
    tx_segments: Vec<TxSegment>,
    /// Number of entries in `tx_segments` already handed to the backend.
    tx_segments_sent: usize,
    /// Scratch for completed transmit IDs returned by the backend.
    tx_done: Box<[TxId]>,
    /// Scratch for receive IDs ready to post to the backend.
    rx_ready: Box<[RxId]>,
    /// Receive IDs completed this pass.
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
415
416impl ProcessingData {
417    fn new() -> Self {
418        Self {
419            tx_segments: Vec::new(),
420            tx_segments_sent: 0,
421            tx_done: vec![TxId(0); 8192].into(),
422            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
423            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
424            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
425            external_data: MultiPagedRangeBuf::new(),
426        }
427    }
428}
429
/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    /// `None` when the guest did not provide a send buffer.
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
445
/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
449
/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    /// Present only on channel 0 (the primary channel).
    primary: Option<PrimaryChannelState>,

    /// Transmit packets in flight, indexed by transmit ID.
    pending_tx_packets: Vec<PendingTxPacket>,
    /// Transmit IDs available for new sends.
    free_tx_packets: Vec<TxId>,
    /// Completions waiting for outgoing ring space.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    /// Tracks which receive sub-allocations are in use by the guest.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}
463
/// Per-queue diagnostic counters surfaced through inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    /// Wakeups that found no work to do.
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
478
/// A transmit completion that still needs to be written to the outgoing ring.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    /// The transmit ID to recycle, if this completion corresponds to a
    /// tracked packet.
    tx_id: Option<TxId>,
    status: protocol::Status,
}
485
/// State machine tracking VF availability and the guest's use of it, kept on
/// the primary channel.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF became unavailable while a data path switch was still pending.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF became unavailable while the data path was switched to it.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}
518
impl std::fmt::Display for PrimaryChannelGuestVfState {
    /// Renders the state as a short human-readable string for inspect output.
    ///
    /// NOTE(review): quoting is inconsistent across arms — some values are
    /// wrapped in escaped quotes, others are bare, and the
    /// `Restoring(DataPathSwitchPending…)` arm emits a trailing `"` with no
    /// opening quote. These are runtime strings consumed by inspect, so they
    /// are left untouched here; confirm the intended convention upstream.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
                    "restoring from client requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "unavailable (previously switching data path)")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}
590
591impl Inspect for PrimaryChannelGuestVfState {
592    fn inspect(&self, req: inspect::Request<'_>) {
593        req.value(self.to_string());
594    }
595}
596
/// A pending link-state change for the guest. The `bool` payload carries the
/// requested link state (`true` presumably meaning link-up — confirm against
/// the coordinator's handling).
#[derive(Debug)]
enum PendingLinkAction {
    /// No action pending.
    Default,
    /// A link change to apply now.
    Active(bool),
    /// A link change deferred by the link delay timer.
    Delay(bool),
}
603
/// State kept only on the primary channel (channel 0): VF tracking, RNDIS
/// control-message queues, RSS, offloads, and link-state bookkeeping.
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    /// `None` until the data path state is known.
    is_data_path_switched: Option<bool>,
    /// Control messages waiting for a free receive buffer.
    control_messages: VecDeque<ControlMessage>,
    /// Total bytes queued in `control_messages`.
    control_messages_len: usize,
    /// Reserved receive buffers currently free for control responses.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    /// Offloads currently active for the guest.
    offload_config: OffloadConfig,
    /// Set when an offload-change notification still needs to be sent.
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
619
impl Inspect for PrimaryChannelState {
    /// Reports primary-channel diagnostics; every field is marked `Safe`
    /// sensitivity so it appears in redacted output.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                // Render the pending action as a short debug-style string.
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
671
/// Hardware-offload configuration: transmit/receive checksum offloads and
/// large-send offload (LSO) for IPv4/IPv6.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}
683
684impl OffloadConfig {
685    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
686        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
687        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
688        self.lso4 &= supported.lso4;
689        self.lso6 &= supported.lso6;
690    }
691}
692
/// Per-protocol checksum-offload switches for one direction (tx or rx).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
706
707impl ChecksumOffloadConfig {
708    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
709        self.ipv4_header &= supported.ipv4_header;
710        self.tcp4 &= supported.tcp4;
711        self.udp4 &= supported.udp4;
712        self.tcp6 &= supported.tcp6;
713        self.udp6 &= supported.udp6;
714    }
715
716    fn flags(
717        &self,
718    ) -> (
719        rndisprot::Ipv4ChecksumOffload,
720        rndisprot::Ipv6ChecksumOffload,
721    ) {
722        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
723        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
724        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
725        if self.ipv4_header {
726            v4.set_ip_options_supported(on);
727            v4.set_ip_checksum(on);
728        }
729        if self.tcp4 {
730            v4.set_ip_options_supported(on);
731            v4.set_tcp_options_supported(on);
732            v4.set_tcp_checksum(on);
733        }
734        if self.tcp6 {
735            v6.set_ip_extension_headers_supported(on);
736            v6.set_tcp_options_supported(on);
737            v6.set_tcp_checksum(on);
738        }
739        if self.udp4 {
740            v4.set_ip_options_supported(on);
741            v4.set_udp_checksum(on);
742        }
743        if self.udp6 {
744            v6.set_ip_extension_headers_supported(on);
745            v6.set_udp_checksum(on);
746        }
747        (v4, v6)
748    }
749}
750
impl OffloadConfig {
    /// Builds the NDIS_OFFLOAD structure (revision 3) advertising this
    /// configuration's checksum and LSOv2 capabilities to the guest.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        // Checksum capabilities: tx and rx flag words, each for v4 and v6,
        // all using 802.3 encapsulation.
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        // Large-send-offload v2 capabilities; fields stay zeroed for
        // protocols where LSO is disabled.
        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}
798
/// Lifecycle of the guest's RNDIS stack on this adapter.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Waiting for the guest's RNDIS initialize message.
    Initializing,
    /// Initialized and processing messages.
    Operational,
    /// Halted by the guest.
    Halted,
}
805
806impl PrimaryChannelState {
807    fn new(offload_config: OffloadConfig) -> Self {
808        Self {
809            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
810            is_data_path_switched: None,
811            control_messages: VecDeque::new(),
812            control_messages_len: 0,
813            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
814                .map(ControlMessageId)
815                .collect(),
816            rss_state: None,
817            requested_num_queues: 1,
818            rndis_state: RndisState::Initializing,
819            pending_offload_change: false,
820            offload_config,
821            tx_spread_sent: false,
822            guest_link_up: true,
823            pending_link_action: PendingLinkAction::Default,
824        }
825    }
826
827    fn restore(
828        guest_vf_state: &saved_state::GuestVfState,
829        rndis_state: &saved_state::RndisState,
830        offload_config: &saved_state::OffloadConfig,
831        pending_offload_change: bool,
832        num_queues: u16,
833        indirection_table_size: u16,
834        rx_bufs: &RxBuffers,
835        control_messages: Vec<saved_state::IncomingControlMessage>,
836        rss_state: Option<saved_state::RssState>,
837        tx_spread_sent: bool,
838        guest_link_down: bool,
839        pending_link_action: Option<bool>,
840    ) -> Result<Self, NetRestoreError> {
841        // Restore control messages.
842        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();
843
844        let control_messages = control_messages
845            .into_iter()
846            .map(|msg| ControlMessage {
847                message_type: msg.message_type,
848                data: msg.data.into(),
849            })
850            .collect();
851
852        // Compute the free control buffers.
853        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
854            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
855            .collect();
856
857        let rss_state = rss_state
858            .map(|mut rss| {
859                if rss.indirection_table.len() > indirection_table_size as usize {
860                    // Dynamic reduction of indirection table can cause unexpected and hard to investigate issues
861                    // with performance and processor overloading.
862                    return Err(NetRestoreError::ReducedIndirectionTableSize);
863                }
864                if rss.indirection_table.len() < indirection_table_size as usize {
865                    tracing::warn!(
866                        saved_indirection_table_size = rss.indirection_table.len(),
867                        adapter_indirection_table_size = indirection_table_size,
868                        "increasing indirection table size",
869                    );
870                    // Dynamic increase of indirection table is done by duplicating the existing entries until
871                    // the desired size is reached.
872                    let table_clone = rss.indirection_table.clone();
873                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
874                    rss.indirection_table
875                        .extend(table_clone.iter().cycle().take(num_to_add));
876                }
877                Ok(RssState {
878                    key: rss
879                        .key
880                        .try_into()
881                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
882                    indirection_table: rss.indirection_table,
883                })
884            })
885            .transpose()?;
886
887        let rndis_state = match rndis_state {
888            saved_state::RndisState::Initializing => RndisState::Initializing,
889            saved_state::RndisState::Operational => RndisState::Operational,
890            saved_state::RndisState::Halted => RndisState::Halted,
891        };
892
893        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
894        let offload_config = OffloadConfig {
895            checksum_tx: ChecksumOffloadConfig {
896                ipv4_header: offload_config.checksum_tx.ipv4_header,
897                tcp4: offload_config.checksum_tx.tcp4,
898                udp4: offload_config.checksum_tx.udp4,
899                tcp6: offload_config.checksum_tx.tcp6,
900                udp6: offload_config.checksum_tx.udp6,
901            },
902            checksum_rx: ChecksumOffloadConfig {
903                ipv4_header: offload_config.checksum_rx.ipv4_header,
904                tcp4: offload_config.checksum_rx.tcp4,
905                udp4: offload_config.checksum_rx.udp4,
906                tcp6: offload_config.checksum_rx.tcp6,
907                udp6: offload_config.checksum_rx.udp6,
908            },
909            lso4: offload_config.lso4,
910            lso6: offload_config.lso6,
911        };
912
913        let pending_link_action = if let Some(pending) = pending_link_action {
914            PendingLinkAction::Active(pending)
915        } else {
916            PendingLinkAction::Default
917        };
918
919        Ok(Self {
920            guest_vf_state,
921            is_data_path_switched: None,
922            control_messages,
923            control_messages_len,
924            free_control_buffers,
925            rss_state,
926            requested_num_queues: num_queues,
927            rndis_state,
928            pending_offload_change,
929            offload_config,
930            tx_spread_sent,
931            guest_link_up: !guest_link_down,
932            pending_link_action,
933        })
934    }
935}
936
/// A queued control message awaiting processing by the primary channel.
struct ControlMessage {
    // The RNDIS message type code for this message.
    message_type: u32,
    // The raw message payload bytes.
    data: Box<[u8]>,
}
941
942const TX_PACKET_QUOTA: usize = 1024;
943
944impl ActiveState {
945    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
946        Self {
947            primary,
948            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
949            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
950            pending_tx_completions: VecDeque::new(),
951            pending_rx_packets: VecDeque::new(),
952            rx_bufs: RxBuffers::new(recv_buffer_count),
953            stats: Default::default(),
954        }
955    }
956
957    fn restore(
958        channel: &saved_state::Channel,
959        recv_buffer_count: u32,
960    ) -> Result<Self, NetRestoreError> {
961        let mut active = Self::new(None, recv_buffer_count);
962        let saved_state::Channel {
963            pending_tx_completions,
964            in_use_rx,
965        } = channel;
966        for rx in in_use_rx {
967            active
968                .rx_bufs
969                .allocate(rx.buffers.as_slice().iter().copied())?;
970        }
971        for &transaction_id in pending_tx_completions {
972            // Consume tx quota if any is available. If not, still
973            // allow the restore since tx quota might change from
974            // release to release.
975            let tx_id = active.free_tx_packets.pop();
976            if let Some(id) = tx_id {
977                // This shouldn't be referenced, but set it in case it is in the future.
978                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
979            }
980            // Save/Restore does not preserve the status of pending tx completions,
981            // completing any pending completions with 'success' to avoid making changes to saved_state.
982            active
983                .pending_tx_completions
984                .push_back(PendingTxCompletion {
985                    transaction_id,
986                    tx_id,
987                    status: protocol::Status::SUCCESS,
988                });
989        }
990        Ok(active)
991    }
992}
993
/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    // Number of backend packets still outstanding for this tx transaction.
    pending_packet_count: usize,
    // The guest's transaction ID, echoed back when the tx is completed.
    transaction_id: u64,
}
1001
/// The maximum rx batch size, in packets.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1010
/// A network adapter.
pub struct Nic {
    // The vmbus instance ID used in this adapter's channel offer.
    instance_id: Guid,
    // Resources provided by vmbus when the device is installed.
    resources: DeviceResources,
    // The coordinator task, which owns the per-channel worker tasks.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    // Sender for messages to the running coordinator task; populated when the
    // coordinator is inserted.
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    // Shared adapter configuration and counters.
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
1020
/// A builder for constructing a [`Nic`].
pub struct NicBuilder {
    // Optional accelerated-networking virtual function to offer to the guest.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    // Whether to limit the effective outgoing ring buffer size.
    limit_ring_buffer: bool,
    // Upper bound on the number of queues; clamped during build.
    max_queues: u16,
    // Optional callback to query the current guest OS ID.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1027
impl NicBuilder {
    /// Sets whether to limit the effective size of the outgoing ring buffer.
    /// See [`Self::build`] for the rationale.
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Sets the maximum number of queues. The value is clamped during build
    /// to what the endpoint and the netvsp protocol support.
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Associates a virtual function with the NIC.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Sets a callback used to query the guest OS ID, which is consulted to
    /// decide whether ring-size optimizations are safe for the guest (see
    /// `can_use_ring_opt`).
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp the requested queue count to what the endpoint and the
        // protocol allow, with a minimum of one queue.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            // Tx offloads are limited to what the endpoint actually supports.
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        // The coordinator task is created here but not inserted/started until
        // the primary channel is opened.
        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1139
1140fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1141    let Some(guest_os_id) = guest_os_id else {
1142        // guest os id not available.
1143        return false;
1144    };
1145
1146    if !queue.split().0.supports_pending_send_size() {
1147        // guest does not support pending send size.
1148        return false;
1149    }
1150
1151    let Some(open_source_os) = guest_os_id.open_source() else {
1152        // guest os is proprietary (ex: Windows)
1153        return true;
1154    };
1155
1156    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1157        // Although FreeBSD indicates support for `pending send size`, it doesn't
1158        // implement it correctly. This was fixed in FreeBSD version `1400097`.
1159        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1160        // Linux kernels prior to 3.11 have issues with pending send size optimization
1161        // which can affect certain Asynchronous I/O (AIO) network operations.
1162        // Disable ring size optimization for these older kernels to avoid flow control issues.
1163        HvGuestOsOpenSourceType::LINUX => {
1164            // Linux version is encoded as: ((major << 16) | (minor << 8) | patch)
1165            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424
1166            open_source_os.version() >= 199424
1167        }
1168        _ => true,
1169    }
1170}
1171
1172impl Nic {
1173    pub fn builder() -> NicBuilder {
1174        NicBuilder {
1175            virtual_function: None,
1176            limit_ring_buffer: false,
1177            max_queues: !0,
1178            get_guest_os_id: None,
1179        }
1180    }
1181
1182    pub fn shutdown(self) -> Box<dyn Endpoint> {
1183        let (state, _) = self.coordinator.into_inner();
1184        state.endpoint
1185    }
1186}
1187
impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        // Delegate inspection to the coordinator task, which owns the
        // runtime state.
        self.coordinator.inspect_mut(req);
    }
}
1193
#[async_trait]
impl VmbusDevice for Nic {
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            // The vmbus interface ID for synthetic networking,
            // f8615163-df3e-46c5-913f-f2d2f965ed0e.
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, None);
            WorkerState::Init(None)
        } else {
            // The coordinator must be stopped before its state can be accessed.
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        // When the last expected subchannel opens, stop the primary worker.
        // NOTE(review): presumably so the coordinator reconfigures the primary
        // when it restarts below — confirm against coordinator logic.
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            // Opening the primary channel failed; tear down the coordinator
            // inserted above.
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            // Drop every worker's queue before stopping the endpoint.
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VF's added to the guest. This is required to keep guest compat as
            // some apps (such as DPDK) relies on the VF sticking around even after vmbus
            // channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            // Channel is closed; nothing to retarget.
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        // Stop the coordinator first so the workers can be accessed safely.
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}
1351
1352#[async_trait]
1353impl SaveRestoreVmbusDevice for Nic {
1354    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1355        let state = self.saved_state();
1356        Ok(SavedStateBlob::new(state))
1357    }
1358
1359    async fn restore(
1360        &mut self,
1361        control: RestoreControl<'_>,
1362        state: SavedStateBlob,
1363    ) -> Result<(), RestoreError> {
1364        let state: saved_state::SavedState = state.parse()?;
1365        if let Err(err) = self.restore_state(control, state).await {
1366            tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
1367            Err(err.into())
1368        } else {
1369            Ok(())
1370        }
1371    }
1372}
1373
impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        // Build the vmbus queue over this channel's ring buffer.
        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // Query the guest OS ID (if a callback was configured) to decide
        // whether the ring size optimization may be used for this queue.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                // New workers start with the coordinator's current filter.
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1433
/// Parameters carried into [`Nic::insert_coordinator`] when restoring from a
/// saved state.
struct RestoreCoordinatorState {
    // The packet filter that was active at save time.
    active_packet_filter: u32,
}
1437
impl Nic {
    /// Inserts the coordinator task and its per-queue worker slots.
    ///
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                // Pre-create a worker slot (with its own driver) for every
                // possible queue, even if fewer are in use right now.
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                // On restore, carry forward the saved packet filter; otherwise
                // start with no packet types enabled.
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
            },
        );
    }
}
1479
/// Errors that can occur while restoring the NIC from a saved state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    // Shrinking the indirection table is rejected outright; see
    // PrimaryChannelState::restore for the rationale.
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1499
impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        // All netvsp restore failures are surfaced as invalid saved state.
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}
1505
1506impl Nic {
1507    async fn restore_state(
1508        &mut self,
1509        mut control: RestoreControl<'_>,
1510        state: saved_state::SavedState,
1511    ) -> Result<(), NetRestoreError> {
1512        let mut saved_packet_filter = 0u32;
1513        if let Some(state) = state.open {
1514            let open = match &state.primary {
1515                saved_state::Primary::Version => vec![true],
1516                saved_state::Primary::Init(_) => vec![true],
1517                saved_state::Primary::Ready(ready) => {
1518                    ready.channels.iter().map(|x| x.is_some()).collect()
1519                }
1520            };
1521
1522            let mut states: Vec<_> = open.iter().map(|_| None).collect();
1523
1524            // N.B. This will restore the vmbus view of open channels, so any
1525            //      failures after this point could result in inconsistent
1526            //      state (vmbus believes the channel is open/active). There
1527            //      are a number of failure paths after this point because this
1528            //      call also restores vmbus device state, like the GPADL map.
1529            let requests = control
1530                .restore(&open)
1531                .await
1532                .map_err(NetRestoreError::Channel)?;
1533
1534            match state.primary {
1535                saved_state::Primary::Version => {
1536                    states[0] = Some(WorkerState::Init(None));
1537                }
1538                saved_state::Primary::Init(init) => {
1539                    let version = check_version(init.version)
1540                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;
1541
1542                    let recv_buffer = init
1543                        .receive_buffer
1544                        .map(|recv_buffer| {
1545                            ReceiveBuffer::new(
1546                                &self.resources.gpadl_map,
1547                                recv_buffer.gpadl_id,
1548                                recv_buffer.id,
1549                                recv_buffer.sub_allocation_size,
1550                            )
1551                        })
1552                        .transpose()?;
1553
1554                    let send_buffer = init
1555                        .send_buffer
1556                        .map(|send_buffer| {
1557                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
1558                        })
1559                        .transpose()?;
1560
1561                    let state = InitState {
1562                        version,
1563                        ndis_config: init.ndis_config.map(
1564                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
1565                                mtu,
1566                                capabilities: capabilities.into(),
1567                            },
1568                        ),
1569                        ndis_version: init.ndis_version.map(
1570                            |saved_state::NdisVersion { major, minor }| NdisVersion {
1571                                major,
1572                                minor,
1573                            },
1574                        ),
1575                        recv_buffer,
1576                        send_buffer,
1577                    };
1578                    states[0] = Some(WorkerState::Init(Some(state)));
1579                }
1580                saved_state::Primary::Ready(ready) => {
1581                    let saved_state::ReadyPrimary {
1582                        version,
1583                        receive_buffer,
1584                        send_buffer,
1585                        mut control_messages,
1586                        mut rss_state,
1587                        channels,
1588                        ndis_version,
1589                        ndis_config,
1590                        rndis_state,
1591                        guest_vf_state,
1592                        offload_config,
1593                        pending_offload_change,
1594                        tx_spread_sent,
1595                        guest_link_down,
1596                        pending_link_action,
1597                        packet_filter,
1598                    } = ready;
1599
1600                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
1601                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1602
1603                    let version = check_version(version)
1604                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1605
1606                    let request = requests[0].as_ref().unwrap();
1607                    let buffers = Arc::new(ChannelBuffers {
1608                        version,
1609                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1610                        recv_buffer: ReceiveBuffer::new(
1611                            &self.resources.gpadl_map,
1612                            receive_buffer.gpadl_id,
1613                            receive_buffer.id,
1614                            receive_buffer.sub_allocation_size,
1615                        )?,
1616                        send_buffer: {
1617                            if let Some(send_buffer) = send_buffer {
1618                                Some(SendBuffer::new(
1619                                    &self.resources.gpadl_map,
1620                                    send_buffer.gpadl_id,
1621                                )?)
1622                            } else {
1623                                None
1624                            }
1625                        },
1626                        ndis_version: {
1627                            let saved_state::NdisVersion { major, minor } = ndis_version;
1628                            NdisVersion { major, minor }
1629                        },
1630                        ndis_config: {
1631                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1632                            NdisConfig {
1633                                mtu,
1634                                capabilities: capabilities.into(),
1635                            }
1636                        },
1637                    });
1638
1639                    for (channel_idx, channel) in channels.iter().enumerate() {
1640                        let channel = if let Some(channel) = channel {
1641                            channel
1642                        } else {
1643                            continue;
1644                        };
1645
1646                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1647
1648                        // Restore primary channel state.
1649                        if channel_idx == 0 {
1650                            let primary = PrimaryChannelState::restore(
1651                                &guest_vf_state,
1652                                &rndis_state,
1653                                &offload_config,
1654                                pending_offload_change,
1655                                channels.len() as u16,
1656                                self.adapter.indirection_table_size,
1657                                &active.rx_bufs,
1658                                std::mem::take(&mut control_messages),
1659                                rss_state.take(),
1660                                tx_spread_sent,
1661                                guest_link_down,
1662                                pending_link_action,
1663                            )?;
1664                            active.primary = Some(primary);
1665                        }
1666
1667                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1668                            buffers: buffers.clone(),
1669                            state: active,
1670                            data: ProcessingData::new(),
1671                        }));
1672                    }
1673                }
1674            }
1675
1676            // Insert the coordinator and mark that it should try to start the
1677            // network endpoint when it starts running.
1678            self.insert_coordinator(
1679                states.len() as u16,
1680                Some(RestoreCoordinatorState {
1681                    active_packet_filter: saved_packet_filter,
1682                }),
1683            );
1684
1685            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1686                if let Some(state) = state {
1687                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1688                }
1689            }
1690        } else {
1691            control
1692                .restore(&[false])
1693                .await
1694                .map_err(NetRestoreError::Channel)?;
1695        }
1696        Ok(())
1697    }
1698
    /// Builds the serializable saved state for this device.
    ///
    /// Returns `open: None` if the coordinator has never been inserted (the
    /// channel is not open). Otherwise the primary worker's state determines
    /// how much is captured: just the negotiated version, the partial init
    /// handshake, or the full ready state including per-channel in-flight
    /// tx/rx bookkeeping.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                // Version negotiated but no further handshake messages yet.
                WorkerState::Init(None) => saved_state::Primary::Version,
                // Mid-handshake: save whichever pieces have arrived so far.
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued control messages that have not been answered yet.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // Delayed and active link actions are saved identically.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Capture per-channel state for each active queue.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // In flight tx will be considered as dropped packets through save/restore, but need
                                    // to complete the requests back to the guest.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    // Channel never reached the ready state;
                                    // nothing in flight to save.
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine into its
                    // serializable form.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        // Mid-restore: pass the previously captured state
                        // through unchanged.
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
1897}
1898
/// Errors that stop a channel worker.
///
/// Most variants are guest protocol violations or memory-access failures;
/// `Cancelled` and `BufferRevoked` are orderly teardown paths rather than
/// true failures.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
}
1950
1951impl From<task_control::Cancelled> for WorkerError {
1952    fn from(value: task_control::Cancelled) -> Self {
1953        Self::Cancelled(value)
1954    }
1955}
1956
/// Errors that can occur while opening the vmbus channel.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
1964
/// Errors encountered while parsing a single incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
1976
/// Protocol sequencing errors: the guest sent a message that is not valid
/// given the channel's current state.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
1996
/// The decoded body of an incoming vmbus packet: one variant per recognized
/// NVSP message type, plus the two completion variants identified by
/// reserved transaction IDs rather than by message header.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
    RevokeSendBuffer(Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2014
/// An incoming packet parsed by `parse_packet`.
#[derive(Debug)]
struct Packet<'a> {
    /// The decoded message body.
    data: PacketData,
    /// The vmbus transaction ID, if any.
    transaction_id: Option<u64>,
    /// External GPA ranges referenced by the packet (or the send buffer
    /// section for RNDIS packets carried in the send buffer).
    external_data: &'a MultiPagedRangeBuf,
}
2021
/// Reader over a packet's external data ranges in guest memory.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2023
2024impl Packet<'_> {
2025    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2026        PagedRanges::new(self.external_data.iter()).reader(mem)
2027    }
2028}
2029
2030fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2031    reader: &mut impl MemoryRead,
2032) -> Result<T, PacketError> {
2033    reader.read_plain().map_err(PacketError::Access)
2034}
2035
2036fn parse_packet<'a, T: RingMem>(
2037    packet_ref: &queue::PacketRef<'_, T>,
2038    send_buffer: Option<&SendBuffer>,
2039    version: Option<Version>,
2040    external_data: &'a mut MultiPagedRangeBuf,
2041) -> Result<Packet<'a>, PacketError> {
2042    external_data.clear();
2043    let packet = match packet_ref.as_ref() {
2044        IncomingPacket::Data(data) => data,
2045        IncomingPacket::Completion(completion) => {
2046            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2047                PacketData::SendVfAssociationCompletion
2048            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2049                PacketData::SwitchDataPathCompletion
2050            } else {
2051                let mut reader = completion.reader();
2052                let header: protocol::MessageHeader =
2053                    reader.read_plain().map_err(PacketError::Access)?;
2054                match header.message_type {
2055                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2056                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2057                    }
2058                    typ => return Err(PacketError::UnknownType(typ)),
2059                }
2060            };
2061            return Ok(Packet {
2062                data,
2063                transaction_id: Some(completion.transaction_id()),
2064                external_data,
2065            });
2066        }
2067    };
2068
2069    let mut reader = packet.reader();
2070    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2071    let data = match header.message_type {
2072        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2073        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2074            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2075        }
2076        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2077            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2078        }
2079        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2080            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2081        }
2082        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2083            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2084        }
2085        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2086            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2087        }
2088        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2089            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
2090            if message.send_buffer_section_index != 0xffffffff {
2091                let send_buffer_suballocation = send_buffer
2092                    .ok_or(PacketError::InvalidSendBufferIndex)?
2093                    .gpadl
2094                    .first()
2095                    .unwrap()
2096                    .try_subrange(
2097                        message.send_buffer_section_index as usize * 6144,
2098                        message.send_buffer_section_size as usize,
2099                    )
2100                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2101
2102                external_data.push_range(send_buffer_suballocation);
2103            }
2104            PacketData::RndisPacket(message)
2105        }
2106        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2107            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2108        }
2109        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2110            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2111        }
2112        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2113            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2114        }
2115        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2116            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2117        }
2118        typ => return Err(PacketError::UnknownType(typ)),
2119    };
2120    packet
2121        .read_external_ranges(external_data)
2122        .map_err(PacketError::ExternalData)?;
2123    Ok(Packet {
2124        data,
2125        transaction_id: packet.transaction_id(),
2126        external_data,
2127    })
2128}
2129
/// An outgoing NVSP message, stored in a fixed-size buffer large enough for
/// the biggest (v6.1) packet.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    /// Message bytes; `u64` storage guarantees 8-byte alignment.
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    /// Which protocol packet size to emit (determines how much of `buf` is
    /// sent).
    size: PacketSize,
}
2135
impl NvspMessage {
    /// Constructs a message of `size`, writing the header (with
    /// `message_type`) followed by `data` into a zeroed buffer.
    fn new<P: IntoBytes + Immutable + KnownLayout>(
        size: PacketSize,
        message_type: u32,
        data: P,
    ) -> Self {
        // Assert at compile time that the packet will fit in the message
        // buffer. Note that we are checking against the v1 message size here.
        // It's possible this is a v6.1+ message, in which case we could compare
        // against the larger size. So far this has not been necessary. If
        // needed, make a `new_v61` method that does the more relaxed check,
        // rather than weakening this one.
        const {
            assert!(
                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
                "packet might not fit in message"
            )
        };
        let mut message = NvspMessage {
            buf: [0; protocol::PACKET_SIZE_V61 / 8],
            size,
        };
        let header = protocol::MessageHeader { message_type };
        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
        data.write_to_prefix(
            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
        )
        .unwrap();
        message
    }

    /// Returns the message contents as 8-byte-aligned words, truncated to the
    /// message's packet size.
    fn aligned_payload(&self) -> &[u64] {
        // Note that vmbus packets are always 8-byte multiples, so round the
        // protocol packet size up.
        let len = match self.size {
            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.next_multiple_of(8) / 8 },
            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.next_multiple_of(8) / 8 },
        };
        &self.buf[..len]
    }
}
2177
2178impl<T: RingMem> NetChannel<T> {
2179    fn message<P: IntoBytes + Immutable + KnownLayout>(
2180        &self,
2181        message_type: u32,
2182        data: P,
2183    ) -> NvspMessage {
2184        NvspMessage::new(self.packet_size, message_type, data)
2185    }
2186
2187    fn send_completion(
2188        &mut self,
2189        transaction_id: Option<u64>,
2190        message: Option<&NvspMessage>,
2191    ) -> Result<(), WorkerError> {
2192        match transaction_id {
2193            None => Ok(()),
2194            Some(transaction_id) => Ok(self
2195                .queue
2196                .split()
2197                .1
2198                .batched()
2199                .try_write_aligned(
2200                    transaction_id,
2201                    OutgoingPacketType::Completion,
2202                    message.map_or(&[], |m| m.aligned_payload()),
2203                )
2204                .map_err(|err| match err {
2205                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2206                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2207                })?),
2208        }
2209    }
2210}
2211
/// Protocol versions offered during version negotiation, in ascending order.
/// Note that V3 is not included.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2220
2221fn check_version(requested_version: u32) -> Option<Version> {
2222    SUPPORTED_VERSIONS
2223        .iter()
2224        .find(|version| **version as u32 == requested_version)
2225        .copied()
2226}
2227
/// The guest-supplied receive buffer: a contiguous GPADL carved into
/// fixed-size suballocations.
#[derive(Debug)]
struct ReceiveBuffer {
    /// Mapped view of the guest's GPADL.
    gpadl: GpadlView,
    /// The receive buffer ID.
    id: u16,
    /// Number of suballocations that fit in the buffer.
    count: u32,
    /// Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2235
/// Errors validating a guest-supplied send or receive buffer GPADL.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2245
2246impl ReceiveBuffer {
2247    fn new(
2248        gpadl_map: &GpadlMapView,
2249        gpadl_id: GpadlId,
2250        id: u16,
2251        sub_allocation_size: u32,
2252    ) -> Result<Self, BufferError> {
2253        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2254            return Err(BufferError::UnsupportedSuballocationSize(
2255                sub_allocation_size,
2256            ));
2257        }
2258        let gpadl = gpadl_map.map(gpadl_id)?;
2259        let range = gpadl
2260            .contiguous_aligned()
2261            .ok_or(BufferError::UnalignedGpadl)?;
2262        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2263        if num_sub_allocations == 0 {
2264            return Err(BufferError::UnsupportedSuballocationSize(
2265                sub_allocation_size,
2266            ));
2267        }
2268        let recv_buffer = Self {
2269            gpadl,
2270            id,
2271            count: num_sub_allocations,
2272            sub_allocation_size,
2273        };
2274        Ok(recv_buffer)
2275    }
2276
2277    fn range(&self, index: u32) -> PagedRange<'_> {
2278        self.gpadl.first().unwrap().subrange(
2279            (index * self.sub_allocation_size) as usize,
2280            self.sub_allocation_size as usize,
2281        )
2282    }
2283
2284    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2285        assert!(len <= self.sub_allocation_size as usize);
2286        ring::TransferPageRange {
2287            byte_offset: index * self.sub_allocation_size,
2288            byte_count: len as u32,
2289        }
2290    }
2291
2292    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2293        saved_state::ReceiveBuffer {
2294            gpadl_id: self.gpadl.id(),
2295            id: self.id,
2296            sub_allocation_size: self.sub_allocation_size,
2297        }
2298    }
2299}
2300
/// The guest-supplied send buffer: a contiguous GPADL from which the guest
/// transmits RNDIS messages by section index.
#[derive(Debug)]
struct SendBuffer {
    /// Mapped view of the guest's GPADL.
    gpadl: GpadlView,
}
2305
2306impl SendBuffer {
2307    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2308        let gpadl = gpadl_map.map(gpadl_id)?;
2309        gpadl
2310            .contiguous_aligned()
2311            .ok_or(BufferError::UnalignedGpadl)?;
2312        Ok(Self { gpadl })
2313    }
2314}
2315
2316impl<T: RingMem> NetChannel<T> {
    /// Process a single non-packet RNDIS message.
    ///
    /// Only valid on the primary channel. `MESSAGE_TYPE_HALT_MSG` is ignored;
    /// any other type is queued as a control message to be answered later
    /// from the main dispatch loop. Returns an error if the queue byte budget
    /// would be exceeded.
    fn handle_rndis_message(
        &mut self,
        state: &mut ActiveState,
        message_type: u32,
        mut reader: PacketReader<'_>,
    ) -> Result<(), WorkerError> {
        assert_ne!(
            message_type,
            rndisprot::MESSAGE_TYPE_PACKET_MSG,
            "handled elsewhere"
        );
        // Control messages are only supported on the primary channel.
        let control = state
            .primary
            .as_mut()
            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;

        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
            // Currently ignored and does not require a response.
            return Ok(());
        }

        // This is a control message that needs a response. Responding
        // will require a suballocation to be available, which it may
        // not be right now. Enqueue the suballocation to a queue and
        // process the queue as suballocations become available.
        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
        if reader.len() == 0 {
            return Err(WorkerError::RndisMessageTooSmall);
        }
        // Do not let the queue get too large--the guest should not be
        // sending very many control messages at a time.
        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
            return Err(WorkerError::TooManyControlMessages);
        }

        control.control_messages_len += reader.len();
        control.control_messages.push_back(ControlMessage {
            message_type,
            data: reader.read_all()?.into(),
        });

        // The control message queue will be processed in the main dispatch
        // loop.
        Ok(())
    }
2363
    /// Process RNDIS packet messages, which may contain multiple RNDIS packets
    /// in a single vmbus message.
    ///
    /// On entry, the reader has already read the RNDIS message header of the
    /// first RNDIS packet in the message.
    ///
    /// Returns the number of RNDIS packets consumed; the tx segments for each
    /// packet are appended to `segments`.
    fn handle_rndis_packet_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        mut message_len: usize,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        // There may be multiple RNDIS packets in a single message, concatenated
        // with each other. Consume them until there is no more data in the
        // RNDIS message.
        let mut num_packets = 0;
        loop {
            // `message_len` counts from the message header, which the reader
            // has already consumed; the remainder is the payload length and
            // the offset to the next concatenated message.
            let next_message_offset = message_len
                .checked_sub(size_of::<rndisprot::MessageHeader>())
                .ok_or(WorkerError::RndisMessageTooSmall)?;

            self.handle_rndis_packet_message(
                id,
                reader.clone(),
                &buffers.mem,
                segments,
                &mut state.stats,
            )?;
            num_packets += 1;

            // Advance past this packet; any remaining data must be another
            // RNDIS packet message.
            reader.skip(next_message_offset)?;
            if reader.len() == 0 {
                break;
            }
            let header: rndisprot::MessageHeader = reader.read_plain()?;
            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
            }
            message_len = header.message_length as usize;
        }
        Ok(num_packets)
    }
2408
2409    /// Process an RNDIS package message (used to send an Ethernet frame).
2410    fn handle_rndis_packet_message(
2411        &mut self,
2412        id: TxId,
2413        reader: PacketReader<'_>,
2414        mem: &GuestMemory,
2415        segments: &mut Vec<TxSegment>,
2416        stats: &mut QueueStats,
2417    ) -> Result<(), WorkerError> {
2418        // Headers are guaranteed to be in a single PagedRange.
2419        let headers = reader
2420            .clone()
2421            .into_inner()
2422            .paged_ranges()
2423            .next()
2424            .ok_or(WorkerError::RndisMessageTooSmall)?;
2425        let mut data = reader.into_inner();
2426        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2427        if request.num_oob_data_elements != 0
2428            || request.oob_data_length != 0
2429            || request.oob_data_offset != 0
2430            || request.vc_handle != 0
2431        {
2432            return Err(WorkerError::UnsupportedRndisBehavior);
2433        }
2434
2435        if data.len() < request.data_offset as usize
2436            || (data.len() - request.data_offset as usize) < request.data_length as usize
2437            || request.data_length == 0
2438        {
2439            return Err(WorkerError::RndisMessageTooSmall);
2440        }
2441
2442        data.skip(request.data_offset as usize);
2443        data.truncate(request.data_length as usize);
2444
2445        let mut metadata = net_backend::TxMetadata {
2446            id,
2447            len: request.data_length,
2448            ..Default::default()
2449        };
2450
2451        if request.per_packet_info_length != 0 {
2452            let mut ppi = headers
2453                .try_subrange(
2454                    request.per_packet_info_offset as usize,
2455                    request.per_packet_info_length as usize,
2456                )
2457                .ok_or(WorkerError::RndisMessageTooSmall)?;
2458            while !ppi.is_empty() {
2459                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2460                if h.size == 0 {
2461                    return Err(WorkerError::RndisMessageTooSmall);
2462                }
2463                let (this, rest) = ppi
2464                    .try_split(h.size as usize)
2465                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2466                let (_, d) = this
2467                    .try_split(h.per_packet_information_offset as usize)
2468                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2469                match h.typ {
2470                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2471                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2472
2473                        metadata.flags.set_offload_tcp_checksum(
2474                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2475                        );
2476                        metadata.flags.set_offload_udp_checksum(
2477                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2478                        );
2479                        metadata
2480                            .flags
2481                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2482                        metadata.flags.set_is_ipv4(n.is_ipv4());
2483                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2484                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2485                        if metadata.flags.offload_tcp_checksum()
2486                            || metadata.flags.offload_udp_checksum()
2487                        {
2488                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2489                                n.tcp_header_offset() - metadata.l2_len as u16
2490                            } else if n.is_ipv4() {
2491                                let mut reader = data.clone().reader(mem);
2492                                reader.skip(metadata.l2_len as usize)?;
2493                                let mut b = 0;
2494                                reader.read(std::slice::from_mut(&mut b))?;
2495                                (b as u16 >> 4) * 4
2496                            } else {
2497                                // Hope there are no extensions.
2498                                40
2499                            };
2500                        }
2501                    }
2502                    rndisprot::PPI_LSO => {
2503                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2504
2505                        metadata.flags.set_offload_tcp_segmentation(true);
2506                        metadata.flags.set_offload_tcp_checksum(true);
2507                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2508                        metadata.flags.set_is_ipv4(n.is_ipv4());
2509                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2510                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2511                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2512                            return Err(WorkerError::InvalidTcpHeaderOffset);
2513                        }
2514                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2515                        metadata.l4_len = {
2516                            let mut reader = data.clone().reader(mem);
2517                            reader
2518                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2519                            let mut b = 0;
2520                            reader.read(std::slice::from_mut(&mut b))?;
2521                            (b >> 4) * 4
2522                        };
2523                        metadata.max_tcp_segment_size = n.mss() as u16;
2524
2525                        if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2526                            // Not strictly enforced.
2527                            stats.tx_invalid_lso_packets.increment();
2528                        }
2529                    }
2530                    _ => {}
2531                }
2532                ppi = rest;
2533            }
2534        }
2535
2536        let start = segments.len();
2537        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2538            let range = range.map_err(WorkerError::InvalidGpadl)?;
2539            segments.push(TxSegment {
2540                ty: net_backend::TxSegmentType::Tail,
2541                gpa: range.start,
2542                len: range.len() as u32,
2543            });
2544        }
2545
2546        metadata.segment_count = (segments.len() - start) as u8;
2547
2548        stats.tx_packets.increment();
2549        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2550            stats.tx_checksum_packets.increment();
2551        }
2552        if metadata.flags.offload_tcp_segmentation() {
2553            stats.tx_lso_packets.increment();
2554        }
2555
2556        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2557
2558        Ok(())
2559    }
2560
2561    /// Notify the adapter that the guest VF state has changed and it may
2562    /// need to send a message to the guest.
2563    fn guest_vf_is_available(
2564        &mut self,
2565        guest_vf_id: Option<u32>,
2566        version: Version,
2567        config: NdisConfig,
2568    ) -> Result<bool, WorkerError> {
2569        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2570        if version >= Version::V4 && config.capabilities.sriov() {
2571            tracing::info!(
2572                available = serial_number.is_some(),
2573                serial_number,
2574                "sending VF association message."
2575            );
2576            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2577            let message = {
2578                self.message(
2579                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2580                    protocol::Message4SendVfAssociation {
2581                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2582                        serial_number: serial_number.unwrap_or(0),
2583                    },
2584                )
2585            };
2586            self.queue
2587                .split()
2588                .1
2589                .batched()
2590                .try_write_aligned(
2591                    VF_ASSOCIATION_TRANSACTION_ID,
2592                    OutgoingPacketType::InBandWithCompletion,
2593                    message.aligned_payload(),
2594                )
2595                .map_err(|err| match err {
2596                    queue::TryWriteError::Full(len) => {
2597                        tracing::error!(len, "failed to write vf association message");
2598                        WorkerError::OutOfSpace
2599                    }
2600                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2601                })?;
2602            Ok(true)
2603        } else {
2604            tracing::info!(
2605                available = serial_number.is_some(),
2606                serial_number,
2607                major = version.major(),
2608                minor = version.minor(),
2609                sriov_capable = config.capabilities.sriov(),
2610                "Skipping NvspMessage4TypeSendVFAssociation message"
2611            );
2612            Ok(false)
2613        }
2614    }
2615
2616    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2617    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2618        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2619        if version < Version::V5 {
2620            return;
2621        }
2622
2623        #[repr(C)]
2624        #[derive(IntoBytes, Immutable, KnownLayout)]
2625        struct SendIndirectionMsg {
2626            pub message: protocol::Message5SendIndirectionTable,
2627            pub send_indirection_table:
2628                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2629        }
2630
2631        // The offset to the send indirection table from the beginning of the NVSP message.
2632        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2633            + size_of::<protocol::MessageHeader>();
2634        let mut data = SendIndirectionMsg {
2635            message: protocol::Message5SendIndirectionTable {
2636                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2637                table_offset: send_indirection_table_offset as u32,
2638            },
2639            send_indirection_table: Default::default(),
2640        };
2641
2642        for i in 0..data.send_indirection_table.len() {
2643            data.send_indirection_table[i] = i as u32 % num_channels_opened;
2644        }
2645
2646        let header = protocol::MessageHeader {
2647            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2648        };
2649        let result = self
2650            .queue
2651            .split()
2652            .1
2653            .try_write(&queue::OutgoingPacket {
2654                transaction_id: 0,
2655                packet_type: OutgoingPacketType::InBandNoCompletion,
2656                payload: &[header.as_bytes(), data.as_bytes()],
2657            })
2658            .map_err(|err| match err {
2659                queue::TryWriteError::Full(len) => {
2660                    tracing::error!(len, "failed to write send indirection table message");
2661                    WorkerError::OutOfSpace
2662                }
2663                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2664            });
2665        if let Err(err) = result {
2666            tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2667        }
2668    }
2669
2670    /// Notify the guest that the data path has been switched back to synthetic
2671    /// due to some external state change.
2672    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2673        let header = protocol::MessageHeader {
2674            message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2675        };
2676        let data = protocol::Message4SwitchDataPath {
2677            active_data_path: protocol::DataPath::SYNTHETIC.0,
2678        };
2679        let result = self
2680            .queue
2681            .split()
2682            .1
2683            .try_write(&queue::OutgoingPacket {
2684                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2685                packet_type: OutgoingPacketType::InBandWithCompletion,
2686                payload: &[header.as_bytes(), data.as_bytes()],
2687            })
2688            .map_err(|err| match err {
2689                queue::TryWriteError::Full(len) => {
2690                    tracing::error!(
2691                        len,
2692                        "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2693                    );
2694                    WorkerError::OutOfSpace
2695                }
2696                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2697            });
2698        if let Err(err) = result {
2699            tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2700        }
2701    }
2702
    /// Process an internal state change
    ///
    /// Drives the guest VF state machine (`primary.guest_vf_state`) in
    /// response to external changes, sending the appropriate VF
    /// association/data-path messages to the guest. Returns a
    /// [`CoordinatorMessage`] if the coordinator needs to take further action.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
        //      1. completion of switch data path request
        //      2. Switch data path notification (back to synthetic)
        //      3. Disassociate VF adapter.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Notify guest that a VF capability has recently arrived.
            // Only advertise once RNDIS initialization has completed.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    // Advertised; ask the coordinator to track further guest
                    // VF state changes.
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Step the state machine until it settles into a state with no
        // pending transition (the catch-all `break` arm).
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Notify guest that the VF is unavailable. It has already been surprise removed.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;
                    // If the guest had switched the data path to the VF, the
                    // switch back to synthetic still needs to be announced.
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    // The switch outcome must have been recorded before this
                    // state is handled here.
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;
                    if result {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSwitched
                        } else {
                            PrimaryChannelGuestVfState::Ready
                        }
                    } else {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        } else {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    // These states exist only during setup/restore and must
                    // never reach the state-change path.
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                _ => break,
            };
        }
        Ok(None)
    }
2798
    /// Process a control message, writing the response to the provided receive
    /// buffer suballocation.
    ///
    /// `id` names the receive buffer suballocation that the completion/
    /// response is written into; `reader` is positioned just past the RNDIS
    /// message header of a message of type `message_type`.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // INITIALIZE is only valid once, directly after reset.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                // Write the INITIALIZE completion into the suballocation and
                // send it to the guest.
                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise a pending VF, if
                // any.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        // Ideally the VF would not be presented to the guest
                        // until the completion packet has arrived, so that the
                        // guest is prepared. This is most interesting for the
                        // case of a VF associated with multiple guest
                        // adapters, using a concept like vports. In this
                        // scenario it would be better if all of the adapters
                        // were aware a VF was available before the device
                        // arrived. This is not currently possible because the
                        // Linux netvsc driver ignores the completion requested
                        // flag on inband packets and won't send a completion
                        // packet.
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Split the suballocation into the completion header and the
                // OID payload area so the OID handler can write directly into
                // the payload area.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Restart the endpoint if the OID changed some critical
                        // endpoint property.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate a packet filter change to the coordinator.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                // Status indications are host-to-guest only.
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                // Always respond successfully to keepalives.
                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
2975
2976    fn try_send_rndis_message(
2977        &mut self,
2978        transaction_id: u64,
2979        channel_type: u32,
2980        recv_buffer_id: u16,
2981        transfer_pages: &[ring::TransferPageRange],
2982    ) -> Result<Option<usize>, WorkerError> {
2983        let message = self.message(
2984            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2985            protocol::Message1SendRndisPacket {
2986                channel_type,
2987                send_buffer_section_index: 0xffffffff,
2988                send_buffer_section_size: 0,
2989            },
2990        );
2991        let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
2992            transaction_id,
2993            OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
2994            message.aligned_payload(),
2995        ) {
2996            Ok(()) => None,
2997            Err(queue::TryWriteError::Full(n)) => Some(n),
2998            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
2999        };
3000        Ok(pending_send_size)
3001    }
3002
3003    fn send_rndis_control_message(
3004        &mut self,
3005        buffers: &ChannelBuffers,
3006        id: u32,
3007        message_length: usize,
3008    ) -> Result<(), WorkerError> {
3009        let result = self.try_send_rndis_message(
3010            id as u64,
3011            protocol::CONTROL_CHANNEL_TYPE,
3012            buffers.recv_buffer.id,
3013            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3014        )?;
3015
3016        // Ring size is checked before control messages are processed, so failure to write is unexpected.
3017        match result {
3018            None => Ok(()),
3019            Some(len) => {
3020                tracelimit::error_ratelimited!(len, "failed to write control message completion");
3021                Err(WorkerError::OutOfSpace)
3022            }
3023        }
3024    }
3025
3026    fn indicate_status(
3027        &mut self,
3028        buffers: &ChannelBuffers,
3029        id: u32,
3030        status: u32,
3031        payload: &[u8],
3032    ) -> Result<(), WorkerError> {
3033        let buffer = &buffers.recv_buffer.range(id);
3034        let mut writer = buffer.writer(&buffers.mem);
3035        let message_length = write_rndis_message(
3036            &mut writer,
3037            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3038            payload.len(),
3039            &rndisprot::IndicateStatus {
3040                status,
3041                status_buffer_length: payload.len() as u32,
3042                status_buffer_offset: if payload.is_empty() {
3043                    0
3044                } else {
3045                    size_of::<rndisprot::IndicateStatus>() as u32
3046                },
3047            },
3048        )?;
3049        writer.write(payload)?;
3050        self.send_rndis_control_message(buffers, id, message_length)?;
3051        Ok(())
3052    }
3053
    /// Processes pending control messages until all are processed or there are
    /// no available suballocations.
    ///
    /// Also emits a pending offload-configuration change indication once the
    /// queued control messages have drained. Stops early (without error) when
    /// the outgoing ring or the free suballocation pool is exhausted;
    /// processing resumes when space becomes available.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages are only handled on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure the ring buffer has enough room to successfully complete control message handling.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                // Request a notification once this much ring space is free,
                // then resume.
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Each response needs a free receive buffer suballocation; stop
            // until the guest releases one.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the receive buffer in use to allow the guest to release it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                // Account for the dequeued message against the queued-bytes
                // budget.
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued messages remain, so the loop condition guarantees
                // a pending offload change: indicate the current task offload
                // configuration to the guest.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                // Unreachable per the loop condition above.
                unreachable!();
            }
        }
        Ok(())
    }
3107
3108    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3109        if self.restart.is_none() {
3110            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3111                guest_vf_state: guest_vf,
3112                filter_state: packet_filter,
3113            }));
3114        } else if let Some(CoordinatorMessage::Restart) = self.restart {
3115            // If a restart message is pending, do nothing.
3116            // A restart will try to switch the data path based on primary.guest_vf_state.
3117            // A restart will apply packet filter changes.
3118        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3119            // Add the new update to the existing message.
3120            update.guest_vf_state |= guest_vf;
3121            update.filter_state |= packet_filter;
3122        }
3123    }
3124
    /// Requests that the coordinator re-evaluate the guest VF state.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3128
    /// Requests that the coordinator re-apply the packet filter state.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3132}
3133
3134/// Writes an RNDIS message to `writer`.
3135fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3136    writer: &mut impl MemoryWrite,
3137    message_type: u32,
3138    extra: usize,
3139    payload: &T,
3140) -> Result<usize, AccessError> {
3141    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3142    writer.write(
3143        rndisprot::MessageHeader {
3144            message_type,
3145            message_length: message_length as u32,
3146        }
3147        .as_bytes(),
3148    )?;
3149    writer.write(payload.as_bytes())?;
3150    Ok(message_length)
3151}
3152
/// Errors produced while handling an RNDIS OID query or set request.
#[derive(Debug, Error)]
enum OidError {
    /// Guest memory access failed while reading or writing OID data.
    #[error(transparent)]
    Access(#[from] AccessError),
    /// The OID is not recognized by this device.
    #[error("unknown oid")]
    UnknownOid,
    /// A field in the OID input failed validation.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    /// The negotiated NDIS version does not support this request.
    #[error("bad ndis version")]
    BadVersion,
    /// The requested feature is recognized but not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3166
3167impl OidError {
3168    fn as_status(&self) -> u32 {
3169        match self {
3170            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3171            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3172            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3173            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3174        }
3175    }
3176}
3177
/// Default MTU: a standard Ethernet frame (1500-byte payload + 14-byte header).
const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
/// Largest supported (jumbo frame) MTU.
const MAX_MTU: u32 = 9216;

/// Length of an Ethernet header (destination MAC + source MAC + ethertype).
const ETHERNET_HEADER_LEN: u32 = 14;
3183
3184impl Adapter {
3185    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3186        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3187            // For enlightened guests (which is only Windows at the moment), send the
3188            // adapter index, which was previously set as the vport serial number.
3189            if guest_os_id
3190                .microsoft()
3191                .unwrap_or(HvGuestOsMicrosoft::from(0))
3192                .os_id()
3193                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3194            {
3195                self.adapter_index
3196            } else {
3197                vfid
3198            }
3199        } else {
3200            vfid
3201        }
3202    }
3203
    /// Handles an RNDIS OID query from the guest, writing the result to
    /// `writer` and returning the number of bytes written.
    ///
    /// Unknown OIDs fail with [`OidError::UnknownOid`]; guest memory failures
    /// surface as [`OidError::Access`].
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        // Capture the initial capacity so the bytes written can be computed
        // at the end as (initial - remaining).
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    // Ethernet objects operation characteristics
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    // Ethernet objects statistics
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    // PNP operations characteristics */
                    // rndisprot::Oid::OID_PNP_SET_POWER,
                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
                    // RNDIS OIDS
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // NDIS6 OIDs
                let supported_oids_6 = &[
                    // Link State OID
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    // NDIS 6 statistics OID
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    // Offload related OID
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
                ];

                // NDIS 6.30+ (RSS) OIDs
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                // Advertise only the OIDs appropriate for the negotiated NDIS
                // version.
                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; // NdisHardwareStatusReady
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                // Frame size excludes the Ethernet header.
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                // This value is meaningless for virtual NICs. Return what vmswitch returns.
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Like vmswitch, use the first N bytes of Microsoft's MAC address
                // prefix as the vendor ID.
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            // NDIS6 OIDs:
            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, /* MediaConnectStateConnected */
                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                // What the hardware could support.
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                // What is currently enabled for this channel.
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            // DPDK gets confused if the table size is zero,
                            // even if there is only one queue.
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3443
3444    fn handle_oid_set(
3445        &self,
3446        primary: &mut PrimaryChannelState,
3447        oid: rndisprot::Oid,
3448        reader: impl MemoryRead + Clone,
3449    ) -> Result<(bool, Option<u32>), OidError> {
3450        tracing::debug!(?oid, "oid set");
3451
3452        let mut restart_endpoint = false;
3453        let mut packet_filter = None;
3454        match oid {
3455            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3456                packet_filter = self.oid_set_packet_filter(reader)?;
3457            }
3458            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3459                self.oid_set_offload_parameters(reader, primary)?;
3460            }
3461            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3462                self.oid_set_offload_encapsulation(reader)?;
3463            }
3464            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3465                self.oid_set_rndis_config_parameter(reader, primary)?;
3466            }
3467            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3468                // TODO
3469            }
3470            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3471                self.oid_set_rss_parameters(reader, primary)?;
3472
3473                // Endpoints cannot currently change RSS parameters without
3474                // being restarted. This was a limitation driven by some DPDK
3475                // PMDs, and should be fixed.
3476                restart_endpoint = true;
3477            }
3478            _ => {
3479                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3480                return Err(OidError::UnknownOid);
3481            }
3482        }
3483        Ok((restart_endpoint, packet_filter))
3484    }
3485
3486    fn oid_set_rss_parameters(
3487        &self,
3488        mut reader: impl MemoryRead + Clone,
3489        primary: &mut PrimaryChannelState,
3490    ) -> Result<(), OidError> {
3491        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
3492        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3493        let len = reader.len().min(size_of_val(&params));
3494        reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3495
3496        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3497            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3498        {
3499            primary.rss_state = None;
3500            return Ok(());
3501        }
3502
3503        if params.hash_secret_key_size != 40 {
3504            return Err(OidError::InvalidInput("hash_secret_key_size"));
3505        }
3506        if params.indirection_table_size % 4 != 0 {
3507            return Err(OidError::InvalidInput("indirection_table_size"));
3508        }
3509        let indirection_table_size =
3510            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3511        let mut key = [0; 40];
3512        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3513        reader
3514            .clone()
3515            .skip(params.hash_secret_key_offset as usize)?
3516            .read(&mut key)?;
3517        reader
3518            .skip(params.indirection_table_offset as usize)?
3519            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3520        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3521        if indirection_table
3522            .iter()
3523            .any(|&x| x >= self.max_queues as u32)
3524        {
3525            return Err(OidError::InvalidInput("indirection_table"));
3526        }
3527        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3528        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3529            .flatten()
3530            .zip(indir_uninit)
3531        {
3532            *dest = src;
3533        }
3534        primary.rss_state = Some(RssState {
3535            key,
3536            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3537        });
3538        Ok(())
3539    }
3540
3541    fn oid_set_packet_filter(
3542        &self,
3543        reader: impl MemoryRead + Clone,
3544    ) -> Result<Option<u32>, OidError> {
3545        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3546        tracing::debug!(filter, "set packet filter");
3547        Ok(Some(filter))
3548    }
3549
3550    fn oid_set_offload_parameters(
3551        &self,
3552        reader: impl MemoryRead + Clone,
3553        primary: &mut PrimaryChannelState,
3554    ) -> Result<(), OidError> {
3555        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3556            reader,
3557            rndisprot::NdisObjectType::DEFAULT,
3558            1,
3559            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3560        )?;
3561
3562        tracing::debug!(?offload, "offload parameters");
3563        let rndisprot::NdisOffloadParameters {
3564            header: _,
3565            ipv4_checksum,
3566            tcp4_checksum,
3567            udp4_checksum,
3568            tcp6_checksum,
3569            udp6_checksum,
3570            lsov1,
3571            ipsec_v1: _,
3572            lsov2_ipv4,
3573            lsov2_ipv6,
3574            tcp_connection_ipv4: _,
3575            tcp_connection_ipv6: _,
3576            reserved: _,
3577            flags: _,
3578        } = offload;
3579
3580        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3581            return Err(OidError::NotSupported("lsov1"));
3582        }
3583        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3584            primary.offload_config.checksum_tx.ipv4_header = tx;
3585            primary.offload_config.checksum_rx.ipv4_header = rx;
3586        }
3587        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3588            primary.offload_config.checksum_tx.tcp4 = tx;
3589            primary.offload_config.checksum_rx.tcp4 = rx;
3590        }
3591        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3592            primary.offload_config.checksum_tx.tcp6 = tx;
3593            primary.offload_config.checksum_rx.tcp6 = rx;
3594        }
3595        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3596            primary.offload_config.checksum_tx.udp4 = tx;
3597            primary.offload_config.checksum_rx.udp4 = rx;
3598        }
3599        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3600            primary.offload_config.checksum_tx.udp6 = tx;
3601            primary.offload_config.checksum_rx.udp6 = rx;
3602        }
3603        if let Some(enable) = lsov2_ipv4.enable() {
3604            primary.offload_config.lso4 = enable;
3605        }
3606        if let Some(enable) = lsov2_ipv6.enable() {
3607            primary.offload_config.lso6 = enable;
3608        }
3609        primary
3610            .offload_config
3611            .mask_to_supported(&self.offload_support);
3612        primary.pending_offload_change = true;
3613        Ok(())
3614    }
3615
3616    fn oid_set_offload_encapsulation(
3617        &self,
3618        reader: impl MemoryRead + Clone,
3619    ) -> Result<(), OidError> {
3620        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3621            reader,
3622            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3623            1,
3624            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3625        )?;
3626        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3627            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3628                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3629        {
3630            return Err(OidError::NotSupported("ipv4 encap"));
3631        }
3632        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3633            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3634                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3635        {
3636            return Err(OidError::NotSupported("ipv6 encap"));
3637        }
3638        Ok(())
3639    }
3640
3641    fn oid_set_rndis_config_parameter(
3642        &self,
3643        reader: impl MemoryRead + Clone,
3644        primary: &mut PrimaryChannelState,
3645    ) -> Result<(), OidError> {
3646        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3647        if info.name_length > 255 {
3648            return Err(OidError::InvalidInput("name_length"));
3649        }
3650        if info.value_length > 255 {
3651            return Err(OidError::InvalidInput("value_length"));
3652        }
3653        let name = reader
3654            .clone()
3655            .skip(info.name_offset as usize)?
3656            .read_n::<u16>(info.name_length as usize / 2)?;
3657        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3658        let mut value = reader;
3659        value.skip(info.value_offset as usize)?;
3660        let mut value = value.limit(info.value_length as usize);
3661        match info.parameter_type {
3662            rndisprot::NdisParameterType::STRING => {
3663                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3664                let value =
3665                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3666                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3667                let tx = as_num & 1 != 0;
3668                let rx = as_num & 2 != 0;
3669
3670                tracing::debug!(name, value, "rndis config");
3671                match name.as_str() {
3672                    "*IPChecksumOffloadIPv4" => {
3673                        primary.offload_config.checksum_tx.ipv4_header = tx;
3674                        primary.offload_config.checksum_rx.ipv4_header = rx;
3675                    }
3676                    "*LsoV2IPv4" => {
3677                        primary.offload_config.lso4 = as_num != 0;
3678                    }
3679                    "*LsoV2IPv6" => {
3680                        primary.offload_config.lso6 = as_num != 0;
3681                    }
3682                    "*TCPChecksumOffloadIPv4" => {
3683                        primary.offload_config.checksum_tx.tcp4 = tx;
3684                        primary.offload_config.checksum_rx.tcp4 = rx;
3685                    }
3686                    "*TCPChecksumOffloadIPv6" => {
3687                        primary.offload_config.checksum_tx.tcp6 = tx;
3688                        primary.offload_config.checksum_rx.tcp6 = rx;
3689                    }
3690                    "*UDPChecksumOffloadIPv4" => {
3691                        primary.offload_config.checksum_tx.udp4 = tx;
3692                        primary.offload_config.checksum_rx.udp4 = rx;
3693                    }
3694                    "*UDPChecksumOffloadIPv6" => {
3695                        primary.offload_config.checksum_tx.udp6 = tx;
3696                        primary.offload_config.checksum_rx.udp6 = rx;
3697                    }
3698                    _ => {}
3699                }
3700                primary
3701                    .offload_config
3702                    .mask_to_supported(&self.offload_support);
3703            }
3704            rndisprot::NdisParameterType::INTEGER => {
3705                let value: u32 = value.read_plain()?;
3706                tracing::debug!(name, value, "rndis config");
3707            }
3708            parameter_type => tracelimit::warn_ratelimited!(
3709                name,
3710                ?parameter_type,
3711                "unhandled rndis config parameter type"
3712            ),
3713        }
3714        Ok(())
3715    }
3716}
3717
3718fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3719    mut reader: impl MemoryRead,
3720    object_type: rndisprot::NdisObjectType,
3721    min_revision: u8,
3722    min_size: usize,
3723) -> Result<T, OidError> {
3724    let mut buffer = T::new_zeroed();
3725    let sent_size = reader.len();
3726    let len = sent_size.min(size_of_val(&buffer));
3727    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3728    validate_ndis_object_header(
3729        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3730            .unwrap()
3731            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3732        sent_size,
3733        object_type,
3734        min_revision,
3735        min_size,
3736    )?;
3737    Ok(buffer)
3738}
3739
3740fn validate_ndis_object_header(
3741    header: &rndisprot::NdisObjectHeader,
3742    sent_size: usize,
3743    object_type: rndisprot::NdisObjectType,
3744    min_revision: u8,
3745    min_size: usize,
3746) -> Result<(), OidError> {
3747    if header.object_type != object_type {
3748        return Err(OidError::InvalidInput("header.object_type"));
3749    }
3750    if sent_size < header.size as usize {
3751        return Err(OidError::InvalidInput("header.size"));
3752    }
3753    if header.revision < min_revision {
3754        return Err(OidError::InvalidInput("header.revision"));
3755    }
3756    if (header.size as usize) < min_size {
3757        return Err(OidError::InvalidInput("header.size"));
3758    }
3759    Ok(())
3760}
3761
/// Mutable state owned by the coordinator task, which drives the per-queue
/// worker tasks and reacts to control messages.
struct Coordinator {
    /// Incoming control messages (restart, timer start, state updates).
    recv: mpsc::Receiver<CoordinatorMessage>,
    /// Vmbus channel control handle.
    channel_control: ChannelControl,
    /// When set, the queues are torn down and restarted on the next pass
    /// through the processing loop.
    restart: bool,
    /// One worker task per channel; index 0 is the primary channel.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// Channel buffers saved from the primary channel for use by the
    /// subchannel workers when queues are restarted.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of queues currently in use (bounds the live prefix of
    /// `workers`).
    num_queues: u16,
    /// Packet filter value propagated to all workers.
    active_packet_filter: u32,
}
3771
/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        /// Timer used to wait out the delay.
        timer: PolledTimer,
        /// Deadline after which the VF offer proceeds.
        delay_until: Instant,
    },
    /// A VF update is pending.
    Pending,
}
3786
/// State available to the coordinator task, shared between the running
/// task and its inspect implementation.
struct CoordinatorState {
    /// The backend network endpoint.
    endpoint: Box<dyn Endpoint>,
    /// Per-adapter configuration and shared settings.
    adapter: Arc<Adapter>,
    /// The accelerated-networking virtual function, if one is configured.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Tracks an in-flight VF offer/delay operation.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3793
// Inspect support: exposes adapter properties, the endpoint, per-queue
// workers, and the primary channel's shared state. `coordinator` is `None`
// when the coordinator task state is unavailable.
impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Adapter-level (static-ish) configuration.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Mutable field: writing a value updates the ring size limit.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    // Bounce each task so that it sees the new value.
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        // Endpoint-level properties.
        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // Only the first `num_queues` workers are live.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Get the shared channel state from the primary channel.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3869
// Entry point used by the task framework to run the coordinator.
impl AsyncRun<Coordinator> for CoordinatorState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        // All the work happens in `Coordinator::process`; this just forwards
        // the stop token and the shared state.
        coordinator.process(stop, self).await
    }
}
3879
3880impl Coordinator {
    /// The coordinator's main loop: restarts queues when requested, manages
    /// guest VF offer/data-path transitions, and dispatches internal control
    /// messages, endpoint notifications, and timers until the channel is
    /// disconnected.
    ///
    /// Returns `Err(Cancelled)` only if `stop` fires at a restartable await
    /// point; non-restartable operations (queue restart, VF offer) are
    /// deliberately not raced against `stop`.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        // Deadline for the timer requested via CoordinatorMessage::StartTimer;
        // None means no timer is armed.
        let mut sleep_duration: Option<Instant> = None;
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Ensure that all workers except the primary are started. The
            // primary is started below if there are no outstanding messages.
            for worker in &mut self.workers[1..] {
                worker.start();
            }

            // The set of events this loop can wake up on.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // The primary worker is allowed to run, but as no
                // notifications are being processed, if it is waiting for an
                // action then it should remain stopped until this completes
                // and the regular message processing logic resumes. Currently
                // the only message that requires processing is
                // DataPathSwitchPending, so check for that here.
                if !self.workers[0].is_running()
                    && self.primary_mut().is_none_or(|primary| {
                        !matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::DataPathSwitchPending { result: None, .. }
                        )
                    })
                {
                    self.workers[0].start();
                }

                // guest_ready_for_device is not restartable, so do not poll on
                // stop.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Sleeps until the armed timer deadline, or forever if none.
                let timer_sleep = async {
                    if let Some(sleep_duration) = sleep_duration {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(sleep_duration).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                // Race all message sources; the first to complete wins.
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Completes when the VF offer delay elapses
                                // (or never, if no delay is in progress).
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                // If a message is already available, handle it without
                // starting the primary worker; otherwise start the worker and
                // wait (restartable, so race against `stop`).
                let mut wait_for_message = std::pin::pin!(wait_for_message);
                match (&mut wait_for_message).now_or_never() {
                    Some(message) => message,
                    None => {
                        self.workers[0].start();
                        stop.until_stopped(wait_for_message).await?
                    }
                }
            };
            match message {
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Kick the worker as requested.
                    if self.workers[0].is_running() {
                        self.workers[0].stop().await;
                        if let Some(primary) = self.primary_mut() {
                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                                primary.pending_link_action = PendingLinkAction::Active(up);
                            }
                        }
                    }
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Update(update_type)) => {
                    if update_type.filter_state {
                        // Propagate the primary channel's packet filter to the
                        // subchannel workers.
                        self.stop_workers().await;
                        self.active_packet_filter =
                            self.workers[0].state().unwrap().channel.packet_filter;
                        self.workers.iter_mut().skip(1).for_each(|worker| {
                            if let Some(state) = worker.state_mut() {
                                state.channel.packet_filter = self.active_packet_filter;
                                tracing::debug!(
                                    packet_filter = ?self.active_packet_filter,
                                    channel_idx = state.channel_idx,
                                    "update packet filter"
                                );
                            }
                        });
                    }

                    if update_type.guest_vf_state {
                        self.update_guest_vf_state(state).await;
                    }
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    // These are the only link state transitions that are tracked.
                    // 1. up -> down or down -> up
                    // 2. up -> down -> up or down -> up -> down.
                    // All other state transitions are coalesced into one of the above cases.
                    // For example, up -> down -> up -> down is treated as up -> down.
                    // N.B - Always queue up the incoming state to minimize the effects of loss
                    //       of any notifications (for example, during vtl2 servicing).
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    // If there is any existing sleep timer running, cancel it out.
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
                    sleep_duration = Some(duration);
                    // Restart primary task.
                    self.workers[0].stop().await;
                }
                Message::ChannelDisconnected => {
                    // The control channel is gone; exit the loop.
                    break;
                }
            };
        }
        Ok(())
    }
4107
4108    async fn stop_workers(&mut self) {
4109        for worker in &mut self.workers {
4110            worker.stop().await;
4111        }
4112    }
4113
    /// Reconciles the primary channel's guest VF state with the current VF
    /// availability and data path state, e.g. after a restore or a queue
    /// restart. Schedules VF offers and switches the data path as needed.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        // Nothing to do if there is no primary channel state.
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Update guest VF state based on current endpoint properties.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // Ensure guest VF is in proper state.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // VF was advertised but the data path was never switched;
                    // offer the device after a delay.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // The guest already expects the VF; offer it immediately.
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // ensure data path is switched as expected
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                // If the save was after the data path switch already occurred, don't do it again.
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // A data path switch is required; extract the target
                    // direction and transaction id if one was recorded.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // Cancel any outstanding delay timers for VF offers if the data path is
                    // getting switched, since the guest is already issuing
                    // commands assuming a VF.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        // A previous instantiation already switched the data
                        // path.
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // If the device was just removed, make sure the data path is synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any pending delayed offer for a VF that no longer exists.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Notify guest if VF is no longer available
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4330
4331    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4332        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
4333        let drivers = self
4334            .workers
4335            .iter_mut()
4336            .map(|worker| {
4337                let task = worker.task_mut();
4338                task.queue_state = None;
4339                task.driver.clone()
4340            })
4341            .collect::<Vec<_>>();
4342
4343        c_state.endpoint.stop().await;
4344
4345        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4346        {
4347            (primary, sub)
4348        } else {
4349            unreachable!()
4350        };
4351
4352        let state = primary_worker
4353            .state_mut()
4354            .and_then(|worker| worker.state.ready_mut());
4355
4356        let state = if let Some(state) = state {
4357            state
4358        } else {
4359            return Ok(());
4360        };
4361
4362        // Save the channel buffers for use in the subchannel workers.
4363        self.buffers = Some(state.buffers.clone());
4364
4365        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4366        let mut active_queues = Vec::new();
4367        let active_queue_count =
4368            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
4369                // Active queue count is computed as the number of unique entries in the indirection table
4370                active_queues.clone_from(&rss_state.indirection_table);
4371                active_queues.sort();
4372                active_queues.dedup();
4373                active_queues = active_queues
4374                    .into_iter()
4375                    .filter(|&index| index < num_queues)
4376                    .collect::<Vec<_>>();
4377                if !active_queues.is_empty() {
4378                    active_queues.len() as u16
4379                } else {
4380                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
4381                    num_queues
4382                }
4383            } else {
4384                num_queues
4385            };
4386
4387        // Distribute the rx buffers to only the active queues.
4388        let (ranges, mut remote_buffer_id_recvs) =
4389            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
4390        let ranges = Arc::new(ranges);
4391
4392        let mut queues = Vec::new();
4393        let mut rx_buffers = Vec::new();
4394        {
4395            let buffers = &state.buffers;
4396            let guest_buffers = Arc::new(
4397                GuestBuffers::new(
4398                    buffers.mem.clone(),
4399                    buffers.recv_buffer.gpadl.clone(),
4400                    buffers.recv_buffer.sub_allocation_size,
4401                    buffers.ndis_config.mtu,
4402                )
4403                .map_err(WorkerError::GpadlError)?,
4404            );
4405
4406            // Get the list of free rx buffers from each task, then partition
4407            // the list per-active-queue, and produce the queue configuration.
4408            let mut queue_config = Vec::new();
4409            let initial_rx;
4410            {
4411                let states = std::iter::once(Some(&*state)).chain(
4412                    subworkers
4413                        .iter()
4414                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4415                );
4416
4417                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4418                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
4419                    .map(RxId)
4420                    .collect::<Vec<_>>();
4421
4422                let mut initial_rx = initial_rx.as_slice();
4423                let mut range_start = 0;
4424                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4425                let first_queue = if !primary_queue_excluded {
4426                    0
4427                } else {
4428                    // If the primary queue is excluded from the guest supplied
4429                    // indirection table, it is assigned just the reserved
4430                    // buffers.
4431                    queue_config.push(QueueConfig {
4432                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4433                        initial_rx: &[],
4434                        driver: Box::new(drivers[0].clone()),
4435                    });
4436                    rx_buffers.push(RxBufferRange::new(
4437                        ranges.clone(),
4438                        0..RX_RESERVED_CONTROL_BUFFERS,
4439                        None,
4440                    ));
4441                    range_start = RX_RESERVED_CONTROL_BUFFERS;
4442                    1
4443                };
4444                for queue_index in first_queue..num_queues {
4445                    let queue_active = active_queues.is_empty()
4446                        || active_queues.binary_search(&queue_index).is_ok();
4447                    let (range_end, end, buffer_id_recv) = if queue_active {
4448                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4449                            // The last queue gets all the remaining buffers.
4450                            state.buffers.recv_buffer.count
4451                        } else if queue_index == 0 {
4452                            // Queue zero always includes the reserved buffers.
4453                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4454                        } else {
4455                            range_start + ranges.buffers_per_queue
4456                        };
4457                        (
4458                            range_end,
4459                            initial_rx.partition_point(|id| id.0 < range_end),
4460                            Some(remote_buffer_id_recvs.remove(0)),
4461                        )
4462                    } else {
4463                        (range_start, 0, None)
4464                    };
4465
4466                    let (this, rest) = initial_rx.split_at(end);
4467                    queue_config.push(QueueConfig {
4468                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4469                        initial_rx: this,
4470                        driver: Box::new(drivers[queue_index as usize].clone()),
4471                    });
4472                    initial_rx = rest;
4473                    rx_buffers.push(RxBufferRange::new(
4474                        ranges.clone(),
4475                        range_start..range_end,
4476                        buffer_id_recv,
4477                    ));
4478
4479                    range_start = range_end;
4480                }
4481            }
4482
4483            let primary = state.state.primary.as_mut().unwrap();
4484            tracing::debug!(num_queues, "enabling endpoint");
4485
4486            let rss = primary
4487                .rss_state
4488                .as_ref()
4489                .map(|rss| net_backend::RssConfig {
4490                    key: &rss.key,
4491                    indirection_table: &rss.indirection_table,
4492                    flags: 0,
4493                });
4494
4495            c_state
4496                .endpoint
4497                .get_queues(queue_config, rss.as_ref(), &mut queues)
4498                .instrument(tracing::info_span!("netvsp_get_queues"))
4499                .await
4500                .map_err(WorkerError::Endpoint)?;
4501
4502            assert_eq!(queues.len(), num_queues as usize);
4503
4504            // Set the subchannel count.
4505            self.channel_control
4506                .enable_subchannels(num_queues - 1)
4507                .expect("already validated");
4508
4509            self.num_queues = num_queues;
4510        }
4511
4512        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4513        // Provide the queue and receive buffer ranges for each worker.
4514        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
4515            worker.task_mut().queue_state = Some(QueueState {
4516                queue,
4517                target_vp_set: false,
4518                rx_buffer_range: rx_buffer,
4519            });
4520            // Update the receive packet filter for the subchannel worker.
4521            if let Some(worker) = worker.state_mut() {
4522                worker.channel.packet_filter = self.active_packet_filter;
4523                // Clear any pending RxIds as buffers were redistributed.
4524                if let Some(ready_state) = worker.state.ready_mut() {
4525                    ready_state.state.pending_rx_packets.clear();
4526                }
4527            }
4528        }
4529
4530        Ok(())
4531    }
4532
4533    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4534        self.workers[0]
4535            .state_mut()
4536            .unwrap()
4537            .state
4538            .ready_mut()?
4539            .state
4540            .primary
4541            .as_mut()
4542    }
4543
    /// Stops the primary channel worker and then reapplies the guest VF
    /// (accelerated networking) state.
    ///
    /// The worker is stopped first so the restore runs while the primary
    /// channel is quiesced.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4548}
4549
4550impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4551    async fn run(
4552        &mut self,
4553        stop: &mut StopTask<'_>,
4554        worker: &mut Worker<T>,
4555    ) -> Result<(), task_control::Cancelled> {
4556        match worker.process(stop, self).await {
4557            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4558            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4559            Err(err) => {
4560                tracing::error!(
4561                    channel_idx = worker.channel_idx,
4562                    error = &err as &dyn std::error::Error,
4563                    "netvsp error"
4564                );
4565            }
4566        }
4567        Ok(())
4568    }
4569}
4570
impl<T: RingMem + 'static> Worker<T> {
    /// Runs this channel worker's state machine: drives channel
    /// initialization while in `WorkerState::Init`, then runs the main
    /// packet-processing loop once `WorkerState::Ready` is reached. Returns
    /// only on error or cancellation.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        // Be careful not to wait on actions with unbounded blocking time (e.g.
        // guest actions, or waiting for network packets to arrive) without
        // wrapping the wait on `stop.until_stopped`.
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    if self.channel_idx != 0 {
                        // Still waiting for the coordinator to provide receive
                        // buffers. The task will be restarted when they are available.
                        stop.until_stopped(pending()).await?
                    }

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Wake up the coordinator task to start the queues.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::Ready(state);
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Program the interrupt target VP once, unless
                        // interrupts are disabled for this channel.
                        if !queue_state.target_vp_set
                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
                        {
                            tracing::debug!(
                                channel_idx = self.channel_idx,
                                target_vp = self.target_vp,
                                "updating target VP"
                            );
                            queue_state.queue.update_target_vp(self.target_vp).await;
                            queue_state.target_vp_set = true;
                        }

                        queue_state
                    } else {
                        // This task will be restarted when the queues are ready.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    match result {
                        Ok(restart) => {
                            // Only the primary channel's main loop completes
                            // with a coordinator message.
                            assert_eq!(self.channel_idx, 0);
                            let _ = self.coordinator_send.try_send(restart);
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            // Ask the coordinator to rebuild the queues; if the
                            // request cannot be delivered, surface the original
                            // endpoint error instead.
                            if let Err(try_send_err) =
                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
                            {
                                tracing::error!(
                                    try_send_err = &try_send_err as &dyn std::error::Error,
                                    "failed to restart queues"
                                );
                                return Err(WorkerError::Endpoint(err));
                            }
                        }
                        Err(err) => return Err(err),
                    }

                    // Wait here until the coordinator restarts this task in
                    // response to the message sent above.
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4657
4658impl<T: 'static + RingMem> NetChannel<T> {
4659    fn try_next_packet<'a>(
4660        &mut self,
4661        send_buffer: Option<&SendBuffer>,
4662        version: Option<Version>,
4663        external_data: &'a mut MultiPagedRangeBuf,
4664    ) -> Result<Option<Packet<'a>>, WorkerError> {
4665        let (mut read, _) = self.queue.split();
4666        let packet = match read.try_read() {
4667            Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4668                .map_err(WorkerError::Packet)?,
4669            Err(queue::TryReadError::Empty) => return Ok(None),
4670            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4671        };
4672
4673        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4674        Ok(Some(packet))
4675    }
4676
4677    async fn next_packet<'a>(
4678        &mut self,
4679        send_buffer: Option<&'a SendBuffer>,
4680        version: Option<Version>,
4681        external_data: &'a mut MultiPagedRangeBuf,
4682    ) -> Result<Packet<'a>, WorkerError> {
4683        let (mut read, _) = self.queue.split();
4684        let mut packet_ref = read.read().await?;
4685        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
4686            .map_err(WorkerError::Packet)?;
4687        if matches!(packet.data, PacketData::RndisPacket(_)) {
4688            // In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG
4689            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
4690            packet_ref.revert();
4691        }
4692        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4693        Ok(packet)
4694    }
4695
4696    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4697        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4698            && initializing.ndis_version.is_some()
4699            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4700            && initializing.recv_buffer.is_some()
4701    }
4702
    /// Completes channel setup: negotiates the protocol version, collects the
    /// NDIS config/version messages and the receive (and optionally send)
    /// buffer GPADLs, then returns the assembled channel buffers and the
    /// initial active state.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        // Set when an rndis packet arrives before all setup messages have been
        // seen; treated as a signal that the guest considers setup complete.
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // The NDIS config message is optional pre-V2; fall
                        // back to defaults when it was never sent.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough room in the ring to avoid needing to track
            // completion packets.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                // A version has been negotiated; handle the setup messages,
                // rejecting duplicates of any message already received.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // As in the vmswitch, if the MTU is invalid then use the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The MTU determines the receive sub-allocation size.
                        // Pre-V2 versions never send an NDIS config, so use
                        // the default MTU there; on V2+ the config must have
                        // arrived before the receive buffer.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        // Report the buffer layout back to the guest as a
                        // single section of fixed-size sub-allocations.
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // An rndis packet is only acceptable here once all
                        // required setup (send buffer excepted) has arrived.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet; the only expected packet is Init.
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                // Update the packet size so that the appropriate padding is
                                // appended for picky Windows guests.
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    // parse_packet should only produce `Init` before a version
                    // has been negotiated, so other packets are unreachable.
                    _ => unreachable!(),
                }
            }
        }
    }
4917
4918    async fn main_loop(
4919        &mut self,
4920        stop: &mut StopTask<'_>,
4921        ready_state: &mut ReadyState,
4922        queue_state: &mut QueueState,
4923    ) -> Result<CoordinatorMessage, WorkerError> {
4924        let buffers = &ready_state.buffers;
4925        let state = &mut ready_state.state;
4926        let data = &mut ready_state.data;
4927
4928        let ring_spare_capacity = {
4929            let (_, send) = self.queue.split();
4930            let mut limit = if self.can_use_ring_size_opt {
4931                self.adapter.ring_size_limit.load(Ordering::Relaxed)
4932            } else {
4933                0
4934            };
4935            if limit == 0 {
4936                limit = send.capacity() - 2048;
4937            }
4938            send.capacity() - limit
4939        };
4940
4941        // If the packet filter has changed to allow rx packets, add any pended RxIds.
4942        if !state.pending_rx_packets.is_empty()
4943            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
4944        {
4945            let epqueue = queue_state.queue.as_mut();
4946            let (front, back) = state.pending_rx_packets.as_slices();
4947            epqueue.rx_avail(front);
4948            epqueue.rx_avail(back);
4949            state.pending_rx_packets.clear();
4950        }
4951
4952        // Handle any guest state changes since last run.
4953        if let Some(primary) = state.primary.as_mut() {
4954            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
4955                let num_channels_opened =
4956                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
4957                if num_channels_opened == primary.requested_num_queues as usize {
4958                    let (_, mut send) = self.queue.split();
4959                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4960                        .await??;
4961                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
4962                    primary.tx_spread_sent = true;
4963                }
4964            }
4965            if let PendingLinkAction::Active(up) = primary.pending_link_action {
4966                let (_, mut send) = self.queue.split();
4967                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4968                    .await??;
4969                if let Some(id) = primary.free_control_buffers.pop() {
4970                    let connect = if primary.guest_link_up != up {
4971                        primary.pending_link_action = PendingLinkAction::Default;
4972                        up
4973                    } else {
4974                        // For the up -> down -> up OR down -> up -> down case, the first transition
4975                        // is sent immediately and the second transition is queued with a delay.
4976                        primary.pending_link_action =
4977                            PendingLinkAction::Delay(primary.guest_link_up);
4978                        !primary.guest_link_up
4979                    };
4980                    // Mark the receive buffer in use to allow the guest to release it.
4981                    assert!(state.rx_bufs.is_free(id.0));
4982                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
4983                    let state_to_send = if connect {
4984                        rndisprot::STATUS_MEDIA_CONNECT
4985                    } else {
4986                        rndisprot::STATUS_MEDIA_DISCONNECT
4987                    };
4988                    tracing::info!(
4989                        connect,
4990                        mac_address = %self.adapter.mac_address,
4991                        "sending link status"
4992                    );
4993
4994                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
4995                    primary.guest_link_up = connect;
4996                } else {
4997                    primary.pending_link_action = PendingLinkAction::Delay(up);
4998                }
4999
5000                match primary.pending_link_action {
5001                    PendingLinkAction::Delay(_) => {
5002                        return Ok(CoordinatorMessage::StartTimer(
5003                            Instant::now() + LINK_DELAY_DURATION,
5004                        ));
5005                    }
5006                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
5007                    _ => {}
5008                }
5009            }
5010            match primary.guest_vf_state {
5011                PrimaryChannelGuestVfState::Available { .. }
5012                | PrimaryChannelGuestVfState::UnavailableFromAvailable
5013                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
5014                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
5015                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
5016                | PrimaryChannelGuestVfState::DataPathSynthetic => {
5017                    let (_, mut send) = self.queue.split();
5018                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5019                        .await??;
5020                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
5021                        return Ok(message);
5022                    }
5023                }
5024                _ => (),
5025            }
5026        }
5027
5028        loop {
5029            // If the ring is almost full, then do not poll the endpoint,
5030            // since this will cause us to spend time dropping packets and
5031            // reprogramming receive buffers. It's more efficient to let the
5032            // backend drop packets.
5033            let ring_full = {
5034                let (_, mut send) = self.queue.split();
5035                !send.can_write(ring_spare_capacity)?
5036            };
5037
5038            let did_some_work = (!ring_full
5039                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
5040                | self.process_ring_buffer(buffers, state, data, queue_state)?
5041                | (!ring_full
5042                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
5043                | self.transmit_pending_segments(state, data, queue_state)?
5044                | self.send_pending_packets(state)?;
5045
5046            if !did_some_work {
5047                state.stats.spurious_wakes.increment();
5048            }
5049
5050            // Process any outstanding control messages before sleeping in case
5051            // there are now suballocations available for use.
5052            self.process_control_messages(buffers, state)?;
5053
5054            // This should be the only await point waiting on network traffic or
5055            // guest actions. Wrap it in `stop.until_stopped` to allow
5056            // cancellation.
5057            let restart = stop
5058                .until_stopped(std::future::poll_fn(
5059                    |cx| -> Poll<Option<CoordinatorMessage>> {
5060                        // If the ring is almost full, then don't wait for endpoint
5061                        // interrupts. This allows the interrupt rate to fall when the
5062                        // guest cannot keep up with the load.
5063                        if !ring_full {
5064                            // Check the network endpoint for tx completion or rx.
5065                            if queue_state.queue.poll_ready(cx).is_ready() {
5066                                tracing::trace!("endpoint ready");
5067                                return Poll::Ready(None);
5068                            }
5069                        }
5070
5071                        // Check the incoming ring for tx, but only if there are enough
5072                        // free tx packets and no pending tx segments.
5073                        let (mut recv, mut send) = self.queue.split();
5074                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
5075                            && data.tx_segments.is_empty()
5076                            && recv.poll_ready(cx).is_ready()
5077                        {
5078                            tracing::trace!("incoming ring ready");
5079                            return Poll::Ready(None);
5080                        }
5081
5082                        // Check the outgoing ring for space to send rx completions or
5083                        // control message sends, if any are pending.
5084                        //
5085                        // Also, if endpoint processing has been suspended due to the
5086                        // ring being nearly full, then ask the guest to wake us up when
5087                        // there is space again.
5088                        let mut pending_send_size = self.pending_send_size;
5089                        if ring_full {
5090                            pending_send_size = ring_spare_capacity;
5091                        }
5092                        if pending_send_size != 0
5093                            && send.poll_ready(cx, pending_send_size).is_ready()
5094                        {
5095                            tracing::trace!("outgoing ring ready");
5096                            return Poll::Ready(None);
5097                        }
5098
5099                        // Collect any of this queue's receive buffers that were
5100                        // completed by a remote channel. This only happens when the
5101                        // subchannel count changes, so that receive buffer ownership
5102                        // moves between queues while some receive buffers are still in
5103                        // use.
5104                        if let Some(remote_buffer_id_recv) =
5105                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
5106                        {
5107                            while let Poll::Ready(Some(id)) =
5108                                remote_buffer_id_recv.poll_next_unpin(cx)
5109                            {
5110                                if id >= RX_RESERVED_CONTROL_BUFFERS {
5111                                    queue_state.queue.rx_avail(&[RxId(id)]);
5112                                } else {
5113                                    state
5114                                        .primary
5115                                        .as_mut()
5116                                        .unwrap()
5117                                        .free_control_buffers
5118                                        .push(ControlMessageId(id));
5119                                }
5120                            }
5121                        }
5122
5123                        if let Some(restart) = self.restart.take() {
5124                            return Poll::Ready(Some(restart));
5125                        }
5126
5127                        tracing::trace!("network waiting");
5128                        Poll::Pending
5129                    },
5130                ))
5131                .await?;
5132
5133            if let Some(restart) = restart {
5134                break Ok(restart);
5135            }
5136        }
5137    }
5138
    /// Polls the endpoint for completed receives and forwards them to the
    /// guest as one batched RNDIS data message over the channel.
    ///
    /// Returns `Ok(false)` if no receives were available or if the packet
    /// filter is `NDIS_PACKET_TYPE_NONE` (in which case the rx IDs are pended
    /// until the filter changes). Returns `Ok(true)` otherwise — including
    /// when the packets were dropped because the outgoing ring was full.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Pend the newly available RxIds until the packet filter is updated.
            // Under high load this will eventually lead to no available RxIds,
            // which will cause the backend to drop the packets instead of
            // processing them here.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The transaction ID doubles as the first receive buffer ID of this
        // batch; `release_recv_buffers` later uses it to locate and free the
        // whole allocation when the guest completes the message.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        // Mark the buffers in-use until the guest sends the completion.
        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        // Always use the full suballocation size to avoid tracking the
        // message length. See RxBuf::header() for details.
        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                // packet was sent
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring buffer is full. Drop the packets and free the rx
                // buffers. When the ring has limited space, the main loop will
                // stop polling for receive packets.
                state.stats.rx_dropped_ring_full.add(n as u64);

                // Freeing by the first ID releases the batch allocated above.
                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5204
5205    fn process_endpoint_tx(
5206        &mut self,
5207        state: &mut ActiveState,
5208        data: &mut ProcessingData,
5209        epqueue: &mut dyn net_backend::Queue,
5210    ) -> Result<bool, WorkerError> {
5211        // Drain completed transmits.
5212        let result = epqueue.tx_poll(&mut data.tx_done);
5213
5214        match result {
5215            Ok(n) => {
5216                if n == 0 {
5217                    return Ok(false);
5218                }
5219
5220                for &id in &data.tx_done[..n] {
5221                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5222                    assert!(tx_packet.pending_packet_count > 0);
5223                    tx_packet.pending_packet_count -= 1;
5224                    if tx_packet.pending_packet_count == 0 {
5225                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5226                    }
5227                }
5228
5229                Ok(true)
5230            }
5231            Err(TxError::TryRestart(err)) => {
5232                // Complete any pending tx prior to restarting queues.
5233                let pending_tx = state
5234                    .pending_tx_packets
5235                    .iter_mut()
5236                    .enumerate()
5237                    .filter_map(|(id, inflight)| {
5238                        if inflight.pending_packet_count > 0 {
5239                            inflight.pending_packet_count = 0;
5240                            Some(PendingTxCompletion {
5241                                transaction_id: inflight.transaction_id,
5242                                tx_id: Some(TxId(id as u32)),
5243                                status: protocol::Status::SUCCESS,
5244                            })
5245                        } else {
5246                            None
5247                        }
5248                    })
5249                    .collect::<Vec<_>>();
5250                state.pending_tx_completions.extend(pending_tx);
5251                Err(WorkerError::EndpointRequiresQueueRestart(err))
5252            }
5253            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
5254        }
5255    }
5256
5257    fn switch_data_path(
5258        &mut self,
5259        state: &mut ActiveState,
5260        use_guest_vf: bool,
5261        transaction_id: Option<u64>,
5262    ) -> Result<(), WorkerError> {
5263        let primary = state.primary.as_mut().unwrap();
5264        let mut queue_switch_operation = false;
5265        match primary.guest_vf_state {
5266            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5267                // Allow the guest to switch to VTL0, or if the current state
5268                // of the data path is unknown, allow a switch away from VTL0.
5269                // The latter case handles the scenario where the data path has
5270                // been switched but the synthetic device has been restarted.
5271                // The device is queried for the current state of the data path
5272                // but if it is unknown (upgraded from older version that
5273                // doesn't track this data, or failure during query) then the
5274                // safest option is to pass the request through.
5275                if use_guest_vf || primary.is_data_path_switched.is_none() {
5276                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5277                        to_guest: use_guest_vf,
5278                        id: transaction_id,
5279                        result: None,
5280                    };
5281                    queue_switch_operation = true;
5282                }
5283            }
5284            PrimaryChannelGuestVfState::DataPathSwitched => {
5285                if !use_guest_vf {
5286                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5287                        to_guest: false,
5288                        id: transaction_id,
5289                        result: None,
5290                    };
5291                    queue_switch_operation = true;
5292                }
5293            }
5294            _ if use_guest_vf => {
5295                tracing::warn!(
5296                    state = %primary.guest_vf_state,
5297                    use_guest_vf,
5298                    "Data path switch requested while device is in wrong state"
5299                );
5300            }
5301            _ => (),
5302        };
5303        if queue_switch_operation {
5304            self.send_coordinator_update_vf();
5305        } else {
5306            self.send_completion(transaction_id, None)?;
5307        }
5308        Ok(())
5309    }
5310
    /// Drains the guest's outgoing ring, dispatching each packet: RNDIS tx
    /// data, receive-buffer completions, subchannel requests, buffer
    /// revocations, data path switch requests, and OID queries.
    ///
    /// Returns `Ok(true)` if at least one ring packet was consumed.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if !data.tx_segments.is_empty() {
            // There are still segments pending transmission. Skip polling the
            // ring until they are sent to increase backpressure and minimize
            // unnecessary wakeups.
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Stop when there is no free tx ID to assign to another packet.
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // Nothing was queued to the endpoint, so complete
                            // the guest packet right away.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                PacketData::RndisPacketComplete(_completion) => {
                    // The guest is done with a previously-sent rx batch;
                    // return the buffers to the endpoint.
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    // Only ALLOCATE requests within the adapter's queue limit
                    // succeed; anything else fails with no side effects.
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    if subchannel_count > 0 {
                        // Ask the coordinator to restart with the new queue
                        // count (requested subchannels plus the primary).
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                // No operation for VF association completion packets as not all clients send them
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    // OID queries are not implemented; report NOT_SUPPORTED.
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    // NOTE(review): this error variant names the
                    // SwitchDataPathCompletion case but is raised for any
                    // unexpected packet (including primary-only packets seen
                    // on a subchannel) — confirm whether a more general
                    // variant is intended.
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Hand any segments queued above to the endpoint; record a stall if
        // they could not all be accepted this pass.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5452
5453    // Transmit any pending segments. Returns Ok(true) if work was done--if any
5454    // segments were transmitted.
5455    fn transmit_pending_segments(
5456        &mut self,
5457        state: &mut ActiveState,
5458        data: &mut ProcessingData,
5459        queue_state: &mut QueueState,
5460    ) -> Result<bool, WorkerError> {
5461        if data.tx_segments.is_empty() {
5462            return Ok(false);
5463        }
5464        let sent = data.tx_segments_sent;
5465        let did_work =
5466            self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5467        Ok(did_work)
5468    }
5469
    /// Returns true if all pending segments were transmitted.
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        // Offer the not-yet-sent tail of the segment list to the endpoint;
        // it may accept fewer segments than offered.
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(segments)
            .map_err(WorkerError::Endpoint)?;

        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        if sync {
            // The endpoint completed these segments synchronously, so
            // complete the packets now rather than waiting for tx_poll.
            while let Some(head) = segments.first() {
                // Each packet starts with a Head segment carrying its
                // metadata; `segment_count` covers the head plus its body
                // segments, so advancing by it steps to the next packet.
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        // Once everything has been sent, reset the pending list so the main
        // loop resumes polling the ring for new tx packets.
        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5509
    /// Parses and dispatches one RNDIS message received in a guest tx ring
    /// packet, using tx ID `id` for any resulting endpoint transmits.
    ///
    /// Data messages (`MESSAGE_TYPE_PACKET_MSG`) are translated into tx
    /// segments appended to `segments`; the return value is the number of
    /// endpoint packets queued. A return of 0 (e.g. for control messages)
    /// means the caller must complete the guest packet itself.
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut total_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        // The tx ID came off the free list, so it must not be in use.
        assert_eq!(tx_packet.pending_packet_count, 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe the data to catch accesses that are out of bounds. This
        // simplifies error handling for backends that use
        // [`GuestMemory::iova`].
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        let header: rndisprot::MessageHeader = reader.read_plain()?;
        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
            // Remember where this message's segments start so they can be
            // rolled back on failure.
            let start = segments.len();
            match self.handle_rndis_packet_messages(
                buffers,
                state,
                id,
                header.message_length as usize,
                reader,
                segments,
            ) {
                Ok(n) => {
                    // The guest packet completes only once all `n` endpoint
                    // packets have completed.
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
                    total_packets += n;
                }
                Err(err) => {
                    // Roll back any segments added for this message.
                    segments.truncate(start);
                    return Err(err);
                }
            }
        } else {
            // Non-data RNDIS message; dispatch to the control message path.
            self.handle_rndis_message(state, header.message_type, reader)?;
        }

        Ok(total_packets)
    }
5562
5563    fn try_send_tx_packet(
5564        &mut self,
5565        transaction_id: u64,
5566        status: protocol::Status,
5567    ) -> Result<bool, WorkerError> {
5568        let message = self.message(
5569            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5570            protocol::Message1SendRndisPacketComplete { status },
5571        );
5572        let result = self.queue.split().1.batched().try_write_aligned(
5573            transaction_id,
5574            OutgoingPacketType::Completion,
5575            message.aligned_payload(),
5576        );
5577        let sent = match result {
5578            Ok(()) => true,
5579            Err(queue::TryWriteError::Full(n)) => {
5580                self.pending_send_size = n;
5581                false
5582            }
5583            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5584        };
5585        Ok(sent)
5586    }
5587
5588    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5589        let mut did_some_work = false;
5590        while let Some(pending) = state.pending_tx_completions.front() {
5591            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5592                return Ok(did_some_work);
5593            }
5594            did_some_work = true;
5595            if let Some(id) = pending.tx_id {
5596                state.free_tx_packets.push(id);
5597            }
5598            tracing::trace!(?pending, "sent tx completion");
5599            state.pending_tx_completions.pop_front();
5600        }
5601
5602        self.pending_send_size = 0;
5603        Ok(did_some_work)
5604    }
5605
5606    fn complete_tx_packet(
5607        &mut self,
5608        state: &mut ActiveState,
5609        id: TxId,
5610        status: protocol::Status,
5611    ) -> Result<(), WorkerError> {
5612        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5613        assert_eq!(tx_packet.pending_packet_count, 0);
5614        if self.pending_send_size == 0
5615            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5616        {
5617            tracing::trace!(id = id.0, "sent tx completion");
5618            state.free_tx_packets.push(id);
5619        } else {
5620            tracing::trace!(id = id.0, "pended tx completion");
5621            state.pending_tx_completions.push_back(PendingTxCompletion {
5622                transaction_id: tx_packet.transaction_id,
5623                tx_id: Some(id),
5624                status,
5625            });
5626        }
5627        Ok(())
5628    }
5629}
5630
impl ActiveState {
    /// Releases the receive buffer batch identified by `transaction_id`,
    /// which encodes the first buffer ID of the batch (see
    /// `process_endpoint_rx`).
    ///
    /// Data buffer IDs are appended to `done` for return to the endpoint,
    /// while reserved control buffer IDs go back to the primary channel's
    /// free list. Buffers belonging to a remote queue's ID range are
    /// forwarded to that queue instead of being recycled here.
    ///
    /// Returns `None` if the transaction ID does not correspond to an
    /// outstanding allocation (the caller treats this as an invalid
    /// completion from the guest).
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID specifies the first rx buffer ID.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    // NOTE(review): if `primary` is None this returns early,
                    // leaving any remaining IDs unreleased — presumably
                    // control-buffer IDs only occur on the primary channel;
                    // confirm.
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
}