netvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The user-mode netvsp VMBus device implementation.
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
19use crate::protocol::Version;
20use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
21use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
22use async_trait::async_trait;
23pub use buffers::BufferPool;
24use buffers::sub_allocation_size_for_mtu;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures::channel::mpsc;
28use futures::channel::mpsc::TrySendError;
29use futures_concurrency::future::Race;
30use guestmem::AccessError;
31use guestmem::GuestMemory;
32use guestmem::GuestMemoryError;
33use guestmem::MemoryRead;
34use guestmem::MemoryWrite;
35use guestmem::ranges::PagedRange;
36use guestmem::ranges::PagedRanges;
37use guestmem::ranges::PagedRangesReader;
38use guid::Guid;
39use hvdef::hypercall::HvGuestOsId;
40use hvdef::hypercall::HvGuestOsMicrosoft;
41use hvdef::hypercall::HvGuestOsMicrosoftIds;
42use hvdef::hypercall::HvGuestOsOpenSourceType;
43use inspect::Inspect;
44use inspect::InspectMut;
45use inspect::SensitivityLevel;
46use inspect_counters::Counter;
47use inspect_counters::Histogram;
48use mesh::rpc::Rpc;
49use net_backend::Endpoint;
50use net_backend::EndpointAction;
51use net_backend::QueueConfig;
52use net_backend::RxId;
53use net_backend::TxError;
54use net_backend::TxId;
55use net_backend::TxSegment;
56use net_backend_resources::mac_address::MacAddress;
57use pal_async::timer::Instant;
58use pal_async::timer::PolledTimer;
59use ring::gparange::MultiPagedRangeIter;
60use rx_bufs::RxBuffers;
61use rx_bufs::SubAllocationInUse;
62use std::collections::VecDeque;
63use std::fmt::Debug;
64use std::future::pending;
65use std::mem::offset_of;
66use std::ops::Range;
67use std::sync::Arc;
68use std::sync::atomic::AtomicUsize;
69use std::sync::atomic::Ordering;
70use std::task::Poll;
71use std::time::Duration;
72use task_control::AsyncRun;
73use task_control::InspectTaskMut;
74use task_control::StopTask;
75use task_control::TaskControl;
76use thiserror::Error;
77use tracing::Instrument;
78use vmbus_async::queue;
79use vmbus_async::queue::ExternalDataError;
80use vmbus_async::queue::IncomingPacket;
81use vmbus_async::queue::Queue;
82use vmbus_channel::bus::OfferParams;
83use vmbus_channel::bus::OpenRequest;
84use vmbus_channel::channel::ChannelControl;
85use vmbus_channel::channel::ChannelOpenError;
86use vmbus_channel::channel::ChannelRestoreError;
87use vmbus_channel::channel::DeviceResources;
88use vmbus_channel::channel::RestoreControl;
89use vmbus_channel::channel::SaveRestoreVmbusDevice;
90use vmbus_channel::channel::VmbusDevice;
91use vmbus_channel::gpadl::GpadlId;
92use vmbus_channel::gpadl::GpadlMapView;
93use vmbus_channel::gpadl::GpadlView;
94use vmbus_channel::gpadl::UnknownGpadlId;
95use vmbus_channel::gpadl_ring::GpadlRingMem;
96use vmbus_channel::gpadl_ring::gpadl_channel;
97use vmbus_ring as ring;
98use vmbus_ring::OutgoingPacketType;
99use vmbus_ring::RingMem;
100use vmbus_ring::gparange::MultiPagedRangeBuf;
101use vmcore::save_restore::RestoreError;
102use vmcore::save_restore::SaveError;
103use vmcore::save_restore::SavedStateBlob;
104use vmcore::vm_task::VmTaskDriver;
105use vmcore::vm_task::VmTaskDriverSource;
106use zerocopy::FromBytes;
107use zerocopy::FromZeros;
108use zerocopy::Immutable;
109use zerocopy::IntoBytes;
110use zerocopy::KnownLayout;
111
// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but also need room for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. Worst case requires a completion message plus two
// additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION)
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// The maximum number of subchannels offered per vNIC.
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the AN device is offered too quickly, it
// could get the "eth0" name. In provisioning scenarios, the scripts will make
// assumptions about which interface should be used, with eth0 being the
// default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
// Use a much shorter delay in unit tests to keep them fast.
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
// Use a much shorter delay in unit tests to keep them fast.
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
145
/// Flags selecting which pieces of network state a
/// [`CoordinatorMessage::Update`] should refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}

/// Requests sent to the coordinator task by the channel workers.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}
165
/// Per-channel worker that processes ring activity for one vmbus channel
/// (the primary channel or a subchannel).
struct Worker<T: RingMem> {
    /// Index of this channel; 0 is the primary channel.
    channel_idx: u16,
    /// Target VP for channel interrupts, if assigned.
    target_vp: Option<u32>,
    mem: GuestMemory,
    channel: NetChannel<T>,
    /// Current protocol/negotiation state.
    state: WorkerState,
    /// Sender for requests to the coordinator task.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

/// Backend queue state paired with a worker; kept separate so it can be
/// inspected independently of the worker task.
struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}
179
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Reports worker protocol state, ring state, and queue statistics for
    /// diagnostics via the inspect framework.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report if neither the worker nor a queue exists.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            // Packet/completion counters exist only after negotiation completes.
            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
232
/// Protocol state of a channel worker.
enum WorkerState {
    /// Negotiation in progress; `None` until the protocol version is agreed.
    Init(Option<InitState>),
    /// Negotiation complete; packets are being processed.
    Ready(ReadyState),
    /// Paused pending coordinator action, holding any previous ready state.
    WaitingForCoordinator(Option<ReadyState>),
}
238
239impl WorkerState {
240    fn ready(&self) -> Option<&ReadyState> {
241        if let Self::Ready(state) = self {
242            Some(state)
243        } else {
244            None
245        }
246    }
247
248    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249        if let Self::Ready(state) = self {
250            Some(state)
251        } else {
252            None
253        }
254    }
255}
256
/// Channel state accumulated while negotiation is in progress.
struct InitState {
    /// The negotiated protocol version.
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

/// NDIS version reported during negotiation.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

/// NDIS configuration received during negotiation.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

/// State for a channel that has finished negotiation.
struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}
286
287impl ReadyState {
288    /// Any in-flight TX packets submitted to the old endpoint queues will
289    /// never be completed, so this method:
290    /// 1. Queues completions for all outstanding sends so the guest gets
291    ///    responses and the associated transmit slots are reclaimed, ensuring
292    ///    the worker is ready to poll post restart of the queues.
293    /// 2. Clears leftover transmit segments that referenced the old endpoint
294    ///    queue so they are not submitted to the new one.
295    ///
296    fn reset_tx_after_endpoint_stop(&mut self) {
297        let state = &mut self.state;
298
299        // Queue completions for in-flight TX packets that were lost when the
300        // endpoint stopped. They will get picked up when the worker restarts.
301        let pending_tx = state
302            .pending_tx_packets
303            .iter_mut()
304            .enumerate()
305            .filter_map(|(id, inflight)| {
306                if inflight.pending_packet_count > 0 {
307                    inflight.pending_packet_count = 0;
308                    Some(PendingTxCompletion {
309                        transaction_id: inflight.transaction_id,
310                        tx_id: Some(TxId(id as u32)),
311                        status: protocol::Status::SUCCESS,
312                    })
313                } else {
314                    None
315                }
316            })
317            .collect::<Vec<_>>();
318        state.pending_tx_completions.extend(pending_tx);
319
320        // Clear leftover TX segments from the previous endpoint queue;
321        // they cannot be submitted to the new queue.
322        self.data.tx_segments.clear();
323        self.data.tx_segments_sent = 0;
324    }
325}
326
/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
341
/// Per-adapter configuration and state shared by all channels of a vNIC.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    /// Maximum number of queues (and hence channels) supported.
    max_queues: u16,
    indirection_table_size: u16,
    /// The offloads supported by the backend endpoint.
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Optional callback to query the guest OS ID reported via hypercalls.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

/// Backend queue state for a single channel.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    /// Whether the target VP has been applied to the queue's driver.
    target_vp_set: bool,
}
362
/// The range of receive buffer IDs owned by one queue, plus the plumbing used
/// to route IDs owned by other queues back to their owners.
struct RxBufferRange {
    id_range: Range<u32>,
    /// Receiver for buffer IDs forwarded from other queues, if any.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}
368
369impl RxBufferRange {
370    fn new(
371        ranges: Arc<RxBufferRanges>,
372        id_range: Range<u32>,
373        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
374    ) -> Self {
375        Self {
376            id_range,
377            remote_buffer_id_recv,
378            remote_ranges: ranges,
379        }
380    }
381
382    fn send_if_remote(&self, id: u32) -> bool {
383        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
384        // ID is owned by the current range.
385        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
386            false
387        } else {
388            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
389            // The total number of receive buffers may not evenly divide among
390            // the active queues. Any extra buffers are given to the last
391            // queue, so redirect any larger values there.
392            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
393            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
394            true
395        }
396    }
397}
398
/// The division of receive buffer IDs across all queues, with one forwarding
/// channel per queue.
struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
403
404impl RxBufferRanges {
405    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
406        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
407        #[expect(clippy::disallowed_methods)] // TODO
408        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
409        (
410            Self {
411                buffers_per_queue,
412                buffer_id_send: send,
413            },
414            recv,
415        )
416    }
417}
418
/// RSS (receive-side scaling) parameters configured by the guest.
struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// The internal channel state.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: PacketSize,
    /// The outgoing ring space to request before resuming sends.
    pending_send_size: usize,
    /// A pending coordinator message to deliver, if any.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}

// Use an enum to give the compiler more visibility into the packet size.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// [`protocol::PACKET_SIZE_V1`]
    V1,
    /// [`protocol::PACKET_SIZE_V61`]
    V61,
}

/// Buffers used during packet processing.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    /// Number of entries of `tx_segments` already submitted to the endpoint.
    tx_segments_sent: usize,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
455
456impl ProcessingData {
457    fn new() -> Self {
458        Self {
459            tx_segments: Vec::new(),
460            tx_segments_sent: 0,
461            tx_done: vec![TxId(0); 8192].into(),
462            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
463            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
464            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
465            external_data: MultiPagedRangeBuf::new(),
466        }
467    }
468}
469
/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    /// State present only on the primary channel; `None` on subchannels.
    primary: Option<PrimaryChannelState>,

    // In-flight transmit tracking: packets indexed by TxId, with free slot IDs
    // and completions waiting to be written back to the ring.
    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    stats: QueueStats,
}

/// Per-queue counters exposed through inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

/// A transmit completion waiting to be written back to the ring.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    tx_id: Option<TxId>,
    status: protocol::Status,
}
525
/// The state of the guest VF as tracked by the primary channel.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available while a data path switch was pending, but
    /// it is no longer available.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available with the data path switched to it, but it
    /// is no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}
558
559impl std::fmt::Display for PrimaryChannelGuestVfState {
560    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561        match self {
562            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
563            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
564                write!(f, "restoring")
565            }
566            PrimaryChannelGuestVfState::Restoring(
567                saved_state::GuestVfState::AvailableAdvertised,
568            ) => write!(f, "restoring from guest notified of vfid"),
569            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
570                write!(f, "restoring from vf present")
571            }
572            PrimaryChannelGuestVfState::Restoring(
573                saved_state::GuestVfState::DataPathSwitchPending {
574                    to_guest, result, ..
575                },
576            ) => {
577                write!(
578                    f,
579                    "restoring from client requested data path switch: to {} {}",
580                    if *to_guest { "guest" } else { "synthetic" },
581                    if let Some(result) = result {
582                        if *result { "succeeded\"" } else { "failed\"" }
583                    } else {
584                        "in progress\""
585                    }
586                )
587            }
588            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
589                write!(f, "restoring from data path in guest")
590            }
591            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
592            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
593                write!(f, "\"unavailable (previously available)\"")
594            }
595            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
596                write!(f, "unavailable (previously switching data path)")
597            }
598            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
599                write!(f, "\"unavailable (previously using guest VF)\"")
600            }
601            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
602            PrimaryChannelGuestVfState::AvailableAdvertised => {
603                write!(f, "\"available, guest notified\"")
604            }
605            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
606            PrimaryChannelGuestVfState::DataPathSwitchPending {
607                to_guest, result, ..
608            } => {
609                write!(
610                    f,
611                    "\"switching to {} {}",
612                    if *to_guest { "guest" } else { "synthetic" },
613                    if let Some(result) = result {
614                        if *result { "succeeded\"" } else { "failed\"" }
615                    } else {
616                        "in progress\""
617                    }
618                )
619            }
620            PrimaryChannelGuestVfState::DataPathSwitched => {
621                write!(f, "\"available and data path switched\"")
622            }
623            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
624                f,
625                "\"available but data path switched back to synthetic due to external state change\""
626            ),
627        }
628    }
629}
630
impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        // Reports the state as its Display string.
        req.value(self.to_string());
    }
}

/// A pending link up/down notification for the guest.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action is pending.
    Default,
    /// A link change to the given up state is being processed.
    Active(bool),
    /// A link change to the given up state is being delayed.
    Delay(bool),
}
643
/// State tracked only on the primary channel of a vNIC.
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    /// Queued incoming control messages awaiting processing.
    control_messages: VecDeque<ControlMessage>,
    /// Total data bytes across `control_messages`.
    control_messages_len: usize,
    /// Reserved receive buffers currently free for control message use.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
