netvsp/
lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The user-mode netvsp VMBus device implementation.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod buffers;
mod protocol;
pub mod resolver;
mod rndisprot;
mod rx_bufs;
mod saved_state;
mod test;

use crate::buffers::GuestBuffers;
use crate::protocol::Message1RevokeReceiveBuffer;
use crate::protocol::Message1RevokeSendBuffer;
use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
use crate::protocol::Version;
use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
use async_trait::async_trait;
pub use buffers::BufferPool;
use buffers::sub_allocation_size_for_mtu;
use futures::FutureExt;
use futures::StreamExt;
use futures::channel::mpsc;
use futures_concurrency::future::Race;
use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRanges;
use guestmem::ranges::PagedRangesReader;
use guid::Guid;
use hvdef::hypercall::HvGuestOsId;
use hvdef::hypercall::HvGuestOsMicrosoft;
use hvdef::hypercall::HvGuestOsMicrosoftIds;
use hvdef::hypercall::HvGuestOsOpenSourceType;
use inspect::Inspect;
use inspect::InspectMut;
use inspect::SensitivityLevel;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use mesh::rpc::Rpc;
use net_backend::Endpoint;
use net_backend::EndpointAction;
use net_backend::L3Protocol;
use net_backend::QueueConfig;
use net_backend::RxId;
use net_backend::TxError;
use net_backend::TxId;
use net_backend::TxSegment;
use net_backend_resources::mac_address::MacAddress;
use pal_async::timer::Instant;
use pal_async::timer::PolledTimer;
use ring::gparange::MultiPagedRangeIter;
use rx_bufs::RxBuffers;
use rx_bufs::SubAllocationInUse;
use std::cmp;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::future::pending;
use std::mem::offset_of;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Poll;
use std::time::Duration;
use task_control::AsyncRun;
use task_control::InspectTaskMut;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing::Instrument;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::channel::ChannelControl;
use vmbus_channel::channel::ChannelOpenError;
use vmbus_channel::channel::ChannelRestoreError;
use vmbus_channel::channel::DeviceResources;
use vmbus_channel::channel::RestoreControl;
use vmbus_channel::channel::SaveRestoreVmbusDevice;
use vmbus_channel::channel::VmbusDevice;
use vmbus_channel::gpadl::GpadlId;
use vmbus_channel::gpadl::GpadlMapView;
use vmbus_channel::gpadl::GpadlView;
use vmbus_channel::gpadl::UnknownGpadlId;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_channel::gpadl_ring::gpadl_channel;
use vmbus_ring as ring;
use vmbus_ring::OutgoingPacketType;
use vmbus_ring::RingMem;
use vmbus_ring::gparange::GpnList;
use vmbus_ring::gparange::MultiPagedRangeBuf;
use vmcore::save_restore::RestoreError;
use vmcore::save_restore::SaveError;
use vmcore::save_restore::SavedStateBlob;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but room is also needed for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. The worst case requires a completion message plus
// two additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION).
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the accelerated networking (AN) device is
// offered too quickly, it could get the "eth0" name. In provisioning
// scenarios, the scripts will make assumptions about which interface should
// be used, with eth0 being the default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);

#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}

#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}

struct Worker<T: RingMem> {
    channel_idx: u16,
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}

impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}

#[expect(clippy::large_enum_variant)]
enum WorkerState {
    Init(Option<InitState>),
    Ready(ReadyState),
}

impl WorkerState {
    fn ready(&self) -> Option<&ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }

    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }
}

struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}

/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
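///
/// A minimal sketch of an implementation, assuming a hypothetical `MyVf`
/// type that knows its vfid and receives availability-change RPCs over a
/// stream:
///
/// ```ignore
/// #[async_trait]
/// impl VirtualFunction for MyVf {
///     async fn id(&self) -> Option<u32> {
///         // Report the VF as currently available.
///         Some(self.vfid)
///     }
///     async fn guest_ready_for_device(&mut self) {
///         // Expose the device to the guest here.
///     }
///     async fn wait_for_state_change(&mut self) -> Rpc<(), ()> {
///         // Each RPC carries an availability change; completing it reports
///         // whether the change was handled.
///         self.state_change_recv.next().await.unwrap()
///     }
/// }
/// ```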
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}

struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    max_queues: u16,
    indirection_table_size: u16,
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}

struct RxBufferRange {
    id_range: Range<u32>,
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}

impl RxBufferRange {
    fn new(
        ranges: Arc<RxBufferRanges>,
        id_range: Range<u32>,
        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    ) -> Self {
        Self {
            id_range,
            remote_buffer_id_recv,
            remote_ranges: ranges,
        }
    }

    fn send_if_remote(&self, id: u32) -> bool {
        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
        // ID is owned by the current range.
        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
            false
        } else {
            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
            // The total number of receive buffers may not evenly divide among
            // the active queues. Any extra buffers are given to the last
            // queue, so redirect any larger values there.
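            // For example, with 1040 total receive buffers and 4 queues,
            // buffers_per_queue = (1040 - 16) / 4 = 256, so buffer ID 700
            // belongs to remote queue index (700 - 16) / 256 = 2.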
            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
            true
        }
    }
}

struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}

impl RxBufferRanges {
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}

struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// The internal channel state.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: usize,
    pending_send_size: usize,
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}

/// Buffers used during packet processing.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
}

impl ProcessingData {
    fn new() -> Self {
        Self {
            tx_segments: Vec::new(),
            tx_done: vec![TxId(0); 8192].into(),
            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
        }
    }
}

/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    primary: Option<PrimaryChannelState>,

    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    stats: QueueStats,
}

#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    tx_id: Option<TxId>,
    status: protocol::Status,
}

#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available, but it is no longer available; a data
    /// path switch was pending when it became unavailable.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available, but it is no longer available; the data
    /// path had been switched to the VF.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and the guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}

impl std::fmt::Display for PrimaryChannelGuestVfState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
                    "\"restoring from client requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "\"unavailable (previously switching data path)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}

impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}

#[derive(Debug)]
enum PendingLinkAction {
    Default,
    Active(bool),
    Delay(bool),
}

struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}

impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}

#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}

#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}

impl ChecksumOffloadConfig {
    fn flags(
        &self,
    ) -> (
        rndisprot::Ipv4ChecksumOffload,
        rndisprot::Ipv6ChecksumOffload,
    ) {
        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
        if self.ipv4_header {
            v4.set_ip_options_supported(on);
            v4.set_ip_checksum(on);
        }
        if self.tcp4 {
            v4.set_ip_options_supported(on);
            v4.set_tcp_options_supported(on);
            v4.set_tcp_checksum(on);
        }
        if self.tcp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_tcp_options_supported(on);
            v6.set_tcp_checksum(on);
        }
        if self.udp4 {
            v4.set_ip_options_supported(on);
            v4.set_udp_checksum(on);
        }
        if self.udp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_udp_checksum(on);
        }
        (v4, v6)
    }
}

impl OffloadConfig {
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            // Use the same maximum as vmswitch.
            const MAX_OFFLOAD_SIZE: u32 = 0xF53C;
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = 2;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = 2;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}

#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    Initializing,
    Operational,
    Halted,
}

impl PrimaryChannelState {
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of the indirection table can cause unexpected and
                    // hard-to-investigate issues with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of the indirection table is done by duplicating the
                    // existing entries until the desired size is reached.
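                    // For example, a saved table [0, 1, 2] grown to a size of
                    // 8 becomes [0, 1, 2, 0, 1, 2, 0, 1].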
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}

struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

const TX_PACKET_QUOTA: usize = 1024;

impl ActiveState {
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Consume tx quota if any is available. If not, still
            // allow the restore since tx quota might change from
            // release to release.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                // This shouldn't be referenced, but set it in case it is in the future.
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            // Save/restore does not preserve the status of pending tx completions, so
            // complete any pending completions with 'success' to avoid changing the saved state.
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}

/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    pending_packet_count: usize,
    transaction_id: u64,
}

/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;

/// A network adapter.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}

pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}

impl NicBuilder {
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}

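/// Returns whether the "pending send size" ring optimization can be used for
/// this channel. The ring must support it, and the guest OS must be one that
/// is known to implement it correctly (see the version checks below).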
fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
    let Some(guest_os_id) = guest_os_id else {
        // guest os id not available.
        return false;
    };

    if !queue.split().0.supports_pending_send_size() {
        // guest does not support pending send size.
        return false;
    }

    let Some(open_source_os) = guest_os_id.open_source() else {
        // guest os is proprietary (e.g., Windows)
        return true;
    };

    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        // Although FreeBSD indicates support for `pending send size`, it doesn't
        // implement it correctly. This was fixed in FreeBSD version `1400097`.
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        // Linux kernels prior to 3.11 have issues with pending send size optimization
        // which can affect certain Asynchronous I/O (AIO) network operations.
        // Disable ring size optimization for these older kernels to avoid flow control issues.
        HvGuestOsOpenSourceType::LINUX => {
            // Linux version is encoded as: ((major << 16) | (minor << 8) | patch)
            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424
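            // As another worked example, Linux 4.15.0 encodes as
            // (4 << 16) | (15 << 8) | 0 = 265984, which passes this check.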
            open_source_os.version() >= 199424
        }
        _ => true,
    }
}

impl Nic {
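    /// Returns a builder for constructing a [`Nic`].
    ///
    /// A minimal usage sketch, assuming a `driver_source`, an `instance_id`,
    /// a boxed `endpoint`, and a `mac_address` are already in scope:
    ///
    /// ```ignore
    /// let nic = Nic::builder()
    ///     .max_queues(4)
    ///     .build(&driver_source, instance_id, endpoint, mac_address, 0);
    /// ```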
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            max_queues: !0,
            get_guest_os_id: None,
        }
    }

    pub fn shutdown(self) -> Box<dyn Endpoint> {
        let (state, _) = self.coordinator.into_inner();
        state.endpoint
    }
}

impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}

#[async_trait]
impl VmbusDevice for Nic {
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
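            // The Hyper-V synthetic network device interface type,
            // f8615163-df3e-46c5-913f-f2d2f965ede3.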
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe3],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, false);
            WorkerState::Init(None)
        } else {
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VFs added to the guest. This is required for guest compatibility,
            // as some apps (such as DPDK) rely on the VF sticking around even after the
            // vmbus channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}

#[async_trait]
impl SaveRestoreVmbusDevice for Nic {
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        let state = self.saved_state();
        Ok(SavedStateBlob::new(state))
    }

    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        let state: saved_state::SavedState = state.parse()?;
        if let Err(err) = self.restore_state(control, state).await {
            tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
            Err(err.into())
        } else {
            Ok(())
        }
    }
}

impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: protocol::PACKET_SIZE_V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: rndisprot::NDIS_PACKET_TYPE_NONE,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}

impl Nic {
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: bool) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring,
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
            },
        );
    }
}

#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}

impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}

impl Nic {
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // N.B. This will restore the vmbus view of open channels, so any
            //      failures after this point could result in inconsistent
            //      state (vmbus believes the channel is open/active). There
            //      are a number of failure paths after this point because this
            //      call also restores vmbus device state, like the GPADL map.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
1581                            if let Some(send_buffer) = send_buffer {
1582                                Some(SendBuffer::new(
1583                                    &self.resources.gpadl_map,
1584                                    send_buffer.gpadl_id,
1585                                )?)
1586                            } else {
1587                                None
1588                            }
1589                        },
1590                        ndis_version: {
1591                            let saved_state::NdisVersion { major, minor } = ndis_version;
1592                            NdisVersion { major, minor }
1593                        },
1594                        ndis_config: {
1595                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1596                            NdisConfig {
1597                                mtu,
1598                                capabilities: capabilities.into(),
1599                            }
1600                        },
1601                    });
1602
1603                    for (channel_idx, channel) in channels.iter().enumerate() {
1604                        let channel = if let Some(channel) = channel {
1605                            channel
1606                        } else {
1607                            continue;
1608                        };
1609
1610                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1611
1612                        // Restore primary channel state.
1613                        if channel_idx == 0 {
1614                            let primary = PrimaryChannelState::restore(
1615                                &guest_vf_state,
1616                                &rndis_state,
1617                                &offload_config,
1618                                pending_offload_change,
1619                                channels.len() as u16,
1620                                self.adapter.indirection_table_size,
1621                                &active.rx_bufs,
1622                                std::mem::take(&mut control_messages),
1623                                rss_state.take(),
1624                                tx_spread_sent,
1625                                guest_link_down,
1626                                pending_link_action,
1627                            )?;
1628                            active.primary = Some(primary);
1629                        }
1630
1631                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1632                            buffers: buffers.clone(),
1633                            state: active,
1634                            data: ProcessingData::new(),
1635                        }));
1636                    }
1637                }
1638            }
1639
1640            // Insert the coordinator and mark that it should try to start the
1641            // network endpoint when it starts running.
1642            self.insert_coordinator(states.len() as u16, true);
1643
1644            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1645                if let Some(state) = state {
1646                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1647                }
1648            }
1649            for worker in self.coordinator.state_mut().unwrap().workers.iter_mut() {
1650                if let Some(worker_state) = worker.state_mut() {
1651                    worker_state.channel.packet_filter = saved_packet_filter;
1652                }
1653            }
1654        } else {
1655            control
1656                .restore(&[false])
1657                .await
1658                .map_err(NetRestoreError::Channel)?;
1659        }
1660        Ok(())
1661    }
1662
1663    fn saved_state(&self) -> saved_state::SavedState {
1664        let open = if let Some(coordinator) = self.coordinator.state() {
1665            let primary = coordinator.workers[0].state().unwrap();
1666            let primary = match &primary.state {
1667                WorkerState::Init(None) => saved_state::Primary::Version,
1668                WorkerState::Init(Some(init)) => {
1669                    saved_state::Primary::Init(saved_state::InitPrimary {
1670                        version: init.version as u32,
1671                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
1672                            saved_state::NdisConfig {
1673                                mtu,
1674                                capabilities: capabilities.into(),
1675                            }
1676                        }),
1677                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
1678                            saved_state::NdisVersion { major, minor }
1679                        }),
1680                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
1681                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
1682                            gpadl_id: x.gpadl.id(),
1683                        }),
1684                    })
1685                }
1686                WorkerState::Ready(ready) => {
1687                    let primary = ready.state.primary.as_ref().unwrap();
1688
1689                    let rndis_state = match primary.rndis_state {
1690                        RndisState::Initializing => saved_state::RndisState::Initializing,
1691                        RndisState::Operational => saved_state::RndisState::Operational,
1692                        RndisState::Halted => saved_state::RndisState::Halted,
1693                    };
1694
1695                    let offload_config = saved_state::OffloadConfig {
1696                        checksum_tx: saved_state::ChecksumOffloadConfig {
1697                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
1698                            tcp4: primary.offload_config.checksum_tx.tcp4,
1699                            udp4: primary.offload_config.checksum_tx.udp4,
1700                            tcp6: primary.offload_config.checksum_tx.tcp6,
1701                            udp6: primary.offload_config.checksum_tx.udp6,
1702                        },
1703                        checksum_rx: saved_state::ChecksumOffloadConfig {
1704                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
1705                            tcp4: primary.offload_config.checksum_rx.tcp4,
1706                            udp4: primary.offload_config.checksum_rx.udp4,
1707                            tcp6: primary.offload_config.checksum_rx.tcp6,
1708                            udp6: primary.offload_config.checksum_rx.udp6,
1709                        },
1710                        lso4: primary.offload_config.lso4,
1711                        lso6: primary.offload_config.lso6,
1712                    };
1713
1714                    let control_messages = primary
1715                        .control_messages
1716                        .iter()
1717                        .map(|message| saved_state::IncomingControlMessage {
1718                            message_type: message.message_type,
1719                            data: message.data.to_vec(),
1720                        })
1721                        .collect();
1722
1723                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
1724                        key: rss.key.into(),
1725                        indirection_table: rss.indirection_table.clone(),
1726                    });
1727
1728                    let pending_link_action = match primary.pending_link_action {
1729                        PendingLinkAction::Default => None,
1730                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
1731                            Some(action)
1732                        }
1733                    };
1734
1735                    let channels = coordinator.workers[..coordinator.num_queues as usize]
1736                        .iter()
1737                        .map(|worker| {
1738                            worker.state().map(|worker| {
1739                                if let Some(ready) = worker.state.ready() {
                                    // In-flight tx packets are treated as dropped across
                                    // save/restore, but their requests still need to be
                                    // completed back to the guest.
1742                                    let pending_tx_completions = ready
1743                                        .state
1744                                        .pending_tx_completions
1745                                        .iter()
1746                                        .map(|pending| pending.transaction_id)
1747                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
1748                                            |inflight| {
1749                                                (inflight.pending_packet_count > 0)
1750                                                    .then_some(inflight.transaction_id)
1751                                            },
1752                                        ))
1753                                        .collect();
1754
1755                                    saved_state::Channel {
1756                                        pending_tx_completions,
1757                                        in_use_rx: {
1758                                            ready
1759                                                .state
1760                                                .rx_bufs
1761                                                .allocated()
1762                                                .map(|id| saved_state::Rx {
1763                                                    buffers: id.collect(),
1764                                                })
1765                                                .collect()
1766                                        },
1767                                    }
1768                                } else {
1769                                    saved_state::Channel {
1770                                        pending_tx_completions: Vec::new(),
1771                                        in_use_rx: Vec::new(),
1772                                    }
1773                                }
1774                            })
1775                        })
1776                        .collect();
1777
1778                    let guest_vf_state = match primary.guest_vf_state {
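                        // Transitional VF states collapse to the nearest
                        // stable saved state; an in-flight data path switch is
                        // saved as pending so it can be re-driven on restore.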
1779                        PrimaryChannelGuestVfState::Initializing
1780                        | PrimaryChannelGuestVfState::Unavailable
1781                        | PrimaryChannelGuestVfState::Available { .. } => {
1782                            saved_state::GuestVfState::NoState
1783                        }
1784                        PrimaryChannelGuestVfState::UnavailableFromAvailable
1785                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
1786                            saved_state::GuestVfState::AvailableAdvertised
1787                        }
1788                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
1789                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
1790                            to_guest,
1791                            id,
1792                        } => saved_state::GuestVfState::DataPathSwitchPending {
1793                            to_guest,
1794                            id,
1795                            result: None,
1796                        },
1797                        PrimaryChannelGuestVfState::DataPathSwitchPending {
1798                            to_guest,
1799                            id,
1800                            result,
1801                        } => saved_state::GuestVfState::DataPathSwitchPending {
1802                            to_guest,
1803                            id,
1804                            result,
1805                        },
1806                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
1807                        | PrimaryChannelGuestVfState::DataPathSwitched
1808                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
1809                            saved_state::GuestVfState::DataPathSwitched
1810                        }
1811                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
1812                    };
1813
1814                    let worker_0_packet_filter = coordinator.workers[0]
1815                        .state()
1816                        .unwrap()
1817                        .channel
1818                        .packet_filter;
1819                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
1820                        version: ready.buffers.version as u32,
1821                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
1822                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
1823                            saved_state::SendBuffer {
1824                                gpadl_id: sb.gpadl.id(),
1825                            }
1826                        }),
1827                        rndis_state,
1828                        guest_vf_state,
1829                        offload_config,
1830                        pending_offload_change: primary.pending_offload_change,
1831                        control_messages,
1832                        rss_state,
1833                        channels,
1834                        ndis_config: {
1835                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
1836                            saved_state::NdisConfig {
1837                                mtu,
1838                                capabilities: capabilities.into(),
1839                            }
1840                        },
1841                        ndis_version: {
1842                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
1843                            saved_state::NdisVersion { major, minor }
1844                        },
1845                        tx_spread_sent: primary.tx_spread_sent,
1846                        guest_link_down: !primary.guest_link_up,
1847                        pending_link_action,
1848                        packet_filter: Some(worker_0_packet_filter),
1849                    })
1850                }
1851            };
1852
1853            let state = saved_state::OpenState { primary };
1854            Some(state)
1855        } else {
1856            None
1857        };
1858
1859        saved_state::SavedState { open }
1860    }
1861}
1862
1863#[derive(Debug, Error)]
1864enum WorkerError {
1865    #[error("packet error")]
1866    Packet(#[source] PacketError),
1867    #[error("unexpected packet order: {0}")]
1868    UnexpectedPacketOrder(#[source] PacketOrderError),
1869    #[error("unknown rndis message type: {0}")]
1870    UnknownRndisMessageType(u32),
1871    #[error("memory access error")]
1872    Access(#[from] AccessError),
1873    #[error("rndis message too small")]
1874    RndisMessageTooSmall,
1875    #[error("unsupported rndis behavior")]
1876    UnsupportedRndisBehavior,
1877    #[error("vmbus queue error")]
1878    Queue(#[from] queue::Error),
1879    #[error("too many control messages")]
1880    TooManyControlMessages,
1881    #[error("invalid rndis packet completion")]
1882    InvalidRndisPacketCompletion,
1883    #[error("missing transaction id")]
1884    MissingTransactionId,
1885    #[error("invalid gpadl")]
1886    InvalidGpadl(#[source] guestmem::InvalidGpn),
1887    #[error("gpadl error")]
1888    GpadlError(#[source] GuestMemoryError),
1889    #[error("gpa direct error")]
1890    GpaDirectError(#[source] GuestMemoryError),
1891    #[error("endpoint")]
1892    Endpoint(#[source] anyhow::Error),
1893    #[error("message not supported on sub channel: {0}")]
1894    NotSupportedOnSubChannel(u32),
1895    #[error("the ring buffer ran out of space, which should not be possible")]
1896    OutOfSpace,
1897    #[error("send/receive buffer error")]
1898    Buffer(#[from] BufferError),
1899    #[error("invalid rndis state")]
1900    InvalidRndisState,
1901    #[error("rndis message type not implemented")]
1902    RndisMessageTypeNotImplemented,
1903    #[error("invalid TCP header offset")]
1904    InvalidTcpHeaderOffset,
1905    #[error("cancelled")]
1906    Cancelled(task_control::Cancelled),
1907    #[error("tearing down because send/receive buffer is revoked")]
1908    BufferRevoked,
1909    #[error("endpoint requires queue restart: {0}")]
1910    EndpointRequiresQueueRestart(#[source] anyhow::Error),
1911}
1912
1913impl From<task_control::Cancelled> for WorkerError {
1914    fn from(value: task_control::Cancelled) -> Self {
1915        Self::Cancelled(value)
1916    }
1917}
1918
1919#[derive(Debug, Error)]
1920enum OpenError {
1921    #[error("error establishing ring buffer")]
1922    Ring(#[source] vmbus_channel::gpadl_ring::Error),
1923    #[error("error establishing vmbus queue")]
1924    Queue(#[source] queue::Error),
1925}
1926
1927#[derive(Debug, Error)]
1928enum PacketError {
1929    #[error("UnknownType {0}")]
1930    UnknownType(u32),
1931    #[error("Access")]
1932    Access(#[source] AccessError),
1933    #[error("ExternalData")]
1934    ExternalData(#[source] ExternalDataError),
1935    #[error("InvalidSendBufferIndex")]
1936    InvalidSendBufferIndex,
1937}
1938
1939#[derive(Debug, Error)]
1940enum PacketOrderError {
1941    #[error("Invalid PacketData")]
1942    InvalidPacketData,
1943    #[error("Unexpected RndisPacket")]
1944    UnexpectedRndisPacket,
1945    #[error("SendNdisVersion already exists")]
1946    SendNdisVersionExists,
1947    #[error("SendNdisConfig already exists")]
1948    SendNdisConfigExists,
1949    #[error("SendReceiveBuffer already exists")]
1950    SendReceiveBufferExists,
1951    #[error("SendReceiveBuffer missing MTU")]
1952    SendReceiveBufferMissingMTU,
1953    #[error("SendSendBuffer already exists")]
1954    SendSendBufferExists,
1955    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
1956    SwitchDataPathCompletionPrimaryChannelState,
1957}
1958
1959#[derive(Debug)]
1960enum PacketData {
1961    Init(protocol::MessageInit),
1962    SendNdisVersion(protocol::Message1SendNdisVersion),
1963    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
1964    SendSendBuffer(protocol::Message1SendSendBuffer),
1965    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
1966    RevokeSendBuffer(Message1RevokeSendBuffer),
1967    RndisPacket(protocol::Message1SendRndisPacket),
1968    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
1969    SendNdisConfig(protocol::Message2SendNdisConfig),
1970    SwitchDataPath(protocol::Message4SwitchDataPath),
1971    OidQueryEx(protocol::Message5OidQueryEx),
1972    SubChannelRequest(protocol::Message5SubchannelRequest),
1973    SendVfAssociationCompletion,
1974    SwitchDataPathCompletion,
1975}
1976
1977#[derive(Debug)]
1978struct Packet<'a> {
1979    data: PacketData,
1980    transaction_id: Option<u64>,
1981    external_data: MultiPagedRangeBuf<GpnList>,
1982    send_buffer_suballocation: PagedRange<'a>,
1983}
1984
1985type PacketReader<'a> = PagedRangesReader<
1986    'a,
1987    std::iter::Chain<std::iter::Once<PagedRange<'a>>, MultiPagedRangeIter<'a>>,
1988>;
1989
1990impl Packet<'_> {
1991    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
1992        PagedRanges::new(
1993            std::iter::once(self.send_buffer_suballocation).chain(self.external_data.iter()),
1994        )
1995        .reader(mem)
1996    }
1997}
1998
1999fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2000    reader: &mut impl MemoryRead,
2001) -> Result<T, PacketError> {
2002    reader.read_plain().map_err(PacketError::Access)
2003}
2004
2005fn parse_packet<'a, T: RingMem>(
2006    packet_ref: &queue::PacketRef<'_, T>,
2007    send_buffer: Option<&'a SendBuffer>,
2008    version: Option<Version>,
2009) -> Result<Packet<'a>, PacketError> {
2010    let packet = match packet_ref.as_ref() {
2011        IncomingPacket::Data(data) => data,
2012        IncomingPacket::Completion(completion) => {
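            // The VF-association and data-path-switch messages are sent with
            // reserved transaction IDs, so their completions are matched here
            // by ID. Any other completion must be an RNDIS packet completion.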
2013            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2014                PacketData::SendVfAssociationCompletion
2015            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2016                PacketData::SwitchDataPathCompletion
2017            } else {
2018                let mut reader = completion.reader();
2019                let header: protocol::MessageHeader =
2020                    reader.read_plain().map_err(PacketError::Access)?;
2021                match header.message_type {
2022                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2023                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2024                    }
2025                    typ => return Err(PacketError::UnknownType(typ)),
2026                }
2027            };
2028            return Ok(Packet {
2029                data,
2030                transaction_id: Some(completion.transaction_id()),
2031                external_data: MultiPagedRangeBuf::empty(),
2032                send_buffer_suballocation: PagedRange::empty(),
2033            });
2034        }
2035    };
2036
2037    let mut reader = packet.reader();
2038    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2039    let mut send_buffer_suballocation = PagedRange::empty();
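    // N.B. `version` is `None` until negotiation completes. Since `None`
    //      compares less than any `Some(_)`, every version-gated arm below
    //      rejects messages that arrive before MESSAGE_TYPE_INIT succeeds.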
2040    let data = match header.message_type {
2041        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2042        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2043            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2044        }
2045        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2046            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2047        }
2048        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2049            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2050        }
2051        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2052            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2053        }
2054        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2055            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2056        }
2057        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2058            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
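            // A section index of 0xffffffff means the RNDIS message is not in
            // the send buffer (it arrives via the packet's external GPA ranges
            // instead). Otherwise, the message body begins at the given
            // fixed-size (6144-byte) section of the send buffer.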
2059            if message.send_buffer_section_index != 0xffffffff {
2060                send_buffer_suballocation = send_buffer
2061                    .ok_or(PacketError::InvalidSendBufferIndex)?
2062                    .gpadl
2063                    .first()
2064                    .unwrap()
2065                    .try_subrange(
2066                        message.send_buffer_section_index as usize * 6144,
2067                        message.send_buffer_section_size as usize,
2068                    )
2069                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2070            }
2071            PacketData::RndisPacket(message)
2072        }
2073        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2074            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2075        }
2076        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2077            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2078        }
2079        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2080            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2081        }
2082        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2083            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2084        }
2085        typ => return Err(PacketError::UnknownType(typ)),
2086    };
2087    Ok(Packet {
2088        data,
2089        transaction_id: packet.transaction_id(),
2090        external_data: packet
2091            .read_external_ranges()
2092            .map_err(PacketError::ExternalData)?,
2093        send_buffer_suballocation,
2094    })
2095}
2096
2097#[derive(Debug, Copy, Clone)]
2098struct NvspMessage<T> {
2099    header: protocol::MessageHeader,
2100    data: T,
2101    padding: &'static [u8],
2102}
2103
2104impl<T: IntoBytes + Immutable + KnownLayout> NvspMessage<T> {
2105    fn payload(&self) -> [&[u8]; 3] {
2106        [self.header.as_bytes(), self.data.as_bytes(), self.padding]
2107    }
2108}
2109
2110impl<T: RingMem> NetChannel<T> {
2111    fn message<P: IntoBytes + Immutable + KnownLayout>(
2112        &self,
2113        message_type: u32,
2114        data: P,
2115    ) -> NvspMessage<P> {
2116        let padding = self.padding(&data);
2117        NvspMessage {
2118            header: protocol::MessageHeader { message_type },
2119            data,
2120            padding,
2121        }
2122    }
2123
2124    /// Returns zero padding bytes to round the payload up to the packet size.
2125    /// Only needed for Windows guests, which are picky about packet sizes.
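    /// For example, when `packet_size` is `protocol::PACKET_SIZE_V61`, a
    /// message whose header plus body is shorter is zero-padded up to that
    /// size, while a message at or beyond it gets an empty padding slice.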
2126    fn padding<P: IntoBytes + Immutable + KnownLayout>(&self, data: &P) -> &'static [u8] {
2127        static PADDING: &[u8] = &[0; protocol::PACKET_SIZE_V61];
2128        let padding_len = self.packet_size
2129            - cmp::min(
2130                self.packet_size,
2131                size_of::<protocol::MessageHeader>() + data.as_bytes().len(),
2132            );
2133        &PADDING[..padding_len]
2134    }
2135
2136    fn send_completion(
2137        &mut self,
2138        transaction_id: Option<u64>,
2139        payload: &[&[u8]],
2140    ) -> Result<(), WorkerError> {
2141        match transaction_id {
2142            None => Ok(()),
2143            Some(transaction_id) => Ok(self
2144                .queue
2145                .split()
2146                .1
2147                .try_write(&queue::OutgoingPacket {
2148                    transaction_id,
2149                    packet_type: OutgoingPacketType::Completion,
2150                    payload,
2151                })
2152                .map_err(|err| match err {
2153                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2154                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2155                })?),
2156        }
2157    }
2158}
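
// For illustration, a typical send path using the helpers above (a minimal
// sketch; the message type and body mirror the VF association message sent
// later in this file, while `channel` and `transaction_id` are hypothetical):
//
//     let msg = channel.message(
//         protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
//         protocol::Message4SendVfAssociation {
//             vf_allocated: 1,
//             serial_number: 0,
//         },
//     );
//     channel.send_completion(Some(transaction_id), &msg.payload())?;
//
// `payload()` yields the header, body, and padding as three slices, so the
// ring write gathers them without an intermediate copy.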
2159
2160static SUPPORTED_VERSIONS: &[Version] = &[
2161    Version::V1,
2162    Version::V2,
2163    Version::V4,
2164    Version::V5,
2165    Version::V6,
2166    Version::V61,
2167];
2168
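// Maps the guest-requested NVSP version number to a supported `Version`. For
// example, a request of 0x60001 (major 6 in the high 16 bits, minor 1 in the
// low 16 bits, the standard NVSP encoding) negotiates `Version::V61`; any
// value not listed above yields `None`.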
2169fn check_version(requested_version: u32) -> Option<Version> {
2170    SUPPORTED_VERSIONS
2171        .iter()
2172        .find(|version| **version as u32 == requested_version)
2173        .copied()
2174}
2175
2176#[derive(Debug)]
2177struct ReceiveBuffer {
2178    gpadl: GpadlView,
2179    id: u16,
2180    count: u32,
2181    sub_allocation_size: u32,
2182}
2183
2184#[derive(Debug, Error)]
2185enum BufferError {
2186    #[error("unsupported suballocation size {0}")]
2187    UnsupportedSuballocationSize(u32),
2188    #[error("unaligned gpadl")]
2189    UnalignedGpadl,
2190    #[error("unknown gpadl ID")]
2191    UnknownGpadlId(#[from] UnknownGpadlId),
2192}
2193
2194impl ReceiveBuffer {
2195    fn new(
2196        gpadl_map: &GpadlMapView,
2197        gpadl_id: GpadlId,
2198        id: u16,
2199        sub_allocation_size: u32,
2200    ) -> Result<Self, BufferError> {
2201        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2202            return Err(BufferError::UnsupportedSuballocationSize(
2203                sub_allocation_size,
2204            ));
2205        }
2206        let gpadl = gpadl_map.map(gpadl_id)?;
2207        let range = gpadl
2208            .contiguous_aligned()
2209            .ok_or(BufferError::UnalignedGpadl)?;
2210        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2211        if num_sub_allocations == 0 {
2212            return Err(BufferError::UnsupportedSuballocationSize(
2213                sub_allocation_size,
2214            ));
2215        }
2216        let recv_buffer = Self {
2217            gpadl,
2218            id,
2219            count: num_sub_allocations,
2220            sub_allocation_size,
2221        };
2222        Ok(recv_buffer)
2223    }
2224
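    /// Returns the `index`th suballocation as a paged range. For example, with
    /// a (hypothetical) suballocation size of 2048 bytes, `range(3)` covers
    /// bytes 6144..8192 of the receive buffer GPADL.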
2225    fn range(&self, index: u32) -> PagedRange<'_> {
2226        self.gpadl.first().unwrap().subrange(
2227            (index * self.sub_allocation_size) as usize,
2228            self.sub_allocation_size as usize,
2229        )
2230    }
2231
2232    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2233        assert!(len <= self.sub_allocation_size as usize);
2234        ring::TransferPageRange {
2235            byte_offset: index * self.sub_allocation_size,
2236            byte_count: len as u32,
2237        }
2238    }
2239
2240    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2241        saved_state::ReceiveBuffer {
2242            gpadl_id: self.gpadl.id(),
2243            id: self.id,
2244            sub_allocation_size: self.sub_allocation_size,
2245        }
2246    }
2247}
2248
2249#[derive(Debug)]
2250struct SendBuffer {
2251    gpadl: GpadlView,
2252}
2253
2254impl SendBuffer {
2255    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2256        let gpadl = gpadl_map.map(gpadl_id)?;
2257        gpadl
2258            .contiguous_aligned()
2259            .ok_or(BufferError::UnalignedGpadl)?;
2260        Ok(Self { gpadl })
2261    }
2262}
2263
2264impl<T: RingMem> NetChannel<T> {
2265    /// Process a single RNDIS message.
2266    fn handle_rndis_message(
2267        &mut self,
2268        buffers: &ChannelBuffers,
2269        state: &mut ActiveState,
2270        id: TxId,
2271        message_type: u32,
2272        mut reader: PacketReader<'_>,
2273        segments: &mut Vec<TxSegment>,
2274    ) -> Result<bool, WorkerError> {
2275        let is_packet = match message_type {
2276            rndisprot::MESSAGE_TYPE_PACKET_MSG => {
2277                self.handle_rndis_packet_message(
2278                    id,
2279                    reader,
2280                    &buffers.mem,
2281                    segments,
2282                    &mut state.stats,
2283                )?;
2284                true
2285            }
2286            rndisprot::MESSAGE_TYPE_HALT_MSG => false,
2287            n => {
2288                let control = state
2289                    .primary
2290                    .as_mut()
2291                    .ok_or(WorkerError::NotSupportedOnSubChannel(n))?;
2292
                // This is a control message that needs a response. Responding
                // requires a receive buffer suballocation, which may not be
                // available right now. Enqueue the message and process the
                // queue as suballocations become available.
2297                const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2298                if reader.len() == 0 {
2299                    return Err(WorkerError::RndisMessageTooSmall);
2300                }
                // Do not let the queue get too large; the guest should not be
                // sending very many control messages at a time.
2303                if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2304                    return Err(WorkerError::TooManyControlMessages);
2305                }
2306                control.control_messages_len += reader.len();
2307                control.control_messages.push_back(ControlMessage {
2308                    message_type,
2309                    data: reader.read_all()?.into(),
2310                });
2311
2312                false
2313                // The queue will be processed in the main dispatch loop.
2314            }
2315        };
2316        Ok(is_packet)
2317    }
2318
    /// Process an RNDIS packet message (used to send an Ethernet frame).
2320    fn handle_rndis_packet_message(
2321        &mut self,
2322        id: TxId,
2323        reader: PacketReader<'_>,
2324        mem: &GuestMemory,
2325        segments: &mut Vec<TxSegment>,
2326        stats: &mut QueueStats,
2327    ) -> Result<(), WorkerError> {
2328        // Headers are guaranteed to be in a single PagedRange.
2329        let headers = reader
2330            .clone()
2331            .into_inner()
2332            .paged_ranges()
2333            .find(|r| !r.is_empty())
2334            .ok_or(WorkerError::RndisMessageTooSmall)?;
2335        let mut data = reader.into_inner();
2336        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2337        if request.num_oob_data_elements != 0
2338            || request.oob_data_length != 0
2339            || request.oob_data_offset != 0
2340            || request.vc_handle != 0
2341        {
2342            return Err(WorkerError::UnsupportedRndisBehavior);
2343        }
2344
2345        if data.len() < request.data_offset as usize
2346            || (data.len() - request.data_offset as usize) < request.data_length as usize
2347            || request.data_length == 0
2348        {
2349            return Err(WorkerError::RndisMessageTooSmall);
2350        }
2351
2352        data.skip(request.data_offset as usize);
2353        data.truncate(request.data_length as usize);
2354
2355        let mut metadata = net_backend::TxMetadata {
2356            id,
2357            len: request.data_length as usize,
2358            ..Default::default()
2359        };
2360
2361        if request.per_packet_info_length != 0 {
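            // Per-packet info is a packed list of variable-size records. Each
            // record begins with a `PerPacketInfo` header whose `size` covers
            // the whole record and whose offset locates the payload within it.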
2362            let mut ppi = headers
2363                .try_subrange(
2364                    request.per_packet_info_offset as usize,
2365                    request.per_packet_info_length as usize,
2366                )
2367                .ok_or(WorkerError::RndisMessageTooSmall)?;
2368            while !ppi.is_empty() {
2369                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2370                if h.size == 0 {
2371                    return Err(WorkerError::RndisMessageTooSmall);
2372                }
2373                let (this, rest) = ppi
2374                    .try_split(h.size as usize)
2375                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2376                let (_, d) = this
2377                    .try_split(h.per_packet_information_offset as usize)
2378                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2379                match h.typ {
2380                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2381                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2382
2383                        metadata.offload_tcp_checksum =
2384                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
2385                        metadata.offload_udp_checksum =
2386                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum();
2387                        metadata.offload_ip_header_checksum = n.is_ipv4() && n.ip_header_checksum();
2388                        metadata.l3_protocol = if n.is_ipv4() {
2389                            L3Protocol::Ipv4
2390                        } else if n.is_ipv6() {
2391                            L3Protocol::Ipv6
2392                        } else {
2393                            L3Protocol::Unknown
2394                        };
2395                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2396                        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2397                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2398                                n.tcp_header_offset() - metadata.l2_len as u16
2399                            } else if n.is_ipv4() {
                                // The checksum PPI did not supply a usable TCP
                                // header offset, so derive the IPv4 header
                                // length from the IHL field: the low nibble of
                                // the first header byte, in 32-bit words.
                                let mut reader = data.clone().reader(mem);
                                reader.skip(metadata.l2_len as usize)?;
                                let mut b = 0;
                                reader.read(std::slice::from_mut(&mut b))?;
                                (b as u16 & 0xf) * 4
                            } else {
                                // Assume the IPv6 header carries no extension
                                // headers and use its fixed 40-byte length.
                                40
2408                            };
2409                        }
2410                    }
2411                    rndisprot::PPI_LSO => {
2412                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2413
2414                        metadata.offload_tcp_segmentation = true;
2415                        metadata.offload_tcp_checksum = true;
2416                        metadata.offload_ip_header_checksum = n.is_ipv4();
2417                        metadata.l3_protocol = if n.is_ipv4() {
2418                            L3Protocol::Ipv4
2419                        } else {
2420                            L3Protocol::Ipv6
2421                        };
2422                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2423                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2424                            return Err(WorkerError::InvalidTcpHeaderOffset);
2425                        }
2426                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2427                        metadata.l4_len = {
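                            // Read the TCP data offset (header length in
                            // 32-bit words) from the high nibble of the byte
                            // at offset 12 of the TCP header.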
2428                            let mut reader = data.clone().reader(mem);
2429                            reader
2430                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2431                            let mut b = 0;
2432                            reader.read(std::slice::from_mut(&mut b))?;
2433                            (b >> 4) * 4
2434                        };
2435                        metadata.max_tcp_segment_size = n.mss() as u16;
2436                    }
2437                    _ => {}
2438                }
2439                ppi = rest;
2440            }
2441        }
2442
2443        let start = segments.len();
2444        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2445            let range = range.map_err(WorkerError::InvalidGpadl)?;
2446            segments.push(TxSegment {
2447                ty: net_backend::TxSegmentType::Tail,
2448                gpa: range.start,
2449                len: range.len() as u32,
2450            });
2451        }
2452
2453        metadata.segment_count = segments.len() - start;
2454
2455        stats.tx_packets.increment();
2456        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2457            stats.tx_checksum_packets.increment();
2458        }
2459        if metadata.offload_tcp_segmentation {
2460            stats.tx_lso_packets.increment();
2461        }
2462
2463        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2464
2465        Ok(())
2466    }
2467
    /// Notify the guest whether a VF is currently available. Returns whether
    /// the association message was sent, which requires protocol version 4 or
    /// later and an SR-IOV-capable NDIS configuration.
2470    fn guest_vf_is_available(
2471        &mut self,
2472        guest_vf_id: Option<u32>,
2473        version: Version,
2474        config: NdisConfig,
2475    ) -> Result<bool, WorkerError> {
2476        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2477        if version >= Version::V4 && config.capabilities.sriov() {
2478            tracing::info!(
2479                available = serial_number.is_some(),
2480                serial_number,
2481                "sending VF association message."
2482            );
2483            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2484            let message = {
2485                self.message(
2486                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2487                    protocol::Message4SendVfAssociation {
2488                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2489                        serial_number: serial_number.unwrap_or(0),
2490                    },
2491                )
2492            };
2493            self.queue
2494                .split()
2495                .1
2496                .try_write(&queue::OutgoingPacket {
2497                    transaction_id: VF_ASSOCIATION_TRANSACTION_ID,
2498                    packet_type: OutgoingPacketType::InBandWithCompletion,
2499                    payload: &message.payload(),
2500                })
2501                .map_err(|err| match err {
2502                    queue::TryWriteError::Full(len) => {
2503                        tracing::error!(len, "failed to write vf association message");
2504                        WorkerError::OutOfSpace
2505                    }
2506                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2507                })?;
2508            Ok(true)
2509        } else {
2510            tracing::info!(
2511                available = serial_number.is_some(),
2512                serial_number,
2513                major = version.major(),
2514                minor = version.minor(),
2515                sriov_capable = config.capabilities.sriov(),
2516                "Skipping NvspMessage4TypeSendVFAssociation message"
2517            );
2518            Ok(false)
2519        }
2520    }
2521
2522    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2523    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2524        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2525        if version < Version::V5 {
2526            return;
2527        }
2528
2529        #[repr(C)]
2530        #[derive(IntoBytes, Immutable, KnownLayout)]
2531        struct SendIndirectionMsg {
2532            pub message: protocol::Message5SendIndirectionTable,
2533            pub send_indirection_table:
2534                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2535        }
2536
2537        // The offset to the send indirection table from the beginning of the NVSP message.
2538        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2539            + size_of::<protocol::MessageHeader>();
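        // Sanity check: `MessageHeader` is a single u32 and the message body
        // above is two u32s, so the table lands at byte offset 12 from the
        // start of the NVSP message.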
2540        let mut data = SendIndirectionMsg {
2541            message: protocol::Message5SendIndirectionTable {
2542                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2543                table_offset: send_indirection_table_offset as u32,
2544            },
2545            send_indirection_table: Default::default(),
2546        };
2547
        for (i, entry) in data.send_indirection_table.iter_mut().enumerate() {
            *entry = i as u32 % num_channels_opened;
        }
2551
2552        let message = NvspMessage {
2553            header: protocol::MessageHeader {
2554                message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2555            },
2556            data,
2557            padding: &[],
2558        };
2559        let result = self
2560            .queue
2561            .split()
2562            .1
2563            .try_write(&queue::OutgoingPacket {
2564                transaction_id: 0,
2565                packet_type: OutgoingPacketType::InBandNoCompletion,
2566                payload: &message.payload(),
2567            })
2568            .map_err(|err| match err {
2569                queue::TryWriteError::Full(len) => {
2570                    tracing::error!(len, "failed to write send indirection table message");
2571                    WorkerError::OutOfSpace
2572                }
2573                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2574            });
2575        if let Err(err) = result {
2576            tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2577        }
2578    }
2579
2580    /// Notify the guest that the data path has been switched back to synthetic
2581    /// due to some external state change.
2582    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2583        let message = NvspMessage {
2584            header: protocol::MessageHeader {
2585                message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2586            },
2587            data: protocol::Message4SwitchDataPath {
2588                active_data_path: protocol::DataPath::SYNTHETIC.0,
2589            },
2590            padding: &[],
2591        };
2592        let result = self
2593            .queue
2594            .split()
2595            .1
2596            .try_write(&queue::OutgoingPacket {
2597                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2598                packet_type: OutgoingPacketType::InBandWithCompletion,
2599                payload: &message.payload(),
2600            })
2601            .map_err(|err| match err {
2602                queue::TryWriteError::Full(len) => {
2603                    tracing::error!(
2604                        len,
2605                        "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2606                    );
2607                    WorkerError::OutOfSpace
2608                }
2609                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2610            });
2611        if let Err(err) = result {
2612            tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2613        }
2614    }
2615
    /// Process an internal state change.
2617    async fn handle_state_change(
2618        &mut self,
2619        primary: &mut PrimaryChannelState,
2620        buffers: &ChannelBuffers,
2621    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
2622        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
2623        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
2624        //      1. completion of switch data path request
2625        //      2. Switch data path notification (back to synthetic)
2626        //      3. Disassociate VF adapter.
2627        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2628            // Notify guest that a VF capability has recently arrived.
2629            if primary.rndis_state == RndisState::Operational {
2630                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
2631                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2632                    return Ok(Some(CoordinatorMessage::Update(
2633                        CoordinatorMessageUpdateType {
2634                            guest_vf_state: true,
2635                            ..Default::default()
2636                        },
2637                    )));
2638                } else if let Some(true) = primary.is_data_path_switched {
2639                    tracing::error!(
2640                        "Data path switched, but current guest negotiation does not support VTL0 VF"
2641                    );
2642                }
2643            }
2644            return Ok(None);
2645        }
2646        loop {
2647            primary.guest_vf_state = match primary.guest_vf_state {
2648                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Notify the guest that the VF is unavailable. It has
                    // already been surprise-removed.
2650                    if primary.rndis_state == RndisState::Operational {
2651                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
2652                    }
2653                    PrimaryChannelGuestVfState::Unavailable
2654                }
2655                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
2656                    to_guest,
2657                    id,
2658                } => {
2659                    // Complete the data path switch request.
2660                    self.send_completion(id, &[])?;
2661                    if to_guest {
2662                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
2663                    } else {
2664                        PrimaryChannelGuestVfState::UnavailableFromAvailable
2665                    }
2666                }
2667                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
2668                    // Notify guest that the data path is now synthetic.
2669                    self.guest_vf_data_path_switched_to_synthetic();
2670                    PrimaryChannelGuestVfState::UnavailableFromAvailable
2671                }
2672                PrimaryChannelGuestVfState::DataPathSynthetic => {
2673                    // Notify guest that the data path is now synthetic.
2674                    self.guest_vf_data_path_switched_to_synthetic();
2675                    PrimaryChannelGuestVfState::Ready
2676                }
2677                PrimaryChannelGuestVfState::DataPathSwitchPending {
2678                    to_guest,
2679                    id,
2680                    result,
2681                } => {
2682                    let result = result.expect("DataPathSwitchPending should have been processed");
2683                    // Complete the data path switch request.
2684                    self.send_completion(id, &[])?;
2685                    if result {
2686                        if to_guest {
2687                            PrimaryChannelGuestVfState::DataPathSwitched
2688                        } else {
2689                            PrimaryChannelGuestVfState::Ready
2690                        }
2691                    } else {
2692                        if to_guest {
2693                            PrimaryChannelGuestVfState::DataPathSynthetic
2694                        } else {
2695                            tracing::error!(
2696                                "Failure when guest requested switch back to synthetic"
2697                            );
2698                            PrimaryChannelGuestVfState::DataPathSwitched
2699                        }
2700                    }
2701                }
2702                PrimaryChannelGuestVfState::Initializing
2703                | PrimaryChannelGuestVfState::Restoring(_) => {
2704                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
2705                }
2706                _ => break,
2707            };
2708        }
2709        Ok(None)
2710    }
2711
2712    /// Process a control message, writing the response to the provided receive
2713    /// buffer suballocation.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        // Ideally the VF would not be presented to the guest
                        // until the completion packet has arrived, so that the
                        // guest is prepared. This is most interesting for the
                        // case of a VF associated with multiple guest
                        // adapters, using a concept like vports. In this
                        // scenario it would be better if all of the adapters
                        // were aware a VF was available before the device
                        // arrived. This is not currently possible because the
                        // Linux netvsc driver ignores the completion requested
                        // flag on inband packets and won't send a completion
                        // packet.
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Restart the endpoint if the OID changed some critical
                        // endpoint property.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }

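    /// Attempts to write a `MESSAGE1_TYPE_SEND_RNDIS_PACKET` message to the
    /// outgoing ring as a transfer-pages packet referencing the given receive
    /// buffer ranges. Returns `Ok(None)` on success, or `Ok(Some(len))` with
    /// the ring space required if the ring is currently full.
    ///
    /// A minimal usage sketch, mirroring the control-channel path below
    /// (`id` and `message_length` are illustrative):
    ///
    /// ```ignore
    /// let pending = self.try_send_rndis_message(
    ///     id as u64,
    ///     protocol::CONTROL_CHANNEL_TYPE,
    ///     buffers.recv_buffer.id,
    ///     std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
    /// )?;
    /// ```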
    fn try_send_rndis_message(
        &mut self,
        transaction_id: u64,
        channel_type: u32,
        recv_buffer_id: u16,
        transfer_pages: &[ring::TransferPageRange],
    ) -> Result<Option<usize>, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
            protocol::Message1SendRndisPacket {
                channel_type,
                send_buffer_section_index: 0xffffffff,
                send_buffer_section_size: 0,
            },
        );
        let pending_send_size = match self.queue.split().1.try_write(&queue::OutgoingPacket {
            transaction_id,
            packet_type: OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
            payload: &message.payload(),
        }) {
            Ok(()) => None,
            Err(queue::TryWriteError::Full(n)) => Some(n),
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(pending_send_size)
    }

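    /// Sends an RNDIS control message response of `message_length` bytes from
    /// receive buffer suballocation `id` over the control channel, failing
    /// with `WorkerError::OutOfSpace` if the ring is unexpectedly full.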
    fn send_rndis_control_message(
        &mut self,
        buffers: &ChannelBuffers,
        id: u32,
        message_length: usize,
    ) -> Result<(), WorkerError> {
        let result = self.try_send_rndis_message(
            id as u64,
            protocol::CONTROL_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
        )?;

        // Ring size is checked before control messages are processed, so failure to write is unexpected.
        match result {
            None => Ok(()),
            Some(len) => {
                tracelimit::error_ratelimited!(len, "failed to write control message completion");
                Err(WorkerError::OutOfSpace)
            }
        }
    }

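    /// Sends an RNDIS status indication (`MESSAGE_TYPE_INDICATE_STATUS_MSG`)
    /// with the given status code and optional status buffer payload, using
    /// receive buffer suballocation `id`.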
    fn indicate_status(
        &mut self,
        buffers: &ChannelBuffers,
        id: u32,
        status: u32,
        payload: &[u8],
    ) -> Result<(), WorkerError> {
        let buffer = &buffers.recv_buffer.range(id);
        let mut writer = buffer.writer(&buffers.mem);
        let message_length = write_rndis_message(
            &mut writer,
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
            payload.len(),
            &rndisprot::IndicateStatus {
                status,
                status_buffer_length: payload.len() as u32,
                status_buffer_offset: if payload.is_empty() {
                    0
                } else {
                    size_of::<rndisprot::IndicateStatus>() as u32
                },
            },
        )?;
        writer.write(payload)?;
        self.send_rndis_control_message(buffers, id, message_length)?;
        Ok(())
    }

    /// Processes pending control messages until all are processed or there are
    /// no available suballocations.
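    ///
    /// Each message consumes one free receive-buffer suballocation for its
    /// response, and processing stops early if the outgoing ring cannot hold
    /// at least `MIN_CONTROL_RING_SIZE` bytes.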
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure the ring buffer has enough room to successfully complete control message handling.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the receive buffer as in use so that the guest's subsequent
            // release of it can be tracked.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                unreachable!();
            }
        }
        Ok(())
    }

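    /// Queues an update message for the coordinator, coalescing it with any
    /// update already pending. A pending restart subsumes both update kinds,
    /// so in that case nothing is added.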
    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
        if self.restart.is_none() {
            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
                guest_vf_state: guest_vf,
                filter_state: packet_filter,
            }));
        } else if let Some(CoordinatorMessage::Restart) = self.restart {
            // If a restart message is pending, do nothing.
            // A restart will try to switch the data path based on primary.guest_vf_state.
            // A restart will apply packet filter changes.
        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
            // Add the new update to the existing message.
            update.guest_vf_state |= guest_vf;
            update.filter_state |= packet_filter;
        }
    }

    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }

    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
}

/// Writes an RNDIS message to `writer`.
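///
/// The message consists of a `MessageHeader` followed by `payload`; `extra`
/// is the number of additional bytes the caller will append (such as a status
/// buffer), which are included in the header's `message_length`. Returns the
/// total message length in bytes.
///
/// A minimal sketch of a caller, modeled on the keepalive completion path
/// above (`buffer_range`, `mem`, and `request` are illustrative):
///
/// ```ignore
/// let message_length = write_rndis_message(
///     &mut buffer_range.writer(mem),
///     rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
///     0,
///     &rndisprot::KeepaliveComplete {
///         request_id: request.request_id,
///         status: rndisprot::STATUS_SUCCESS,
///     },
/// )?;
/// ```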
fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
    writer: &mut impl MemoryWrite,
    message_type: u32,
    extra: usize,
    payload: &T,
) -> Result<usize, AccessError> {
    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
    writer.write(
        rndisprot::MessageHeader {
            message_type,
            message_length: message_length as u32,
        }
        .as_bytes(),
    )?;
    writer.write(payload.as_bytes())?;
    Ok(message_length)
}

#[derive(Debug, Error)]
enum OidError {
    #[error(transparent)]
    Access(#[from] AccessError),
    #[error("unknown oid")]
    UnknownOid,
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    #[error("bad ndis version")]
    BadVersion,
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}

impl OidError {
    fn as_status(&self) -> u32 {
        match self {
            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
            OidError::Access(_) => rndisprot::STATUS_FAILURE,
        }
    }
}

const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
const MAX_MTU: u32 = 9216;

const ETHERNET_HEADER_LEN: u32 = 14;

impl Adapter {
    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
            // For enlightened guests (which is only Windows at the moment), send the
            // adapter index, which was previously set as the vport serial number.
            if guest_os_id
                .microsoft()
                .unwrap_or(HvGuestOsMicrosoft::from(0))
                .os_id()
                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
            {
                self.adapter_index
            } else {
                vfid
            }
        } else {
            vfid
        }
    }

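    /// Handles an RNDIS OID query, writing the response payload to `writer`.
    /// Returns the number of bytes written.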
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    // Ethernet objects operational characteristics
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    // Ethernet objects statistics
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    // PNP operational characteristics
                    // rndisprot::Oid::OID_PNP_SET_POWER,
                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
                    // RNDIS OIDs
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // NDIS6 OIDs
                let supported_oids_6 = &[
                    // Link state OIDs
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    // NDIS 6 statistics OIDs
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    // Offload related OIDs
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
                ];

                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; // NdisHardwareStatusReady
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                // This value is meaningless for virtual NICs. Return what vmswitch returns.
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Like vmswitch, use Microsoft's MAC address prefix (00-15-5D)
                // as the vendor ID.
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            // NDIS6 OIDs:
            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, /* MediaConnectStateConnected */
                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            // DPDK gets confused if the table size is zero,
                            // even if there is only one queue.
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }

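    /// Handles an RNDIS OID set request. On success, returns whether the
    /// endpoint must be restarted to apply the change and the new packet
    /// filter, if one was set.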
    fn handle_oid_set(
        &self,
        primary: &mut PrimaryChannelState,
        oid: rndisprot::Oid,
        reader: impl MemoryRead + Clone,
    ) -> Result<(bool, Option<u32>), OidError> {
        tracing::debug!(?oid, "oid set");

        let mut restart_endpoint = false;
        let mut packet_filter = None;
        match oid {
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
                packet_filter = self.oid_set_packet_filter(reader)?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
                self.oid_set_offload_parameters(reader, primary)?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                self.oid_set_offload_encapsulation(reader)?;
            }
            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
                self.oid_set_rndis_config_parameter(reader, primary)?;
            }
            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
                // TODO
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
                self.oid_set_rss_parameters(reader, primary)?;

                // Endpoints cannot currently change RSS parameters without
                // being restarted. This was a limitation driven by some DPDK
                // PMDs, and should be fixed.
                restart_endpoint = true;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
                return Err(OidError::UnknownOid);
            }
        }
        Ok((restart_endpoint, packet_filter))
    }

    fn oid_set_rss_parameters(
        &self,
        mut reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
        let len = reader.len().min(size_of_val(&params));
        reader.clone().read(&mut params.as_mut_bytes()[..len])?;

        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
        {
            primary.rss_state = None;
            return Ok(());
        }

        if params.hash_secret_key_size != 40 {
            return Err(OidError::InvalidInput("hash_secret_key_size"));
        }
        if params.indirection_table_size % 4 != 0 {
            return Err(OidError::InvalidInput("indirection_table_size"));
        }
        let indirection_table_size =
            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
        let mut key = [0; 40];
        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
        reader
            .clone()
            .skip(params.hash_secret_key_offset as usize)?
            .read(&mut key)?;
        reader
            .skip(params.indirection_table_offset as usize)?
            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
        if indirection_table
            .iter()
            .any(|&x| x >= self.max_queues as u32)
        {
            return Err(OidError::InvalidInput("indirection_table"));
        }
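        // If the guest supplied fewer entries than the endpoint's indirection
        // table size, tile the supplied entries across the remainder.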
        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
            .flatten()
            .zip(indir_uninit)
        {
            *dest = src;
        }
        primary.rss_state = Some(RssState {
            key,
            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
        });
        Ok(())
    }

    fn oid_set_packet_filter(
        &self,
        reader: impl MemoryRead + Clone,
    ) -> Result<Option<u32>, OidError> {
        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
        tracing::debug!(filter, "set packet filter");
        Ok(Some(filter))
    }

    fn oid_set_offload_parameters(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::DEFAULT,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
        )?;

        tracing::debug!(?offload, "offload parameters");
        let rndisprot::NdisOffloadParameters {
            header: _,
            ipv4_checksum,
            tcp4_checksum,
            udp4_checksum,
            tcp6_checksum,
            udp6_checksum,
            lsov1,
            ipsec_v1: _,
            lsov2_ipv4,
            lsov2_ipv6,
            tcp_connection_ipv4: _,
            tcp_connection_ipv6: _,
            reserved: _,
            flags: _,
        } = offload;

        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
            return Err(OidError::NotSupported("lsov1"));
        }
        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.ipv4_header = tx;
            primary.offload_config.checksum_rx.ipv4_header = rx;
        }
        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp4 = tx;
            primary.offload_config.checksum_rx.tcp4 = rx;
        }
        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp6 = tx;
            primary.offload_config.checksum_rx.tcp6 = rx;
        }
        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp4 = tx;
            primary.offload_config.checksum_rx.udp4 = rx;
        }
        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp6 = tx;
            primary.offload_config.checksum_rx.udp6 = rx;
        }
        if let Some(enable) = lsov2_ipv4.enable() {
            primary.offload_config.lso4 = enable && self.offload_support.lso4;
        }
        if let Some(enable) = lsov2_ipv6.enable() {
            primary.offload_config.lso6 = enable && self.offload_support.lso6;
        }
        primary.pending_offload_change = true;
        Ok(())
    }

    fn oid_set_offload_encapsulation(
        &self,
        reader: impl MemoryRead + Clone,
    ) -> Result<(), OidError> {
        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
        )?;
        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
        {
            return Err(OidError::NotSupported("ipv4 encap"));
        }
        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
        {
            return Err(OidError::NotSupported("ipv6 encap"));
        }
        Ok(())
    }

    fn oid_set_rndis_config_parameter(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
        if info.name_length > 255 {
            return Err(OidError::InvalidInput("name_length"));
        }
        if info.value_length > 255 {
            return Err(OidError::InvalidInput("value_length"));
        }
        let name = reader
            .clone()
            .skip(info.name_offset as usize)?
            .read_n::<u16>(info.name_length as usize / 2)?;
        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
        let mut value = reader;
        value.skip(info.value_offset as usize)?;
        let mut value = value.limit(info.value_length as usize);
        match info.parameter_type {
            rndisprot::NdisParameterType::STRING => {
                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
                let value =
                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
                let tx = as_num & 1 != 0;
                let rx = as_num & 2 != 0;

                tracing::debug!(name, value, "rndis config");
                match name.as_str() {
                    "*IPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.ipv4_header = tx;
                        primary.offload_config.checksum_rx.ipv4_header = rx;
                    }
                    "*LsoV2IPv4" => {
                        primary.offload_config.lso4 = as_num != 0 && self.offload_support.lso4;
                    }
                    "*LsoV2IPv6" => {
                        primary.offload_config.lso6 = as_num != 0 && self.offload_support.lso6;
                    }
                    "*TCPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.tcp4 = tx;
                        primary.offload_config.checksum_rx.tcp4 = rx;
                    }
                    "*TCPChecksumOffloadIPv6" => {
                        primary.offload_config.checksum_tx.tcp6 = tx;
                        primary.offload_config.checksum_rx.tcp6 = rx;
                    }
                    "*UDPChecksumOffloadIPv4" => {
                        primary.offload_config.checksum_tx.udp4 = tx;
                        primary.offload_config.checksum_rx.udp4 = rx;
                    }
                    "*UDPChecksumOffloadIPv6" => {
                        primary.offload_config.checksum_tx.udp6 = tx;
                        primary.offload_config.checksum_rx.udp6 = rx;
                    }
                    _ => {}
                }
            }
            rndisprot::NdisParameterType::INTEGER => {
                let value: u32 = value.read_plain()?;
                tracing::debug!(name, value, "rndis config");
            }
            parameter_type => tracelimit::warn_ratelimited!(
                name,
                ?parameter_type,
                "unhandled rndis config parameter type"
            ),
        }
        Ok(())
    }
}

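/// Reads an NDIS object of type `T` from `reader`, zero-filling any bytes
/// beyond what the guest sent, and validates the embedded object header
/// against the expected type, minimum revision, and minimum size.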
fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
    mut reader: impl MemoryRead,
    object_type: rndisprot::NdisObjectType,
    min_revision: u8,
    min_size: usize,
) -> Result<T, OidError> {
    let mut buffer = T::new_zeroed();
    let sent_size = reader.len();
    let len = sent_size.min(size_of_val(&buffer));
    reader.read(&mut buffer.as_mut_bytes()[..len])?;
    validate_ndis_object_header(
        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
            .unwrap()
            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
        sent_size,
        object_type,
        min_revision,
        min_size,
    )?;
    Ok(buffer)
}

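/// Validates an NDIS object header: the object type must match, the guest
/// must have sent at least `header.size` bytes, and the revision and size
/// must meet the caller's minimums.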
fn validate_ndis_object_header(
    header: &rndisprot::NdisObjectHeader,
    sent_size: usize,
    object_type: rndisprot::NdisObjectType,
    min_revision: u8,
    min_size: usize,
) -> Result<(), OidError> {
    if header.object_type != object_type {
        return Err(OidError::InvalidInput("header.object_type"));
    }
    if sent_size < header.size as usize {
        return Err(OidError::InvalidInput("header.size"));
    }
    if header.revision < min_revision {
        return Err(OidError::InvalidInput("header.revision"));
    }
    if (header.size as usize) < min_size {
        return Err(OidError::InvalidInput("header.size"));
    }
    Ok(())
}

struct Coordinator {
    recv: mpsc::Receiver<CoordinatorMessage>,
    channel_control: ChannelControl,
    restart: bool,
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    buffers: Option<Arc<ChannelBuffers>>,
    num_queues: u16,
}

/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// A VF update is pending.
    Pending,
}

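/// State owned by the coordinator task: the backend endpoint, the shared
/// adapter configuration, and the optional guest-visible virtual function.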
struct CoordinatorState {
    endpoint: Box<dyn Endpoint>,
    adapter: Arc<Adapter>,
    virtual_function: Option<Box<dyn VirtualFunction>>,
    pending_vf_state: CoordinatorStatePendingVfState,
}

impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    // Bounce each task so that it sees the new value.
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Get the shared channel state from the primary channel.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}

impl AsyncRun<Coordinator> for CoordinatorState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}

impl Coordinator {
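    /// Runs the coordinator loop: restarts the queues when requested, then
    /// races internal messages against endpoint actions, VF state changes,
    /// VF offer delays, and timer expiration, dispatching whichever wins.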
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        let mut sleep_duration: Option<Instant> = None;
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Ensure that all workers except the primary are started. The
            // primary is started below if there are no outstanding messages.
            for worker in &mut self.workers[1..] {
                worker.start();
            }

            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // While a VF offer is pending, no notifications are processed,
                // so the primary worker may run only if it is not waiting for
                // an action; otherwise it stays stopped until the offer
                // completes and regular message processing resumes. Currently
                // the only message that requires processing is
                // DataPathSwitchPending, so check for that here.
                if !self.workers[0].is_running()
                    && self.primary_mut().is_none_or(|primary| {
                        !matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::DataPathSwitchPending { result: None, .. }
                        )
                    })
                {
                    self.workers[0].start();
                }

                // guest_ready_for_device is not restartable, so do not poll on
                // stop.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                let timer_sleep = async {
                    if let Some(sleep_duration) = sleep_duration {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(sleep_duration).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                let mut wait_for_message = std::pin::pin!(wait_for_message);
                match (&mut wait_for_message).now_or_never() {
                    Some(message) => message,
                    None => {
                        self.workers[0].start();
                        stop.until_stopped(wait_for_message).await?
                    }
                }
            };
3926            match message {
3927                Message::UpdateFromVf(rpc) => {
3928                    rpc.handle(async |_| {
3929                        self.update_guest_vf_state(state).await;
3930                    })
3931                    .await;
3932                }
3933                Message::OfferVfDevice => {
3934                    self.workers[0].stop().await;
3935                    if let Some(primary) = self.primary_mut() {
3936                        if matches!(
3937                            primary.guest_vf_state,
3938                            PrimaryChannelGuestVfState::AvailableAdvertised
3939                        ) {
3940                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
3941                        }
3942                    }
3943
3944                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
3945                }
3946                Message::PendingVfStateComplete => {
3947                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
3948                }
3949                Message::TimerExpired => {
3950                    // Kick the worker as requested.
3951                    if self.workers[0].is_running() {
3952                        self.workers[0].stop().await;
3953                        if let Some(primary) = self.primary_mut() {
3954                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
3955                                primary.pending_link_action = PendingLinkAction::Active(up);
3956                            }
3957                        }
3958                    }
3959                    sleep_duration = None;
3960                }
3961                Message::Internal(CoordinatorMessage::Update(update_type)) => {
3962                    if update_type.filter_state {
3963                        self.stop_workers().await;
3964                        let worker_0_packet_filter =
3965                            self.workers[0].state().unwrap().channel.packet_filter;
3966                        self.workers.iter_mut().skip(1).for_each(|worker| {
3967                            if let Some(state) = worker.state_mut() {
3968                                state.channel.packet_filter = worker_0_packet_filter;
3969                                tracing::debug!(
3970                                    packet_filter = ?worker_0_packet_filter,
3971                                    channel_idx = state.channel_idx,
3972                                    "update packet filter"
3973                                );
3974                            }
3975                        });
3976                    }
3977
3978                    if update_type.guest_vf_state {
3979                        self.update_guest_vf_state(state).await;
3980                    }
3981                }
3982                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
3983                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
3984                    self.workers[0].stop().await;
3985
                    // Only these link state transitions are tracked:
                    // 1. up -> down or down -> up
                    // 2. up -> down -> up or down -> up -> down.
                    // All other transitions are coalesced into one of the above cases.
                    // For example, up -> down -> up -> down is treated as up -> down.
                    // N.B. Always queue the incoming state to minimize the effect of
                    //      lost notifications (for example, during VTL2 servicing).
3993                    if let Some(primary) = self.primary_mut() {
3994                        primary.pending_link_action = PendingLinkAction::Active(connect);
3995                    }
3996
                    // If an existing sleep timer is running, cancel it.
3998                    sleep_duration = None;
3999                }
4000                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
4001                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
4002                    sleep_duration = Some(duration);
4003                    // Restart primary task.
4004                    self.workers[0].stop().await;
4005                }
4006                Message::ChannelDisconnected => {
4007                    break;
4008                }
4009            };
4010        }
4011        Ok(())
4012    }
4013
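
    /// Stops all channel workers, one at a time.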
4014    async fn stop_workers(&mut self) {
4015        for worker in &mut self.workers {
4016            worker.stop().await;
4017        }
4018    }
4019
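
    /// Reconciles the guest-visible VF state with the current endpoint state,
    /// e.g. after a save/restore cycle or a VF hot add/remove.
    ///
    /// If a VF is present, this may schedule a (possibly delayed) device offer
    /// and replay an in-flight data path switch; if the VF is gone, it forces
    /// the data path back to synthetic and marks the VF unavailable.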
4020    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4021        let primary = match self.primary_mut() {
4022            Some(primary) => primary,
4023            None => return,
4024        };
4025
4026        // Update guest VF state based on current endpoint properties.
4027        let virtual_function = c_state.virtual_function.as_mut();
4028        let guest_vf_id = match &virtual_function {
4029            Some(vf) => vf.id().await,
4030            None => None,
4031        };
4032        if let Some(guest_vf_id) = guest_vf_id {
            // Ensure the guest VF is in the proper state.
4034            match primary.guest_vf_state {
4035                PrimaryChannelGuestVfState::AvailableAdvertised
4036                | PrimaryChannelGuestVfState::Restoring(
4037                    saved_state::GuestVfState::AvailableAdvertised,
4038                ) => {
4039                    if !primary.is_data_path_switched.unwrap_or(false) {
4040                        let timer = PolledTimer::new(&c_state.adapter.driver);
4041                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
4042                            timer,
4043                            delay_until: Instant::now() + VF_DEVICE_DELAY,
4044                        };
4045                    }
4046                }
4047                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4048                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4049                | PrimaryChannelGuestVfState::Ready
4050                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
4051                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4052                | PrimaryChannelGuestVfState::Restoring(
4053                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4054                )
4055                | PrimaryChannelGuestVfState::DataPathSwitched
4056                | PrimaryChannelGuestVfState::Restoring(
4057                    saved_state::GuestVfState::DataPathSwitched,
4058                )
4059                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4060                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4061                }
4062                _ => (),
4063            };
            // Ensure the data path is switched as expected.
4065            if let PrimaryChannelGuestVfState::Restoring(
4066                saved_state::GuestVfState::DataPathSwitchPending {
4067                    to_guest,
4068                    id,
4069                    result,
4070                },
4071            ) = primary.guest_vf_state
4072            {
                // If the saved state was captured after the data path switch
                // had already occurred, don't switch again.
4074                if result.is_some() {
4075                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
4076                        to_guest,
4077                        id,
4078                        result,
4079                    };
4080                    return;
4081                }
4082            }
4083            primary.guest_vf_state = match primary.guest_vf_state {
4084                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4085                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4086                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4087                | PrimaryChannelGuestVfState::Restoring(
4088                    saved_state::GuestVfState::DataPathSwitchPending { .. },
4089                )
4090                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4091                    let (to_guest, id) = match primary.guest_vf_state {
4092                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4093                            to_guest,
4094                            id,
4095                        }
4096                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
4097                            to_guest, id, ..
4098                        }
4099                        | PrimaryChannelGuestVfState::Restoring(
4100                            saved_state::GuestVfState::DataPathSwitchPending {
4101                                to_guest, id, ..
4102                            },
4103                        ) => (to_guest, id),
4104                        _ => (true, None),
4105                    };
                    // Cancel any outstanding delay timers for VF offers if the
                    // data path is being switched, since the guest is already
                    // issuing commands that assume a VF.
4109                    if matches!(
4110                        c_state.pending_vf_state,
4111                        CoordinatorStatePendingVfState::Delay { .. }
4112                    ) {
4113                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
4114                    }
4115                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
4116                    let result = if let Err(err) = result {
4117                        tracing::error!(
4118                            err = err.as_ref() as &dyn std::error::Error,
4119                            to_guest,
4120                            "Failed to switch guest VF data path"
4121                        );
4122                        false
4123                    } else {
4124                        primary.is_data_path_switched = Some(to_guest);
4125                        true
4126                    };
4127                    match primary.guest_vf_state {
4128                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4129                            ..
4130                        }
4131                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4132                        | PrimaryChannelGuestVfState::Restoring(
4133                            saved_state::GuestVfState::DataPathSwitchPending { .. },
4134                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
4135                            to_guest,
4136                            id,
4137                            result: Some(result),
4138                        },
4139                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
4140                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
4141                    }
4142                }
4143                PrimaryChannelGuestVfState::Initializing
4144                | PrimaryChannelGuestVfState::Unavailable
4145                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
4146                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
4147                }
4148                PrimaryChannelGuestVfState::AvailableAdvertised
4149                | PrimaryChannelGuestVfState::Restoring(
4150                    saved_state::GuestVfState::AvailableAdvertised,
4151                ) => {
4152                    if !primary.is_data_path_switched.unwrap_or(false) {
4153                        PrimaryChannelGuestVfState::AvailableAdvertised
4154                    } else {
4155                        // A previous instantiation already switched the data
4156                        // path.
4157                        PrimaryChannelGuestVfState::DataPathSwitched
4158                    }
4159                }
4160                PrimaryChannelGuestVfState::DataPathSwitched
4161                | PrimaryChannelGuestVfState::Restoring(
4162                    saved_state::GuestVfState::DataPathSwitched,
4163                ) => PrimaryChannelGuestVfState::DataPathSwitched,
4164                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4165                    PrimaryChannelGuestVfState::Ready
4166                }
4167                _ => primary.guest_vf_state,
4168            };
4169        } else {
            // If the device was just removed, make sure the data path is synthetic.
4171            match primary.guest_vf_state {
4172                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
4173                | PrimaryChannelGuestVfState::Restoring(
4174                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
4175                ) => {
4176                    if !to_guest {
4177                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4178                            tracing::warn!(
4179                                err = err.as_ref() as &dyn std::error::Error,
4180                                "Failed setting data path back to synthetic after guest VF was removed."
4181                            );
4182                        }
4183                        primary.is_data_path_switched = Some(false);
4184                    }
4185                }
4186                PrimaryChannelGuestVfState::DataPathSwitched
4187                | PrimaryChannelGuestVfState::Restoring(
4188                    saved_state::GuestVfState::DataPathSwitched,
4189                ) => {
4190                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
4191                        tracing::warn!(
4192                            err = err.as_ref() as &dyn std::error::Error,
4193                            "Failed setting data path back to synthetic after guest VF was removed."
4194                        );
4195                    }
4196                    primary.is_data_path_switched = Some(false);
4197                }
4198                _ => (),
4199            }
4200            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
4201                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
4202            }
            // Notify the guest if the VF is no longer available.
4204            primary.guest_vf_state = match primary.guest_vf_state {
4205                PrimaryChannelGuestVfState::Initializing
4206                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
4207                | PrimaryChannelGuestVfState::Available { .. } => {
4208                    PrimaryChannelGuestVfState::Unavailable
4209                }
4210                PrimaryChannelGuestVfState::AvailableAdvertised
4211                | PrimaryChannelGuestVfState::Restoring(
4212                    saved_state::GuestVfState::AvailableAdvertised,
4213                )
4214                | PrimaryChannelGuestVfState::Ready
4215                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
4216                    PrimaryChannelGuestVfState::UnavailableFromAvailable
4217                }
4218                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
4219                | PrimaryChannelGuestVfState::Restoring(
4220                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
4221                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
4222                    to_guest,
4223                    id,
4224                },
4225                PrimaryChannelGuestVfState::DataPathSwitched
4226                | PrimaryChannelGuestVfState::Restoring(
4227                    saved_state::GuestVfState::DataPathSwitched,
4228                )
4229                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4230                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4231                }
4232                _ => primary.guest_vf_state,
4233            }
4234        }
4235    }
4236
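
    /// Tears down and rebuilds the backend endpoint queues, redistributing
    /// receive buffers across the active queues (as derived from the RSS
    /// indirection table) and handing the new queue state to each worker.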
4237    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4238        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
4239        let drivers = self
4240            .workers
4241            .iter_mut()
4242            .map(|worker| {
4243                let task = worker.task_mut();
4244                task.queue_state = None;
4245                task.driver.clone()
4246            })
4247            .collect::<Vec<_>>();
4248
4249        c_state.endpoint.stop().await;
4250
4251        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4252        {
4253            (primary, sub)
4254        } else {
4255            unreachable!()
4256        };
4257
4258        let state = primary_worker
4259            .state_mut()
4260            .and_then(|worker| worker.state.ready_mut());
4261
4262        let state = if let Some(state) = state {
4263            state
4264        } else {
4265            return Ok(());
4266        };
4267
4268        // Save the channel buffers for use in the subchannel workers.
4269        self.buffers = Some(state.buffers.clone());
4270
4271        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4272        let mut active_queues = Vec::new();
4273        let active_queue_count =
4274            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                // The active queue count is the number of unique indirection
                // table entries that reference a valid queue.
4276                active_queues.clone_from(&rss_state.indirection_table);
4277                active_queues.sort();
4278                active_queues.dedup();
4279                active_queues = active_queues
4280                    .into_iter()
4281                    .filter(|&index| index < num_queues)
4282                    .collect::<Vec<_>>();
4283                if !active_queues.is_empty() {
4284                    active_queues.len() as u16
4285                } else {
4286                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
4287                    num_queues
4288                }
4289            } else {
4290                num_queues
4291            };
4292
        // Distribute the rx buffers only to the active queues.
4294        let (ranges, mut remote_buffer_id_recvs) =
4295            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
4296        let ranges = Arc::new(ranges);
4297
4298        let mut queues = Vec::new();
4299        let mut rx_buffers = Vec::new();
4300        {
4301            let buffers = &state.buffers;
4302            let guest_buffers = Arc::new(
4303                GuestBuffers::new(
4304                    buffers.mem.clone(),
4305                    buffers.recv_buffer.gpadl.clone(),
4306                    buffers.recv_buffer.sub_allocation_size,
4307                    buffers.ndis_config.mtu,
4308                )
4309                .map_err(WorkerError::GpadlError)?,
4310            );
4311
4312            // Get the list of free rx buffers from each task, then partition
4313            // the list per-active-queue, and produce the queue configuration.
4314            let mut queue_config = Vec::new();
4315            let initial_rx;
4316            {
4317                let states = std::iter::once(Some(&*state)).chain(
4318                    subworkers
4319                        .iter()
4320                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4321                );
4322
4323                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4324                    .filter(|&n| {
4325                        states
4326                            .clone()
4327                            .flatten()
                            .all(|s| s.state.rx_bufs.is_free(n))
4329                    })
4330                    .map(RxId)
4331                    .collect::<Vec<_>>();
4332
4333                let mut initial_rx = initial_rx.as_slice();
4334                let mut range_start = 0;
4335                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4336                let first_queue = if !primary_queue_excluded {
4337                    0
4338                } else {
                    // If the primary queue is excluded from the guest-supplied
                    // indirection table, it is assigned just the reserved
                    // buffers.
4342                    queue_config.push(QueueConfig {
4343                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4344                        initial_rx: &[],
4345                        driver: Box::new(drivers[0].clone()),
4346                    });
4347                    rx_buffers.push(RxBufferRange::new(
4348                        ranges.clone(),
4349                        0..RX_RESERVED_CONTROL_BUFFERS,
4350                        None,
4351                    ));
4352                    range_start = RX_RESERVED_CONTROL_BUFFERS;
4353                    1
4354                };
4355                for queue_index in first_queue..num_queues {
4356                    let queue_active = active_queues.is_empty()
4357                        || active_queues.binary_search(&queue_index).is_ok();
4358                    let (range_end, end, buffer_id_recv) = if queue_active {
4359                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4360                            // The last queue gets all the remaining buffers.
4361                            state.buffers.recv_buffer.count
4362                        } else if queue_index == 0 {
4363                            // Queue zero always includes the reserved buffers.
4364                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4365                        } else {
4366                            range_start + ranges.buffers_per_queue
4367                        };
4368                        (
4369                            range_end,
4370                            initial_rx.partition_point(|id| id.0 < range_end),
4371                            Some(remote_buffer_id_recvs.remove(0)),
4372                        )
4373                    } else {
4374                        (range_start, 0, None)
4375                    };
4376
4377                    let (this, rest) = initial_rx.split_at(end);
4378                    queue_config.push(QueueConfig {
4379                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
4380                        initial_rx: this,
4381                        driver: Box::new(drivers[queue_index as usize].clone()),
4382                    });
4383                    initial_rx = rest;
4384                    rx_buffers.push(RxBufferRange::new(
4385                        ranges.clone(),
4386                        range_start..range_end,
4387                        buffer_id_recv,
4388                    ));
4389
4390                    range_start = range_end;
4391                }
4392            }
4393
4394            let primary = state.state.primary.as_mut().unwrap();
4395            tracing::debug!(num_queues, "enabling endpoint");
4396
4397            let rss = primary
4398                .rss_state
4399                .as_ref()
4400                .map(|rss| net_backend::RssConfig {
4401                    key: &rss.key,
4402                    indirection_table: &rss.indirection_table,
4403                    flags: 0,
4404                });
4405
4406            c_state
4407                .endpoint
4408                .get_queues(queue_config, rss.as_ref(), &mut queues)
4409                .instrument(tracing::info_span!("netvsp_get_queues"))
4410                .await
4411                .map_err(WorkerError::Endpoint)?;
4412
4413            assert_eq!(queues.len(), num_queues as usize);
4414
4415            // Set the subchannel count.
4416            self.channel_control
4417                .enable_subchannels(num_queues - 1)
4418                .expect("already validated");
4419
4420            self.num_queues = num_queues;
4421        }
4422
4423        let worker_0_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4424        // Provide the queue and receive buffer ranges for each worker.
4425        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
4426            worker.task_mut().queue_state = Some(QueueState {
4427                queue,
4428                target_vp_set: false,
4429                rx_buffer_range: rx_buffer,
4430            });
            // Update the receive packet filter to match the primary channel.
4432            if let Some(worker) = worker.state_mut() {
4433                worker.channel.packet_filter = worker_0_packet_filter;
4434                // Clear any pending RxIds as buffers were redistributed.
4435                if let Some(ready_state) = worker.state.ready_mut() {
4436                    ready_state.state.pending_rx_packets.clear();
4437                }
4438            }
4439        }
4440
4441        Ok(())
4442    }
4443
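
    /// Returns the primary channel state, if the primary worker is in the
    /// ready state.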
4444    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4445        self.workers[0]
4446            .state_mut()
4447            .unwrap()
4448            .state
4449            .ready_mut()?
4450            .state
4451            .primary
4452            .as_mut()
4453    }
4454
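
    /// Stops the primary worker and re-evaluates the guest VF state.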
4455    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
4456        self.workers[0].stop().await;
4457        self.restore_guest_vf_state(c_state).await;
4458    }
4459}
4460
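// Runs a channel worker as a cancellable task: errors other than cancellation
// are logged and swallowed, and a revoked buffer is treated as a clean exit.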
4461impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4462    async fn run(
4463        &mut self,
4464        stop: &mut StopTask<'_>,
4465        worker: &mut Worker<T>,
4466    ) -> Result<(), task_control::Cancelled> {
4467        match worker.process(stop, self).await {
4468            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4469            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4470            Err(err) => {
4471                tracing::error!(
4472                    channel_idx = worker.channel_idx,
4473                    error = &err as &dyn std::error::Error,
4474                    "netvsp error"
4475                );
4476            }
4477        }
4478        Ok(())
4479    }
4480}
4481
4482impl<T: RingMem + 'static> Worker<T> {
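    /// The worker's processing loop: completes channel initialization (on the
    /// primary channel; subchannels wait for the coordinator), then processes
    /// traffic until cancelled or a fatal error occurs.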
4483    async fn process(
4484        &mut self,
4485        stop: &mut StopTask<'_>,
4486        queue: &mut NetQueue,
4487    ) -> Result<(), WorkerError> {
4488        // Be careful not to wait on actions with unbounded blocking time (e.g.
4489        // guest actions, or waiting for network packets to arrive) without
4490        // wrapping the wait on `stop.until_stopped`.
4491        loop {
4492            match &mut self.state {
4493                WorkerState::Init(initializing) => {
4494                    if self.channel_idx != 0 {
4495                        // Still waiting for the coordinator to provide receive
4496                        // buffers. The task will be restarted when they are available.
4497                        stop.until_stopped(pending()).await?
4498                    }
4499
4500                    tracelimit::info_ratelimited!("network accepted");
4501
4502                    let (buffers, state) = stop
4503                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
4504                        .await??;
4505
4506                    let state = ReadyState {
4507                        buffers: Arc::new(buffers),
4508                        state,
4509                        data: ProcessingData::new(),
4510                    };
4511
4512                    // Wake up the coordinator task to start the queues.
4513                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);
4514
4515                    tracelimit::info_ratelimited!("network initialized");
4516                    self.state = WorkerState::Ready(state);
4517                }
4518                WorkerState::Ready(state) => {
4519                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
4520                        if !queue_state.target_vp_set
4521                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
4522                        {
4523                            tracing::debug!(
4524                                channel_idx = self.channel_idx,
4525                                target_vp = self.target_vp,
4526                                "updating target VP"
4527                            );
4528                            queue_state.queue.update_target_vp(self.target_vp).await;
4529                            queue_state.target_vp_set = true;
4530                        }
4531
4532                        queue_state
4533                    } else {
4534                        // This task will be restarted when the queues are ready.
4535                        stop.until_stopped(pending()).await?
4536                    };
4537
4538                    let result = self.channel.main_loop(stop, state, queue_state).await;
4539                    match result {
4540                        Ok(restart) => {
4541                            assert_eq!(self.channel_idx, 0);
4542                            let _ = self.coordinator_send.try_send(restart);
4543                        }
4544                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
4545                            tracelimit::warn_ratelimited!(
4546                                err = err.as_ref() as &dyn std::error::Error,
4547                                "Endpoint requires queues to restart",
4548                            );
4549                            if let Err(try_send_err) =
4550                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
4551                            {
4552                                tracing::error!(
4553                                    try_send_err = &try_send_err as &dyn std::error::Error,
4554                                    "failed to restart queues"
4555                                );
4556                                return Err(WorkerError::Endpoint(err));
4557                            }
4558                        }
4559                        Err(err) => return Err(err),
4560                    }
4561
4562                    stop.until_stopped(pending()).await?
4563                }
4564            }
4565        }
4566    }
4567}
4568
4569impl<T: 'static + RingMem> NetChannel<T> {
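    /// Attempts to read and parse the next incoming vmbus packet without
    /// waiting, returning `Ok(None)` if the incoming ring is empty.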
4570    fn try_next_packet<'a>(
4571        &mut self,
4572        send_buffer: Option<&'a SendBuffer>,
4573        version: Option<Version>,
4574    ) -> Result<Option<Packet<'a>>, WorkerError> {
4575        let (mut read, _) = self.queue.split();
4576        let packet = match read.try_read() {
4577            Ok(packet) => {
4578                parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
4579            }
4580            Err(queue::TryReadError::Empty) => return Ok(None),
4581            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4582        };
4583
4584        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4585        Ok(Some(packet))
4586    }
4587
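
    /// Reads and parses the next incoming vmbus packet, waiting for one to
    /// arrive. An RNDIS packet is left in the ring (the read is reverted) so
    /// that it can be reprocessed once initialization completes.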
4588    async fn next_packet<'a>(
4589        &mut self,
4590        send_buffer: Option<&'a SendBuffer>,
4591        version: Option<Version>,
4592    ) -> Result<Packet<'a>, WorkerError> {
4593        let (mut read, _) = self.queue.split();
4594        let mut packet_ref = read.read().await?;
4595        let packet =
4596            parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
4597        if matches!(packet.data, PacketData::RndisPacket(_)) {
            // In WorkerState::Init, if an RNDIS packet is received, assume it
            // is MESSAGE_TYPE_INITIALIZE_MSG.
4599            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
4600            packet_ref.revert();
4601        }
4602        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4603        Ok(packet)
4604    }
4605
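
    /// Returns true if enough of the handshake has completed to finish
    /// initialization: NDIS config (for V2+), NDIS version, the receive
    /// buffer, and (unless allowed to be missing) the send buffer.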
4606    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4607        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4608            && initializing.ndis_version.is_some()
4609            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4610            && initializing.recv_buffer.is_some()
4611    }
4612
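
    /// Drives version and buffer negotiation with the guest until the channel
    /// is fully initialized, producing the shared channel buffers and the
    /// initial active state.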
4613    async fn initialize(
4614        &mut self,
4615        initializing: &mut Option<InitState>,
4616        mem: GuestMemory,
4617    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
4618        let mut has_init_packet_arrived = false;
4619        loop {
4620            if let Some(initializing) = &mut *initializing {
4621                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
4622                    let recv_buffer = initializing.recv_buffer.take().unwrap();
4623                    let send_buffer = initializing.send_buffer.take();
4624                    let state = ActiveState::new(
4625                        Some(PrimaryChannelState::new(
4626                            self.adapter.offload_support.clone(),
4627                        )),
4628                        recv_buffer.count,
4629                    );
4630                    let buffers = ChannelBuffers {
4631                        version: initializing.version,
4632                        mem,
4633                        recv_buffer,
4634                        send_buffer,
4635                        ndis_version: initializing.ndis_version.take().unwrap(),
4636                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
4637                            mtu: DEFAULT_MTU,
4638                            capabilities: protocol::NdisConfigCapabilities::new(),
4639                        }),
4640                    };
4641
4642                    break Ok((buffers, state));
4643                }
4644            }
4645
4646            // Wait for enough room in the ring to avoid needing to track
4647            // completion packets.
4648            self.queue
4649                .split()
4650                .1
4651                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
4652                .await?;
4653
4654            let packet = self
4655                .next_packet(None, initializing.as_ref().map(|x| x.version))
4656                .await?;
4657
4658            if let Some(initializing) = &mut *initializing {
4659                match packet.data {
4660                    PacketData::SendNdisConfig(config) => {
4661                        if initializing.ndis_config.is_some() {
4662                            return Err(WorkerError::UnexpectedPacketOrder(
4663                                PacketOrderError::SendNdisConfigExists,
4664                            ));
4665                        }
4666
4667                        // As in the vmswitch, if the MTU is invalid then use the default.
4668                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
4669                            config.mtu
4670                        } else {
4671                            DEFAULT_MTU
4672                        };
4673
4674                        // The UEFI client expects a completion packet, which can be empty.
4675                        self.send_completion(packet.transaction_id, &[])?;
4676                        initializing.ndis_config = Some(NdisConfig {
4677                            mtu,
4678                            capabilities: config.capabilities,
4679                        });
4680                    }
4681                    PacketData::SendNdisVersion(version) => {
4682                        if initializing.ndis_version.is_some() {
4683                            return Err(WorkerError::UnexpectedPacketOrder(
4684                                PacketOrderError::SendNdisVersionExists,
4685                            ));
4686                        }
4687
4688                        // The UEFI client expects a completion packet, which can be empty.
4689                        self.send_completion(packet.transaction_id, &[])?;
4690                        initializing.ndis_version = Some(NdisVersion {
4691                            major: version.ndis_major_version,
4692                            minor: version.ndis_minor_version,
4693                        });
4694                    }
4695                    PacketData::SendReceiveBuffer(message) => {
4696                        if initializing.recv_buffer.is_some() {
4697                            return Err(WorkerError::UnexpectedPacketOrder(
4698                                PacketOrderError::SendReceiveBufferExists,
4699                            ));
4700                        }
4701
4702                        let mtu = if let Some(cfg) = &initializing.ndis_config {
4703                            cfg.mtu
4704                        } else if initializing.version < Version::V2 {
4705                            DEFAULT_MTU
4706                        } else {
4707                            return Err(WorkerError::UnexpectedPacketOrder(
4708                                PacketOrderError::SendReceiveBufferMissingMTU,
4709                            ));
4710                        };
4711
4712                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);
4713
4714                        let recv_buffer = ReceiveBuffer::new(
4715                            &self.gpadl_map,
4716                            message.gpadl_handle,
4717                            message.id,
4718                            sub_allocation_size,
4719                        )?;
4720
4721                        self.send_completion(
4722                            packet.transaction_id,
4723                            &self
4724                                .message(
4725                                    protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
4726                                    protocol::Message1SendReceiveBufferComplete {
4727                                        status: protocol::Status::SUCCESS,
4728                                        num_sections: 1,
4729                                        sections: [protocol::ReceiveBufferSection {
4730                                            offset: 0,
4731                                            sub_allocation_size: recv_buffer.sub_allocation_size,
4732                                            num_sub_allocations: recv_buffer.count,
4733                                            end_offset: recv_buffer.sub_allocation_size
4734                                                * recv_buffer.count,
4735                                        }],
4736                                    },
4737                                )
4738                                .payload(),
4739                        )?;
4740                        initializing.recv_buffer = Some(recv_buffer);
4741                    }
4742                    PacketData::SendSendBuffer(message) => {
4743                        if initializing.send_buffer.is_some() {
4744                            return Err(WorkerError::UnexpectedPacketOrder(
4745                                PacketOrderError::SendSendBufferExists,
4746                            ));
4747                        }
4748
4749                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
4750                        self.send_completion(
4751                            packet.transaction_id,
4752                            &self
4753                                .message(
4754                                    protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
4755                                    protocol::Message1SendSendBufferComplete {
4756                                        status: protocol::Status::SUCCESS,
4757                                        section_size: 6144,
4758                                    },
4759                                )
4760                                .payload(),
4761                        )?;
4762
4763                        initializing.send_buffer = Some(send_buffer);
4764                    }
4765                    PacketData::RndisPacket(rndis_packet) => {
4766                        if !Self::is_ready_to_initialize(initializing, true) {
4767                            return Err(WorkerError::UnexpectedPacketOrder(
4768                                PacketOrderError::UnexpectedRndisPacket,
4769                            ));
4770                        }
4771                        tracing::debug!(
4772                            channel_type = rndis_packet.channel_type,
4773                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
4774                        );
4775                        has_init_packet_arrived = true;
4776                    }
4777                    _ => {
4778                        return Err(WorkerError::UnexpectedPacketOrder(
4779                            PacketOrderError::InvalidPacketData,
4780                        ));
4781                    }
4782                }
4783            } else {
4784                match packet.data {
4785                    PacketData::Init(init) => {
4786                        let requested_version = init.protocol_version;
4787                        let version = check_version(requested_version);
4788                        let mut message = self.message(
4789                            protocol::MESSAGE_TYPE_INIT_COMPLETE,
4790                            protocol::MessageInitComplete {
4791                                deprecated: protocol::INVALID_PROTOCOL_VERSION,
4792                                maximum_mdl_chain_length: 34,
4793                                status: protocol::Status::NONE,
4794                            },
4795                        );
4796                        if let Some(version) = version {
4797                            if version == Version::V1 {
4798                                message.data.deprecated = Version::V1 as u32;
4799                            }
4800                            message.data.status = protocol::Status::SUCCESS;
4801                        } else {
4802                            tracing::debug!(requested_version, "unrecognized version");
4803                        }
4804                        self.send_completion(packet.transaction_id, &message.payload())?;
4805
4806                        if let Some(version) = version {
4807                            tracelimit::info_ratelimited!(?version, "network negotiated");
4808
4809                            if version >= Version::V61 {
4810                                // Update the packet size so that the appropriate padding is
4811                                // appended for picky Windows guests.
4812                                self.packet_size = protocol::PACKET_SIZE_V61;
4813                            }
4814                            *initializing = Some(InitState {
4815                                version,
4816                                ndis_config: None,
4817                                ndis_version: None,
4818                                recv_buffer: None,
4819                                send_buffer: None,
4820                            });
4821                        }
4822                    }
4823                    _ => unreachable!(),
4824                }
4825            }
4826        }
4827    }
4828
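
    /// The steady-state processing loop for a channel: moves packets between
    /// the guest rings and the backend endpoint queue, returning only when
    /// the coordinator needs to take action (or on error).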
4829    async fn main_loop(
4830        &mut self,
4831        stop: &mut StopTask<'_>,
4832        ready_state: &mut ReadyState,
4833        queue_state: &mut QueueState,
4834    ) -> Result<CoordinatorMessage, WorkerError> {
4835        let buffers = &ready_state.buffers;
4836        let state = &mut ready_state.state;
4837        let data = &mut ready_state.data;
4838
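        // Compute the send ring space to hold in reserve. If the free space
        // drops below this reserve, the ring is treated as (almost) full and
        // endpoint processing is paused until the guest catches up.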
4839        let ring_spare_capacity = {
4840            let (_, send) = self.queue.split();
4841            let mut limit = if self.can_use_ring_size_opt {
4842                self.adapter.ring_size_limit.load(Ordering::Relaxed)
4843            } else {
4844                0
4845            };
4846            if limit == 0 {
4847                limit = send.capacity() - 2048;
4848            }
4849            send.capacity() - limit
4850        };
4851
4852        // If the packet filter has changed to allow rx packets, add any pended RxIds.
4853        if !state.pending_rx_packets.is_empty()
4854            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
4855        {
4856            let epqueue = queue_state.queue.as_mut();
4857            let (front, back) = state.pending_rx_packets.as_slices();
4858            epqueue.rx_avail(front);
4859            epqueue.rx_avail(back);
4860            state.pending_rx_packets.clear();
4861        }
4862
4863        // Handle any guest state changes since last run.
4864        if let Some(primary) = state.primary.as_mut() {
4865            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
4866                let num_channels_opened =
4867                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
4868                if num_channels_opened == primary.requested_num_queues as usize {
4869                    let (_, mut send) = self.queue.split();
4870                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4871                        .await??;
4872                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
4873                    primary.tx_spread_sent = true;
4874                }
4875            }
4876            if let PendingLinkAction::Active(up) = primary.pending_link_action {
4877                let (_, mut send) = self.queue.split();
4878                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4879                    .await??;
4880                if let Some(id) = primary.free_control_buffers.pop() {
4881                    let connect = if primary.guest_link_up != up {
4882                        primary.pending_link_action = PendingLinkAction::Default;
4883                        up
4884                    } else {
                        // For the up -> down -> up or down -> up -> down case,
                        // the first transition is sent immediately and the
                        // second transition is queued with a delay.
4887                        primary.pending_link_action =
4888                            PendingLinkAction::Delay(primary.guest_link_up);
4889                        !primary.guest_link_up
4890                    };
4891                    // Mark the receive buffer in use to allow the guest to release it.
4892                    assert!(state.rx_bufs.is_free(id.0));
4893                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
4894                    let state_to_send = if connect {
4895                        rndisprot::STATUS_MEDIA_CONNECT
4896                    } else {
4897                        rndisprot::STATUS_MEDIA_DISCONNECT
4898                    };
4899                    tracing::info!(
4900                        connect,
4901                        mac_address = %self.adapter.mac_address,
4902                        "sending link status"
4903                    );
4904
4905                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
4906                    primary.guest_link_up = connect;
4907                } else {
4908                    primary.pending_link_action = PendingLinkAction::Delay(up);
4909                }
4910
4911                match primary.pending_link_action {
4912                    PendingLinkAction::Delay(_) => {
4913                        return Ok(CoordinatorMessage::StartTimer(
4914                            Instant::now() + LINK_DELAY_DURATION,
4915                        ));
4916                    }
4917                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
4918                    _ => {}
4919                }
4920            }
4921            match primary.guest_vf_state {
4922                PrimaryChannelGuestVfState::Available { .. }
4923                | PrimaryChannelGuestVfState::UnavailableFromAvailable
4924                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
4925                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
4926                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
4927                | PrimaryChannelGuestVfState::DataPathSynthetic => {
4928                    let (_, mut send) = self.queue.split();
4929                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
4930                        .await??;
4931                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
4932                        return Ok(message);
4933                    }
4934                }
4935                _ => (),
4936            }
4937        }
4938
4939        loop {
4940            // If the ring is almost full, then do not poll the endpoint,
4941            // since this will cause us to spend time dropping packets and
4942            // reprogramming receive buffers. It's more efficient to let the
4943            // backend drop packets.
4944            let ring_full = {
4945                let (_, mut send) = self.queue.split();
4946                !send.can_write(ring_spare_capacity)?
4947            };
4948
4949            let did_some_work = (!ring_full
4950                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
4951                | self.process_ring_buffer(buffers, state, data, queue_state)?
4952                | (!ring_full
4953                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
4954                | self.transmit_pending_segments(state, data, queue_state)?
4955                | self.send_pending_packets(state)?;
4956
4957            if !did_some_work {
4958                state.stats.spurious_wakes.increment();
4959            }
4960
4961            // Process any outstanding control messages before sleeping in case
4962            // there are now suballocations available for use.
4963            self.process_control_messages(buffers, state)?;
4964
4965            // This should be the only await point waiting on network traffic or
4966            // guest actions. Wrap it in `stop.until_stopped` to allow
4967            // cancellation.
4968            let restart = stop
4969                .until_stopped(std::future::poll_fn(
4970                    |cx| -> Poll<Option<CoordinatorMessage>> {
4971                        // If the ring is almost full, then don't wait for endpoint
4972                        // interrupts. This allows the interrupt rate to fall when the
4973                        // guest cannot keep up with the load.
4974                        if !ring_full {
4975                            // Check the network endpoint for tx completion or rx.
4976                            if queue_state.queue.poll_ready(cx).is_ready() {
4977                                tracing::trace!("endpoint ready");
4978                                return Poll::Ready(None);
4979                            }
4980                        }
4981
4982                        // Check the incoming ring for tx, but only if there are enough
4983                        // free tx packets.
4984                        let (mut recv, mut send) = self.queue.split();
4985                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
4986                            && data.tx_segments.is_empty()
4987                            && recv.poll_ready(cx).is_ready()
4988                        {
4989                            tracing::trace!("incoming ring ready");
4990                            return Poll::Ready(None);
4991                        }
4992
4993                        // Check the outgoing ring for space to send rx completions or
4994                        // control message sends, if any are pending.
4995                        //
4996                        // Also, if endpoint processing has been suspended due to the
4997                        // ring being nearly full, then ask the guest to wake us up when
4998                        // there is space again.
4999                        let mut pending_send_size = self.pending_send_size;
5000                        if ring_full {
5001                            pending_send_size = ring_spare_capacity;
5002                        }
5003                        if pending_send_size != 0
5004                            && send.poll_ready(cx, pending_send_size).is_ready()
5005                        {
5006                            tracing::trace!("outgoing ring ready");
5007                            return Poll::Ready(None);
5008                        }
5009
5010                        // Collect any of this queue's receive buffers that were
5011                        // completed by a remote channel. This only happens when the
5012                        // subchannel count changes, so that receive buffer ownership
5013                        // moves between queues while some receive buffers are still in
5014                        // use.
5015                        if let Some(remote_buffer_id_recv) =
5016                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
5017                        {
5018                            while let Poll::Ready(Some(id)) =
5019                                remote_buffer_id_recv.poll_next_unpin(cx)
5020                            {
5021                                if id >= RX_RESERVED_CONTROL_BUFFERS {
5022                                    queue_state.queue.rx_avail(&[RxId(id)]);
5023                                } else {
5024                                    state
5025                                        .primary
5026                                        .as_mut()
5027                                        .unwrap()
5028                                        .free_control_buffers
5029                                        .push(ControlMessageId(id));
5030                                }
5031                            }
5032                        }
5033
5034                        if let Some(restart) = self.restart.take() {
5035                            return Poll::Ready(Some(restart));
5036                        }
5037
5038                        tracing::trace!("network waiting");
5039                        Poll::Pending
5040                    },
5041                ))
5042                .await?;
5043
5044            if let Some(restart) = restart {
5045                break Ok(restart);
5046            }
5047        }
5048    }
5049
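
    /// Polls the endpoint for received packets and forwards them to the guest
    /// via the receive buffer. If the packet filter currently blocks rx, the
    /// RxIds are pended instead. Returns false if there was no work to do.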
5050    fn process_endpoint_rx(
5051        &mut self,
5052        buffers: &ChannelBuffers,
5053        state: &mut ActiveState,
5054        data: &mut ProcessingData,
5055        epqueue: &mut dyn net_backend::Queue,
5056    ) -> Result<bool, WorkerError> {
5057        let n = epqueue
5058            .rx_poll(&mut data.rx_ready)
5059            .map_err(WorkerError::Endpoint)?;
5060        if n == 0 {
5061            return Ok(false);
5062        }
5063
5064        state.stats.rx_packets_per_wake.add_sample(n as u64);
5065
5066        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
5067            tracing::trace!(
5068                packet_filter = self.packet_filter,
5069                "rx packets dropped due to packet filter"
5070            );
5071            // Pend the newly available RxIds until the packet filter is updated.
5072            // Under high load this will eventually lead to no available RxIds,
5073            // which will cause the backend to drop the packets instead of
5074            // processing them here.
5075            state.pending_rx_packets.extend(&data.rx_ready[..n]);
5076            state.stats.rx_dropped_filtered.add(n as u64);
5077            return Ok(false);
5078        }
5079
5080        let transaction_id = data.rx_ready[0].0.into();
5081        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);
5082
5083        state.rx_bufs.allocate(ready_ids.clone()).unwrap();
5084
5085        // Always use the full suballocation size to avoid tracking the
5086        // message length. See RxBuf::header() for details.
5087        let len = buffers.recv_buffer.sub_allocation_size as usize;
5088        data.transfer_pages.clear();
5089        data.transfer_pages
5090            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));
5091
5092        match self.try_send_rndis_message(
5093            transaction_id,
5094            protocol::DATA_CHANNEL_TYPE,
5095            buffers.recv_buffer.id,
5096            &data.transfer_pages,
5097        )? {
5098            None => {
                // The packet was sent successfully.
5100                state.stats.rx_packets.add(n as u64);
5101            }
5102            Some(_) => {
5103                // Ring buffer is full. Drop the packets and free the rx
5104                // buffers. When the ring has limited space, the main loop will
5105                // stop polling for receive packets.
5106                state.stats.rx_dropped_ring_full.add(n as u64);
5107
5108                state.rx_bufs.free(data.rx_ready[0].0);
5109                epqueue.rx_avail(&data.rx_ready[..n]);
5110            }
5111        }
5112
5113        Ok(true)
5114    }
5115
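
    /// Polls the endpoint for completed transmits and completes the
    /// corresponding guest tx packets. On a restartable endpoint error, pends
    /// completions for all in-flight packets before requesting a queue
    /// restart.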
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Drain completed transmits.
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Complete any pending tx prior to restarting queues.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }

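    /// Applies a guest request to switch the data path between the synthetic
    /// NIC and the guest-visible VF.
    ///
    /// A condensed view of the transitions handled below (a summary of this
    /// function only, not of every `PrimaryChannelGuestVfState` variant):
    ///
    /// ```text
    /// AvailableAdvertised | Ready
    ///   --(to VF, or current path unknown)--> DataPathSwitchPending
    /// DataPathSwitched
    ///   --(away from VF)--------------------> DataPathSwitchPending
    /// any other state, request for VF:        log a warning, no switch
    /// ```
    ///
    /// Queued switches are handed to the coordinator; all other requests are
    /// completed immediately.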
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // Allow the guest to switch to VTL0, or, if the current state
                // of the data path is unknown, allow a switch away from VTL0.
                // The latter case handles the scenario where the data path has
                // been switched but the synthetic device has been restarted.
                // The device is queried for the current state of the data
                // path, but if it is unknown (e.g. after an upgrade from an
                // older version that doesn't track this data, or after a
                // failed query), the safest option is to pass the request
                // through.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while the device is in the wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, &[])?;
        }
        Ok(())
    }

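    /// Processes packets from the incoming ring buffer until the ring is
    /// empty, transmit IDs run out, or transmit segments back up.
    ///
    /// RNDIS data packets are parsed into transmit segments and offered to
    /// the backend; RNDIS completions return receive buffers to the endpoint;
    /// control messages (subchannel allocation, data path switch, buffer
    /// revocation, OID queries) are handled inline.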
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            if state.free_tx_packets.is_empty() || !data.tx_segments.is_empty() {
                break;
            }
            let packet = if let Some(packet) =
                self.try_next_packet(buffers.send_buffer.as_ref(), Some(buffers.version))?
            {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    assert!(data.tx_segments.is_empty());
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    let num_packets = match result {
                        Ok(num_packets) => num_packets,
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                            continue;
                        }
                    };
                    total_packets += num_packets as u64;
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += num_packets;

                    if num_packets != 0 {
                        if self.transmit_segments(state, data, queue_state, id, num_packets)?
                            < num_packets
                        {
                            state.stats.tx_stalled.increment();
                        }
                    } else {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                                protocol::Message5SubchannelComplete {
                                    status,
                                    num_sub_channels: subchannel_count,
                                },
                            )
                            .payload(),
                    )?;

                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                                protocol::Message5OidQueryExComplete {
                                    status: rndisprot::STATUS_NOT_SUPPORTED,
                                    bytes: 0,
                                },
                            )
                            .payload(),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }

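    /// Retries transmit segments left queued by an earlier wake. Returns
    /// `Ok(true)` only if the backend accepted every remaining packet.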
    fn transmit_pending_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if data.tx_segments.is_empty() {
            return Ok(false);
        }
        let net_backend::TxSegmentType::Head(metadata) = &data.tx_segments[0].ty else {
            unreachable!()
        };
        let id = metadata.id;
        let num_packets = state.pending_tx_packets[id.0 as usize].pending_packet_count;
        let packets_sent = self.transmit_segments(state, data, queue_state, id, num_packets)?;
        Ok(num_packets == packets_sent)
    }

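    /// Offers the queued transmit segments to the backend queue and returns
    /// how many packets it accepted.
    ///
    /// A sketch of the completion bookkeeping, assuming a guest packet that
    /// expanded into two backend packets, all accepted synchronously:
    ///
    /// ```ignore
    /// let (sync, segments_sent) = queue_state.queue.tx_avail(&data.tx_segments)?;
    /// // sync == true: the backend finished these packets inline and will not
    /// // report them through tx_poll, so the pending count is decremented
    /// // here and the guest packet can be completed immediately.
    /// ```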
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
        id: TxId,
        num_packets: usize,
    ) -> Result<usize, WorkerError> {
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&data.tx_segments)
            .map_err(WorkerError::Endpoint)?;

        assert!(segments_sent <= data.tx_segments.len());

        let packets_sent = if segments_sent == data.tx_segments.len() {
            num_packets
        } else {
            net_backend::packet_count(&data.tx_segments[..segments_sent])
        };

        data.tx_segments.drain(..segments_sent);

        if sync {
            state.pending_tx_packets[id.0 as usize].pending_packet_count -= packets_sent;
        }

        if state.pending_tx_packets[id.0 as usize].pending_packet_count == 0 {
            self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
        }

        Ok(packets_sent)
    }

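    /// Parses a guest RNDIS message, which may carry several RNDIS packets
    /// concatenated back to back:
    ///
    /// ```text
    /// |MessageHeader| body ... |MessageHeader| body ... |
    /// '-- message_length bytes ''-- message_length bytes '
    /// ```
    ///
    /// Each header is read from a clone of the reader so that the outer
    /// reader can skip exactly `message_length` bytes to the next packet.
    /// Returns the number of backend transmit packets produced.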
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert!(tx_packet.pending_packet_count == 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe the data to catch accesses that are out of bounds. This
        // simplifies error handling for backends that use
        // [`GuestMemory::iova`].
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        // There may be multiple RNDIS packets in a single message,
        // concatenated with each other. Consume them until there is no more
        // data in the RNDIS message.
        while reader.len() > 0 {
            let mut this_reader = reader.clone();
            let header: rndisprot::MessageHeader = this_reader.read_plain()?;
            if self.handle_rndis_message(
                buffers,
                state,
                id,
                header.message_type,
                this_reader,
                segments,
            )? {
                num_packets += 1;
            }
            reader.skip(header.message_length as usize)?;
        }

        Ok(num_packets)
    }

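    /// Attempts to write a transmit completion to the outgoing ring without
    /// waiting. On `TryWriteError::Full(n)`, records the needed ring space in
    /// `pending_send_size` so that the caller can wait for that much space
    /// before retrying, and returns `Ok(false)`.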
    fn try_send_tx_packet(
        &mut self,
        transaction_id: u64,
        status: protocol::Status,
    ) -> Result<bool, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
            protocol::Message1SendRndisPacketComplete { status },
        );
        let result = self.queue.split().1.try_write(&queue::OutgoingPacket {
            transaction_id,
            packet_type: OutgoingPacketType::Completion,
            payload: &message.payload(),
        });
        let sent = match result {
            Ok(()) => true,
            Err(queue::TryWriteError::Full(n)) => {
                self.pending_send_size = n;
                false
            }
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(sent)
    }

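    /// Flushes completions that were queued while the ring was full, stopping
    /// at the first one that still does not fit. Clears `pending_send_size`
    /// once the queue fully drains.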
    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
        let mut did_some_work = false;
        while let Some(pending) = state.pending_tx_completions.front() {
            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
                return Ok(did_some_work);
            }
            did_some_work = true;
            if let Some(id) = pending.tx_id {
                state.free_tx_packets.push(id);
            }
            tracing::trace!(?pending, "sent tx completion");
            state.pending_tx_completions.pop_front();
        }

        self.pending_send_size = 0;
        Ok(did_some_work)
    }

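    /// Completes a guest transmit packet, writing the completion immediately
    /// when possible. If the ring is full, or earlier completions are still
    /// queued (preserving completion order), the completion is queued on
    /// `pending_tx_completions` instead.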
    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
}

impl ActiveState {
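    /// Returns the receive buffers identified by a guest completion to their
    /// owners. The completion's transaction ID names the first buffer in the
    /// batch (see `process_endpoint_rx`).
    ///
    /// Routing of each freed ID, with `RX_RESERVED_CONTROL_BUFFERS` as the
    /// boundary:
    ///
    /// ```text
    /// ID in another queue's range        -> forwarded via send_if_remote
    /// ID >= RX_RESERVED_CONTROL_BUFFERS  -> back to the endpoint (done)
    /// ID <  RX_RESERVED_CONTROL_BUFFERS  -> back to free_control_buffers
    /// ```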
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID specifies the first rx buffer ID.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
5557}