// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The user-mode netvsp VMBus device implementation.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod buffers;
mod protocol;
pub mod resolver;
mod rndisprot;
mod rx_bufs;
mod saved_state;
mod test;

use crate::buffers::GuestBuffers;
use crate::protocol::Message1RevokeReceiveBuffer;
use crate::protocol::Message1RevokeSendBuffer;
use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
use crate::protocol::Version;
use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
use async_trait::async_trait;
pub use buffers::BufferPool;
use buffers::sub_allocation_size_for_mtu;
use futures::FutureExt;
use futures::StreamExt;
use futures::channel::mpsc;
use futures_concurrency::future::Race;
use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRanges;
use guestmem::ranges::PagedRangesReader;
use guid::Guid;
use hvdef::hypercall::HvGuestOsId;
use hvdef::hypercall::HvGuestOsMicrosoft;
use hvdef::hypercall::HvGuestOsMicrosoftIds;
use hvdef::hypercall::HvGuestOsOpenSourceType;
use inspect::Inspect;
use inspect::InspectMut;
use inspect::SensitivityLevel;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use mesh::rpc::Rpc;
use net_backend::Endpoint;
use net_backend::EndpointAction;
use net_backend::L3Protocol;
use net_backend::QueueConfig;
use net_backend::RxId;
use net_backend::TxError;
use net_backend::TxId;
use net_backend::TxSegment;
use net_backend_resources::mac_address::MacAddress;
use pal_async::timer::Instant;
use pal_async::timer::PolledTimer;
use ring::gparange::MultiPagedRangeIter;
use rx_bufs::RxBuffers;
use rx_bufs::SubAllocationInUse;
use std::cmp;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::future::pending;
use std::mem::offset_of;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Poll;
use std::time::Duration;
use task_control::AsyncRun;
use task_control::InspectTaskMut;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing::Instrument;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::channel::ChannelControl;
use vmbus_channel::channel::ChannelOpenError;
use vmbus_channel::channel::ChannelRestoreError;
use vmbus_channel::channel::DeviceResources;
use vmbus_channel::channel::RestoreControl;
use vmbus_channel::channel::SaveRestoreVmbusDevice;
use vmbus_channel::channel::VmbusDevice;
use vmbus_channel::gpadl::GpadlId;
use vmbus_channel::gpadl::GpadlMapView;
use vmbus_channel::gpadl::GpadlView;
use vmbus_channel::gpadl::UnknownGpadlId;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_channel::gpadl_ring::gpadl_channel;
use vmbus_ring as ring;
use vmbus_ring::OutgoingPacketType;
use vmbus_ring::RingMem;
use vmbus_ring::gparange::GpnList;
use vmbus_ring::gparange::MultiPagedRangeBuf;
use vmcore::save_restore::RestoreError;
use vmcore::save_restore::SaveError;
use vmcore::save_restore::SavedStateBlob;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

// The minimum ring space required to handle a control message. Most control
// messages need to send only a completion packet, but there must also be room
// for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. The worst
// case requires a completion message plus two additional inband messages
// (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION).
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the accelerated networking (AN) device is
// offered too quickly, it could get the "eth0" name. In provisioning
// scenarios, scripts make assumptions about which interface should be used,
// with eth0 being the default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);

#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}

#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}

struct Worker<T: RingMem> {
    channel_idx: u16,
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}

impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}

#[expect(clippy::large_enum_variant)]
enum WorkerState {
    Init(Option<InitState>),
    Ready(ReadyState),
}

impl WorkerState {
    fn ready(&self) -> Option<&ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }

    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }
}

struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}

/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
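///
/// # Example
///
/// A minimal sketch of an implementation that never exposes a VF
/// (hypothetical, for illustration only):
///
/// ```ignore
/// struct NoVf;
///
/// #[async_trait]
/// impl VirtualFunction for NoVf {
///     async fn id(&self) -> Option<u32> {
///         None // no VF is currently available
///     }
///     async fn guest_ready_for_device(&mut self) {}
///     async fn wait_for_state_change(&mut self) -> Rpc<(), ()> {
///         std::future::pending().await // never reports a change
///     }
/// }
/// ```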
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}

struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    max_queues: u16,
    indirection_table_size: u16,
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}

struct RxBufferRange {
    id_range: Range<u32>,
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}

impl RxBufferRange {
    fn new(
        ranges: Arc<RxBufferRanges>,
        id_range: Range<u32>,
        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    ) -> Self {
        Self {
            id_range,
            remote_buffer_id_recv,
            remote_ranges: ranges,
        }
    }

    fn send_if_remote(&self, id: u32) -> bool {
        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
        // ID is owned by the current range.
        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
            false
        } else {
            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
            // The total number of receive buffers may not evenly divide among
            // the active queues. Any extra buffers are given to the last
            // queue, so redirect any larger values there.
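            // Worked example (hypothetical numbers): with 1040 receive
            // buffers and 4 queues, buffers_per_queue = (1040 - 16) / 4 =
            // 256, so buffer ID 300 maps to index (300 - 16) / 256 = 1 and
            // is forwarded to queue 1's channel.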
            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
            true
        }
    }
}

struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}

impl RxBufferRanges {
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}

struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// The internal channel state.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: usize,
    pending_send_size: usize,
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}

/// Buffers used during packet processing.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
}

impl ProcessingData {
    fn new() -> Self {
        Self {
            tx_segments: Vec::new(),
            tx_done: vec![TxId(0); 8192].into(),
            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
        }
    }
}

/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    primary: Option<PrimaryChannelState>,

    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    stats: QueueStats,
}

#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    tx_id: Option<TxId>,
    status: protocol::Status,
}

#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available and a data path switch was in progress,
    /// but the VF is no longer available.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available and the data path had been switched to
    /// it, but the VF is no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and the guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}

impl std::fmt::Display for PrimaryChannelGuestVfState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
                    "\"restoring from client requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "unavailable (previously switching data path)")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}

impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}

#[derive(Debug)]
enum PendingLinkAction {
    Default,
    Active(bool),
    Delay(bool),
}

struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}

impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}

#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}

impl OffloadConfig {
    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
        self.lso4 &= supported.lso4;
        self.lso6 &= supported.lso6;
    }
}

#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}

impl ChecksumOffloadConfig {
    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
        self.ipv4_header &= supported.ipv4_header;
        self.tcp4 &= supported.tcp4;
        self.udp4 &= supported.udp4;
        self.tcp6 &= supported.tcp6;
        self.udp6 &= supported.udp6;
    }

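    /// Builds the NDIS checksum capability flags for this configuration. For
    /// example, a config with only `tcp4: true` yields v4 flags with IP
    /// options, TCP options, and TCP checksum marked as supported, and leaves
    /// the v6 flags empty.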
    fn flags(
        &self,
    ) -> (
        rndisprot::Ipv4ChecksumOffload,
        rndisprot::Ipv6ChecksumOffload,
    ) {
        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
        if self.ipv4_header {
            v4.set_ip_options_supported(on);
            v4.set_ip_checksum(on);
        }
        if self.tcp4 {
            v4.set_ip_options_supported(on);
            v4.set_tcp_options_supported(on);
            v4.set_tcp_checksum(on);
        }
        if self.tcp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_tcp_options_supported(on);
            v6.set_tcp_checksum(on);
        }
        if self.udp4 {
            v4.set_ip_options_supported(on);
            v4.set_udp_checksum(on);
        }
        if self.udp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_udp_checksum(on);
        }
        (v4, v6)
    }
}

impl OffloadConfig {
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            // Use the same maximum as vmswitch.
            const MAX_OFFLOAD_SIZE: u32 = 0xF53C;
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = 2;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = 2;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}

#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    Initializing,
    Operational,
    Halted,
}

impl PrimaryChannelState {
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of the indirection table can cause unexpected and
                    // hard-to-investigate issues with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of the indirection table is done by duplicating
                    // the existing entries until the desired size is reached.
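                    // For example (hypothetical sizes): a saved table of
                    // [0, 1, 2] grown to 8 entries becomes
                    // [0, 1, 2, 0, 1, 2, 0, 1].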
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}

struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

const TX_PACKET_QUOTA: usize = 1024;

impl ActiveState {
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Consume tx quota if any is available. If not, still allow the
            // restore, since tx quota might change from release to release.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                // This shouldn't be referenced, but set it in case it is in the future.
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            // Save/restore does not preserve the status of pending tx completions,
            // so complete any pending completions with 'success' to avoid making
            // changes to `saved_state`.
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}

/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    pending_packet_count: usize,
    transaction_id: u64,
}

/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;

/// A network adapter.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}

pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}

impl NicBuilder {
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
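    ///
    /// A usage sketch (hypothetical values; `driver_source`, `endpoint`, and
    /// `mac_address` are assumed to come from the caller's environment):
    ///
    /// ```ignore
    /// let nic = Nic::builder()
    ///     .max_queues(4)
    ///     .build(&driver_source, instance_id, endpoint, mac_address, 0);
    /// ```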
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };
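        // With TX_PACKET_QUOTA = 1024, this evaluates to 1024 for
        // fast-completion endpoints and 256 otherwise.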

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}

fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
    let Some(guest_os_id) = guest_os_id else {
        // Guest OS ID not available.
        return false;
    };

    if !queue.split().0.supports_pending_send_size() {
        // The guest does not support pending send size.
        return false;
    }

    let Some(open_source_os) = guest_os_id.open_source() else {
        // The guest OS is proprietary (e.g. Windows).
        return true;
    };

    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        // Although FreeBSD indicates support for `pending send size`, it doesn't
        // implement it correctly. This was fixed in FreeBSD version `1400097`.
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        // Linux kernels prior to 3.11 have issues with the pending send size
        // optimization that can affect certain Asynchronous I/O (AIO) network
        // operations. Disable the ring size optimization for these older
        // kernels to avoid flow control issues.
        HvGuestOsOpenSourceType::LINUX => {
            // The Linux version is encoded as (major << 16) | (minor << 8) | patch,
            // so Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424.
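            // For reference, the version could be unpacked like this
            // (hypothetical helper, illustration only):
            //     fn decode(v: u32) -> (u32, u32, u32) {
            //         (v >> 16, (v >> 8) & 0xff, v & 0xff)
            //     }
            // decode(199424) == (3, 11, 0)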
            open_source_os.version() >= 199424
        }
        _ => true,
    }
}

impl Nic {
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            max_queues: !0,
            get_guest_os_id: None,
        }
    }

    pub fn shutdown(self) -> Box<dyn Endpoint> {
        let (state, _) = self.coordinator.into_inner();
        state.endpoint
    }
}

impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}

#[async_trait]
impl VmbusDevice for Nic {
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, false);
            WorkerState::Init(None)
        } else {
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VFs added to the guest. This is required for guest
            // compatibility, as some apps (such as DPDK) rely on the VF sticking
            // around even after the vmbus channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}

#[async_trait]
impl SaveRestoreVmbusDevice for Nic {
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        let state = self.saved_state();
        Ok(SavedStateBlob::new(state))
    }

    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        let state: saved_state::SavedState = state.parse()?;
        if let Err(err) = self.restore_state(control, state).await {
            tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
            Err(err.into())
        } else {
            Ok(())
        }
    }
}

impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: protocol::PACKET_SIZE_V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: rndisprot::NDIS_PACKET_TYPE_NONE,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}

impl Nic {
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: bool) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring,
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
            },
        );
    }
}

#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}

impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}

impl Nic {
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // N.B. This will restore the vmbus view of open channels, so any
            //      failures after this point could result in inconsistent
            //      state (vmbus believes the channel is open/active). There
            //      are a number of failure paths after this point because this
            //      call also restores vmbus device state, like the GPADL map.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1586
1587                    let version = check_version(version)
1588                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1589
1590                    let request = requests[0].as_ref().unwrap();
1591                    let buffers = Arc::new(ChannelBuffers {
1592                        version,
1593                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1594                        recv_buffer: ReceiveBuffer::new(
1595                            &self.resources.gpadl_map,
1596                            receive_buffer.gpadl_id,
1597                            receive_buffer.id,
1598                            receive_buffer.sub_allocation_size,
1599                        )?,
1600                        send_buffer: {
1601                            if let Some(send_buffer) = send_buffer {
1602                                Some(SendBuffer::new(
1603                                    &self.resources.gpadl_map,
1604                                    send_buffer.gpadl_id,
1605                                )?)
1606                            } else {
1607                                None
1608                            }
1609                        },
1610                        ndis_version: {
1611                            let saved_state::NdisVersion { major, minor } = ndis_version;
1612                            NdisVersion { major, minor }
1613                        },
1614                        ndis_config: {
1615                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1616                            NdisConfig {
1617                                mtu,
1618                                capabilities: capabilities.into(),
1619                            }
1620                        },
1621                    });
1622
1623                    for (channel_idx, channel) in channels.iter().enumerate() {
1624                        let channel = if let Some(channel) = channel {
1625                            channel
1626                        } else {
1627                            continue;
1628                        };
1629
1630                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1631
1632                        // Restore primary channel state.
1633                        if channel_idx == 0 {
1634                            let primary = PrimaryChannelState::restore(
1635                                &guest_vf_state,
1636                                &rndis_state,
1637                                &offload_config,
1638                                pending_offload_change,
1639                                channels.len() as u16,
1640                                self.adapter.indirection_table_size,
1641                                &active.rx_bufs,
1642                                std::mem::take(&mut control_messages),
1643                                rss_state.take(),
1644                                tx_spread_sent,
1645                                guest_link_down,
1646                                pending_link_action,
1647                            )?;
1648                            active.primary = Some(primary);
1649                        }
1650
1651                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1652                            buffers: buffers.clone(),
1653                            state: active,
1654                            data: ProcessingData::new(),
1655                        }));
1656                    }
1657                }
1658            }
1659
1660            // Insert the coordinator and mark that it should try to start the
1661            // network endpoint when it starts running.
1662            self.insert_coordinator(states.len() as u16, true);
1663
1664            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1665                if let Some(state) = state {
1666                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1667                }
1668            }
1669            for worker in self.coordinator.state_mut().unwrap().workers.iter_mut() {
1670                if let Some(worker_state) = worker.state_mut() {
1671                    worker_state.channel.packet_filter = saved_packet_filter;
1672                }
1673            }
1674        } else {
1675            control
1676                .restore(&[false])
1677                .await
1678                .map_err(NetRestoreError::Channel)?;
1679        }
1680        Ok(())
1681    }
1682
1683    fn saved_state(&self) -> saved_state::SavedState {
1684        let open = if let Some(coordinator) = self.coordinator.state() {
1685            let primary = coordinator.workers[0].state().unwrap();
1686            let primary = match &primary.state {
1687                WorkerState::Init(None) => saved_state::Primary::Version,
1688                WorkerState::Init(Some(init)) => {
1689                    saved_state::Primary::Init(saved_state::InitPrimary {
1690                        version: init.version as u32,
1691                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
1692                            saved_state::NdisConfig {
1693                                mtu,
1694                                capabilities: capabilities.into(),
1695                            }
1696                        }),
1697                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
1698                            saved_state::NdisVersion { major, minor }
1699                        }),
1700                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
1701                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
1702                            gpadl_id: x.gpadl.id(),
1703                        }),
1704                    })
1705                }
1706                WorkerState::Ready(ready) => {
1707                    let primary = ready.state.primary.as_ref().unwrap();
1708
1709                    let rndis_state = match primary.rndis_state {
1710                        RndisState::Initializing => saved_state::RndisState::Initializing,
1711                        RndisState::Operational => saved_state::RndisState::Operational,
1712                        RndisState::Halted => saved_state::RndisState::Halted,
1713                    };
1714
1715                    let offload_config = saved_state::OffloadConfig {
1716                        checksum_tx: saved_state::ChecksumOffloadConfig {
1717                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
1718                            tcp4: primary.offload_config.checksum_tx.tcp4,
1719                            udp4: primary.offload_config.checksum_tx.udp4,
1720                            tcp6: primary.offload_config.checksum_tx.tcp6,
1721                            udp6: primary.offload_config.checksum_tx.udp6,
1722                        },
1723                        checksum_rx: saved_state::ChecksumOffloadConfig {
1724                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
1725                            tcp4: primary.offload_config.checksum_rx.tcp4,
1726                            udp4: primary.offload_config.checksum_rx.udp4,
1727                            tcp6: primary.offload_config.checksum_rx.tcp6,
1728                            udp6: primary.offload_config.checksum_rx.udp6,
1729                        },
1730                        lso4: primary.offload_config.lso4,
1731                        lso6: primary.offload_config.lso6,
1732                    };
1733
1734                    let control_messages = primary
1735                        .control_messages
1736                        .iter()
1737                        .map(|message| saved_state::IncomingControlMessage {
1738                            message_type: message.message_type,
1739                            data: message.data.to_vec(),
1740                        })
1741                        .collect();
1742
1743                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
1744                        key: rss.key.into(),
1745                        indirection_table: rss.indirection_table.clone(),
1746                    });
1747
1748                    let pending_link_action = match primary.pending_link_action {
1749                        PendingLinkAction::Default => None,
1750                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
1751                            Some(action)
1752                        }
1753                    };
1754
1755                    let channels = coordinator.workers[..coordinator.num_queues as usize]
1756                        .iter()
1757                        .map(|worker| {
1758                            worker.state().map(|worker| {
1759                                if let Some(ready) = worker.state.ready() {
1760                                    // In-flight tx packets are treated as dropped across save/restore,
1761                                    // but their requests still need to be completed back to the guest.
1762                                    let pending_tx_completions = ready
1763                                        .state
1764                                        .pending_tx_completions
1765                                        .iter()
1766                                        .map(|pending| pending.transaction_id)
1767                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
1768                                            |inflight| {
1769                                                (inflight.pending_packet_count > 0)
1770                                                    .then_some(inflight.transaction_id)
1771                                            },
1772                                        ))
1773                                        .collect();
1774
1775                                    saved_state::Channel {
1776                                        pending_tx_completions,
1777                                        in_use_rx: {
1778                                            ready
1779                                                .state
1780                                                .rx_bufs
1781                                                .allocated()
1782                                                .map(|id| saved_state::Rx {
1783                                                    buffers: id.collect(),
1784                                                })
1785                                                .collect()
1786                                        },
1787                                    }
1788                                } else {
1789                                    saved_state::Channel {
1790                                        pending_tx_completions: Vec::new(),
1791                                        in_use_rx: Vec::new(),
1792                                    }
1793                                }
1794                            })
1795                        })
1796                        .collect();
1797
1798                    let guest_vf_state = match primary.guest_vf_state {
1799                        PrimaryChannelGuestVfState::Initializing
1800                        | PrimaryChannelGuestVfState::Unavailable
1801                        | PrimaryChannelGuestVfState::Available { .. } => {
1802                            saved_state::GuestVfState::NoState
1803                        }
1804                        PrimaryChannelGuestVfState::UnavailableFromAvailable
1805                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
1806                            saved_state::GuestVfState::AvailableAdvertised
1807                        }
1808                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
1809                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
1810                            to_guest,
1811                            id,
1812                        } => saved_state::GuestVfState::DataPathSwitchPending {
1813                            to_guest,
1814                            id,
1815                            result: None,
1816                        },
1817                        PrimaryChannelGuestVfState::DataPathSwitchPending {
1818                            to_guest,
1819                            id,
1820                            result,
1821                        } => saved_state::GuestVfState::DataPathSwitchPending {
1822                            to_guest,
1823                            id,
1824                            result,
1825                        },
1826                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
1827                        | PrimaryChannelGuestVfState::DataPathSwitched
1828                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
1829                            saved_state::GuestVfState::DataPathSwitched
1830                        }
1831                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
1832                    };
1833
1834                    let worker_0_packet_filter = coordinator.workers[0]
1835                        .state()
1836                        .unwrap()
1837                        .channel
1838                        .packet_filter;
1839                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
1840                        version: ready.buffers.version as u32,
1841                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
1842                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
1843                            saved_state::SendBuffer {
1844                                gpadl_id: sb.gpadl.id(),
1845                            }
1846                        }),
1847                        rndis_state,
1848                        guest_vf_state,
1849                        offload_config,
1850                        pending_offload_change: primary.pending_offload_change,
1851                        control_messages,
1852                        rss_state,
1853                        channels,
1854                        ndis_config: {
1855                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
1856                            saved_state::NdisConfig {
1857                                mtu,
1858                                capabilities: capabilities.into(),
1859                            }
1860                        },
1861                        ndis_version: {
1862                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
1863                            saved_state::NdisVersion { major, minor }
1864                        },
1865                        tx_spread_sent: primary.tx_spread_sent,
1866                        guest_link_down: !primary.guest_link_up,
1867                        pending_link_action,
1868                        packet_filter: Some(worker_0_packet_filter),
1869                    })
1870                }
1871            };
1872
1873            let state = saved_state::OpenState { primary };
1874            Some(state)
1875        } else {
1876            None
1877        };
1878
1879        saved_state::SavedState { open }
1880    }
1881}
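
// Illustrative save/restore round trip (names from this file; the
// `RestoreControl` comes from the vmbus channel infrastructure and the async
// call must run on an executor):
//
//     let saved = nic.saved_state();            // side-effect-free snapshot
//     nic.restore_state(control, saved).await?; // rebuilds coordinator + workers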
1882
1883#[derive(Debug, Error)]
1884enum WorkerError {
1885    #[error("packet error")]
1886    Packet(#[source] PacketError),
1887    #[error("unexpected packet order: {0}")]
1888    UnexpectedPacketOrder(#[source] PacketOrderError),
1889    #[error("unknown rndis message type: {0}")]
1890    UnknownRndisMessageType(u32),
1891    #[error("memory access error")]
1892    Access(#[from] AccessError),
1893    #[error("rndis message too small")]
1894    RndisMessageTooSmall,
1895    #[error("unsupported rndis behavior")]
1896    UnsupportedRndisBehavior,
1897    #[error("vmbus queue error")]
1898    Queue(#[from] queue::Error),
1899    #[error("too many control messages")]
1900    TooManyControlMessages,
1901    #[error("invalid rndis packet completion")]
1902    InvalidRndisPacketCompletion,
1903    #[error("missing transaction id")]
1904    MissingTransactionId,
1905    #[error("invalid gpadl")]
1906    InvalidGpadl(#[source] guestmem::InvalidGpn),
1907    #[error("gpadl error")]
1908    GpadlError(#[source] GuestMemoryError),
1909    #[error("gpa direct error")]
1910    GpaDirectError(#[source] GuestMemoryError),
1911    #[error("endpoint")]
1912    Endpoint(#[source] anyhow::Error),
1913    #[error("message not supported on sub channel: {0}")]
1914    NotSupportedOnSubChannel(u32),
1915    #[error("the ring buffer ran out of space, which should not be possible")]
1916    OutOfSpace,
1917    #[error("send/receive buffer error")]
1918    Buffer(#[from] BufferError),
1919    #[error("invalid rndis state")]
1920    InvalidRndisState,
1921    #[error("rndis message type not implemented")]
1922    RndisMessageTypeNotImplemented,
1923    #[error("invalid TCP header offset")]
1924    InvalidTcpHeaderOffset,
1925    #[error("cancelled")]
1926    Cancelled(task_control::Cancelled),
1927    #[error("tearing down because send/receive buffer is revoked")]
1928    BufferRevoked,
1929    #[error("endpoint requires queue restart: {0}")]
1930    EndpointRequiresQueueRestart(#[source] anyhow::Error),
1931}
1932
1933impl From<task_control::Cancelled> for WorkerError {
1934    fn from(value: task_control::Cancelled) -> Self {
1935        Self::Cancelled(value)
1936    }
1937}
1938
1939#[derive(Debug, Error)]
1940enum OpenError {
1941    #[error("error establishing ring buffer")]
1942    Ring(#[source] vmbus_channel::gpadl_ring::Error),
1943    #[error("error establishing vmbus queue")]
1944    Queue(#[source] queue::Error),
1945}
1946
1947#[derive(Debug, Error)]
1948enum PacketError {
1949    #[error("UnknownType {0}")]
1950    UnknownType(u32),
1951    #[error("Access")]
1952    Access(#[source] AccessError),
1953    #[error("ExternalData")]
1954    ExternalData(#[source] ExternalDataError),
1955    #[error("InvalidSendBufferIndex")]
1956    InvalidSendBufferIndex,
1957}
1958
1959#[derive(Debug, Error)]
1960enum PacketOrderError {
1961    #[error("Invalid PacketData")]
1962    InvalidPacketData,
1963    #[error("Unexpected RndisPacket")]
1964    UnexpectedRndisPacket,
1965    #[error("SendNdisVersion already exists")]
1966    SendNdisVersionExists,
1967    #[error("SendNdisConfig already exists")]
1968    SendNdisConfigExists,
1969    #[error("SendReceiveBuffer already exists")]
1970    SendReceiveBufferExists,
1971    #[error("SendReceiveBuffer missing MTU")]
1972    SendReceiveBufferMissingMTU,
1973    #[error("SendSendBuffer already exists")]
1974    SendSendBufferExists,
1975    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
1976    SwitchDataPathCompletionPrimaryChannelState,
1977}
1978
1979#[derive(Debug)]
1980enum PacketData {
1981    Init(protocol::MessageInit),
1982    SendNdisVersion(protocol::Message1SendNdisVersion),
1983    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
1984    SendSendBuffer(protocol::Message1SendSendBuffer),
1985    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
1986    RevokeSendBuffer(Message1RevokeSendBuffer),
1987    RndisPacket(protocol::Message1SendRndisPacket),
1988    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
1989    SendNdisConfig(protocol::Message2SendNdisConfig),
1990    SwitchDataPath(protocol::Message4SwitchDataPath),
1991    OidQueryEx(protocol::Message5OidQueryEx),
1992    SubChannelRequest(protocol::Message5SubchannelRequest),
1993    SendVfAssociationCompletion,
1994    SwitchDataPathCompletion,
1995}
1996
1997#[derive(Debug)]
1998struct Packet<'a> {
1999    data: PacketData,
2000    transaction_id: Option<u64>,
2001    external_data: MultiPagedRangeBuf<GpnList>,
2002    send_buffer_suballocation: PagedRange<'a>,
2003}
2004
2005type PacketReader<'a> = PagedRangesReader<
2006    'a,
2007    std::iter::Chain<std::iter::Once<PagedRange<'a>>, MultiPagedRangeIter<'a>>,
2008>;
2009
2010impl Packet<'_> {
2011    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2012        PagedRanges::new(
2013            std::iter::once(self.send_buffer_suballocation).chain(self.external_data.iter()),
2014        )
2015        .reader(mem)
2016    }
2017}
2018
2019fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2020    reader: &mut impl MemoryRead,
2021) -> Result<T, PacketError> {
2022    reader.read_plain().map_err(PacketError::Access)
2023}
2024
2025fn parse_packet<'a, T: RingMem>(
2026    packet_ref: &queue::PacketRef<'_, T>,
2027    send_buffer: Option<&'a SendBuffer>,
2028    version: Option<Version>,
2029) -> Result<Packet<'a>, PacketError> {
2030    let packet = match packet_ref.as_ref() {
2031        IncomingPacket::Data(data) => data,
2032        IncomingPacket::Completion(completion) => {
2033            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2034                PacketData::SendVfAssociationCompletion
2035            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2036                PacketData::SwitchDataPathCompletion
2037            } else {
2038                let mut reader = completion.reader();
2039                let header: protocol::MessageHeader =
2040                    reader.read_plain().map_err(PacketError::Access)?;
2041                match header.message_type {
2042                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2043                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2044                    }
2045                    typ => return Err(PacketError::UnknownType(typ)),
2046                }
2047            };
2048            return Ok(Packet {
2049                data,
2050                transaction_id: Some(completion.transaction_id()),
2051                external_data: MultiPagedRangeBuf::empty(),
2052                send_buffer_suballocation: PagedRange::empty(),
2053            });
2054        }
2055    };
2056
2057    let mut reader = packet.reader();
2058    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2059    let mut send_buffer_suballocation = PagedRange::empty();
2060    let data = match header.message_type {
2061        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2062        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2063            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2064        }
2065        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2066            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2067        }
2068        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2069            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2070        }
2071        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2072            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2073        }
2074        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2075            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2076        }
2077        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2078            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
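            // Send buffer sections are spaced 6144 bytes apart; an index of
            // 0xffffffff means the RNDIS message is carried entirely in the
            // packet's external GPA ranges instead.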
2079            if message.send_buffer_section_index != 0xffffffff {
2080                send_buffer_suballocation = send_buffer
2081                    .ok_or(PacketError::InvalidSendBufferIndex)?
2082                    .gpadl
2083                    .first()
2084                    .unwrap()
2085                    .try_subrange(
2086                        message.send_buffer_section_index as usize * 6144,
2087                        message.send_buffer_section_size as usize,
2088                    )
2089                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2090            }
2091            PacketData::RndisPacket(message)
2092        }
2093        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2094            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2095        }
2096        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2097            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2098        }
2099        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2100            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2101        }
2102        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2103            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2104        }
2105        typ => return Err(PacketError::UnknownType(typ)),
2106    };
2107    Ok(Packet {
2108        data,
2109        transaction_id: packet.transaction_id(),
2110        external_data: packet
2111            .read_external_ranges()
2112            .map_err(PacketError::ExternalData)?,
2113        send_buffer_suballocation,
2114    })
2115}
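
// Note: each arm in `parse_packet` is gated on the negotiated protocol
// version, so a message type introduced by a newer version than the one in
// effect falls through to `PacketError::UnknownType` rather than being
// parsed.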
2116
2117#[derive(Debug, Copy, Clone)]
2118struct NvspMessage<T> {
2119    header: protocol::MessageHeader,
2120    data: T,
2121    padding: &'static [u8],
2122}
2123
2124impl<T: IntoBytes + Immutable + KnownLayout> NvspMessage<T> {
2125    fn payload(&self) -> [&[u8]; 3] {
2126        [self.header.as_bytes(), self.data.as_bytes(), self.padding]
2127    }
2128}
2129
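// Example (illustrative; `message_type`, `data`, and `writer` stand in for
// the caller's values): the three `payload()` slices are written to the ring
// as a single packet, so header, body, and padding are never copied into an
// intermediate buffer.
//
//     let message = NvspMessage {
//         header: protocol::MessageHeader { message_type },
//         data,
//         padding: &[],
//     };
//     writer.try_write(&queue::OutgoingPacket {
//         transaction_id: 0,
//         packet_type: OutgoingPacketType::InBandNoCompletion,
//         payload: &message.payload(),
//     })?;
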
2130impl<T: RingMem> NetChannel<T> {
2131    fn message<P: IntoBytes + Immutable + KnownLayout>(
2132        &self,
2133        message_type: u32,
2134        data: P,
2135    ) -> NvspMessage<P> {
2136        let padding = self.padding(&data);
2137        NvspMessage {
2138            header: protocol::MessageHeader { message_type },
2139            data,
2140            padding,
2141        }
2142    }
2143
2144    /// Returns zero padding bytes to round the payload up to the packet size.
2145    /// Only needed for Windows guests, which are picky about packet sizes.
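    /// E.g. (hypothetical sizes): with a packet size of 40 and a combined
    /// header-plus-data length of 32, eight zero bytes are returned; payloads
    /// of 40 bytes or more get an empty slice.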
2146    fn padding<P: IntoBytes + Immutable + KnownLayout>(&self, data: &P) -> &'static [u8] {
2147        static PADDING: &[u8] = &[0; protocol::PACKET_SIZE_V61];
2148        let padding_len = self.packet_size
2149            - cmp::min(
2150                self.packet_size,
2151                size_of::<protocol::MessageHeader>() + data.as_bytes().len(),
2152            );
2153        &PADDING[..padding_len]
2154    }
2155
2156    fn send_completion(
2157        &mut self,
2158        transaction_id: Option<u64>,
2159        payload: &[&[u8]],
2160    ) -> Result<(), WorkerError> {
2161        match transaction_id {
2162            None => Ok(()),
2163            Some(transaction_id) => Ok(self
2164                .queue
2165                .split()
2166                .1
2167                .try_write(&queue::OutgoingPacket {
2168                    transaction_id,
2169                    packet_type: OutgoingPacketType::Completion,
2170                    payload,
2171                })
2172                .map_err(|err| match err {
2173                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2174                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2175                })?),
2176        }
2177    }
2178}
2179
2180static SUPPORTED_VERSIONS: &[Version] = &[
2181    Version::V1,
2182    Version::V2,
2183    Version::V4,
2184    Version::V5,
2185    Version::V6,
2186    Version::V61,
2187];
2188
2189fn check_version(requested_version: u32) -> Option<Version> {
2190    SUPPORTED_VERSIONS
2191        .iter()
2192        .find(|version| **version as u32 == requested_version)
2193        .copied()
2194}
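
// Example (illustrative): version negotiation accepts only an exact match
// against SUPPORTED_VERSIONS.
//
//     match check_version(requested) {
//         Some(version) => { /* accept; gate later messages on `version` */ }
//         None => { /* fail the guest's init request */ }
//     }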
2195
2196#[derive(Debug)]
2197struct ReceiveBuffer {
2198    gpadl: GpadlView,
2199    id: u16,
2200    count: u32,
2201    sub_allocation_size: u32,
2202}
2203
2204#[derive(Debug, Error)]
2205enum BufferError {
2206    #[error("unsupported suballocation size {0}")]
2207    UnsupportedSuballocationSize(u32),
2208    #[error("unaligned gpadl")]
2209    UnalignedGpadl,
2210    #[error("unknown gpadl ID")]
2211    UnknownGpadlId(#[from] UnknownGpadlId),
2212}
2213
2214impl ReceiveBuffer {
2215    fn new(
2216        gpadl_map: &GpadlMapView,
2217        gpadl_id: GpadlId,
2218        id: u16,
2219        sub_allocation_size: u32,
2220    ) -> Result<Self, BufferError> {
2221        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2222            return Err(BufferError::UnsupportedSuballocationSize(
2223                sub_allocation_size,
2224            ));
2225        }
2226        let gpadl = gpadl_map.map(gpadl_id)?;
2227        let range = gpadl
2228            .contiguous_aligned()
2229            .ok_or(BufferError::UnalignedGpadl)?;
2230        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2231        if num_sub_allocations == 0 {
2232            return Err(BufferError::UnsupportedSuballocationSize(
2233                sub_allocation_size,
2234            ));
2235        }
2236        let recv_buffer = Self {
2237            gpadl,
2238            id,
2239            count: num_sub_allocations,
2240            sub_allocation_size,
2241        };
2242        Ok(recv_buffer)
2243    }
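
    // Worked example (hypothetical sizes): a contiguous 61440-byte GPADL with
    // `sub_allocation_size == 6144` yields `count == 10`, and `range(3)` below
    // covers bytes 18432..24576 of the buffer.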
2244
2245    fn range(&self, index: u32) -> PagedRange<'_> {
2246        self.gpadl.first().unwrap().subrange(
2247            (index * self.sub_allocation_size) as usize,
2248            self.sub_allocation_size as usize,
2249        )
2250    }
2251
2252    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2253        assert!(len <= self.sub_allocation_size as usize);
2254        ring::TransferPageRange {
2255            byte_offset: index * self.sub_allocation_size,
2256            byte_count: len as u32,
2257        }
2258    }
2259
2260    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2261        saved_state::ReceiveBuffer {
2262            gpadl_id: self.gpadl.id(),
2263            id: self.id,
2264            sub_allocation_size: self.sub_allocation_size,
2265        }
2266    }
2267}
2268
2269#[derive(Debug)]
2270struct SendBuffer {
2271    gpadl: GpadlView,
2272}
2273
2274impl SendBuffer {
2275    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2276        let gpadl = gpadl_map.map(gpadl_id)?;
2277        gpadl
2278            .contiguous_aligned()
2279            .ok_or(BufferError::UnalignedGpadl)?;
2280        Ok(Self { gpadl })
2281    }
2282}
2283
2284impl<T: RingMem> NetChannel<T> {
2285    /// Process a single RNDIS message.
2286    fn handle_rndis_message(
2287        &mut self,
2288        buffers: &ChannelBuffers,
2289        state: &mut ActiveState,
2290        id: TxId,
2291        message_type: u32,
2292        mut reader: PacketReader<'_>,
2293        segments: &mut Vec<TxSegment>,
2294    ) -> Result<bool, WorkerError> {
2295        let is_packet = match message_type {
2296            rndisprot::MESSAGE_TYPE_PACKET_MSG => {
2297                self.handle_rndis_packet_message(
2298                    id,
2299                    reader,
2300                    &buffers.mem,
2301                    segments,
2302                    &mut state.stats,
2303                )?;
2304                true
2305            }
2306            rndisprot::MESSAGE_TYPE_HALT_MSG => false,
2307            n => {
2308                let control = state
2309                    .primary
2310                    .as_mut()
2311                    .ok_or(WorkerError::NotSupportedOnSubChannel(n))?;
2312
2313                // This is a control message that needs a response. Responding
2314                // requires a receive buffer suballocation, which may not be
2315                // available right now. Enqueue the message and process the
2316                // queue as suballocations become available.
2317                const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2318                if reader.len() == 0 {
2319                    return Err(WorkerError::RndisMessageTooSmall);
2320                }
2321                // Do not let the queue get too large--the guest should not be
2322                // sending very many control messages at a time.
2323                if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2324                    return Err(WorkerError::TooManyControlMessages);
2325                }
2326                control.control_messages_len += reader.len();
2327                control.control_messages.push_back(ControlMessage {
2328                    message_type,
2329                    data: reader.read_all()?.into(),
2330                });
2331
2332                false
2333                // The queue will be processed in the main dispatch loop.
2334            }
2335        };
2336        Ok(is_packet)
2337    }
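
    // A `true` return means the message was a data packet and `segments` now
    // holds its tx descriptors; `false` means the message was consumed here
    // (halt, or a queued control message) and produced nothing to transmit.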
2338
2339    /// Process an RNDIS packet message (used to send an Ethernet frame).
2340    fn handle_rndis_packet_message(
2341        &mut self,
2342        id: TxId,
2343        reader: PacketReader<'_>,
2344        mem: &GuestMemory,
2345        segments: &mut Vec<TxSegment>,
2346        stats: &mut QueueStats,
2347    ) -> Result<(), WorkerError> {
2348        // Headers are guaranteed to be in a single PagedRange.
2349        let headers = reader
2350            .clone()
2351            .into_inner()
2352            .paged_ranges()
2353            .find(|r| !r.is_empty())
2354            .ok_or(WorkerError::RndisMessageTooSmall)?;
2355        let mut data = reader.into_inner();
2356        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2357        if request.num_oob_data_elements != 0
2358            || request.oob_data_length != 0
2359            || request.oob_data_offset != 0
2360            || request.vc_handle != 0
2361        {
2362            return Err(WorkerError::UnsupportedRndisBehavior);
2363        }
2364
2365        if data.len() < request.data_offset as usize
2366            || (data.len() - request.data_offset as usize) < request.data_length as usize
2367            || request.data_length == 0
2368        {
2369            return Err(WorkerError::RndisMessageTooSmall);
2370        }
2371
2372        data.skip(request.data_offset as usize);
2373        data.truncate(request.data_length as usize);
2374
2375        let mut metadata = net_backend::TxMetadata {
2376            id,
2377            len: request.data_length as usize,
2378            ..Default::default()
2379        };
2380
2381        if request.per_packet_info_length != 0 {
2382            let mut ppi = headers
2383                .try_subrange(
2384                    request.per_packet_info_offset as usize,
2385                    request.per_packet_info_length as usize,
2386                )
2387                .ok_or(WorkerError::RndisMessageTooSmall)?;
2388            while !ppi.is_empty() {
2389                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2390                if h.size == 0 {
2391                    return Err(WorkerError::RndisMessageTooSmall);
2392                }
2393                let (this, rest) = ppi
2394                    .try_split(h.size as usize)
2395                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2396                let (_, d) = this
2397                    .try_split(h.per_packet_information_offset as usize)
2398                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2399                match h.typ {
2400                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2401                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2402
2403                        metadata.offload_tcp_checksum =
2404                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
2405                        metadata.offload_udp_checksum =
2406                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum();
2407                        metadata.offload_ip_header_checksum = n.is_ipv4() && n.ip_header_checksum();
2408                        metadata.l3_protocol = if n.is_ipv4() {
2409                            L3Protocol::Ipv4
2410                        } else if n.is_ipv6() {
2411                            L3Protocol::Ipv6
2412                        } else {
2413                            L3Protocol::Unknown
2414                        };
2415                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2416                        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2417                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2418                                n.tcp_header_offset() - metadata.l2_len as u16
2419                            } else if n.is_ipv4() {
2420                                let mut reader = data.clone().reader(mem);
2421                                reader.skip(metadata.l2_len as usize)?;
2422                                let mut b = 0;
2423                                reader.read(std::slice::from_mut(&mut b))?;
2424                                (b as u16 & 0xf) * 4 // IPv4 IHL is the low nibble, in 32-bit words
2425                            } else {
2426                                // Assume a fixed 40-byte IPv6 header (no extension headers).
2427                                40
2428                            };
2429                        }
2430                    }
2431                    rndisprot::PPI_LSO => {
2432                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2433
2434                        metadata.offload_tcp_segmentation = true;
2435                        metadata.offload_tcp_checksum = true;
2436                        metadata.offload_ip_header_checksum = n.is_ipv4();
2437                        metadata.l3_protocol = if n.is_ipv4() {
2438                            L3Protocol::Ipv4
2439                        } else {
2440                            L3Protocol::Ipv6
2441                        };
2442                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2443                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2444                            return Err(WorkerError::InvalidTcpHeaderOffset);
2445                        }
2446                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2447                        metadata.l4_len = {
2448                            let mut reader = data.clone().reader(mem);
2449                            reader
2450                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2451                            let mut b = 0;
2452                            reader.read(std::slice::from_mut(&mut b))?;
2453                            (b >> 4) * 4
2454                        };
2455                        metadata.max_tcp_segment_size = n.mss() as u16;
2456                    }
2457                    _ => {}
2458                }
2459                ppi = rest;
2460            }
2461        }
2462
2463        let start = segments.len();
2464        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2465            let range = range.map_err(WorkerError::InvalidGpadl)?;
2466            segments.push(TxSegment {
2467                ty: net_backend::TxSegmentType::Tail,
2468                gpa: range.start,
2469                len: range.len() as u32,
2470            });
2471        }
2472
2473        metadata.segment_count = segments.len() - start;
2474
2475        stats.tx_packets.increment();
2476        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2477            stats.tx_checksum_packets.increment();
2478        }
2479        if metadata.offload_tcp_segmentation {
2480            stats.tx_lso_packets.increment();
2481        }
2482
2483        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2484
2485        Ok(())
2486    }
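
    // Note on the parsing above: `data_offset`/`data_length` and
    // `per_packet_info_offset`/`per_packet_info_length` are interpreted
    // relative to the start of the rndisprot::Packet structure, and each
    // per-packet info entry is an rndisprot::PerPacketInfo header whose
    // payload begins at `per_packet_information_offset` within the entry;
    // entries are walked by their `size` field.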
2487
2488    /// Notify the guest of a change in VF availability, sending the VF
2489    /// association message when the negotiated version and capabilities allow.
2490    fn guest_vf_is_available(
2491        &mut self,
2492        guest_vf_id: Option<u32>,
2493        version: Version,
2494        config: NdisConfig,
2495    ) -> Result<bool, WorkerError> {
2496        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2497        if version >= Version::V4 && config.capabilities.sriov() {
2498            tracing::info!(
2499                available = serial_number.is_some(),
2500                serial_number,
2501                "sending VF association message."
2502            );
2503            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2504            let message = {
2505                self.message(
2506                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2507                    protocol::Message4SendVfAssociation {
2508                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2509                        serial_number: serial_number.unwrap_or(0),
2510                    },
2511                )
2512            };
2513            self.queue
2514                .split()
2515                .1
2516                .try_write(&queue::OutgoingPacket {
2517                    transaction_id: VF_ASSOCIATION_TRANSACTION_ID,
2518                    packet_type: OutgoingPacketType::InBandWithCompletion,
2519                    payload: &message.payload(),
2520                })
2521                .map_err(|err| match err {
2522                    queue::TryWriteError::Full(len) => {
2523                        tracing::error!(len, "failed to write vf association message");
2524                        WorkerError::OutOfSpace
2525                    }
2526                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2527                })?;
2528            Ok(true)
2529        } else {
2530            tracing::info!(
2531                available = serial_number.is_some(),
2532                serial_number,
2533                major = version.major(),
2534                minor = version.minor(),
2535                sriov_capable = config.capabilities.sriov(),
2536                "Skipping NvspMessage4TypeSendVFAssociation message"
2537            );
2538            Ok(false)
2539        }
2540    }
2541
2542    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2543    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2544        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2545        if version < Version::V5 {
2546            return;
2547        }
2548
2549        #[repr(C)]
2550        #[derive(IntoBytes, Immutable, KnownLayout)]
2551        struct SendIndirectionMsg {
2552            pub message: protocol::Message5SendIndirectionTable,
2553            pub send_indirection_table:
2554                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2555        }
2556
2557        // The offset to the send indirection table from the beginning of the NVSP message.
2558        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2559            + size_of::<protocol::MessageHeader>();
2560        let mut data = SendIndirectionMsg {
2561            message: protocol::Message5SendIndirectionTable {
2562                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2563                table_offset: send_indirection_table_offset as u32,
2564            },
2565            send_indirection_table: Default::default(),
2566        };
2567
2568        for i in 0..data.send_indirection_table.len() {
2569            data.send_indirection_table[i] = i as u32 % num_channels_opened;
2570        }
2571
2572        let message = NvspMessage {
2573            header: protocol::MessageHeader {
2574                message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2575            },
2576            data,
2577            padding: &[],
2578        };
2579        let result = self
2580            .queue
2581            .split()
2582            .1
2583            .try_write(&queue::OutgoingPacket {
2584                transaction_id: 0,
2585                packet_type: OutgoingPacketType::InBandNoCompletion,
2586                payload: &message.payload(),
2587            })
2588            .map_err(|err| match err {
2589                queue::TryWriteError::Full(len) => {
2590                    tracing::error!(len, "failed to write send indirection table message");
2591                    WorkerError::OutOfSpace
2592                }
2593                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2594            });
2595        if let Err(err) = result {
2596            tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2597        }
2598    }
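
    // Example: with `num_channels_opened == 4`, the table built above is
    // [0, 1, 2, 3, 0, 1, ...], spreading guest sends round-robin across the
    // opened channels.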
2599
2600    /// Notify the guest that the data path has been switched back to synthetic
2601    /// due to some external state change.
2602    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2603        let message = NvspMessage {
2604            header: protocol::MessageHeader {
2605                message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2606            },
2607            data: protocol::Message4SwitchDataPath {
2608                active_data_path: protocol::DataPath::SYNTHETIC.0,
2609            },
2610            padding: &[],
2611        };
2612        let result = self
2613            .queue
2614            .split()
2615            .1
2616            .try_write(&queue::OutgoingPacket {
2617                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2618                packet_type: OutgoingPacketType::InBandWithCompletion,
2619                payload: &message.payload(),
2620            })
2621            .map_err(|err| match err {
2622                queue::TryWriteError::Full(len) => {
2623                    tracing::error!(
2624                        len,
2625                        "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2626                    );
2627                    WorkerError::OutOfSpace
2628                }
2629                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2630            });
2631        if let Err(err) = result {
2632            tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2633        }
2634    }
2635
2636    /// Process an internal state change.
2637    async fn handle_state_change(
2638        &mut self,
2639        primary: &mut PrimaryChannelState,
2640        buffers: &ChannelBuffers,
2641    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
2642        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
2643        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
2644        //      1. Completion of the switch data path request.
2645        //      2. Switch data path notification (back to synthetic).
2646        //      3. Disassociate VF adapter.
2647        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2648            // Notify guest that a VF capability has recently arrived.
2649            if primary.rndis_state == RndisState::Operational {
2650                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
2651                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2652                    return Ok(Some(CoordinatorMessage::Update(
2653                        CoordinatorMessageUpdateType {
2654                            guest_vf_state: true,
2655                            ..Default::default()
2656                        },
2657                    )));
2658                } else if let Some(true) = primary.is_data_path_switched {
2659                    tracing::error!(
2660                        "Data path switched, but current guest negotiation does not support VTL0 VF"
2661                    );
2662                }
2663            }
2664            return Ok(None);
2665        }
2666        loop {
2667            primary.guest_vf_state = match primary.guest_vf_state {
2668                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
2669                    // Notify guest that the VF is unavailable. It has already been surprise removed.
2670                    if primary.rndis_state == RndisState::Operational {
2671                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
2672                    }
2673                    PrimaryChannelGuestVfState::Unavailable
2674                }
2675                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
2676                    to_guest,
2677                    id,
2678                } => {
                    // Complete the data path switch request.
                    self.send_completion(id, &[])?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    // Complete the data path switch request.
                    self.send_completion(id, &[])?;
                    if result {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSwitched
                        } else {
                            PrimaryChannelGuestVfState::Ready
                        }
                    } else {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        } else {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                _ => break,
            };
        }
        Ok(None)
    }

    /// Process a control message, writing the response to the provided receive
    /// buffer suballocation.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        // Ideally the VF would not be presented to the guest
                        // until the completion packet has arrived, so that the
                        // guest is prepared. This is most interesting for the
                        // case of a VF associated with multiple guest
                        // adapters, using a concept like vports. In this
                        // scenario it would be better if all of the adapters
                        // were aware a VF was available before the device
                        // arrived. This is not currently possible because the
                        // Linux netvsc driver ignores the completion requested
                        // flag on inband packets and won't send a completion
                        // packet.
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Restart the endpoint if the OID changed some critical
                        // endpoint property.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }

    fn try_send_rndis_message(
        &mut self,
        transaction_id: u64,
        channel_type: u32,
        recv_buffer_id: u16,
        transfer_pages: &[ring::TransferPageRange],
    ) -> Result<Option<usize>, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
            protocol::Message1SendRndisPacket {
                channel_type,
                send_buffer_section_index: 0xffffffff,
                send_buffer_section_size: 0,
            },
        );
        let pending_send_size = match self.queue.split().1.try_write(&queue::OutgoingPacket {
            transaction_id,
            packet_type: OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
            payload: &message.payload(),
        }) {
            Ok(()) => None,
            Err(queue::TryWriteError::Full(n)) => Some(n),
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(pending_send_size)
    }

    fn send_rndis_control_message(
        &mut self,
        buffers: &ChannelBuffers,
        id: u32,
        message_length: usize,
    ) -> Result<(), WorkerError> {
        let result = self.try_send_rndis_message(
            id as u64,
            protocol::CONTROL_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
        )?;

        // Ring size is checked before control messages are processed, so failure to write is unexpected.
        match result {
            None => Ok(()),
            Some(len) => {
                tracelimit::error_ratelimited!(len, "failed to write control message completion");
                Err(WorkerError::OutOfSpace)
            }
        }
    }

    fn indicate_status(
        &mut self,
        buffers: &ChannelBuffers,
        id: u32,
        status: u32,
        payload: &[u8],
    ) -> Result<(), WorkerError> {
        let buffer = &buffers.recv_buffer.range(id);
        let mut writer = buffer.writer(&buffers.mem);
        let message_length = write_rndis_message(
            &mut writer,
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
            payload.len(),
            &rndisprot::IndicateStatus {
                status,
                status_buffer_length: payload.len() as u32,
                status_buffer_offset: if payload.is_empty() {
                    0
                } else {
                    size_of::<rndisprot::IndicateStatus>() as u32
                },
            },
        )?;
        writer.write(payload)?;
        self.send_rndis_control_message(buffers, id, message_length)?;
        Ok(())
    }

    /// Processes pending control messages until all are processed or there are
    /// no available suballocations.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure the ring buffer has enough room to successfully complete control message handling.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

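    /// Attempts to write a `MESSAGE1_TYPE_SEND_RNDIS_PACKET` packet to the
    /// outgoing ring, referencing `transfer_pages` within the receive buffer
    /// identified by `recv_buffer_id`. Returns `Ok(None)` on success, or
    /// `Ok(Some(n))` if the ring is full, where `n` is the pending send size
    /// needed before the write can be retried.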
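    /// Sends an RNDIS control message response of `message_length` bytes from
    /// receive buffer suballocation `id`. The ring is sized before control
    /// messages are processed, so a full ring here is treated as an error.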
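    /// Writes an RNDIS `INDICATE_STATUS` message (with an optional `payload`)
    /// into receive buffer suballocation `id` and sends it to the guest.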
            // Mark the receive buffer suballocation as in use so that its
            // release by the guest can be tracked.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                unreachable!();
            }
        }
        Ok(())
    }

    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
        if self.restart.is_none() {
            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
                guest_vf_state: guest_vf,
                filter_state: packet_filter,
            }));
        } else if let Some(CoordinatorMessage::Restart) = self.restart {
            // If a restart message is pending, do nothing.
            // A restart will try to switch the data path based on primary.guest_vf_state.
            // A restart will apply packet filter changes.
        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
            // Add the new update to the existing message.
            update.guest_vf_state |= guest_vf;
            update.filter_state |= packet_filter;
        }
    }

    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }

    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
}

/// Writes an RNDIS message to `writer`.
fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
    writer: &mut impl MemoryWrite,
    message_type: u32,
    extra: usize,
    payload: &T,
) -> Result<usize, AccessError> {
    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
    writer.write(
        rndisprot::MessageHeader {
            message_type,
            message_length: message_length as u32,
        }
        .as_bytes(),
    )?;
    writer.write(payload.as_bytes())?;
    Ok(message_length)
}

#[derive(Debug, Error)]
enum OidError {
    #[error(transparent)]
    Access(#[from] AccessError),
    #[error("unknown oid")]
    UnknownOid,
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    #[error("bad ndis version")]
    BadVersion,
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}

impl OidError {
    fn as_status(&self) -> u32 {
        match self {
            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
            OidError::Access(_) => rndisprot::STATUS_FAILURE,
        }
    }
}

const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
const MAX_MTU: u32 = 9216;

const ETHERNET_HEADER_LEN: u32 = 14;

impl Adapter {
    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
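    /// Queues an update message for the coordinator, merging with any update
    /// that is already pending. A pending restart subsumes both update types,
    /// so nothing extra is queued in that case.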
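///
/// The message is laid out as an `rndisprot::MessageHeader` followed by
/// `payload`; `extra` accounts for any additional bytes the caller writes
/// after the payload (such as an indicate-status buffer) and is included in
/// the reported `message_length`.
///
/// A minimal usage sketch (illustrative only; `writer` and `request_id` are
/// assumed to be in scope), mirroring the keepalive completion path above:
///
/// ```ignore
/// let len = write_rndis_message(
///     &mut writer,
///     rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
///     0, // no extra bytes follow the payload
///     &rndisprot::KeepaliveComplete {
///         request_id,
///         status: rndisprot::STATUS_SUCCESS,
///     },
/// )?;
/// ```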
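// The default MTU matches a standard Ethernet frame: a 1500-byte payload plus
// the 14-byte Ethernet header. 9216 is a common jumbo-frame limit.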
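    /// Returns the serial number to report for the guest-visible VF. Windows
    /// guests get the adapter index (historically used as the vport serial
    /// number); all other guests get `vfid` directly.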
            // For enlightened guests (currently only Windows), send the
            // adapter index, which was previously set as the vport serial
            // number.
            if guest_os_id
                .microsoft()
                .unwrap_or(HvGuestOsMicrosoft::from(0))
                .os_id()
                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
            {
                self.adapter_index
            } else {
                vfid
            }
        } else {
            vfid
        }
    }

    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    // Ethernet objects operation characteristics
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    // Ethernet objects statistics
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
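    /// Handles an RNDIS OID query, writing the result to `writer` and
    /// returning the number of bytes written.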
                    // PNP operations characteristics
                    // rndisprot::Oid::OID_PNP_SET_POWER,
                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
                    // RNDIS OIDS
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // NDIS6 OIDs
                let supported_oids_6 = &[
                    // Link State OID
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    // NDIS 6 statistics OID
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    // Offload related OID
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
                ];

                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; // NdisHardwareStatusReady
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                // This value is meaningless for virtual NICs. Return what vmswitch returns.
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Like vmswitch, use Microsoft's MAC address OUI (00:15:5D)
                // as the vendor ID.
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            // NDIS6 OIDs:
            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, /* MediaConnectStateConnected */
                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            // DPDK gets confused if the table size is zero,
                            // even if there is only one queue.
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }

    fn handle_oid_set(
        &self,
        primary: &mut PrimaryChannelState,
        oid: rndisprot::Oid,
        reader: impl MemoryRead + Clone,
    ) -> Result<(bool, Option<u32>), OidError> {
        tracing::debug!(?oid, "oid set");

        let mut restart_endpoint = false;
        let mut packet_filter = None;
        match oid {
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
                packet_filter = self.oid_set_packet_filter(reader)?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
                self.oid_set_offload_parameters(reader, primary)?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                self.oid_set_offload_encapsulation(reader)?;
            }
            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
                self.oid_set_rndis_config_parameter(reader, primary)?;
            }
            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
                // TODO
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
                self.oid_set_rss_parameters(reader, primary)?;

                // Endpoints cannot currently change RSS parameters without
                // being restarted. This was a limitation driven by some DPDK
                // PMDs, and should be fixed.
                restart_endpoint = true;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
                return Err(OidError::UnknownOid);
            }
        }
        Ok((restart_endpoint, packet_filter))
    }

    fn oid_set_rss_parameters(
        &self,
        mut reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
        let len = reader.len().min(size_of_val(&params));
        reader.clone().read(&mut params.as_mut_bytes()[..len])?;

        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
        {
            primary.rss_state = None;
            return Ok(());
        }

        if params.hash_secret_key_size != 40 {
            return Err(OidError::InvalidInput("hash_secret_key_size"));
        }
        if params.indirection_table_size % 4 != 0 {
            return Err(OidError::InvalidInput("indirection_table_size"));
        }
        let indirection_table_size =
            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
        let mut key = [0; 40];
        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
        reader
            .clone()
            .skip(params.hash_secret_key_offset as usize)?
            .read(&mut key)?;
        reader
            .skip(params.indirection_table_offset as usize)?
            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
        if indirection_table
            .iter()
            .any(|&x| x >= self.max_queues as u32)
        {
            return Err(OidError::InvalidInput("indirection_table"));
        }
        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
            .flatten()
            .zip(indir_uninit)
        {
            *dest = src;
        }
        primary.rss_state = Some(RssState {
            key,
            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
        });
        Ok(())
    }

    fn oid_set_packet_filter(
        &self,
        reader: impl MemoryRead + Clone,
    ) -> Result<Option<u32>, OidError> {
        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
        tracing::debug!(filter, "set packet filter");
        Ok(Some(filter))
    }

    fn oid_set_offload_parameters(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::DEFAULT,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
        )?;

        tracing::debug!(?offload, "offload parameters");
        let rndisprot::NdisOffloadParameters {
            header: _,
            ipv4_checksum,
            tcp4_checksum,
            udp4_checksum,
            tcp6_checksum,
            udp6_checksum,
            lsov1,
            ipsec_v1: _,
            lsov2_ipv4,
            lsov2_ipv6,
            tcp_connection_ipv4: _,
            tcp_connection_ipv6: _,
            reserved: _,
            flags: _,
        } = offload;

        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
            return Err(OidError::NotSupported("lsov1"));
        }
        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.ipv4_header = tx;
            primary.offload_config.checksum_rx.ipv4_header = rx;
        }
        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp4 = tx;
            primary.offload_config.checksum_rx.tcp4 = rx;
        }
        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp6 = tx;
            primary.offload_config.checksum_rx.tcp6 = rx;
        }
        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp4 = tx;
            primary.offload_config.checksum_rx.udp4 = rx;
        }
        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp6 = tx;
            primary.offload_config.checksum_rx.udp6 = rx;
        }
        if let Some(enable) = lsov2_ipv4.enable() {
            primary.offload_config.lso4 = enable;
        }
        if let Some(enable) = lsov2_ipv6.enable() {
            primary.offload_config.lso6 = enable;
        }
        primary
            .offload_config
            .mask_to_supported(&self.offload_support);
        primary.pending_offload_change = true;
        Ok(())
    }

    fn oid_set_offload_encapsulation(
        &self,
        reader: impl MemoryRead + Clone,
    ) -> Result<(), OidError> {
        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
        )?;
        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
        {
            return Err(OidError::NotSupported("ipv4 encap"));
        }
        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
        {
            return Err(OidError::NotSupported("ipv6 encap"));
        }
        Ok(())
    }

    fn oid_set_rndis_config_parameter(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
        if info.name_length > 255 {
            return Err(OidError::InvalidInput("name_length"));
        }
        if info.value_length > 255 {
            return Err(OidError::InvalidInput("value_length"));
        }
        let name = reader
            .clone()
            .skip(info.name_offset as usize)?
            .read_n::<u16>(info.name_length as usize / 2)?;
        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
        let mut value = reader;
        value.skip(info.value_offset as usize)?;
        let mut value = value.limit(info.value_length as usize);
        match info.parameter_type {
            rndisprot::NdisParameterType::STRING => {
                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
                let value =
                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
                let tx = as_num & 1 != 0;
                let rx = as_num & 2 != 0;

                tracing::debug!(name, value, "rndis config");
                match name.as_str() {
                    "*IPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.ipv4_header = tx;
                        primary.offload_config.checksum_rx.ipv4_header = rx;
                    }
                    "*LsoV2IPv4" => {
                        primary.offload_config.lso4 = as_num != 0;
                    }
                    "*LsoV2IPv6" => {
                        primary.offload_config.lso6 = as_num != 0;
                    }
                    "*TCPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.tcp4 = tx;
                        primary.offload_config.checksum_rx.tcp4 = rx;
                    }
                    "*TCPChecksumOffloadIPv6" => {
                        primary.offload_config.checksum_tx.tcp6 = tx;
                        primary.offload_config.checksum_rx.tcp6 = rx;
                    }
                    "*UDPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.udp4 = tx;
                        primary.offload_config.checksum_rx.udp4 = rx;
                    }
                    "*UDPChecksumOffloadIPv6" => {
                        primary.offload_config.checksum_tx.udp6 = tx;
                        primary.offload_config.checksum_rx.udp6 = rx;
                    }
                    _ => {}
                }
                primary
                    .offload_config
                    .mask_to_supported(&self.offload_support);
            }
            rndisprot::NdisParameterType::INTEGER => {
                let value: u32 = value.read_plain()?;
                tracing::debug!(name, value, "rndis config");
            }
            parameter_type => tracelimit::warn_ratelimited!(
                name,
                ?parameter_type,
                "unhandled rndis config parameter type"
            ),
        }
        Ok(())
    }
}

fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
    mut reader: impl MemoryRead,
    object_type: rndisprot::NdisObjectType,
    min_revision: u8,
    min_size: usize,
) -> Result<T, OidError> {
    let mut buffer = T::new_zeroed();
    let sent_size = reader.len();
    let len = sent_size.min(size_of_val(&buffer));
    reader.read(&mut buffer.as_mut_bytes()[..len])?;
    validate_ndis_object_header(
        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
            .unwrap()
            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
        sent_size,
        object_type,
        min_revision,
        min_size,
    )?;
    Ok(buffer)
}

fn validate_ndis_object_header(
    header: &rndisprot::NdisObjectHeader,
    sent_size: usize,
    object_type: rndisprot::NdisObjectType,
    min_revision: u8,
    min_size: usize,
) -> Result<(), OidError> {
    if header.object_type != object_type {
        return Err(OidError::InvalidInput("header.object_type"));
    }
    if sent_size < header.size as usize {
        return Err(OidError::InvalidInput("header.size"));
    }
    if header.revision < min_revision {
        return Err(OidError::InvalidInput("header.revision"));
    }
    if (header.size as usize) < min_size {
        return Err(OidError::InvalidInput("header.size"));
    }
    Ok(())
}

struct Coordinator {
    recv: mpsc::Receiver<CoordinatorMessage>,
    channel_control: ChannelControl,
    restart: bool,
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    buffers: Option<Arc<ChannelBuffers>>,
    num_queues: u16,
}

/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// A VF update is pending.
    Pending,
}

struct CoordinatorState {
    endpoint: Box<dyn Endpoint>,
    adapter: Arc<Adapter>,
    virtual_function: Option<Box<dyn VirtualFunction>>,
    pending_vf_state: CoordinatorStatePendingVfState,
}

impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    // Bounce each task so that it sees the new value.
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Get the shared channel state from the primary channel.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}

impl AsyncRun<Coordinator> for CoordinatorState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}

impl Coordinator {
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        let mut sleep_duration: Option<Instant> = None;
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Ensure that all workers except the primary are started. The
            // primary is started below if there are no outstanding messages.
            for worker in &mut self.workers[1..] {
                worker.start();
            }

            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
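    /// Handles an RNDIS OID set request. Returns whether the endpoint must be
    /// restarted to apply the change, plus the new packet filter, if any.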
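    /// Applies `OID_GEN_RECEIVE_SCALE_PARAMETERS`: validates the hash key and
    /// indirection table, then stores them as the new RSS state. A request
    /// with RSS disabled (or no hash function selected) clears the RSS state.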
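        // If the guest supplied fewer entries than the endpoint's table size,
        // fill out the remainder by repeating the supplied entries in order.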
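                // The leading digit of the value acts as a bitmask, following
                // the standardized NDIS keyword convention: bit 0 enables
                // transmit offload and bit 1 enables receive offload, so "3"
                // enables both directions.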
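/// Reads an NDIS object of type `T` from `reader`, zero-filling any bytes the
/// sender did not provide, then validates the embedded NDIS object header.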
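/// Validates an NDIS object header: the object type must match, the revision
/// and declared size must meet the given minimums, and the sender must have
/// provided at least `header.size` bytes.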
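    /// Runs the coordinator loop: restarts the queues when requested,
    /// sequences guest VF offer and removal, applies packet filter updates,
    /// and forwards link status changes and timer expirations to the primary
    /// worker.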
                // The primary worker is allowed to run, but no notifications
                // are processed while this request is outstanding, so if the
                // worker is waiting for an action it should remain stopped
                // until the request completes and regular message processing
                // resumes. Currently the only message that requires
                // processing is DataPathSwitchPending, so check for that
                // here.
                if !self.workers[0].is_running()
                    && self.primary_mut().is_none_or(|primary| {
                        !matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::DataPathSwitchPending { result: None, .. }
                        )
                    })
                {
                    self.workers[0].start();
                }

                // guest_ready_for_device is not restartable, so do not poll on
                // stop.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                let timer_sleep = async {
                    if let Some(sleep_duration) = sleep_duration {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(sleep_duration).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                let mut wait_for_message = std::pin::pin!(wait_for_message);
                match (&mut wait_for_message).now_or_never() {
                    Some(message) => message,
                    None => {
                        self.workers[0].start();
                        stop.until_stopped(wait_for_message).await?
                    }
                }
            };
            match message {
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Kick the worker as requested.
                    if self.workers[0].is_running() {
                        self.workers[0].stop().await;
                        if let Some(primary) = self.primary_mut() {
                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                                primary.pending_link_action = PendingLinkAction::Active(up);
                            }
                        }
                    }
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Update(update_type)) => {
                    if update_type.filter_state {
                        self.stop_workers().await;
                        let worker_0_packet_filter =
                            self.workers[0].state().unwrap().channel.packet_filter;
                        self.workers.iter_mut().skip(1).for_each(|worker| {
                            if let Some(state) = worker.state_mut() {
                                state.channel.packet_filter = worker_0_packet_filter;
                                tracing::debug!(
                                    packet_filter = ?worker_0_packet_filter,
                                    channel_idx = state.channel_idx,
                                    "update packet filter"
                                );
                            }
                        });
                    }

                    if update_type.guest_vf_state {
                        self.update_guest_vf_state(state).await;
                    }
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

4012                    // These are the only link state transitions that are tracked.
4013                    // 1. up -> down or down -> up
4014                    // 2. up -> down -> up or down -> up -> down.
4015                    // All other state transitions are coalesced into one of the above cases.
4016                    // For example, up -> down -> up -> down is treated as up -> down.
4017                    // N.B - Always queue up the incoming state to minimize the effects of loss
4018                    //       of any notifications (for example, during vtl2 servicing).
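                    // A sketch of the coalescing (hypothetical sequence): with
                    // the guest link currently up, notifications
                    // [down, up, down] each overwrite pending_link_action, so
                    // only Active(false) is queued here, i.e. a single
                    // up -> down transition; the two-step cases are produced
                    // later by the worker's delayed handling.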
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    // If there is an existing sleep timer running, cancel it.
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
                    sleep_duration = Some(duration);
                    // Restart primary task.
                    self.workers[0].stop().await;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }

    async fn stop_workers(&mut self) {
        for worker in &mut self.workers {
            worker.stop().await;
        }
    }

    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Update guest VF state based on current endpoint properties.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // Ensure the guest VF is in the proper state.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // Ensure the data path is switched as expected.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                // If the save was after the data path switch already occurred, don't do it again.
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // Cancel any outstanding delay timers for VF offers if the data path is
                    // getting switched, since the guest is already issuing
                    // commands assuming a VF.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        // A previous instantiation already switched the data
                        // path.
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // If the device was just removed, make sure the data path is synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Notify the guest if the VF is no longer available.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }

    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        // Save the channel buffers for use in the subchannel workers.
        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                // The active queue count is the number of unique, in-range entries in the indirection table.
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };
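        // A minimal sketch (hypothetical values): with num_queues == 4 and an
        // indirection table of [0, 2, 2, 5, 0], the unique entries are
        // [0, 2, 5]; 5 is out of range, leaving [0, 2] and an active queue
        // count of 2, so queues 1 and 3 will receive no rx buffers.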

        // Distribute the rx buffers to only the active queues.
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        {
            let buffers = &state.buffers;
            let guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            // Get the list of free rx buffers from each task, then partition
            // the list per-active-queue, and produce the queue configuration.
            let mut queue_config = Vec::new();
            let initial_rx;
            {
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    // If the primary queue is excluded from the guest supplied
                    // indirection table, it is assigned just the reserved
                    // buffers.
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: &[],
                        driver: Box::new(drivers[0].clone()),
                    });
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            // The last queue gets all the remaining buffers.
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            // Queue zero always includes the reserved buffers.
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        (range_start, 0, None)
                    };

                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: this,
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }
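
            // Sketch of the resulting partition (assuming queue 0 is active):
            // queue 0 owns ids 0..(RX_RESERVED_CONTROL_BUFFERS +
            // buffers_per_queue), each middle active queue owns the next
            // buffers_per_queue ids, and the last active queue owns whatever
            // remains up to recv_buffer.count; inactive queues get an empty
            // range and no initial RxIds.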

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // Set the subchannel count.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        let worker_0_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
        // Provide the queue and receive buffer ranges for each worker.
        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
            // Update the receive packet filter for the subchannel worker.
            if let Some(worker) = worker.state_mut() {
                worker.channel.packet_filter = worker_0_packet_filter;
                // Clear any pending RxIds as buffers were redistributed.
                if let Some(ready_state) = worker.state.ready_mut() {
                    ready_state.state.pending_rx_packets.clear();
                }
            }
        }

        Ok(())
    }

    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
        self.workers[0]
            .state_mut()
            .unwrap()
            .state
            .ready_mut()?
            .state
            .primary
            .as_mut()
    }

    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
}

impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        match worker.process(stop, self).await {
            Ok(()) | Err(WorkerError::BufferRevoked) => {}
            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
            Err(err) => {
                tracing::error!(
                    channel_idx = worker.channel_idx,
                    error = &err as &dyn std::error::Error,
                    "netvsp error"
                );
            }
        }
        Ok(())
    }
}

impl<T: RingMem + 'static> Worker<T> {
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        // Be careful not to wait on actions with unbounded blocking time (e.g.
        // guest actions, or waiting for network packets to arrive) without
        // wrapping the wait in `stop.until_stopped`.
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    if self.channel_idx != 0 {
                        // Still waiting for the coordinator to provide receive
                        // buffers. The task will be restarted when they are available.
                        stop.until_stopped(pending()).await?
                    }

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Wake up the coordinator task to start the queues.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::Ready(state);
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        if !queue_state.target_vp_set
                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
                        {
                            tracing::debug!(
                                channel_idx = self.channel_idx,
                                target_vp = self.target_vp,
                                "updating target VP"
                            );
                            queue_state.queue.update_target_vp(self.target_vp).await;
                            queue_state.target_vp_set = true;
                        }

                        queue_state
                    } else {
                        // This task will be restarted when the queues are ready.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    match result {
                        Ok(restart) => {
                            assert_eq!(self.channel_idx, 0);
                            let _ = self.coordinator_send.try_send(restart);
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            if let Err(try_send_err) =
                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
                            {
                                tracing::error!(
                                    try_send_err = &try_send_err as &dyn std::error::Error,
                                    "failed to restart queues"
                                );
                                return Err(WorkerError::Endpoint(err));
                            }
                        }
                        Err(err) => return Err(err),
                    }

                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}

impl<T: 'static + RingMem> NetChannel<T> {
    fn try_next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Option<Packet<'a>>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let packet = match read.try_read() {
            Ok(packet) => {
                parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
            }
            Err(queue::TryReadError::Empty) => return Ok(None),
            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
        };

        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(Some(packet))
    }

    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet =
            parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            // In WorkerState::Init, if an RNDIS packet is received, assume it
            // is MESSAGE_TYPE_INITIALIZE_MSG.
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            packet_ref.revert();
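            // N.B. Reverting the read leaves the packet in the ring, so it can
            //      be re-read and fully processed once the channel reaches the
            //      ready state.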
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }

    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
            && initializing.ndis_version.is_some()
            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
            && initializing.recv_buffer.is_some()
    }
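
    // For example, a pre-V2 guest (which has no NDIS config message) passes
    // the first clause via `version < Version::V2`, and the
    // `allow_missing_send_buffer` escape hatch lets an early RNDIS packet
    // complete initialization without a send buffer (see the RndisPacket arm
    // in `initialize` below).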

    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough room in the ring to avoid needing to track
            // completion packets.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let packet = self
                .next_packet(None, initializing.as_ref().map(|x| x.version))
                .await?;

            if let Some(initializing) = &mut *initializing {
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // As in the vmswitch, if the MTU is invalid then use the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };
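                        // E.g. a (hypothetical) guest-supplied MTU of 0, or
                        // one above MAX_MTU, falls back to DEFAULT_MTU rather
                        // than failing the handshake.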

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                    protocol::Message1SendReceiveBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        num_sections: 1,
                                        sections: [protocol::ReceiveBufferSection {
                                            offset: 0,
                                            sub_allocation_size: recv_buffer.sub_allocation_size,
                                            num_sub_allocations: recv_buffer.count,
                                            end_offset: recv_buffer.sub_allocation_size
                                                * recv_buffer.count,
                                        }],
                                    },
                                )
                                .payload(),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                    protocol::Message1SendSendBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        section_size: 6144,
                                    },
                                )
                                .payload(),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut message = self.message(
                            protocol::MESSAGE_TYPE_INIT_COMPLETE,
                            protocol::MessageInitComplete {
                                deprecated: protocol::INVALID_PROTOCOL_VERSION,
                                maximum_mdl_chain_length: 34,
                                status: protocol::Status::NONE,
                            },
                        );
                        if let Some(version) = version {
                            if version == Version::V1 {
                                message.data.deprecated = Version::V1 as u32;
                            }
                            message.data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        self.send_completion(packet.transaction_id, &message.payload())?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                // Update the packet size so that the appropriate padding is
                                // appended for picky Windows guests.
                                self.packet_size = protocol::PACKET_SIZE_V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }
    }

    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };
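
        // A worked example (hypothetical sizes): with a 0x10000-byte outgoing
        // ring and no configured ring_size_limit, limit becomes
        // capacity - 2048, so ring_spare_capacity is 2048 and the loop below
        // treats the ring as "full" whenever it cannot write that many bytes.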

        // If the packet filter has changed to allow rx packets, add any pended RxIds.
        if !state.pending_rx_packets.is_empty()
            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
        {
            let epqueue = queue_state.queue.as_mut();
            let (front, back) = state.pending_rx_packets.as_slices();
            epqueue.rx_avail(front);
            epqueue.rx_avail(back);
            state.pending_rx_packets.clear();
        }

        // Handle any guest state changes since last run.
        if let Some(primary) = state.primary.as_mut() {
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // For the up -> down -> up OR down -> up -> down case, the first transition
                        // is sent immediately and the second transition is queued with a delay.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    // Mark the receive buffer in use to allow the guest to release it.
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // If the ring is almost full, then do not poll the endpoint,
            // since this will cause us to spend time dropping packets and
            // reprogramming receive buffers. It's more efficient to let the
            // backend drop packets.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

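            // N.B. Bitwise `|` (not `||`) is deliberate below: every
            //      processing stage runs on each pass, so progress in an
            //      earlier stage does not starve the later ones.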
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            // Process any outstanding control messages before sleeping in case
            // there are now suballocations available for use.
            self.process_control_messages(buffers, state)?;

            // This should be the only await point waiting on network traffic or
            // guest actions. Wrap it in `stop.until_stopped` to allow
            // cancellation.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        // If the ring is almost full, then don't wait for endpoint
                        // interrupts. This allows the interrupt rate to fall when the
                        // guest cannot keep up with the load.
                        if !ring_full {
                            // Check the network endpoint for tx completion or rx.
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // Check the incoming ring for tx, but only if there are enough
                        // free tx packets.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // Check the outgoing ring for space to send rx completions or
                        // control message sends, if any are pending.
                        //
                        // Also, if endpoint processing has been suspended due to the
                        // ring being nearly full, then ask the guest to wake us up when
                        // there is space again.
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Collect any of this queue's receive buffers that were
                        // completed by a remote channel. This only happens when the
                        // subchannel count changes, so that receive buffer ownership
                        // moves between queues while some receive buffers are still in
                        // use.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }

    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Pend the newly available RxIds until the packet filter is updated.
            // Under high load this will eventually lead to no available RxIds,
            // which will cause the backend to drop the packets instead of
            // processing them here.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

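        // The first ready RxId doubles as the vmbus transaction id for the
        // whole batch; all of the buffers are allocated under it and freed by
        // it if the send fails.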
5101        let transaction_id = data.rx_ready[0].0.into();
5102        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);
5103
5104        state.rx_bufs.allocate(ready_ids.clone()).unwrap();
5105
5106        // Always use the full suballocation size to avoid tracking the
5107        // message length. See RxBuf::header() for details.
5108        let len = buffers.recv_buffer.sub_allocation_size as usize;
5109        data.transfer_pages.clear();
5110        data.transfer_pages
5111            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));
5112
5113        match self.try_send_rndis_message(
5114            transaction_id,
5115            protocol::DATA_CHANNEL_TYPE,
5116            buffers.recv_buffer.id,
5117            &data.transfer_pages,
5118        )? {
5119            None => {
5120                // packet was sent
5121                state.stats.rx_packets.add(n as u64);
5122            }
5123            Some(_) => {
5124                // Ring buffer is full. Drop the packets and free the rx
5125                // buffers. When the ring has limited space, the main loop will
5126                // stop polling for receive packets.
5127                state.stats.rx_dropped_ring_full.add(n as u64);
5128
5129                state.rx_bufs.free(data.rx_ready[0].0);
5130                epqueue.rx_avail(&data.rx_ready[..n]);
5131            }
5132        }
5133
5134        Ok(true)
5135    }
5136
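    /// Drains transmit completions from the backend endpoint, completing each
    /// guest tx packet once all of its backend packets have finished. On
    /// `TxError::TryRestart`, completes all in-flight transmits before
    /// reporting that the queues need to be restarted.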
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Drain completed transmits.
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Complete any pending tx prior to restarting the queues.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }

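    /// Handles a guest request to switch the data path between the synthetic
    /// path and the guest VF. If a switch is actually needed, queues the
    /// operation with the coordinator; otherwise the request is completed
    /// immediately.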
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // Allow the guest to switch to VTL0, or, if the current state
                // of the data path is unknown, allow a switch away from VTL0.
                // The latter case handles the scenario where the data path has
                // been switched but the synthetic device has been restarted.
                // The device is queried for the current state of the data
                // path, but if it is unknown (upgraded from an older version
                // that doesn't track this data, or the query failed), the
                // safest option is to pass the request through.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "data path switch requested while the device is in the wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, &[])?;
        }
        Ok(())
    }

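    /// Processes packets from the incoming ring buffer until the ring is
    /// empty, tx IDs are exhausted, or tx segments are left pending. Handles
    /// RNDIS packets and completions, subchannel requests, buffer
    /// revocations, and data path switch requests. Returns `Ok(true)` if any
    /// packet was processed.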
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            if state.free_tx_packets.is_empty() || !data.tx_segments.is_empty() {
                break;
            }
            let packet = if let Some(packet) =
                self.try_next_packet(buffers.send_buffer.as_ref(), Some(buffers.version))?
            {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    assert!(data.tx_segments.is_empty());
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    let num_packets = match result {
                        Ok(num_packets) => num_packets,
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                            continue;
                        }
                    };
                    total_packets += num_packets as u64;
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += num_packets;

                    if num_packets != 0 {
                        if self.transmit_segments(state, data, queue_state, id, num_packets)?
                            < num_packets
                        {
                            state.stats.tx_stalled.increment();
                        }
                    } else {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                                protocol::Message5SubchannelComplete {
                                    status,
                                    num_sub_channels: subchannel_count,
                                },
                            )
                            .payload(),
                    )?;

                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                // VF association completion packets require no action; not
                // all clients send them.
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                                protocol::Message5OidQueryExComplete {
                                    status: rndisprot::STATUS_NOT_SUPPORTED,
                                    bytes: 0,
                                },
                            )
                            .payload(),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }

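    /// Retries transmitting the segments left over from an earlier
    /// `transmit_segments` call. Returns `Ok(true)` if everything pending was
    /// sent.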
    fn transmit_pending_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if data.tx_segments.is_empty() {
            return Ok(false);
        }
        let net_backend::TxSegmentType::Head(metadata) = &data.tx_segments[0].ty else {
            unreachable!()
        };
        let id = metadata.id;
        let num_packets = state.pending_tx_packets[id.0 as usize].pending_packet_count;
        let packets_sent = self.transmit_segments(state, data, queue_state, id, num_packets)?;
        Ok(num_packets == packets_sent)
    }

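    /// Posts the buffered tx segments to the backend queue. If the endpoint
    /// transmitted synchronously, decrements the pending packet count here
    /// rather than waiting for `tx_poll`, completing the guest packet once
    /// the count reaches zero. Returns the number of packets accepted.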
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
        id: TxId,
        num_packets: usize,
    ) -> Result<usize, WorkerError> {
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&data.tx_segments)
            .map_err(WorkerError::Endpoint)?;

        assert!(segments_sent <= data.tx_segments.len());

        let packets_sent = if segments_sent == data.tx_segments.len() {
            num_packets
        } else {
            net_backend::packet_count(&data.tx_segments[..segments_sent])
        };

        data.tx_segments.drain(..segments_sent);

        if sync {
            state.pending_tx_packets[id.0 as usize].pending_packet_count -= packets_sent;
        }

        if state.pending_tx_packets[id.0 as usize].pending_packet_count == 0 {
            self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
        }

        Ok(packets_sent)
    }

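    /// Parses the RNDIS message(s) in a single vmbus packet and dispatches
    /// each to `handle_rndis_message`, accumulating tx segments for any data
    /// packets. Returns the number of data packets produced.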
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert!(tx_packet.pending_packet_count == 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe the data to catch accesses that are out of bounds. This
        // simplifies error handling for backends that use
        // [`GuestMemory::iova`].
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        // There may be multiple RNDIS packets in a single message,
        // concatenated with each other. Consume them until there is no more
        // data in the RNDIS message.
        while reader.len() > 0 {
            let mut this_reader = reader.clone();
            let header: rndisprot::MessageHeader = this_reader.read_plain()?;
            if self.handle_rndis_message(
                buffers,
                state,
                id,
                header.message_type,
                this_reader,
                segments,
            )? {
                num_packets += 1;
            }
            reader.skip(header.message_length as usize)?;
        }

        Ok(num_packets)
    }

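    /// Attempts to write an RNDIS packet completion to the outgoing ring.
    /// Returns `Ok(false)` and records the ring space needed in
    /// `pending_send_size` if the ring is currently full.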
    fn try_send_tx_packet(
        &mut self,
        transaction_id: u64,
        status: protocol::Status,
    ) -> Result<bool, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
            protocol::Message1SendRndisPacketComplete { status },
        );
        let result = self.queue.split().1.try_write(&queue::OutgoingPacket {
            transaction_id,
            packet_type: OutgoingPacketType::Completion,
            payload: &message.payload(),
        });
        let sent = match result {
            Ok(()) => true,
            Err(queue::TryWriteError::Full(n)) => {
                self.pending_send_size = n;
                false
            }
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(sent)
    }

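    /// Flushes queued tx completions to the outgoing ring, stopping early if
    /// the ring fills up. Clears `pending_send_size` once the queue has been
    /// drained. Returns `Ok(true)` if any completion was sent.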
    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
        let mut did_some_work = false;
        while let Some(pending) = state.pending_tx_completions.front() {
            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
                return Ok(did_some_work);
            }
            did_some_work = true;
            if let Some(id) = pending.tx_id {
                state.free_tx_packets.push(id);
            }
            tracing::trace!(?pending, "sent tx completion");
            state.pending_tx_completions.pop_front();
        }

        self.pending_send_size = 0;
        Ok(did_some_work)
    }

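    /// Completes a guest tx packet, writing the completion directly to the
    /// ring when possible and otherwise queuing it. Completions are kept in
    /// order: if any are already pending, the new one is queued behind them.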
    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
}

impl ActiveState {
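    /// Releases the receive buffers completed by the guest, identified by the
    /// transaction ID of the completion packet. Remote buffers are forwarded
    /// to their owning channel; local ones are returned to the queue (via
    /// `done`) or to the control message pool. Returns `None` if the ID does
    /// not correspond to an allocated suballocation.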
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID specifies the first rx buffer ID.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
}