659
impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        // All fields are reported with Safe sensitivity so they remain
        // visible in redacted inspect output.
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
711
/// A set of task offloads: checksum for each direction and large send offload.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// Large send offload v2 for IPv4.
    #[inspect(safe)]
    lso4: bool,
    /// Large send offload v2 for IPv6.
    #[inspect(safe)]
    lso6: bool,
}
723
724impl OffloadConfig {
725    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
726        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
727        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
728        self.lso4 &= supported.lso4;
729        self.lso6 &= supported.lso6;
730    }
731}
732
/// Checksum offload settings for one direction (transmit or receive).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
746
747impl ChecksumOffloadConfig {
748    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
749        self.ipv4_header &= supported.ipv4_header;
750        self.tcp4 &= supported.tcp4;
751        self.udp4 &= supported.udp4;
752        self.tcp6 &= supported.tcp6;
753        self.udp6 &= supported.udp6;
754    }
755
756    fn flags(
757        &self,
758    ) -> (
759        rndisprot::Ipv4ChecksumOffload,
760        rndisprot::Ipv6ChecksumOffload,
761    ) {
762        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
763        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
764        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
765        if self.ipv4_header {
766            v4.set_ip_options_supported(on);
767            v4.set_ip_checksum(on);
768        }
769        if self.tcp4 {
770            v4.set_ip_options_supported(on);
771            v4.set_tcp_options_supported(on);
772            v4.set_tcp_checksum(on);
773        }
774        if self.tcp6 {
775            v6.set_ip_extension_headers_supported(on);
776            v6.set_tcp_options_supported(on);
777            v6.set_tcp_checksum(on);
778        }
779        if self.udp4 {
780            v4.set_ip_options_supported(on);
781            v4.set_udp_checksum(on);
782        }
783        if self.udp6 {
784            v6.set_ip_extension_headers_supported(on);
785            v6.set_udp_checksum(on);
786        }
787        (v4, v6)
788    }
789}
790
impl OffloadConfig {
    /// Builds the RNDIS NDIS_OFFLOAD structure that advertises this
    /// configuration.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        // Checksum offload flags for each direction and IP version.
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        // Large send offload v2 parameters, populated only when enabled.
        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}
838
/// The state of the RNDIS control plane.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Waiting for initialization by the guest.
    Initializing,
    /// Initialized and operational.
    Operational,
    /// Halted by the guest.
    Halted,
}
845
impl PrimaryChannelState {
    /// Creates fresh primary-channel state with the given offload
    /// configuration, no queued control messages, and all reserved control
    /// buffers free.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            // All reserved control-message buffers start out free.
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from saved state.
    ///
    /// `rx_bufs` is consulted to determine which reserved control buffers
    /// are still free; `indirection_table_size` is the adapter's current
    /// RSS indirection table size.
    ///
    /// # Errors
    ///
    /// Fails if the saved indirection table is larger than the adapter's
    /// current table size, or if the saved RSS key has the wrong size.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        // Total queued payload bytes; computed before the messages are moved.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        // A reserved buffer is free unless the rx buffer tracker says it is
        // still in use by the guest.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of indirection table can cause unexpected and hard to investigate issues
                    // with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of indirection table is done by duplicating the existing entries until
                    // the desired size is reached.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        // Map the saved enum back to the runtime representation.
        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        // The VF state transitions through `Restoring` so that the saved
        // state can be reconciled once the device is running again.
        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        // Rebuild the offload configuration field by field from the saved
        // copy.
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        // A saved `Some(..)` means a link action was still pending at save
        // time; resume it as active.
        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
976
/// An incoming control message queued for later processing.
struct ControlMessage {
    /// The message type identifier from the incoming control packet.
    message_type: u32,
    /// The raw message payload.
    data: Box<[u8]>,
}
981
/// The maximum number of tx packets that may be in flight per channel; sizes
/// both the pending-packet table and the free-ID pool in [`ActiveState`].
const TX_PACKET_QUOTA: usize = 1024;
983
984impl ActiveState {
985    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
986        Self {
987            primary,
988            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
989            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
990            pending_tx_completions: VecDeque::new(),
991            pending_rx_packets: VecDeque::new(),
992            rx_bufs: RxBuffers::new(recv_buffer_count),
993            stats: Default::default(),
994        }
995    }
996
997    fn restore(
998        channel: &saved_state::Channel,
999        recv_buffer_count: u32,
1000    ) -> Result<Self, NetRestoreError> {
1001        let mut active = Self::new(None, recv_buffer_count);
1002        let saved_state::Channel {
1003            pending_tx_completions,
1004            in_use_rx,
1005        } = channel;
1006        for rx in in_use_rx {
1007            active
1008                .rx_bufs
1009                .allocate(rx.buffers.as_slice().iter().copied())?;
1010        }
1011        for &transaction_id in pending_tx_completions {
1012            // Consume tx quota if any is available. If not, still
1013            // allow the restore since tx quota might change from
1014            // release to release.
1015            let tx_id = active.free_tx_packets.pop();
1016            if let Some(id) = tx_id {
1017                // This shouldn't be referenced, but set it in case it is in the future.
1018                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
1019            }
1020            // Save/Restore does not preserve the status of pending tx completions,
1021            // completing any pending completions with 'success' to avoid making changes to saved_state.
1022            active
1023                .pending_tx_completions
1024                .push_back(PendingTxCompletion {
1025                    transaction_id,
1026                    tx_id,
1027                    status: protocol::Status::SUCCESS,
1028                });
1029        }
1030        Ok(active)
1031    }
1032}
1033
/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    /// The number of backend packets still outstanding for this tx.
    pending_packet_count: usize,
    /// The guest's transaction ID, used when posting the tx completion.
    transaction_id: u64,
}
1041
/// The maximum rx batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
/// These are tracked via the primary channel's free control buffer list.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1050
/// A network adapter.
pub struct Nic {
    /// The vmbus instance ID, used in the channel offer and task names.
    instance_id: Guid,
    /// Vmbus resources supplied via `install`.
    resources: DeviceResources,
    /// The coordinator task; its state owns the per-queue worker tasks.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    /// Sender for messages to the coordinator; set when the coordinator is
    /// inserted.
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    /// Adapter-wide configuration shared with the workers.
    adapter: Arc<Adapter>,
    /// Source used to build per-queue task drivers.
    driver_source: VmTaskDriverSource,
}
1060
/// Builder for [`Nic`], created via [`Nic::builder`].
pub struct NicBuilder {
    /// Optional accelerated virtual function to offer to the guest.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Whether to cap the effective outgoing ring buffer size.
    limit_ring_buffer: bool,
    /// Upper bound on the number of queues (clamped at build time).
    max_queues: u16,
    /// Optional callback used to query the guest OS ID.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1067
1068impl NicBuilder {
1069    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
1070        self.limit_ring_buffer = limit;
1071        self
1072    }
1073
1074    pub fn max_queues(mut self, max_queues: u16) -> Self {
1075        self.max_queues = max_queues;
1076        self
1077    }
1078
1079    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
1080        self.virtual_function = Some(virtual_function);
1081        self
1082    }
1083
1084    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
1085        self.get_guest_os_id = Some(os_type);
1086        self
1087    }
1088
1089    /// Creates a new NIC.
1090    pub fn build(
1091        self,
1092        driver_source: &VmTaskDriverSource,
1093        instance_id: Guid,
1094        endpoint: Box<dyn Endpoint>,
1095        mac_address: MacAddress,
1096        adapter_index: u32,
1097    ) -> Nic {
1098        let multiqueue = endpoint.multiqueue_support();
1099
1100        let max_queues = self.max_queues.clamp(
1101            1,
1102            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
1103        );
1104
1105        // If requested, limit the effective size of the outgoing ring buffer.
1106        // In a configuration where the NIC is processed synchronously, this
1107        // will ensure that we don't process incoming rx packets and tx packet
1108        // completions until the guest has processed the data it already has.
1109        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };
1110
1111        // If the endpoint completes tx packets quickly, then avoid polling the
1112        // incoming ring (and thus avoid arming the signal from the guest) as
1113        // long as there are any tx packets in flight. This can significantly
1114        // reduce the signal rate from the guest, improving batching.
1115        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
1116            TX_PACKET_QUOTA
1117        } else {
1118            // Avoid getting into a situation where there is always barely
1119            // enough quota.
1120            TX_PACKET_QUOTA / 4
1121        };
1122
1123        let tx_offloads = endpoint.tx_offload_support();
1124
1125        // Always claim support for rx offloads since we can mark any given
1126        // packet as having unknown checksum state.
1127        let offload_support = OffloadConfig {
1128            checksum_rx: ChecksumOffloadConfig {
1129                ipv4_header: true,
1130                tcp4: true,
1131                udp4: true,
1132                tcp6: true,
1133                udp6: true,
1134            },
1135            checksum_tx: ChecksumOffloadConfig {
1136                ipv4_header: tx_offloads.ipv4_header,
1137                tcp4: tx_offloads.tcp,
1138                tcp6: tx_offloads.tcp,
1139                udp4: tx_offloads.udp,
1140                udp6: tx_offloads.udp,
1141            },
1142            lso4: tx_offloads.tso,
1143            lso6: tx_offloads.tso,
1144        };
1145
1146        let driver = driver_source.simple();
1147        let adapter = Arc::new(Adapter {
1148            driver,
1149            mac_address,
1150            max_queues,
1151            indirection_table_size: multiqueue.indirection_table_size,
1152            offload_support,
1153            free_tx_packet_threshold,
1154            ring_size_limit: ring_size_limit.into(),
1155            tx_fast_completions: endpoint.tx_fast_completions(),
1156            adapter_index,
1157            get_guest_os_id: self.get_guest_os_id,
1158            num_sub_channels_opened: AtomicUsize::new(0),
1159            link_speed: endpoint.link_speed(),
1160        });
1161
1162        let coordinator = TaskControl::new(CoordinatorState {
1163            endpoint,
1164            adapter: adapter.clone(),
1165            virtual_function: self.virtual_function,
1166            pending_vf_state: CoordinatorStatePendingVfState::Ready,
1167        });
1168
1169        Nic {
1170            instance_id,
1171            resources: Default::default(),
1172            coordinator,
1173            coordinator_send: None,
1174            adapter,
1175            driver_source: driver_source.clone(),
1176        }
1177    }
1178}
1179
1180fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1181    let Some(guest_os_id) = guest_os_id else {
1182        // guest os id not available.
1183        return false;
1184    };
1185
1186    if !queue.split().0.supports_pending_send_size() {
1187        // guest does not support pending send size.
1188        return false;
1189    }
1190
1191    let Some(open_source_os) = guest_os_id.open_source() else {
1192        // guest os is proprietary (ex: Windows)
1193        return true;
1194    };
1195
1196    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1197        // Although FreeBSD indicates support for `pending send size`, it doesn't
1198        // implement it correctly. This was fixed in FreeBSD version `1400097`.
1199        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1200        // Linux kernels prior to 3.11 have issues with pending send size optimization
1201        // which can affect certain Asynchronous I/O (AIO) network operations.
1202        // Disable ring size optimization for these older kernels to avoid flow control issues.
1203        HvGuestOsOpenSourceType::LINUX => {
1204            // Linux version is encoded as: ((major << 16) | (minor << 8) | patch)
1205            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424
1206            open_source_os.version() >= 199424
1207        }
1208        _ => true,
1209    }
1210}
1211
1212impl Nic {
1213    pub fn builder() -> NicBuilder {
1214        NicBuilder {
1215            virtual_function: None,
1216            limit_ring_buffer: false,
1217            max_queues: !0,
1218            get_guest_os_id: None,
1219        }
1220    }
1221
1222    pub fn shutdown(self) -> Box<dyn Endpoint> {
1223        let (state, _) = self.coordinator.into_inner();
1224        state.endpoint
1225    }
1226}
1227
impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        // Delegate inspection to the coordinator task, which holds the
        // device's runtime state.
        self.coordinator.inspect_mut(req);
    }
}
1233
#[async_trait]
impl VmbusDevice for Nic {
    /// Returns the vmbus offer parameters for the synthetic network adapter.
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            // The synthetic network interface class ID:
            // f8615163-df3e-46c5-913f-f2d2f965ed0e.
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    /// Opens channel `channel_idx`, creating the coordinator on the primary
    /// channel (index 0) and a worker for the channel itself.
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, None);
            WorkerState::Init(None)
        } else {
            // Subchannels must pause the coordinator to access its state.
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        // When the last expected subchannel opens, stop the primary channel's
        // worker; stopped workers are restarted by the coordinator below.
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            // The primary channel failed to open; tear the coordinator back
            // down.
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    /// Closes channel `channel_idx`, tearing down the coordinator and
    /// endpoint when the primary channel closes.
    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!(
                channel_idx,
                instance_id = %self.instance_id,
                "Close called while vmbus channel is already closed"
            );
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            // Drop every queue before stopping the endpoint.
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VF's added to the guest. This is required to keep guest compat as
            // some apps (such as DPDK) relies on the VF sticking around even after vmbus
            // channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    /// Retargets channel `channel_idx`'s interrupts and processing to
    /// `target_vp`.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = Some(target_vp);
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    /// Stops the coordinator and all of its workers.
    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}
1395
1396#[async_trait]
1397impl SaveRestoreVmbusDevice for Nic {
1398    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1399        let state = self.saved_state();
1400        Ok(SavedStateBlob::new(state))
1401    }
1402
1403    async fn restore(
1404        &mut self,
1405        control: RestoreControl<'_>,
1406        state: SavedStateBlob,
1407    ) -> Result<(), RestoreError> {
1408        let state: saved_state::SavedState = state.parse()?;
1409        if let Err(err) = self.restore_state(control, state).await {
1410            tracing::error!(
1411                error = &err as &dyn std::error::Error,
1412                instance_id = %self.instance_id,
1413                "Failed restoring network vmbus state"
1414            );
1415            Err(err.into())
1416        } else {
1417            Ok(())
1418        }
1419    }
1420}
1421
impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    ///
    /// `state` provides the worker's initial protocol state; if `start` is
    /// set, the worker task is started immediately after insertion.
    ///
    /// # Errors
    ///
    /// Fails if the channel's ring cannot be established over the GPADL
    /// ([`OpenError::Ring`]) or the queue cannot be constructed over it
    /// ([`OpenError::Queue`]).
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        // N.B. VMBus doesn't provide a target VP if the channel is not using interrupts. Run on VP
        //      0 in that case.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp.unwrap_or_default());

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // Query the guest OS ID (if a callback was provided) to decide
        // whether the ring's pending-send-size optimization is safe for this
        // guest.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                // New workers inherit the currently active packet filter.
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1483
/// Coordinator parameters carried through a saved-state restore.
struct RestoreCoordinatorState {
    /// The packet filter that was active at save time.
    active_packet_filter: u32,
}
1487
impl Nic {
    /// Creates and inserts the coordinator task (without starting it).
    ///
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                // Pre-create an (empty) task slot with its own driver for
                // every possible queue; workers are inserted as channels
                // open.
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                // When restoring, carry over the packet filter that was
                // active at save time; otherwise start with no packet types
                // enabled.
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}
1530
/// Errors that can occur while restoring netvsp saved state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    // Shrinking the indirection table is refused; see
    // `PrimaryChannelState::restore`.
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1550
1551impl From<NetRestoreError> for RestoreError {
1552    fn from(err: NetRestoreError) -> Self {
1553        RestoreError::InvalidSavedState(anyhow::Error::new(err))
1554    }
1555}
1556
1557impl Nic {
1558    async fn restore_state(
1559        &mut self,
1560        mut control: RestoreControl<'_>,
1561        state: saved_state::SavedState,
1562    ) -> Result<(), NetRestoreError> {
1563        let mut saved_packet_filter = 0u32;
1564        if let Some(state) = state.open {
1565            let open = match &state.primary {
1566                saved_state::Primary::Version => vec![true],
1567                saved_state::Primary::Init(_) => vec![true],
1568                saved_state::Primary::Ready(ready) => {
1569                    ready.channels.iter().map(|x| x.is_some()).collect()
1570                }
1571            };
1572
1573            let mut states: Vec<_> = open.iter().map(|_| None).collect();
1574
1575            // N.B. This will restore the vmbus view of open channels, so any
1576            //      failures after this point could result in inconsistent
1577            //      state (vmbus believes the channel is open/active). There
1578            //      are a number of failure paths after this point because this
1579            //      call also restores vmbus device state, like the GPADL map.
1580            let requests = control
1581                .restore(&open)
1582                .await
1583                .map_err(NetRestoreError::Channel)?;
1584
1585            match state.primary {
1586                saved_state::Primary::Version => {
1587                    states[0] = Some(WorkerState::Init(None));
1588                }
1589                saved_state::Primary::Init(init) => {
1590                    let version = check_version(init.version)
1591                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;
1592
1593                    let recv_buffer = init
1594                        .receive_buffer
1595                        .map(|recv_buffer| {
1596                            ReceiveBuffer::new(
1597                                &self.resources.gpadl_map,
1598                                recv_buffer.gpadl_id,
1599                                recv_buffer.id,
1600                                recv_buffer.sub_allocation_size,
1601                            )
1602                        })
1603                        .transpose()?;
1604
1605                    let send_buffer = init
1606                        .send_buffer
1607                        .map(|send_buffer| {
1608                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
1609                        })
1610                        .transpose()?;
1611
1612                    let state = InitState {
1613                        version,
1614                        ndis_config: init.ndis_config.map(
1615                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
1616                                mtu,
1617                                capabilities: capabilities.into(),
1618                            },
1619                        ),
1620                        ndis_version: init.ndis_version.map(
1621                            |saved_state::NdisVersion { major, minor }| NdisVersion {
1622                                major,
1623                                minor,
1624                            },
1625                        ),
1626                        recv_buffer,
1627                        send_buffer,
1628                    };
1629                    states[0] = Some(WorkerState::Init(Some(state)));
1630                }
1631                saved_state::Primary::Ready(ready) => {
1632                    let saved_state::ReadyPrimary {
1633                        version,
1634                        receive_buffer,
1635                        send_buffer,
1636                        mut control_messages,
1637                        mut rss_state,
1638                        channels,
1639                        ndis_version,
1640                        ndis_config,
1641                        rndis_state,
1642                        guest_vf_state,
1643                        offload_config,
1644                        pending_offload_change,
1645                        tx_spread_sent,
1646                        guest_link_down,
1647                        pending_link_action,
1648                        packet_filter,
1649                    } = ready;
1650
1651                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
1652                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1653
1654                    let version = check_version(version)
1655                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1656
1657                    let request = requests[0].as_ref().unwrap();
1658                    let buffers = Arc::new(ChannelBuffers {
1659                        version,
1660                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1661                        recv_buffer: ReceiveBuffer::new(
1662                            &self.resources.gpadl_map,
1663                            receive_buffer.gpadl_id,
1664                            receive_buffer.id,
1665                            receive_buffer.sub_allocation_size,
1666                        )?,
1667                        send_buffer: {
1668                            if let Some(send_buffer) = send_buffer {
1669                                Some(SendBuffer::new(
1670                                    &self.resources.gpadl_map,
1671                                    send_buffer.gpadl_id,
1672                                )?)
1673                            } else {
1674                                None
1675                            }
1676                        },
1677                        ndis_version: {
1678                            let saved_state::NdisVersion { major, minor } = ndis_version;
1679                            NdisVersion { major, minor }
1680                        },
1681                        ndis_config: {
1682                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1683                            NdisConfig {
1684                                mtu,
1685                                capabilities: capabilities.into(),
1686                            }
1687                        },
1688                    });
1689
1690                    for (channel_idx, channel) in channels.iter().enumerate() {
1691                        let channel = if let Some(channel) = channel {
1692                            channel
1693                        } else {
1694                            continue;
1695                        };
1696
1697                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1698
1699                        // Restore primary channel state.
1700                        if channel_idx == 0 {
1701                            let primary = PrimaryChannelState::restore(
1702                                &guest_vf_state,
1703                                &rndis_state,
1704                                &offload_config,
1705                                pending_offload_change,
1706                                channels.len() as u16,
1707                                self.adapter.indirection_table_size,
1708                                &active.rx_bufs,
1709                                std::mem::take(&mut control_messages),
1710                                rss_state.take(),
1711                                tx_spread_sent,
1712                                guest_link_down,
1713                                pending_link_action,
1714                            )?;
1715                            active.primary = Some(primary);
1716                        }
1717
1718                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1719                            buffers: buffers.clone(),
1720                            state: active,
1721                            data: ProcessingData::new(),
1722                        }));
1723                    }
1724                }
1725            }
1726
1727            // Insert the coordinator and mark that it should try to start the
1728            // network endpoint when it starts running.
1729            self.insert_coordinator(
1730                states.len() as u16,
1731                Some(RestoreCoordinatorState {
1732                    active_packet_filter: saved_packet_filter,
1733                }),
1734            );
1735
1736            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1737                if let Some(state) = state {
1738                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1739                }
1740            }
1741        } else {
1742            control
1743                .restore(&[false])
1744                .await
1745                .map_err(NetRestoreError::Channel)?;
1746        }
1747        Ok(())
1748    }
1749
    /// Builds the serializable saved state for this device.
    ///
    /// If the coordinator has state (the channels are open), the primary
    /// worker's state is translated into the matching `saved_state::Primary`
    /// variant; otherwise `open` is `None`.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                // Version negotiated, but nothing else established yet.
                WorkerState::Init(None) => saved_state::Primary::Version,
                // Mid-initialization: capture whichever of the NDIS
                // version/config and receive/send buffers have been set.
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                // Fully-initialized channel: snapshot the complete runtime
                // state of the primary and every active subchannel.
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    // Snapshot the negotiated checksum/LSO offload
                    // configuration, field by field.
                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued-but-unprocessed control messages are carried
                    // across save/restore so no guest request is dropped.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // A delayed link action is saved the same as an active
                    // one; the default (no action) saves as `None`.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Per-channel state for every active queue.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // In flight tx will be considered as dropped packets through save/restore, but need
                                    // to complete the requests back to the guest.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    // Worker has no ready state; save an
                                    // empty channel placeholder.
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse transient VF states into their durable
                    // save-file equivalents.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        // A restore that never completed re-saves the
                        // original restored state unchanged.
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The packet filter lives on worker 0's channel.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
1951}
1952
/// Errors that terminate a channel worker's processing loop.
///
/// The `#[error]` strings are the user-visible descriptions; most variants
/// wrap a lower-level source error.
// NOTE(review): `Cancelled` uses a manual `From` impl below rather than
// `#[from]` — presumably because `task_control::Cancelled` does not
// implement `std::error::Error`; confirm before changing.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
}
2006
2007impl From<task_control::Cancelled> for WorkerError {
2008    fn from(value: task_control::Cancelled) -> Self {
2009        Self::Cancelled(value)
2010    }
2011}
2012
/// Errors that can occur while opening the vmbus channel.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
2020
/// Errors produced while parsing an incoming vmbus packet in `parse_packet`.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
2032
/// Protocol-ordering violations: a message arrived in a state where it is
/// not allowed (e.g. a buffer offered twice).
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2052
/// The parsed payload of an incoming NVSP vmbus packet.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer),
    RevokeSendBuffer(protocol::Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // These two completions carry no payload; `parse_packet` identifies them
    // solely by their reserved transaction IDs.
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2070
/// An incoming NVSP packet, as parsed by [`parse_packet`].
#[derive(Debug)]
struct Packet<'a> {
    /// The parsed message payload.
    data: PacketData,
    /// The vmbus transaction ID, if the sender expects a completion.
    transaction_id: Option<u64>,
    /// External data ranges (GPA-direct ranges and/or a send-buffer
    /// suballocation) referenced by this packet.
    external_data: &'a MultiPagedRangeBuf,
}
2077
/// Guest-memory reader over a packet's external data ranges.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2079
2080impl Packet<'_> {
2081    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2082        PagedRanges::new(self.external_data.iter()).reader(mem)
2083    }
2084}
2085
2086fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2087    reader: &mut impl MemoryRead,
2088) -> Result<T, PacketError> {
2089    reader.read_plain().map_err(PacketError::Access)
2090}
2091
/// Parses an incoming vmbus packet into a [`Packet`].
///
/// Completion packets are recognized first by their reserved transaction
/// IDs, then by their NVSP message header. Data packets are validated
/// against the negotiated protocol `version`: message types introduced by a
/// later version than the one negotiated are rejected as unknown.
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&SendBuffer>,
    version: Option<Version>,
    external_data: &'a mut MultiPagedRangeBuf,
) -> Result<Packet<'a>, PacketError> {
    // The caller's buffer is reused across packets; start empty.
    external_data.clear();
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            // VF-association and data-path-switch completions are matched by
            // their reserved transaction IDs and carry no parsed payload.
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data,
            });
        }
    };

    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    // Each message type is gated on the minimum protocol version that
    // introduced it.
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // 0xffffffff means the payload is not in the send buffer.
            if message.send_buffer_section_index != 0xffffffff {
                // NOTE(review): 6144 appears to be the fixed send-buffer
                // section size used by the protocol — confirm against the
                // NVSP specification before changing.
                let send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;

                external_data.push_range(send_buffer_suballocation);
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    // Collect any GPA-direct ranges attached to the packet.
    packet
        .read_external_ranges(external_data)
        .map_err(PacketError::ExternalData)?;
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data,
    })
}
2185
/// A fixed-size outgoing NVSP message, stored 8-byte aligned for the vmbus
/// ring.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    /// Message bytes (header followed by payload), held as u64s so the
    /// buffer is naturally 8-byte aligned.
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    /// Which protocol packet size to emit on the wire.
    size: PacketSize,
}
2191
impl NvspMessage {
    /// Constructs a message with an NVSP header of `message_type` followed
    /// by `data`, zero-padded to the full buffer.
    fn new<P: IntoBytes + Immutable + KnownLayout>(
        size: PacketSize,
        message_type: u32,
        data: P,
    ) -> Self {
        // Assert at compile time that the packet will fit in the message
        // buffer. Note that we are checking against the v1 message size here.
        // It's possible this is a v6.1+ message, in which case we could compare
        // against the larger size. So far this has not been necessary. If
        // needed, make a `new_v61` method that does the more relaxed check,
        // rather than weakening this one.
        const {
            assert!(
                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
                "packet might not fit in message"
            )
        };
        let mut message = NvspMessage {
            buf: [0; protocol::PACKET_SIZE_V61 / 8],
            size,
        };
        let header = protocol::MessageHeader { message_type };
        // Header first, then the payload immediately after it.
        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
        data.write_to_prefix(
            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
        )
        .unwrap();
        message
    }

    /// Returns the message as 8-byte words, truncated to the negotiated
    /// packet size.
    fn aligned_payload(&self) -> &[u64] {
        // Note that vmbus packets are always 8-byte multiples, so round the
        // protocol packet size up.
        let len = match self.size {
            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.div_ceil(8) },
            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.div_ceil(8) },
        };
        &self.buf[..len]
    }
}
2233
impl<T: RingMem> NetChannel<T> {
    /// Builds an NVSP message using this channel's negotiated packet size.
    fn message<P: IntoBytes + Immutable + KnownLayout>(
        &self,
        message_type: u32,
        data: P,
    ) -> NvspMessage {
        NvspMessage::new(self.packet_size, message_type, data)
    }

    /// Sends a vmbus completion for `transaction_id`, optionally carrying a
    /// response `message`.
    ///
    /// A `None` transaction ID means the sender did not request a
    /// completion, so nothing is sent. A full ring maps to
    /// `WorkerError::OutOfSpace`, which is documented as "should not be
    /// possible".
    fn send_completion(
        &mut self,
        transaction_id: Option<u64>,
        message: Option<&NvspMessage>,
    ) -> Result<(), WorkerError> {
        match transaction_id {
            None => Ok(()),
            Some(transaction_id) => Ok(self
                .queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    transaction_id,
                    OutgoingPacketType::Completion,
                    message.map_or(&[], |m| m.aligned_payload()),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?),
        }
    }
}
2267
/// Protocol versions this device offers to the guest.
// NOTE(review): V3 is absent — presumably deliberate (not supported by this
// implementation); confirm before adding it.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2276
2277fn check_version(requested_version: u32) -> Option<Version> {
2278    SUPPORTED_VERSIONS
2279        .iter()
2280        .find(|version| **version as u32 == requested_version)
2281        .copied()
2282}
2283
/// A guest-supplied receive buffer, divided into fixed-size suballocations.
#[derive(Debug)]
struct ReceiveBuffer {
    /// The mapped gpadl backing the buffer.
    gpadl: GpadlView,
    /// The buffer ID assigned by the guest.
    id: u16,
    /// Number of whole suballocations that fit in the buffer.
    count: u32,
    /// Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2291
/// Errors validating a guest-supplied send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2301
2302impl ReceiveBuffer {
2303    fn new(
2304        gpadl_map: &GpadlMapView,
2305        gpadl_id: GpadlId,
2306        id: u16,
2307        sub_allocation_size: u32,
2308    ) -> Result<Self, BufferError> {
2309        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2310            return Err(BufferError::UnsupportedSuballocationSize(
2311                sub_allocation_size,
2312            ));
2313        }
2314        let gpadl = gpadl_map.map(gpadl_id)?;
2315        let range = gpadl
2316            .contiguous_aligned()
2317            .ok_or(BufferError::UnalignedGpadl)?;
2318        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2319        if num_sub_allocations == 0 {
2320            return Err(BufferError::UnsupportedSuballocationSize(
2321                sub_allocation_size,
2322            ));
2323        }
2324        let recv_buffer = Self {
2325            gpadl,
2326            id,
2327            count: num_sub_allocations,
2328            sub_allocation_size,
2329        };
2330        Ok(recv_buffer)
2331    }
2332
2333    fn range(&self, index: u32) -> PagedRange<'_> {
2334        self.gpadl.first().unwrap().subrange(
2335            (index * self.sub_allocation_size) as usize,
2336            self.sub_allocation_size as usize,
2337        )
2338    }
2339
2340    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2341        assert!(len <= self.sub_allocation_size as usize);
2342        ring::TransferPageRange {
2343            byte_offset: index * self.sub_allocation_size,
2344            byte_count: len as u32,
2345        }
2346    }
2347
2348    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2349        saved_state::ReceiveBuffer {
2350            gpadl_id: self.gpadl.id(),
2351            id: self.id,
2352            sub_allocation_size: self.sub_allocation_size,
2353        }
2354    }
2355}
2356
/// A guest-supplied send buffer, validated as contiguous and page-aligned.
#[derive(Debug)]
struct SendBuffer {
    /// The mapped gpadl backing the buffer.
    gpadl: GpadlView,
}
2361
2362impl SendBuffer {
2363    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2364        let gpadl = gpadl_map.map(gpadl_id)?;
2365        gpadl
2366            .contiguous_aligned()
2367            .ok_or(BufferError::UnalignedGpadl)?;
2368        Ok(Self { gpadl })
2369    }
2370}
2371
2372impl<T: RingMem> NetChannel<T> {
2373    /// Process a single non-packet RNDIS message.
2374    fn handle_rndis_message(
2375        &mut self,
2376        state: &mut ActiveState,
2377        message_type: u32,
2378        mut reader: PacketReader<'_>,
2379    ) -> Result<(), WorkerError> {
2380        assert_ne!(
2381            message_type,
2382            rndisprot::MESSAGE_TYPE_PACKET_MSG,
2383            "handled elsewhere"
2384        );
2385        let control = state
2386            .primary
2387            .as_mut()
2388            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;
2389
2390        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
2391            // Currently ignored and does not require a response.
2392            return Ok(());
2393        }
2394
2395        // This is a control message that needs a response. Responding
2396        // will require a suballocation to be available, which it may
2397        // not be right now. Enqueue the suballocation to a queue and
2398        // process the queue as suballocations become available.
2399        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2400        if reader.len() == 0 {
2401            return Err(WorkerError::RndisMessageTooSmall);
2402        }
2403        // Do not let the queue get too large--the guest should not be
2404        // sending very many control messages at a time.
2405        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2406            return Err(WorkerError::TooManyControlMessages);
2407        }
2408
2409        control.control_messages_len += reader.len();
2410        control.control_messages.push_back(ControlMessage {
2411            message_type,
2412            data: reader.read_all()?.into(),
2413        });
2414
2415        // The control message queue will be processed in the main dispatch
2416        // loop.
2417        Ok(())
2418    }
2419
2420    /// Process RNDIS packet messages, which may contain multiple RNDIS packets
2421    /// in a single vmbus message.
2422    ///
2423    /// On entry, the reader has already read the RNDIS message header of the
2424    /// first RNDIS packet in the message.
2425    fn handle_rndis_packet_messages(
2426        &mut self,
2427        buffers: &ChannelBuffers,
2428        state: &mut ActiveState,
2429        id: TxId,
2430        mut message_len: usize,
2431        mut reader: PacketReader<'_>,
2432        segments: &mut Vec<TxSegment>,
2433    ) -> Result<usize, WorkerError> {
2434        // There may be multiple RNDIS packets in a single message, concatenated
2435        // with each other. Consume them until there is no more data in the
2436        // RNDIS message.
2437        let mut num_packets = 0;
2438        loop {
2439            let next_message_offset = message_len
2440                .checked_sub(size_of::<rndisprot::MessageHeader>())
2441                .ok_or(WorkerError::RndisMessageTooSmall)?;
2442
2443            self.handle_rndis_packet_message(
2444                id,
2445                reader.clone(),
2446                &buffers.mem,
2447                segments,
2448                &mut state.stats,
2449            )?;
2450            num_packets += 1;
2451
2452            reader.skip(next_message_offset)?;
2453            if reader.len() == 0 {
2454                break;
2455            }
2456            let header: rndisprot::MessageHeader = reader.read_plain()?;
2457            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
2458                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
2459            }
2460            message_len = header.message_length as usize;
2461        }
2462        Ok(num_packets)
2463    }
2464
2465    /// Process an RNDIS package message (used to send an Ethernet frame).
2466    fn handle_rndis_packet_message(
2467        &mut self,
2468        id: TxId,
2469        reader: PacketReader<'_>,
2470        mem: &GuestMemory,
2471        segments: &mut Vec<TxSegment>,
2472        stats: &mut QueueStats,
2473    ) -> Result<(), WorkerError> {
2474        // Headers are guaranteed to be in a single PagedRange.
2475        let headers = reader
2476            .clone()
2477            .into_inner()
2478            .paged_ranges()
2479            .next()
2480            .ok_or(WorkerError::RndisMessageTooSmall)?;
2481        let mut data = reader.into_inner();
2482        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2483        if request.num_oob_data_elements != 0
2484            || request.oob_data_length != 0
2485            || request.oob_data_offset != 0
2486            || request.vc_handle != 0
2487        {
2488            return Err(WorkerError::UnsupportedRndisBehavior);
2489        }
2490
2491        if data.len() < request.data_offset as usize
2492            || (data.len() - request.data_offset as usize) < request.data_length as usize
2493            || request.data_length == 0
2494        {
2495            return Err(WorkerError::RndisMessageTooSmall);
2496        }
2497
2498        data.skip(request.data_offset as usize);
2499        data.truncate(request.data_length as usize);
2500
2501        let mut metadata = net_backend::TxMetadata {
2502            id,
2503            len: request.data_length,
2504            ..Default::default()
2505        };
2506
2507        if request.per_packet_info_length != 0 {
2508            let mut ppi = headers
2509                .try_subrange(
2510                    request.per_packet_info_offset as usize,
2511                    request.per_packet_info_length as usize,
2512                )
2513                .ok_or(WorkerError::RndisMessageTooSmall)?;
2514            while !ppi.is_empty() {
2515                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2516                if h.size == 0 {
2517                    return Err(WorkerError::RndisMessageTooSmall);
2518                }
2519                let (this, rest) = ppi
2520                    .try_split(h.size as usize)
2521                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2522                let (_, d) = this
2523                    .try_split(h.per_packet_information_offset as usize)
2524                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2525                match h.typ {
2526                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2527                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2528
2529                        metadata.flags.set_offload_tcp_checksum(
2530                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2531                        );
2532                        metadata.flags.set_offload_udp_checksum(
2533                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2534                        );
2535                        metadata
2536                            .flags
2537                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2538                        metadata.flags.set_is_ipv4(n.is_ipv4());
2539                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2540                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2541                        if metadata.flags.offload_tcp_checksum()
2542                            || metadata.flags.offload_udp_checksum()
2543                        {
2544                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2545                                n.tcp_header_offset() - metadata.l2_len as u16
2546                            } else if n.is_ipv4() {
2547                                let mut reader = data.clone().reader(mem);
2548                                reader.skip(metadata.l2_len as usize)?;
2549                                let mut b = 0;
2550                                reader.read(std::slice::from_mut(&mut b))?;
2551                                (b as u16 >> 4) * 4
2552                            } else {
2553                                // Hope there are no extensions.
2554                                40
2555                            };
2556                        }
2557                    }
2558                    rndisprot::PPI_LSO => {
2559                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2560
2561                        metadata.flags.set_offload_tcp_segmentation(true);
2562                        metadata.flags.set_offload_tcp_checksum(true);
2563                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2564                        metadata.flags.set_is_ipv4(n.is_ipv4());
2565                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2566                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2567                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2568                            return Err(WorkerError::InvalidTcpHeaderOffset);
2569                        }
2570                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2571                        metadata.l4_len = {
2572                            let mut reader = data.clone().reader(mem);
2573                            reader
2574                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2575                            let mut b = 0;
2576                            reader.read(std::slice::from_mut(&mut b))?;
2577                            (b >> 4) * 4
2578                        };
2579                        metadata.max_tcp_segment_size = n.mss() as u16;
2580
2581                        if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2582                            // Not strictly enforced.
2583                            stats.tx_invalid_lso_packets.increment();
2584                        }
2585                    }
2586                    _ => {}
2587                }
2588                ppi = rest;
2589            }
2590        }
2591
2592        let start = segments.len();
2593        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2594            let range = range.map_err(WorkerError::InvalidGpadl)?;
2595            segments.push(TxSegment {
2596                ty: net_backend::TxSegmentType::Tail,
2597                gpa: range.start,
2598                len: range.len() as u32,
2599            });
2600        }
2601
2602        metadata.segment_count = (segments.len() - start) as u8;
2603
2604        stats.tx_packets.increment();
2605        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2606            stats.tx_checksum_packets.increment();
2607        }
2608        if metadata.flags.offload_tcp_segmentation() {
2609            stats.tx_lso_packets.increment();
2610        }
2611
2612        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2613
2614        Ok(())
2615    }
2616
2617    /// Notify the adapter that the guest VF state has changed and it may
2618    /// need to send a message to the guest.
2619    fn guest_vf_is_available(
2620        &mut self,
2621        guest_vf_id: Option<u32>,
2622        version: Version,
2623        config: NdisConfig,
2624    ) -> Result<bool, WorkerError> {
2625        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2626        if version >= Version::V4 && config.capabilities.sriov() {
2627            tracing::info!(
2628                available = serial_number.is_some(),
2629                serial_number,
2630                "sending VF association message"
2631            );
2632            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2633            let message = {
2634                self.message(
2635                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2636                    protocol::Message4SendVfAssociation {
2637                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2638                        serial_number: serial_number.unwrap_or(0),
2639                    },
2640                )
2641            };
2642            self.queue
2643                .split()
2644                .1
2645                .batched()
2646                .try_write_aligned(
2647                    VF_ASSOCIATION_TRANSACTION_ID,
2648                    OutgoingPacketType::InBandWithCompletion,
2649                    message.aligned_payload(),
2650                )
2651                .map_err(|err| match err {
2652                    queue::TryWriteError::Full(len) => {
2653                        tracing::error!(len, "failed to write vf association message");
2654                        WorkerError::OutOfSpace
2655                    }
2656                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2657                })?;
2658            Ok(true)
2659        } else {
2660            tracing::info!(
2661                available = serial_number.is_some(),
2662                serial_number,
2663                major = version.major(),
2664                minor = version.minor(),
2665                sriov_capable = config.capabilities.sriov(),
2666                "Skipping NvspMessage4TypeSendVFAssociation message"
2667            );
2668            Ok(false)
2669        }
2670    }
2671
2672    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2673    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2674        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2675        if version < Version::V5 {
2676            return;
2677        }
2678
2679        #[repr(C)]
2680        #[derive(IntoBytes, Immutable, KnownLayout)]
2681        struct SendIndirectionMsg {
2682            pub message: protocol::Message5SendIndirectionTable,
2683            pub send_indirection_table:
2684                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2685        }
2686
2687        // The offset to the send indirection table from the beginning of the NVSP message.
2688        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2689            + size_of::<protocol::MessageHeader>();
2690        let mut data = SendIndirectionMsg {
2691            message: protocol::Message5SendIndirectionTable {
2692                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2693                table_offset: send_indirection_table_offset as u32,
2694            },
2695            send_indirection_table: Default::default(),
2696        };
2697
2698        for i in 0..data.send_indirection_table.len() {
2699            data.send_indirection_table[i] = i as u32 % num_channels_opened;
2700        }
2701
2702        let header = protocol::MessageHeader {
2703            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2704        };
2705        let result = self
2706            .queue
2707            .split()
2708            .1
2709            .try_write(&queue::OutgoingPacket {
2710                transaction_id: 0,
2711                packet_type: OutgoingPacketType::InBandNoCompletion,
2712                payload: &[header.as_bytes(), data.as_bytes()],
2713            })
2714            .map_err(|err| match err {
2715                queue::TryWriteError::Full(len) => {
2716                    tracing::error!(len, "failed to write send indirection table message");
2717                    WorkerError::OutOfSpace
2718                }
2719                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2720            });
2721        if let Err(err) = result {
2722            tracing::error!(
2723                error = &err as &dyn std::error::Error,
2724                "Failed to notify guest about the send indirection table"
2725            );
2726        }
2727    }
2728
2729    /// Notify the guest that the data path has been switched back to synthetic
2730    /// due to some external state change.
2731    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2732        let header = protocol::MessageHeader {
2733            message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2734        };
2735        let data = protocol::Message4SwitchDataPath {
2736            active_data_path: protocol::DataPath::SYNTHETIC.0,
2737        };
2738        let result = self
2739            .queue
2740            .split()
2741            .1
2742            .try_write(&queue::OutgoingPacket {
2743                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2744                packet_type: OutgoingPacketType::InBandWithCompletion,
2745                payload: &[header.as_bytes(), data.as_bytes()],
2746            })
2747            .map_err(|err| match err {
2748                queue::TryWriteError::Full(len) => {
2749                    tracing::error!(len, "failed to write switch data path message");
2750                    WorkerError::OutOfSpace
2751                }
2752                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2753            });
2754        if let Err(err) = result {
2755            tracing::error!(
2756                error = &err as &dyn std::error::Error,
2757                "Failed to notify guest that data path is now synthetic"
2758            );
2759        }
2760    }
2761
    /// Process an internal state change
    ///
    /// Advances `primary.guest_vf_state` through any transitional states,
    /// sending the corresponding guest notifications, until the state machine
    /// settles. Returns a coordinator message when follow-up work is needed.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
        //      1. completion of switch data path request
        //      2. Switch data path notification (back to synthetic)
        //      3. Disassociate VF adapter.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Notify guest that a VF capability has recently arrived.
            // Advertisement only happens once RNDIS is operational; otherwise
            // wait for a later state change.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    // Ask the coordinator to track the advertised VF state.
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Run the state machine until a non-transitional state is reached
        // (the `_ => break` arm below).
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Notify guest that the VF is unavailable. It has already been surprise removed.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;

                    match (to_guest, result) {
                        // Switching to guest VF successful.
                        (true, true) => PrimaryChannelGuestVfState::DataPathSwitched,
                        // Switching to guest VF failed, stay synthetic.
                        (true, false) => {
                            tracing::error!(
                                "Failure switching to guest VF, remaining on synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        }
                        // Switching to synthetic successful.
                        (false, true) => PrimaryChannelGuestVfState::Ready,
                        // Switching to synthetic failed, assume VF remains active.
                        (false, false) => {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // All remaining states are stable; stop processing.
                _ => break,
            };
        }
        Ok(None)
    }
2862
    /// Process a control message, writing the response to the provided receive
    /// buffer suballocation.
    ///
    /// `reader` provides the message body (after the RNDIS message header);
    /// `id` identifies the receive-buffer suballocation used for the reply.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // The reply is written into this suballocation of the receive buffer.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // INITIALIZE is only valid once, while still initializing.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                // Write the INITIALIZE completion into the suballocation and
                // post it to the guest.
                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise any pending VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        // Ideally the VF would not be presented to the guest
                        // until the completion packet has arrived, so that the
                        // guest is prepared. This is most interesting for the
                        // case of a VF associated with multiple guest
                        // adapters, using a concept like vports. In this
                        // scenario it would be better if all of the adapters
                        // were aware a VF was available before the device
                        // arrived. This is not currently possible because the
                        // Linux netvsc driver ignores the completion requested
                        // flag on inband packets and won't send a completion
                        // packet.
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Split the reply buffer: the completion header goes first,
                // the OID query result (information buffer) follows it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                // `tx` is the number of information-buffer bytes written.
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Restart the endpoint if the OID changed some critical
                        // endpoint property.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate a changed packet filter to the coordinator.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            oid = ?request.oid,
                            "oid set failure"
                        );
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                // Keepalives are always acknowledged with success.
                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
3043
3044    fn try_send_rndis_message(
3045        &mut self,
3046        transaction_id: u64,
3047        channel_type: u32,
3048        recv_buffer_id: u16,
3049        transfer_pages: &[ring::TransferPageRange],
3050    ) -> Result<Option<usize>, WorkerError> {
3051        let message = self.message(
3052            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
3053            protocol::Message1SendRndisPacket {
3054                channel_type,
3055                send_buffer_section_index: 0xffffffff,
3056                send_buffer_section_size: 0,
3057            },
3058        );
3059        let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3060            transaction_id,
3061            OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3062            message.aligned_payload(),
3063        ) {
3064            Ok(()) => None,
3065            Err(queue::TryWriteError::Full(n)) => Some(n),
3066            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3067        };
3068        Ok(pending_send_size)
3069    }
3070
3071    fn send_rndis_control_message(
3072        &mut self,
3073        buffers: &ChannelBuffers,
3074        id: u32,
3075        message_length: usize,
3076    ) -> Result<(), WorkerError> {
3077        let result = self.try_send_rndis_message(
3078            id as u64,
3079            protocol::CONTROL_CHANNEL_TYPE,
3080            buffers.recv_buffer.id,
3081            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3082        )?;
3083
3084        // Ring size is checked before control messages are processed, so failure to write is unexpected.
3085        match result {
3086            None => Ok(()),
3087            Some(len) => {
3088                tracelimit::error_ratelimited!(len, "failed to write control message completion");
3089                Err(WorkerError::OutOfSpace)
3090            }
3091        }
3092    }
3093
3094    fn indicate_status(
3095        &mut self,
3096        buffers: &ChannelBuffers,
3097        id: u32,
3098        status: u32,
3099        payload: &[u8],
3100    ) -> Result<(), WorkerError> {
3101        let buffer = &buffers.recv_buffer.range(id);
3102        let mut writer = buffer.writer(&buffers.mem);
3103        let message_length = write_rndis_message(
3104            &mut writer,
3105            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3106            payload.len(),
3107            &rndisprot::IndicateStatus {
3108                status,
3109                status_buffer_length: payload.len() as u32,
3110                status_buffer_offset: if payload.is_empty() {
3111                    0
3112                } else {
3113                    size_of::<rndisprot::IndicateStatus>() as u32
3114                },
3115            },
3116        )?;
3117        writer.write(payload)?;
3118        self.send_rndis_control_message(buffers, id, message_length)?;
3119        Ok(())
3120    }
3121
    /// Processes pending control messages until all are processed or there are
    /// no available suballocations.
    ///
    /// Also delivers a pending offload-configuration change indication once
    /// the queued control messages have drained.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages are only processed on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure the ring buffer has enough room to successfully complete control message handling.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Each response consumes one control suballocation; stop until
            // the guest releases one if none are free.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the receive buffer in use to allow the guest to release it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued messages remain, so the loop condition implies a
                // pending offload change; indicate the current config.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                unreachable!();
            }
        }
        Ok(())
    }
3175
3176    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3177        if self.restart.is_none() {
3178            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3179                guest_vf_state: guest_vf,
3180                filter_state: packet_filter,
3181            }));
3182        } else if let Some(CoordinatorMessage::Restart) = self.restart {
3183            // If a restart message is pending, do nothing.
3184            // A restart will try to switch the data path based on primary.guest_vf_state.
3185            // A restart will apply packet filter changes.
3186        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3187            // Add the new update to the existing message.
3188            update.guest_vf_state |= guest_vf;
3189            update.filter_state |= packet_filter;
3190        }
3191    }
3192
    /// Queues a coordinator update requesting reevaluation of the guest VF
    /// state.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3196
    /// Queues a coordinator update requesting that packet filter changes be
    /// applied.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3200}
3201
3202/// Writes an RNDIS message to `writer`.
3203fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3204    writer: &mut impl MemoryWrite,
3205    message_type: u32,
3206    extra: usize,
3207    payload: &T,
3208) -> Result<usize, AccessError> {
3209    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3210    writer.write(
3211        rndisprot::MessageHeader {
3212            message_type,
3213            message_length: message_length as u32,
3214        }
3215        .as_bytes(),
3216    )?;
3217    writer.write(payload.as_bytes())?;
3218    Ok(message_length)
3219}
3220
/// Errors produced while handling an RNDIS OID query or set request.
#[derive(Debug, Error)]
enum OidError {
    /// Guest memory access failed while reading or writing OID data.
    #[error(transparent)]
    Access(#[from] AccessError),
    /// The OID is not recognized by this implementation.
    #[error("unknown oid")]
    UnknownOid,
    /// The named field of the OID input failed validation.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    /// The request used an unsupported NDIS version.
    #[error("bad ndis version")]
    BadVersion,
    /// The named feature is not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3234
3235impl OidError {
3236    fn as_status(&self) -> u32 {
3237        match self {
3238            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3239            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3240            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3241            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3242        }
3243    }
3244}
3245
// MTU values are in bytes and include the ethernet header (see the
// OID_GEN_MAXIMUM_FRAME_SIZE handling, which subtracts ETHERNET_HEADER_LEN).
const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
const MAX_MTU: u32 = 9216;

// Length in bytes of an ethernet header (destination MAC + source MAC +
// ethertype).
const ETHERNET_HEADER_LEN: u32 = 14;
3251
3252impl Adapter {
3253    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3254        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3255            // For enlightened guests (which is only Windows at the moment), send the
3256            // adapter index, which was previously set as the vport serial number.
3257            if guest_os_id
3258                .microsoft()
3259                .unwrap_or(HvGuestOsMicrosoft::from(0))
3260                .os_id()
3261                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3262            {
3263                self.adapter_index
3264            } else {
3265                vfid
3266            }
3267        } else {
3268            vfid
3269        }
3270    }
3271
    /// Handles an RNDIS OID query, writing the response payload to `writer`.
    ///
    /// Returns the number of bytes written on success.
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        // Capture the writer's capacity so the bytes written can be computed
        // at the end from the remaining length.
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    // Ethernet objects operation characteristics
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    // Ethernet objects statistics
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    // PNP operations characteristics
                    // rndisprot::Oid::OID_PNP_SET_POWER,
                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
                    // RNDIS OIDS
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // NDIS6 OIDs
                let supported_oids_6 = &[
                    // Link State OID
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    // NDIS 6 statistics OID
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    // Offload related OID
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
                ];

                // NDIS 6.30+ (RSS-capable) OIDs.
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; // NdisHardwareStatusReady
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                // The frame size excludes the ethernet header.
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                // This value is meaningless for virtual NICs. Return what vmswitch returns.
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Like vmswitch, use the first N bytes of Microsoft's MAC address
                // prefix as the vendor ID.
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                // The name is returned as a fixed-size UTF-16 buffer.
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            // NDIS6 OIDs:
            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, /* MediaConnectStateConnected */
                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                // Report what the endpoint supports; only the header-declared
                // prefix of the structure is written.
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                // Report the currently-configured offloads for this channel.
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            // DPDK gets confused if the table size is zero,
                            // even if there is only one queue.
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3511
3512    fn handle_oid_set(
3513        &self,
3514        primary: &mut PrimaryChannelState,
3515        oid: rndisprot::Oid,
3516        reader: impl MemoryRead + Clone,
3517    ) -> Result<(bool, Option<u32>), OidError> {
3518        tracing::debug!(?oid, "oid set");
3519
3520        let mut restart_endpoint = false;
3521        let mut packet_filter = None;
3522        match oid {
3523            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3524                packet_filter = self.oid_set_packet_filter(reader)?;
3525            }
3526            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3527                self.oid_set_offload_parameters(reader, primary)?;
3528            }
3529            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3530                self.oid_set_offload_encapsulation(reader)?;
3531            }
3532            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3533                self.oid_set_rndis_config_parameter(reader, primary)?;
3534            }
3535            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3536                // TODO
3537            }
3538            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3539                let rss_was_enabled = self.oid_set_rss_parameters(reader, primary)?;
3540
3541                // Endpoints cannot currently change RSS parameters without
3542                // being restarted. This was a limitation driven by some DPDK
3543                // PMDs, and should be fixed.
3544                //
3545                // Skip the restart if RSS was already disabled and the guest
3546                // is disabling it again — nothing has changed.
3547                if rss_was_enabled || primary.rss_state.is_some() {
3548                    restart_endpoint = true;
3549                }
3550            }
3551            _ => {
3552                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3553                return Err(OidError::UnknownOid);
3554            }
3555        }
3556        Ok((restart_endpoint, packet_filter))
3557    }
3558
    /// Handles OID_GEN_RECEIVE_SCALE_PARAMETERS, updating `primary.rss_state`.
    ///
    /// Returns whether RSS was enabled before this call (used by the caller
    /// to decide whether an endpoint restart is needed).
    fn oid_set_rss_parameters(
        &self,
        mut reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<bool, OidError> {
        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
        let len = reader.len().min(size_of_val(&params));
        reader.clone().read(&mut params.as_mut_bytes()[..len])?;

        let rss_was_enabled = primary.rss_state.is_some();

        // Disabling RSS (explicitly, or by selecting no hash function) drops
        // any existing RSS state.
        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
        {
            primary.rss_state = None;
            return Ok(rss_was_enabled);
        }

        if params.hash_secret_key_size != 40 {
            return Err(OidError::InvalidInput("hash_secret_key_size"));
        }
        if params.indirection_table_size % 4 != 0 {
            return Err(OidError::InvalidInput("indirection_table_size"));
        }
        // The guest's table size is in bytes (4 bytes per entry); clamp the
        // entry count to the advertised table size.
        let indirection_table_size =
            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
        let mut key = [0; 40];
        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
        // Key and table are located by offsets within the OID input buffer.
        reader
            .clone()
            .skip(params.hash_secret_key_offset as usize)?
            .read(&mut key)?;
        reader
            .skip(params.indirection_table_offset as usize)?
            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
        // Every entry must target a valid queue.
        if indirection_table
            .iter()
            .any(|&x| x >= self.max_queues as u32)
        {
            return Err(OidError::InvalidInput("indirection_table"));
        }
        // If the guest provided fewer entries than the full table, fill the
        // remainder by cycling through the provided entries.
        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
            .flatten()
            .zip(indir_uninit)
        {
            *dest = src;
        }
        primary.rss_state = Some(RssState {
            key,
            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
        });
        Ok(rss_was_enabled)
    }
3615
3616    fn oid_set_packet_filter(
3617        &self,
3618        reader: impl MemoryRead + Clone,
3619    ) -> Result<Option<u32>, OidError> {
3620        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3621        tracing::debug!(filter, "set packet filter");
3622        Ok(Some(filter))
3623    }
3624
    /// Handles OID_TCP_OFFLOAD_PARAMETERS, updating the channel's offload
    /// configuration and scheduling a status indication for the guest.
    fn oid_set_offload_parameters(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::DEFAULT,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
        )?;

        tracing::debug!(?offload, "offload parameters");
        // Fully destructure so that a newly-added field forces this code to
        // be revisited.
        let rndisprot::NdisOffloadParameters {
            header: _,
            ipv4_checksum,
            tcp4_checksum,
            udp4_checksum,
            tcp6_checksum,
            udp6_checksum,
            lsov1,
            ipsec_v1: _,
            lsov2_ipv4,
            lsov2_ipv6,
            tcp_connection_ipv4: _,
            tcp_connection_ipv6: _,
            reserved: _,
            flags: _,
        } = offload;

        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
            return Err(OidError::NotSupported("lsov1"));
        }
        // For each parameter, a `None` from `tx_rx()`/`enable()` leaves the
        // existing setting untouched.
        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.ipv4_header = tx;
            primary.offload_config.checksum_rx.ipv4_header = rx;
        }
        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp4 = tx;
            primary.offload_config.checksum_rx.tcp4 = rx;
        }
        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp6 = tx;
            primary.offload_config.checksum_rx.tcp6 = rx;
        }
        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp4 = tx;
            primary.offload_config.checksum_rx.udp4 = rx;
        }
        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp6 = tx;
            primary.offload_config.checksum_rx.udp6 = rx;
        }
        if let Some(enable) = lsov2_ipv4.enable() {
            primary.offload_config.lso4 = enable;
        }
        if let Some(enable) = lsov2_ipv6.enable() {
            primary.offload_config.lso6 = enable;
        }
        // Never enable more than the endpoint actually supports.
        primary
            .offload_config
            .mask_to_supported(&self.offload_support);
        // The resulting configuration is indicated to the guest later, from
        // the control message processing loop.
        primary.pending_offload_change = true;
        Ok(())
    }
3690
3691    fn oid_set_offload_encapsulation(
3692        &self,
3693        reader: impl MemoryRead + Clone,
3694    ) -> Result<(), OidError> {
3695        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3696            reader,
3697            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3698            1,
3699            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3700        )?;
3701        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3702            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3703                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3704        {
3705            return Err(OidError::NotSupported("ipv4 encap"));
3706        }
3707        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3708            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3709                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3710        {
3711            return Err(OidError::NotSupported("ipv6 encap"));
3712        }
3713        Ok(())
3714    }
3715
3716    fn oid_set_rndis_config_parameter(
3717        &self,
3718        reader: impl MemoryRead + Clone,
3719        primary: &mut PrimaryChannelState,
3720    ) -> Result<(), OidError> {
3721        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3722        if info.name_length > 255 {
3723            return Err(OidError::InvalidInput("name_length"));
3724        }
3725        if info.value_length > 255 {
3726            return Err(OidError::InvalidInput("value_length"));
3727        }
3728        let name = reader
3729            .clone()
3730            .skip(info.name_offset as usize)?
3731            .read_n::<u16>(info.name_length as usize / 2)?;
3732        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3733        let mut value = reader;
3734        value.skip(info.value_offset as usize)?;
3735        let mut value = value.limit(info.value_length as usize);
3736        match info.parameter_type {
3737            rndisprot::NdisParameterType::STRING => {
3738                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3739                let value =
3740                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3741                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3742                let tx = as_num & 1 != 0;
3743                let rx = as_num & 2 != 0;
3744
3745                tracing::debug!(name, value, "rndis config");
3746                match name.as_str() {
3747                    "*IPChecksumOffloadIPv4" => {
3748                        primary.offload_config.checksum_tx.ipv4_header = tx;
3749                        primary.offload_config.checksum_rx.ipv4_header = rx;
3750                    }
3751                    "*LsoV2IPv4" => {
3752                        primary.offload_config.lso4 = as_num != 0;
3753                    }
3754                    "*LsoV2IPv6" => {
3755                        primary.offload_config.lso6 = as_num != 0;
3756                    }
3757                    "*TCPChecksumOffloadIPv4" => {
3758                        primary.offload_config.checksum_tx.tcp4 = tx;
3759                        primary.offload_config.checksum_rx.tcp4 = rx;
3760                    }
3761                    "*TCPChecksumOffloadIPv6" => {
3762                        primary.offload_config.checksum_tx.tcp6 = tx;
3763                        primary.offload_config.checksum_rx.tcp6 = rx;
3764                    }
3765                    "*UDPChecksumOffloadIPv4" => {
3766                        primary.offload_config.checksum_tx.udp4 = tx;
3767                        primary.offload_config.checksum_rx.udp4 = rx;
3768                    }
3769                    "*UDPChecksumOffloadIPv6" => {
3770                        primary.offload_config.checksum_tx.udp6 = tx;
3771                        primary.offload_config.checksum_rx.udp6 = rx;
3772                    }
3773                    _ => {}
3774                }
3775                primary
3776                    .offload_config
3777                    .mask_to_supported(&self.offload_support);
3778            }
3779            rndisprot::NdisParameterType::INTEGER => {
3780                let value: u32 = value.read_plain()?;
3781                tracing::debug!(name, value, "rndis config");
3782            }
3783            parameter_type => tracelimit::warn_ratelimited!(
3784                name,
3785                ?parameter_type,
3786                "unhandled rndis config parameter type"
3787            ),
3788        }
3789        Ok(())
3790    }
3791}
3792
3793fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3794    mut reader: impl MemoryRead,
3795    object_type: rndisprot::NdisObjectType,
3796    min_revision: u8,
3797    min_size: usize,
3798) -> Result<T, OidError> {
3799    let mut buffer = T::new_zeroed();
3800    let sent_size = reader.len();
3801    let len = sent_size.min(size_of_val(&buffer));
3802    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3803    validate_ndis_object_header(
3804        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3805            .unwrap()
3806            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3807        sent_size,
3808        object_type,
3809        min_revision,
3810        min_size,
3811    )?;
3812    Ok(buffer)
3813}
3814
3815fn validate_ndis_object_header(
3816    header: &rndisprot::NdisObjectHeader,
3817    sent_size: usize,
3818    object_type: rndisprot::NdisObjectType,
3819    min_revision: u8,
3820    min_size: usize,
3821) -> Result<(), OidError> {
3822    if header.object_type != object_type {
3823        return Err(OidError::InvalidInput("header.object_type"));
3824    }
3825    if sent_size < header.size as usize {
3826        return Err(OidError::InvalidInput("header.size"));
3827    }
3828    if header.revision < min_revision {
3829        return Err(OidError::InvalidInput("header.revision"));
3830    }
3831    if (header.size as usize) < min_size {
3832        return Err(OidError::InvalidInput("header.size"));
3833    }
3834    Ok(())
3835}
3836
/// Control-plane state for the netvsp device: owns the per-queue worker
/// tasks and the state shared across them.
struct Coordinator {
    // Control messages sent to the coordinator by the channel workers.
    recv: mpsc::Receiver<CoordinatorMessage>,
    channel_control: ChannelControl,
    // When set, the queues are torn down and restarted on the next pass of
    // the coordinator loop.
    restart: bool,
    // Worker 0 is the primary channel; the rest are subchannels.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    // Channel buffers saved from the primary for use by subchannel workers.
    buffers: Option<Arc<ChannelBuffers>>,
    num_queues: u16,
    // Packet filter propagated from the primary to the subchannel workers.
    active_packet_filter: u32,
    // Deadline for the coordinator's sleep timer, if one is armed.
    sleep_deadline: Option<Instant>,
}
3847
/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        /// Timer used to wait out the delay.
        timer: PolledTimer,
        /// Deadline after which the VF offer proceeds.
        delay_until: Instant,
    },
    /// A VF update is pending.
    Pending,
}
3862
/// Task state for the coordinator, held separately from [`Coordinator`] by
/// the task-control framework.
struct CoordinatorState {
    // The backend network endpoint driving the queues.
    endpoint: Box<dyn Endpoint>,
    // Shared adapter configuration.
    adapter: Arc<Adapter>,
    // Accelerated (VF) device control, if a VF is assigned.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    // State machine for asynchronously offering the VF to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3869
impl InspectTaskMut<Coordinator> for CoordinatorState {
    /// Emits inspect state for the adapter and endpoint, and — when the
    /// coordinator task state is available — the per-queue workers and the
    /// primary channel's shared state.
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Static adapter configuration, plus a writable knob for the ring
        // size limit.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    // Bounce each task so that it sees the new value.
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // One child node per active queue, indexed by queue number.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Get the shared channel state from the primary channel.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3945
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Runs the coordinator loop until the channel disconnects or the task
    /// is cancelled via `stop`.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3955
3956impl Coordinator {
    /// The coordinator's main loop.
    ///
    /// Restarts the endpoint queues when `self.restart` is set, keeps the
    /// worker tasks running, then races the possible wakeup sources
    /// (internal messages, endpoint actions, VF state changes, timers) and
    /// handles whichever fires first. Returns `Ok(())` when the channel
    /// disconnects, or `Cancelled` if stopped.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Re-query the data path state from the endpoint, since the
                // restart may have changed it.
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Ensure that all workers except the primary are started. The
            // primary is started below if there are no outstanding messages.
            for worker in &mut self.workers[1..] {
                worker.start();
            }
            if !self.workers[0].is_running()
                && self.workers[0].state().is_none_or(|worker| {
                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
                })
            {
                self.workers[0].start();
            }

            // The set of events that can wake the coordinator.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // guest_ready_for_device is not restartable, so do not poll on
                // stop.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Fires once the armed sleep deadline (if any) elapses.
                let timer_sleep = async {
                    if let Some(deadline) = self.sleep_deadline {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(deadline).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Fires when the VF offer delay (if any)
                                // elapses.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::Internal(msg) => {
                    self.handle_coordinator_message(msg, state).await;
                }
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Stop the primary worker before mutating its state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Kick the worker as requested.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                            primary.pending_link_action = PendingLinkAction::Active(up);
                        }
                    }
                    self.sleep_deadline = None;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    // These are the only link state transitions that are tracked.
                    // 1. up -> down or down -> up
                    // 2. up -> down -> up or down -> up -> down.
                    // All other state transitions are coalesced into one of the above cases.
                    // For example, up -> down -> up -> down is treated as up -> down.
                    // N.B - Always queue up the incoming state to minimize the effects of loss
                    //       of any notifications (for example, during vtl2 servicing).
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    // If there is any existing sleep timer running, cancel it out.
                    self.sleep_deadline = None;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4139
4140    async fn handle_coordinator_message(
4141        &mut self,
4142        msg: CoordinatorMessage,
4143        state: &mut CoordinatorState,
4144    ) {
4145        self.workers[0].stop().await;
4146        if let Some(worker) = self.workers[0].state_mut() {
4147            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
4148                let WorkerState::WaitingForCoordinator(Some(ready)) =
4149                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
4150                else {
4151                    unreachable!("valid ready state")
4152                };
4153                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
4154            }
4155        }
4156        match msg {
4157            CoordinatorMessage::Update(update_type) => {
4158                if update_type.filter_state {
4159                    self.stop_workers().await;
4160                    self.active_packet_filter =
4161                        self.workers[0].state().unwrap().channel.packet_filter;
4162                    self.workers.iter_mut().skip(1).for_each(|worker| {
4163                        if let Some(state) = worker.state_mut() {
4164                            state.channel.packet_filter = self.active_packet_filter;
4165                            tracing::debug!(
4166                                packet_filter = ?self.active_packet_filter,
4167                                channel_idx = state.channel_idx,
4168                                "update packet filter"
4169                            );
4170                        }
4171                    });
4172                }
4173
4174                if update_type.guest_vf_state {
4175                    self.update_guest_vf_state(state).await;
4176                }
4177            }
4178            CoordinatorMessage::StartTimer(deadline) => {
4179                self.sleep_deadline = Some(deadline);
4180            }
4181            CoordinatorMessage::Restart => self.restart = true,
4182        }
4183    }
4184
4185    async fn stop_workers(&mut self) {
4186        for worker in &mut self.workers {
4187            worker.stop().await;
4188        }
4189    }
4190
    /// Reconciles the primary channel's guest VF state with the current
    /// endpoint/VF configuration, e.g. after a queue restart or a restore.
    ///
    /// Ensures the data path (synthetic vs. VF) matches what the guest
    /// expects, and schedules VF device offers where needed. No-op if there
    /// is no primary channel state.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Update guest VF state based on current endpoint properties.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // Ensure guest VF is in proper state.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // VF was advertised but the data path is still synthetic:
                    // schedule a delayed VF offer.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // ensure data path is switched as expected
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                // If the save was after the data path switch already occurred, don't do it again.
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // A data path switch was requested or previously in
                    // effect; replay it against the endpoint now.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // Cancel any outstanding delay timers for VF offers if the data path is
                    // getting switched, since the guest is already issuing
                    // commands assuming a VF.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        // A previous instantiation already switched the data
                        // path.
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // If the device was just removed, make sure the data path is synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Any pending VF offer is moot now that the VF is gone.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Notify guest if VF is no longer available
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4408
    /// Rebuilds all backend queues after a configuration change (e.g. a new
    /// requested queue count or RSS indirection table) and redistributes the
    /// guest's receive buffers across the active queues.
    ///
    /// Steps, in order:
    /// 1. Drop every worker's queue state and stop the endpoint.
    /// 2. Compute the active queue set from the RSS indirection table,
    ///    falling back to all requested queues when there is no valid table.
    /// 3. Partition the free rx buffers among the active queues and request
    ///    fresh queues from the endpoint.
    /// 4. Hand each worker its new queue and rx buffer range, propagating the
    ///    primary channel's packet filter.
    ///
    /// Returns `Ok(())` without doing anything further if the primary worker
    /// is not in the ready state.
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        // Split off the primary worker (index 0) from the subchannel workers.
        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        // Nothing to restart until the primary channel has finished
        // initialization.
        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        // Save the channel buffers for use in the subchannel workers.
        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                // Active queue count is computed as the number of unique entries in the indirection table
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                // Drop out-of-range entries; `active_queues` stays sorted,
                // which the `binary_search` below relies on.
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };

        // Distribute the rx buffers to only the active queues.
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        {
            let buffers = &state.buffers;
            let guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            // Get the list of free rx buffers from each task, then partition
            // the list per-active-queue, and produce the queue configuration.
            let mut queue_config = Vec::new();
            let initial_rx;
            {
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                // A buffer is only initially available if it is free in every
                // worker's rx tracking state.
                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    // If the primary queue is excluded from the guest supplied
                    // indirection table, it is assigned just the reserved
                    // buffers.
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: &[],
                        driver: Box::new(drivers[0].clone()),
                    });
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    // An empty `active_queues` means no RSS table: every
                    // queue is active.
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            // The last queue gets all the remaining buffers.
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            // Queue zero always includes the reserved buffers.
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            // `initial_rx` is sorted by construction, so
                            // `partition_point` finds the split for this range.
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        // Inactive queues get an empty range and no buffers.
                        (range_start, 0, None)
                    };

                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: this,
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // Set the subchannel count.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
        // Provide the queue and receive buffer ranges for each worker.
        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
            // Update the receive packet filter for the subchannel worker.
            if let Some(worker) = worker.state_mut() {
                worker.channel.packet_filter = self.active_packet_filter;
                // Clear any pending RxIds as buffers were redistributed
                // and reset TX tracking after the endpoint stop.
                if let Some(ready_state) = worker.state.ready_mut() {
                    ready_state.state.pending_rx_packets.clear();
                    ready_state.reset_tx_after_endpoint_stop();
                }
            }
        }

        Ok(())
    }
4612
4613    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4614        self.workers[0]
4615            .state_mut()
4616            .unwrap()
4617            .state
4618            .ready_mut()?
4619            .state
4620            .primary
4621            .as_mut()
4622    }
4623
    /// Stops the primary worker, then re-applies the guest VF state via
    /// `restore_guest_vf_state`.
    ///
    /// NOTE(review): the stop appears intended to keep the restore from
    /// racing with the primary channel's processing loop — confirm.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4628}
4629
4630impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4631    async fn run(
4632        &mut self,
4633        stop: &mut StopTask<'_>,
4634        worker: &mut Worker<T>,
4635    ) -> Result<(), task_control::Cancelled> {
4636        match worker.process(stop, self).await {
4637            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4638            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4639            Err(err) => {
4640                tracing::error!(
4641                    error = &err as &dyn std::error::Error,
4642                    channel_idx = worker.channel_idx,
4643                    "netvsp error"
4644                );
4645            }
4646        }
4647        Ok(())
4648    }
4649}
4650
4651impl<T: RingMem + 'static> Worker<T> {
4652    async fn process(
4653        &mut self,
4654        stop: &mut StopTask<'_>,
4655        queue: &mut NetQueue,
4656    ) -> Result<(), WorkerError> {
4657        // Be careful not to wait on actions with unbounded blocking time (e.g.
4658        // guest actions, or waiting for network packets to arrive) without
4659        // wrapping the wait on `stop.until_stopped`.
4660        loop {
4661            match &mut self.state {
4662                WorkerState::Init(initializing) => {
4663                    assert_eq!(self.channel_idx, 0);
4664
4665                    tracelimit::info_ratelimited!("network accepted");
4666
4667                    let (buffers, state) = stop
4668                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
4669                        .await??;
4670
4671                    let state = ReadyState {
4672                        buffers: Arc::new(buffers),
4673                        state,
4674                        data: ProcessingData::new(),
4675                    };
4676
4677                    // Wake up the coordinator task to start the queues.
4678                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);
4679
4680                    tracelimit::info_ratelimited!("network initialized");
4681                    self.state = WorkerState::WaitingForCoordinator(Some(state));
4682                }
4683                WorkerState::WaitingForCoordinator(_) => {
4684                    assert_eq!(self.channel_idx, 0);
4685                    // Waiting for the coordinator to process a message and
4686                    // restart the primary worker.
4687                    stop.until_stopped(pending()).await?
4688                }
4689                WorkerState::Ready(state) => {
4690                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
4691                        if !queue_state.target_vp_set {
4692                            if let Some(target_vp) = self.target_vp {
4693                                tracing::debug!(
4694                                    channel_idx = self.channel_idx,
4695                                    target_vp,
4696                                    "updating target VP"
4697                                );
4698                                queue_state.queue.update_target_vp(target_vp).await;
4699                                queue_state.target_vp_set = true;
4700                            }
4701                        }
4702
4703                        queue_state
4704                    } else {
4705                        // This task will be restarted when the queues are ready.
4706                        stop.until_stopped(pending()).await?
4707                    };
4708
4709                    let result = self.channel.main_loop(stop, state, queue_state).await;
4710                    let msg = match result {
4711                        Ok(restart) => {
4712                            assert_eq!(self.channel_idx, 0);
4713                            restart
4714                        }
4715                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
4716                            tracelimit::warn_ratelimited!(
4717                                err = err.as_ref() as &dyn std::error::Error,
4718                                "Endpoint requires queues to restart",
4719                            );
4720                            CoordinatorMessage::Restart
4721                        }
4722                        Err(err) => return Err(err),
4723                    };
4724
4725                    let WorkerState::Ready(ready) = std::mem::replace(
4726                        &mut self.state,
4727                        WorkerState::WaitingForCoordinator(None),
4728                    ) else {
4729                        unreachable!("must be running in ready state")
4730                    };
4731                    let _ = std::mem::replace(
4732                        &mut self.state,
4733                        WorkerState::WaitingForCoordinator(Some(ready)),
4734                    );
4735                    self.coordinator_send
4736                        .try_send(msg)
4737                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
4738                    stop.until_stopped(pending()).await?
4739                }
4740            }
4741        }
4742    }
4743}
4744
4745impl<T: 'static + RingMem> NetChannel<T> {
4746    fn try_next_packet<'a>(
4747        &mut self,
4748        send_buffer: Option<&SendBuffer>,
4749        version: Option<Version>,
4750        external_data: &'a mut MultiPagedRangeBuf,
4751    ) -> Result<Option<Packet<'a>>, WorkerError> {
4752        let (mut read, _) = self.queue.split();
4753        let packet = match read.try_read() {
4754            Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4755                .map_err(WorkerError::Packet)?,
4756            Err(queue::TryReadError::Empty) => return Ok(None),
4757            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4758        };
4759
4760        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4761        Ok(Some(packet))
4762    }
4763
4764    async fn next_packet<'a>(
4765        &mut self,
4766        send_buffer: Option<&'a SendBuffer>,
4767        version: Option<Version>,
4768        external_data: &'a mut MultiPagedRangeBuf,
4769    ) -> Result<Packet<'a>, WorkerError> {
4770        let (mut read, _) = self.queue.split();
4771        let mut packet_ref = read.read().await?;
4772        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
4773            .map_err(WorkerError::Packet)?;
4774        if matches!(packet.data, PacketData::RndisPacket(_)) {
4775            // In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG
4776            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
4777            packet_ref.revert();
4778        }
4779        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4780        Ok(packet)
4781    }
4782
4783    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4784        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4785            && initializing.ndis_version.is_some()
4786            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4787            && initializing.recv_buffer.is_some()
4788    }
4789
    /// Runs the netvsp initialization handshake with the guest.
    ///
    /// Loops processing the INIT message (version negotiation) and the
    /// subsequent configuration messages (NDIS config, NDIS version, receive
    /// buffer, send buffer), sending the completion packets the guest
    /// expects, until enough state has been collected to build the
    /// [`ChannelBuffers`] and initial [`ActiveState`].
    ///
    /// An rndis packet arriving early is treated as an implicit
    /// MESSAGE_TYPE_INITIALIZE_MSG: `next_packet` leaves it in the ring, and
    /// initialization completes with whatever configuration has been
    /// gathered so far.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                // Either all required setup messages have arrived, or the
                // guest has already moved on to sending rndis traffic.
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Pre-V2 guests never send an NDIS config; default it.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough room in the ring to avoid needing to track
            // completion packets.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                // Version already negotiated: expect configuration messages.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // As in the vmswitch, if the MTU is invalid then use the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The sub-allocation size depends on the MTU, which
                        // V2+ guests must have configured by now.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // Only legal once enough configuration has arrived
                        // (the send buffer may still be missing).
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet: only an INIT message is valid.
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                // Update the packet size so that the appropriate padding is
                                // appended for picky Windows guests.
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    // `parse_packet` with no negotiated version only yields
                    // Init packets.
                    _ => unreachable!(),
                }
            }
        }
    }
5004
5005    async fn main_loop(
5006        &mut self,
5007        stop: &mut StopTask<'_>,
5008        ready_state: &mut ReadyState,
5009        queue_state: &mut QueueState,
5010    ) -> Result<CoordinatorMessage, WorkerError> {
5011        let buffers = &ready_state.buffers;
5012        let state = &mut ready_state.state;
5013        let data = &mut ready_state.data;
5014
5015        let ring_spare_capacity = {
5016            let (_, send) = self.queue.split();
5017            let mut limit = if self.can_use_ring_size_opt {
5018                self.adapter.ring_size_limit.load(Ordering::Relaxed)
5019            } else {
5020                0
5021            };
5022            if limit == 0 {
5023                limit = send.capacity() - 2048;
5024            }
5025            send.capacity() - limit
5026        };
5027
5028        // If the packet filter has changed to allow rx packets, add any pended RxIds.
5029        if !state.pending_rx_packets.is_empty()
5030            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
5031        {
5032            let epqueue = queue_state.queue.as_mut();
5033            let (front, back) = state.pending_rx_packets.as_slices();
5034            epqueue.rx_avail(front);
5035            epqueue.rx_avail(back);
5036            state.pending_rx_packets.clear();
5037        }
5038
5039        // Handle any guest state changes since last run.
5040        if let Some(primary) = state.primary.as_mut() {
5041            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
5042                let num_channels_opened =
5043                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
5044                if num_channels_opened == primary.requested_num_queues as usize {
5045                    let (_, mut send) = self.queue.split();
5046                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5047                        .await??;
5048                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
5049                    primary.tx_spread_sent = true;
5050                }
5051            }
5052            if let PendingLinkAction::Active(up) = primary.pending_link_action {
5053                let (_, mut send) = self.queue.split();
5054                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5055                    .await??;
5056                if let Some(id) = primary.free_control_buffers.pop() {
5057                    let connect = if primary.guest_link_up != up {
5058                        primary.pending_link_action = PendingLinkAction::Default;
5059                        up
5060                    } else {
5061                        // For the up -> down -> up OR down -> up -> down case, the first transition
5062                        // is sent immediately and the second transition is queued with a delay.
5063                        primary.pending_link_action =
5064                            PendingLinkAction::Delay(primary.guest_link_up);
5065                        !primary.guest_link_up
5066                    };
5067                    // Mark the receive buffer in use to allow the guest to release it.
5068                    assert!(state.rx_bufs.is_free(id.0));
5069                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
5070                    let state_to_send = if connect {
5071                        rndisprot::STATUS_MEDIA_CONNECT
5072                    } else {
5073                        rndisprot::STATUS_MEDIA_DISCONNECT
5074                    };
5075                    tracing::info!(
5076                        connect,
5077                        mac_address = %self.adapter.mac_address,
5078                        "sending link status"
5079                    );
5080
5081                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
5082                    primary.guest_link_up = connect;
5083                } else {
5084                    primary.pending_link_action = PendingLinkAction::Delay(up);
5085                }
5086
5087                match primary.pending_link_action {
5088                    PendingLinkAction::Delay(_) => {
5089                        return Ok(CoordinatorMessage::StartTimer(
5090                            Instant::now() + LINK_DELAY_DURATION,
5091                        ));
5092                    }
5093                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
5094                    _ => {}
5095                }
5096            }
5097            match primary.guest_vf_state {
5098                PrimaryChannelGuestVfState::Available { .. }
5099                | PrimaryChannelGuestVfState::UnavailableFromAvailable
5100                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
5101                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
5102                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
5103                | PrimaryChannelGuestVfState::DataPathSynthetic => {
5104                    let (_, mut send) = self.queue.split();
5105                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5106                        .await??;
5107                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
5108                        return Ok(message);
5109                    }
5110                }
5111                _ => (),
5112            }
5113        }
5114
5115        loop {
5116            // If the ring is almost full, then do not poll the endpoint,
5117            // since this will cause us to spend time dropping packets and
5118            // reprogramming receive buffers. It's more efficient to let the
5119            // backend drop packets.
5120            let ring_full = {
5121                let (_, mut send) = self.queue.split();
5122                !send.can_write(ring_spare_capacity)?
5123            };
5124
5125            let did_some_work = (!ring_full
5126                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
5127                | self.process_ring_buffer(buffers, state, data, queue_state)?
5128                | (!ring_full
5129                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
5130                | self.transmit_pending_segments(state, data, queue_state)?
5131                | self.send_pending_packets(state)?;
5132
5133            if !did_some_work {
5134                state.stats.spurious_wakes.increment();
5135            }
5136
5137            // Process any outstanding control messages before sleeping in case
5138            // there are now suballocations available for use.
5139            self.process_control_messages(buffers, state)?;
5140
5141            // This should be the only await point waiting on network traffic or
5142            // guest actions. Wrap it in `stop.until_stopped` to allow
5143            // cancellation.
5144            let restart = stop
5145                .until_stopped(std::future::poll_fn(
5146                    |cx| -> Poll<Option<CoordinatorMessage>> {
5147                        // If the ring is almost full, then don't wait for endpoint
5148                        // interrupts. This allows the interrupt rate to fall when the
5149                        // guest cannot keep up with the load.
5150                        if !ring_full {
5151                            // Check the network endpoint for tx completion or rx.
5152                            if queue_state.queue.poll_ready(cx).is_ready() {
5153                                tracing::trace!("endpoint ready");
5154                                return Poll::Ready(None);
5155                            }
5156                        }
5157
5158                        // Check the incoming ring for tx, but only if there are enough
5159                        // free tx packets and no pending tx segments.
5160                        let (mut recv, mut send) = self.queue.split();
5161                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
5162                            && data.tx_segments.is_empty()
5163                            && recv.poll_ready(cx).is_ready()
5164                        {
5165                            tracing::trace!("incoming ring ready");
5166                            return Poll::Ready(None);
5167                        }
5168
5169                        // Check the outgoing ring for space to send rx completions or
5170                        // control message sends, if any are pending.
5171                        //
5172                        // Also, if endpoint processing has been suspended due to the
5173                        // ring being nearly full, then ask the guest to wake us up when
5174                        // there is space again.
5175                        let mut pending_send_size = self.pending_send_size;
5176                        if ring_full {
5177                            pending_send_size = ring_spare_capacity;
5178                        }
5179                        if pending_send_size != 0
5180                            && send.poll_ready(cx, pending_send_size).is_ready()
5181                        {
5182                            tracing::trace!("outgoing ring ready");
5183                            return Poll::Ready(None);
5184                        }
5185
5186                        // Collect any of this queue's receive buffers that were
5187                        // completed by a remote channel. This only happens when the
5188                        // subchannel count changes, so that receive buffer ownership
5189                        // moves between queues while some receive buffers are still in
5190                        // use.
5191                        if let Some(remote_buffer_id_recv) =
5192                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
5193                        {
5194                            while let Poll::Ready(Some(id)) =
5195                                remote_buffer_id_recv.poll_next_unpin(cx)
5196                            {
5197                                if id >= RX_RESERVED_CONTROL_BUFFERS {
5198                                    queue_state.queue.rx_avail(&[RxId(id)]);
5199                                } else {
5200                                    state
5201                                        .primary
5202                                        .as_mut()
5203                                        .unwrap()
5204                                        .free_control_buffers
5205                                        .push(ControlMessageId(id));
5206                                }
5207                            }
5208                        }
5209
5210                        if let Some(restart) = self.restart.take() {
5211                            return Poll::Ready(Some(restart));
5212                        }
5213
5214                        tracing::trace!("network waiting");
5215                        Poll::Pending
5216                    },
5217                ))
5218                .await?;
5219
5220            if let Some(restart) = restart {
5221                break Ok(restart);
5222            }
5223        }
5224    }
5225
    /// Polls the backend endpoint for received packets and forwards them to
    /// the guest as a single batched RNDIS message.
    ///
    /// Returns `Ok(false)` when there was nothing to do: no packets were
    /// ready, or every ready packet was pended because the guest's packet
    /// filter is `NDIS_PACKET_TYPE_NONE`. Returns `Ok(true)` when packets
    /// were either delivered to the guest or dropped because the ring was
    /// full.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Pend the newly available RxIds until the packet filter is updated.
            // Under high load this will eventually lead to no available RxIds,
            // which will cause the backend to drop the packets instead of
            // processing them here.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The transaction ID is the first rx buffer ID of the batch;
        // `release_recv_buffers` uses it later to free the whole batch when
        // the guest completes the message.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        // Mark the batch as in use so the guest's completion can release it.
        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        // Always use the full suballocation size to avoid tracking the
        // message length. See RxBuf::header() for details.
        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                // packet was sent
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring buffer is full. Drop the packets and free the rx
                // buffers. When the ring has limited space, the main loop will
                // stop polling for receive packets.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5291
5292    fn process_endpoint_tx(
5293        &mut self,
5294        state: &mut ActiveState,
5295        data: &mut ProcessingData,
5296        epqueue: &mut dyn net_backend::Queue,
5297    ) -> Result<bool, WorkerError> {
5298        // Drain completed transmits.
5299        let result = epqueue.tx_poll(&mut data.tx_done);
5300
5301        match result {
5302            Ok(n) => {
5303                if n == 0 {
5304                    return Ok(false);
5305                }
5306
5307                for &id in &data.tx_done[..n] {
5308                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5309                    assert!(tx_packet.pending_packet_count > 0);
5310                    tx_packet.pending_packet_count -= 1;
5311                    if tx_packet.pending_packet_count == 0 {
5312                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5313                    }
5314                }
5315
5316                Ok(true)
5317            }
5318            Err(TxError::TryRestart(err)) => {
5319                // In-flight TX packets will be cleaned up by
5320                // `reset_tx_after_endpoint_stop` during the queue restart.
5321                Err(WorkerError::EndpointRequiresQueueRestart(err))
5322            }
5323            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
5324        }
5325    }
5326
5327    fn switch_data_path(
5328        &mut self,
5329        state: &mut ActiveState,
5330        use_guest_vf: bool,
5331        transaction_id: Option<u64>,
5332    ) -> Result<(), WorkerError> {
5333        let primary = state.primary.as_mut().unwrap();
5334        let mut queue_switch_operation = false;
5335        match primary.guest_vf_state {
5336            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5337                // Allow the guest to switch to VTL0, or if the current state
5338                // of the data path is unknown, allow a switch away from VTL0.
5339                // The latter case handles the scenario where the data path has
5340                // been switched but the synthetic device has been restarted.
5341                // The device is queried for the current state of the data path
5342                // but if it is unknown (upgraded from older version that
5343                // doesn't track this data, or failure during query) then the
5344                // safest option is to pass the request through.
5345                if use_guest_vf || primary.is_data_path_switched.is_none() {
5346                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5347                        to_guest: use_guest_vf,
5348                        id: transaction_id,
5349                        result: None,
5350                    };
5351                    queue_switch_operation = true;
5352                }
5353            }
5354            PrimaryChannelGuestVfState::DataPathSwitched => {
5355                if !use_guest_vf {
5356                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5357                        to_guest: false,
5358                        id: transaction_id,
5359                        result: None,
5360                    };
5361                    queue_switch_operation = true;
5362                }
5363            }
5364            _ if use_guest_vf => {
5365                tracing::warn!(
5366                    state = %primary.guest_vf_state,
5367                    use_guest_vf,
5368                    "Data path switch requested while device is in wrong state"
5369                );
5370            }
5371            _ => (),
5372        };
5373        if queue_switch_operation {
5374            self.send_coordinator_update_vf();
5375        } else {
5376            self.send_completion(transaction_id, None)?;
5377        }
5378        Ok(())
5379    }
5380
    /// Processes guest-to-host packets from the incoming ring buffer: RNDIS
    /// data packets, receive-buffer completions, and control requests
    /// (subchannel allocation, data path switch, OID queries, buffer
    /// revocation).
    ///
    /// Returns `Ok(true)` if at least one packet was consumed from the ring.
    /// Returns `Err(WorkerError::BufferRevoked)` when the guest revokes a
    /// send/receive buffer, which terminates channel processing.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if !data.tx_segments.is_empty() {
            // There are still segments pending transmission. Skip polling the
            // ring until they are sent to increase backpressure and minimize
            // unnecessary wakeups.
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        // Drain the ring while free tx packet IDs remain.
        loop {
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                // Guest transmit: translate the RNDIS message into tx
                // segments. A failure completes the packet with FAILURE
                // rather than tearing down the channel.
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // Zero packets means the message was consumed
                            // without generating any tx work; complete now.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                // Guest has finished with a batch of receive buffers; return
                // them to the endpoint (or to the owning channel/free list).
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    // The number of requested subchannels has to stay below the maximum queue limit
                    // because one queue is always reserved for the primary channel. In other words,
                    // the subchannels plus the primary channel must fit within the max_queues value,
                    // which means subchannels + 1 ≤ max_queues, so the subchannel count must be
                    // strictly less than max_queues.
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    // A successful allocation requires reconfiguring the
                    // queues; ask the coordinator to restart the workers.
                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(protocol::Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                // No operation for VF association completion packets as not all clients send them
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                // Anything else (including primary-only packets arriving on a
                // subchannel) is a protocol violation.
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Push the generated tx segments to the endpoint; note if it stalled.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5533
5534    // Transmit any pending segments. Returns Ok(true) if work was done--if any
5535    // segments were transmitted.
5536    fn transmit_pending_segments(
5537        &mut self,
5538        state: &mut ActiveState,
5539        data: &mut ProcessingData,
5540        queue_state: &mut QueueState,
5541    ) -> Result<bool, WorkerError> {
5542        if data.tx_segments.is_empty() {
5543            return Ok(false);
5544        }
5545        let sent = data.tx_segments_sent;
5546        let did_work =
5547            self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5548        Ok(did_work)
5549    }
5550
    /// Returns true if all pending segments were transmitted.
    ///
    /// Hands the not-yet-sent portion of `data.tx_segments` to the endpoint.
    /// If the endpoint reports the transmit was synchronous, the guest tx
    /// packets covered by the sent segments are completed immediately;
    /// otherwise completion arrives later via `process_endpoint_tx`.
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(segments)
            .map_err(WorkerError::Endpoint)?;

        // Narrow to the segments the endpoint actually accepted.
        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        if sync {
            // Complete the packets now.
            while let Some(head) = segments.first() {
                // Accepted segments always start at a packet head.
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                // Skip past this packet's segments to the next head.
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        // Once everything is sent, reset the pending segment list so the
        // ring can be polled again.
        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5590
5591    fn handle_rndis(
5592        &mut self,
5593        buffers: &ChannelBuffers,
5594        id: TxId,
5595        state: &mut ActiveState,
5596        packet: &Packet<'_>,
5597        segments: &mut Vec<TxSegment>,
5598    ) -> Result<usize, WorkerError> {
5599        let mut total_packets = 0;
5600        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5601        assert_eq!(tx_packet.pending_packet_count, 0);
5602        tx_packet.transaction_id = packet
5603            .transaction_id
5604            .ok_or(WorkerError::MissingTransactionId)?;
5605
5606        // Probe the data to catch accesses that are out of bounds. This
5607        // simplifies error handling for backends that use
5608        // [`GuestMemory::iova`].
5609        packet
5610            .external_data
5611            .iter()
5612            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
5613            .map_err(WorkerError::GpaDirectError)?;
5614
5615        let mut reader = packet.rndis_reader(&buffers.mem);
5616        let header: rndisprot::MessageHeader = reader.read_plain()?;
5617        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
5618            let start = segments.len();
5619            match self.handle_rndis_packet_messages(
5620                buffers,
5621                state,
5622                id,
5623                header.message_length as usize,
5624                reader,
5625                segments,
5626            ) {
5627                Ok(n) => {
5628                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
5629                    total_packets += n;
5630                }
5631                Err(err) => {
5632                    // Roll back any segments added for this message.
5633                    segments.truncate(start);
5634                    return Err(err);
5635                }
5636            }
5637        } else {
5638            self.handle_rndis_message(state, header.message_type, reader)?;
5639        }
5640
5641        Ok(total_packets)
5642    }
5643
5644    fn try_send_tx_packet(
5645        &mut self,
5646        transaction_id: u64,
5647        status: protocol::Status,
5648    ) -> Result<bool, WorkerError> {
5649        let message = self.message(
5650            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5651            protocol::Message1SendRndisPacketComplete { status },
5652        );
5653        let result = self.queue.split().1.batched().try_write_aligned(
5654            transaction_id,
5655            OutgoingPacketType::Completion,
5656            message.aligned_payload(),
5657        );
5658        let sent = match result {
5659            Ok(()) => true,
5660            Err(queue::TryWriteError::Full(n)) => {
5661                self.pending_send_size = n;
5662                false
5663            }
5664            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5665        };
5666        Ok(sent)
5667    }
5668
5669    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5670        let mut did_some_work = false;
5671        while let Some(pending) = state.pending_tx_completions.front() {
5672            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5673                return Ok(did_some_work);
5674            }
5675            did_some_work = true;
5676            if let Some(id) = pending.tx_id {
5677                state.free_tx_packets.push(id);
5678            }
5679            tracing::trace!(?pending, "sent tx completion");
5680            state.pending_tx_completions.pop_front();
5681        }
5682
5683        self.pending_send_size = 0;
5684        Ok(did_some_work)
5685    }
5686
5687    fn complete_tx_packet(
5688        &mut self,
5689        state: &mut ActiveState,
5690        id: TxId,
5691        status: protocol::Status,
5692    ) -> Result<(), WorkerError> {
5693        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5694        assert_eq!(tx_packet.pending_packet_count, 0);
5695        if self.pending_send_size == 0
5696            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5697        {
5698            tracing::trace!(id = id.0, "sent tx completion");
5699            state.free_tx_packets.push(id);
5700        } else {
5701            tracing::trace!(id = id.0, "pended tx completion");
5702            state.pending_tx_completions.push_back(PendingTxCompletion {
5703                transaction_id: tx_packet.transaction_id,
5704                tx_id: Some(id),
5705                status,
5706            });
5707        }
5708        Ok(())
5709    }
5710}
5711
5712impl ActiveState {
5713    fn release_recv_buffers(
5714        &mut self,
5715        transaction_id: u64,
5716        rx_buffer_range: &RxBufferRange,
5717        done: &mut Vec<RxId>,
5718    ) -> Option<()> {
5719        // The transaction ID specifies the first rx buffer ID.
5720        let first_id: u32 = transaction_id.try_into().ok()?;
5721        let ids = self.rx_bufs.free(first_id)?;
5722        for id in ids {
5723            if !rx_buffer_range.send_if_remote(id) {
5724                if id >= RX_RESERVED_CONTROL_BUFFERS {
5725                    done.push(RxId(id));
5726                } else {
5727                    self.primary
5728                        .as_mut()?
5729                        .free_control_buffers
5730                        .push(ControlMessageId(id));
5731                }
5732            }
5733        }
5734        Some(())
5735    }
5736}