// netvsp/lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The user-mode netvsp VMBus device implementation.
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
19use crate::protocol::Version;
20use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
21use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
22use async_trait::async_trait;
23pub use buffers::BufferPool;
24use buffers::sub_allocation_size_for_mtu;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures::channel::mpsc;
28use futures::channel::mpsc::TrySendError;
29use futures_concurrency::future::Race;
30use guestmem::AccessError;
31use guestmem::GuestMemory;
32use guestmem::GuestMemoryError;
33use guestmem::MemoryRead;
34use guestmem::MemoryWrite;
35use guestmem::ranges::PagedRange;
36use guestmem::ranges::PagedRanges;
37use guestmem::ranges::PagedRangesReader;
38use guid::Guid;
39use hvdef::hypercall::HvGuestOsId;
40use hvdef::hypercall::HvGuestOsMicrosoft;
41use hvdef::hypercall::HvGuestOsMicrosoftIds;
42use hvdef::hypercall::HvGuestOsOpenSourceType;
43use inspect::Inspect;
44use inspect::InspectMut;
45use inspect::SensitivityLevel;
46use inspect_counters::Counter;
47use inspect_counters::Histogram;
48use mesh::rpc::Rpc;
49use net_backend::Endpoint;
50use net_backend::EndpointAction;
51use net_backend::QueueConfig;
52use net_backend::RxId;
53use net_backend::TxError;
54use net_backend::TxId;
55use net_backend::TxSegment;
56use net_backend_resources::mac_address::MacAddress;
57use pal_async::timer::Instant;
58use pal_async::timer::PolledTimer;
59use ring::gparange::MultiPagedRangeIter;
60use rx_bufs::RxBuffers;
61use rx_bufs::SubAllocationInUse;
62use std::collections::VecDeque;
63use std::fmt::Debug;
64use std::future::pending;
65use std::mem::offset_of;
66use std::ops::Range;
67use std::sync::Arc;
68use std::sync::atomic::AtomicUsize;
69use std::sync::atomic::Ordering;
70use std::task::Poll;
71use std::time::Duration;
72use task_control::AsyncRun;
73use task_control::InspectTaskMut;
74use task_control::StopTask;
75use task_control::TaskControl;
76use thiserror::Error;
77use tracing::Instrument;
78use vmbus_async::queue;
79use vmbus_async::queue::ExternalDataError;
80use vmbus_async::queue::IncomingPacket;
81use vmbus_async::queue::Queue;
82use vmbus_channel::bus::OfferParams;
83use vmbus_channel::bus::OpenRequest;
84use vmbus_channel::channel::ChannelControl;
85use vmbus_channel::channel::ChannelOpenError;
86use vmbus_channel::channel::ChannelRestoreError;
87use vmbus_channel::channel::DeviceResources;
88use vmbus_channel::channel::RestoreControl;
89use vmbus_channel::channel::SaveRestoreVmbusDevice;
90use vmbus_channel::channel::VmbusDevice;
91use vmbus_channel::gpadl::GpadlId;
92use vmbus_channel::gpadl::GpadlMapView;
93use vmbus_channel::gpadl::GpadlView;
94use vmbus_channel::gpadl::UnknownGpadlId;
95use vmbus_channel::gpadl_ring::GpadlRingMem;
96use vmbus_channel::gpadl_ring::gpadl_channel;
97use vmbus_ring as ring;
98use vmbus_ring::OutgoingPacketType;
99use vmbus_ring::RingMem;
100use vmbus_ring::gparange::MultiPagedRangeBuf;
101use vmcore::save_restore::RestoreError;
102use vmcore::save_restore::SaveError;
103use vmcore::save_restore::SavedStateBlob;
104use vmcore::vm_task::VmTaskDriver;
105use vmcore::vm_task::VmTaskDriverSource;
106use zerocopy::FromBytes;
107use zerocopy::FromZeros;
108use zerocopy::Immutable;
109use zerocopy::IntoBytes;
110use zerocopy::KnownLayout;
111
// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but also need room for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. Worst case requires a completion message plus two
// additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION)
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
// NOTE: both IDs below sit at the very top of the u64 range, presumably to avoid colliding with guest-chosen
// transaction IDs — confirm against the protocol usage if extending.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// Upper bound on the number of subchannels a single vNIC will offer.
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the AN device is offered too quickly, it
// could get the "eth0" name. In provisioning scenarios, the scripts will make
// assumptions about which interface should be used, with eth0 being the
// default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
// Shorter delay under test so unit tests don't stall on the full 1s wait.
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
145
/// Payload of [`CoordinatorMessage::Update`]: which pieces of network state
/// the coordinator should refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    guest_vf_state: bool,
    /// Update the receive filter for all channels.
    filter_state: bool,
}
154
/// Requests sent to the coordinator task (see `coordinator_send` on
/// [`Worker`]).
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update network state.
    Update(CoordinatorMessageUpdateType),
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}
165
/// Per-channel worker that processes one vmbus ring buffer.
struct Worker<T: RingMem> {
    /// Index of the channel this worker serves (0 is the primary channel).
    channel_idx: u16,
    /// The target VP for channel interrupts, if known.
    target_vp: Option<u32>,
    mem: GuestMemory,
    /// The vmbus channel state for this worker.
    channel: NetChannel<T>,
    /// Protocol/negotiation state machine for the channel.
    state: WorkerState,
    /// Used to send [`CoordinatorMessage`]s to the coordinator task.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
174
/// A backend queue and the task driver it runs on, paired with a [`Worker`]
/// via [`InspectTaskMut`].
struct NetQueue {
    driver: VmTaskDriver,
    /// The backend queue state, if one is currently set up.
    queue_state: Option<QueueState>,
}
179
180impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
181    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
182        if worker.is_none() && self.queue_state.is_none() {
183            req.ignore();
184            return;
185        }
186
187        let mut resp = req.respond();
188        resp.field("driver", &self.driver);
189        if let Some(worker) = worker {
190            resp.field(
191                "protocol_state",
192                match &worker.state {
193                    WorkerState::Init(None) => "version",
194                    WorkerState::Init(Some(_)) => "init",
195                    WorkerState::Ready(_) => "ready",
196                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
197                },
198            )
199            .field("ring", &worker.channel.queue)
200            .field(
201                "can_use_ring_size_optimization",
202                worker.channel.can_use_ring_size_opt,
203            );
204
205            if let WorkerState::Ready(state) = &worker.state {
206                resp.field(
207                    "outstanding_tx_packets",
208                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
209                )
210                .field("pending_rx_packets", state.state.pending_rx_packets.len())
211                .field(
212                    "pending_tx_completions",
213                    state.state.pending_tx_completions.len(),
214                )
215                .field("free_tx_packets", state.state.free_tx_packets.len())
216                .merge(&state.state.stats);
217            }
218
219            resp.field("packet_filter", worker.channel.packet_filter);
220        }
221
222        if let Some(queue_state) = &mut self.queue_state {
223            resp.field_mut("queue", &mut queue_state.queue)
224                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
225                .field(
226                    "rx_buffers_start",
227                    queue_state.rx_buffer_range.id_range.start,
228                );
229        }
230    }
231}
232
/// The protocol state of a channel worker.
enum WorkerState {
    /// Negotiating; `None` until a version has been accepted (reported as
    /// "version" vs "init" in inspect output).
    Init(Option<InitState>),
    /// Fully negotiated and processing packets.
    Ready(ReadyState),
    /// Paused, waiting on the coordinator; holds any ready state to resume with.
    WaitingForCoordinator(Option<ReadyState>),
}
238
239impl WorkerState {
240    fn ready(&self) -> Option<&ReadyState> {
241        if let Self::Ready(state) = self {
242            Some(state)
243        } else {
244            None
245        }
246    }
247
248    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249        if let Self::Ready(state) = self {
250            Some(state)
251        } else {
252            None
253        }
254    }
255}
256
/// State accumulated while the channel negotiates the netvsp protocol.
struct InitState {
    /// The accepted protocol version.
    version: Version,
    /// The NDIS configuration (MTU, capabilities), once received.
    ndis_config: Option<NdisConfig>,
    /// The NDIS version, once received.
    ndis_version: Option<NdisVersion>,
    /// The receive buffer, once established.
    recv_buffer: Option<ReceiveBuffer>,
    /// The send buffer, once established.
    send_buffer: Option<SendBuffer>,
}
264
/// An NDIS (major, minor) version pair.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
272
/// NDIS configuration reported for the channel.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    /// The MTU, in bytes.
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
280
/// State for a channel that has completed protocol negotiation.
struct ReadyState {
    /// Shared, read-mostly channel buffers (see [`ChannelBuffers`]).
    buffers: Arc<ChannelBuffers>,
    /// Mutable per-channel processing state.
    state: ActiveState,
    /// Scratch buffers reused across processing passes.
    data: ProcessingData,
}
286
287impl ReadyState {
288    /// Any in-flight TX packets submitted to the old endpoint queues will
289    /// never be completed, so this method:
290    /// 1. Queues completions for all outstanding sends so the guest gets
291    ///    responses and the associated transmit slots are reclaimed, ensuring
292    ///    the worker is ready to poll post restart of the queues.
293    /// 2. Clears leftover transmit segments that referenced the old endpoint
294    ///    queue so they are not submitted to the new one.
295    ///
296    fn reset_tx_after_endpoint_stop(&mut self) {
297        let state = &mut self.state;
298
299        // Queue completions for in-flight TX packets that were lost when the
300        // endpoint stopped. They will get picked up when the worker restarts.
301        let pending_tx = state
302            .pending_tx_packets
303            .iter_mut()
304            .enumerate()
305            .filter_map(|(id, inflight)| {
306                if inflight.pending_packet_count > 0 {
307                    inflight.pending_packet_count = 0;
308                    Some(PendingTxCompletion {
309                        transaction_id: inflight.transaction_id,
310                        tx_id: Some(TxId(id as u32)),
311                        status: protocol::Status::SUCCESS,
312                    })
313                } else {
314                    None
315                }
316            })
317            .collect::<Vec<_>>();
318        state.pending_tx_completions.extend(pending_tx);
319
320        // Clear leftover TX segments from the previous endpoint queue;
321        // they cannot be submitted to the new queue.
322        self.data.tx_segments.clear();
323        self.data.tx_segments_sent = 0;
324    }
325}
326
/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of `None` signifies that the VF is
    /// not currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
341
/// Adapter-wide configuration and state shared across all channels of a vNIC.
struct Adapter {
    driver: VmTaskDriver,
    /// The MAC address reported to the guest.
    mac_address: MacAddress,
    /// Maximum number of queues/subchannels offered.
    max_queues: u16,
    /// Size of the RSS indirection table.
    indirection_table_size: u16,
    /// The offloads supported by the backend endpoint.
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    /// Whether the endpoint completes transmits quickly (used as a hint by
    /// the processing loop).
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Callback to query the current guest OS ID, if available.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    /// The link speed reported to the guest.
    link_speed: u64,
}
356
/// Runtime state for one backend queue.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    /// Pool backing this queue's receive buffers.
    pool: BufferPool,
    /// The receive-buffer IDs owned by this queue.
    rx_buffer_range: RxBufferRange,
    /// Whether the channel's target VP has been applied.
    target_vp_set: bool,
}
363
/// The span of receive-buffer IDs owned by a queue, plus the channels used to
/// route IDs that belong to other queues.
struct RxBufferRange {
    /// The IDs owned by this queue.
    id_range: Range<u32>,
    /// Receives buffer IDs forwarded from other queues, if any.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    /// Shared table used to forward IDs to their owning queue.
    remote_ranges: Arc<RxBufferRanges>,
}
369
370impl RxBufferRange {
371    fn new(
372        ranges: Arc<RxBufferRanges>,
373        id_range: Range<u32>,
374        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
375    ) -> Self {
376        Self {
377            id_range,
378            remote_buffer_id_recv,
379            remote_ranges: ranges,
380        }
381    }
382
383    fn send_if_remote(&self, id: u32) -> bool {
384        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
385        // ID is owned by the current range.
386        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
387            false
388        } else {
389            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
390            // The total number of receive buffers may not evenly divide among
391            // the active queues. Any extra buffers are given to the last
392            // queue, so redirect any larger values there.
393            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
394            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
395            true
396        }
397    }
398}
399
/// Shared routing table mapping receive-buffer IDs to their owning queues.
struct RxBufferRanges {
    /// Buffers per queue, excluding the reserved control buffers. The last
    /// queue additionally owns any remainder.
    buffers_per_queue: u32,
    /// One sender per queue for forwarding IDs to their owner.
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
404
405impl RxBufferRanges {
406    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
407        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
408        #[expect(clippy::disallowed_methods)] // TODO
409        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
410        (
411            Self {
412                buffers_per_queue,
413                buffer_id_send: send,
414            },
415            recv,
416        )
417    }
418}
419
/// RSS (receive-side scaling) parameters configured by the guest.
struct RssState {
    /// The 40-byte RSS hash key.
    key: [u8; 40],
    /// Maps hash values to queue indices.
    indirection_table: Vec<u16>,
}
424
/// The internal channel state.
struct NetChannel<T: RingMem> {
    /// Adapter-wide shared state.
    adapter: Arc<Adapter>,
    /// The vmbus ring queue for this channel.
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    /// The packet size in effect for the negotiated protocol version.
    packet_size: PacketSize,
    /// Outgoing ring space to wait for before resuming sends.
    pending_send_size: usize,
    /// A message to deliver to the coordinator, if one is pending.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    /// The current receive packet filter.
    packet_filter: u32,
}
436
// Use an enum to give the compiler more visibility into the packet size.
/// The per-packet header size in effect, named by protocol version.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// [`protocol::PACKET_SIZE_V1`]
    V1,
    /// [`protocol::PACKET_SIZE_V61`]
    V61,
}
445
/// Buffers used during packet processing.
struct ProcessingData {
    /// Transmit segments staged for submission to the backend queue.
    tx_segments: Vec<TxSegment>,
    /// Count of `tx_segments` entries already submitted (see
    /// `ReadyState::reset_tx_after_endpoint_stop`).
    tx_segments_sent: usize,
    /// Scratch array for completed transmit IDs.
    tx_done: Box<[TxId]>,
    /// Scratch array for receive IDs ready to be posted.
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
456
457impl ProcessingData {
458    fn new() -> Self {
459        Self {
460            tx_segments: Vec::new(),
461            tx_segments_sent: 0,
462            tx_done: vec![TxId(0); 8192].into(),
463            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
464            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
465            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
466            external_data: MultiPagedRangeBuf::new(),
467        }
468    }
469}
470
/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    /// The negotiated protocol version.
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    /// The guest-provided receive buffer.
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    /// The guest-provided send buffer, if one was established.
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
486
/// An ID assigned to a control message. This is also the index of the receive
/// buffer holding the message.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
490
/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    /// Primary-channel-only state; `None` on subchannels.
    primary: Option<PrimaryChannelState>,

    /// In-flight transmit packets, indexed by [`TxId`].
    pending_tx_packets: Vec<PendingTxPacket>,
    /// Transmit IDs available for new sends.
    free_tx_packets: Vec<TxId>,
    /// Transmit completions not yet delivered to the guest.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    /// Receive IDs pending processing.
    pending_rx_packets: VecDeque<RxId>,

    rx_bufs: RxBuffers,

    /// Per-queue counters reported via inspect.
    stats: QueueStats,
}
504
/// Per-queue diagnostic counters, surfaced through `Inspect`.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    // Wakes that found no work to do.
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
519
/// A transmit completion queued for delivery to the guest.
#[derive(Debug)]
struct PendingTxCompletion {
    /// The vmbus transaction ID from the originating send.
    transaction_id: u64,
    /// The transmit ID to recycle, if any.
    tx_id: Option<TxId>,
    /// Completion status to report.
    status: protocol::Status,
}
526
/// Tracks the guest-visible state of the VF on the primary channel.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet
    Initializing,
    /// State is being restored from a save
    Restoring(saved_state::GuestVfState),
    /// No VF available for the guest
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available, but became unavailable while a data path
    /// switch was still pending.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available and the data path had been switched, but
    /// the VF is no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}
559
560impl std::fmt::Display for PrimaryChannelGuestVfState {
561    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562        match self {
563            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
564            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
565                write!(f, "restoring")
566            }
567            PrimaryChannelGuestVfState::Restoring(
568                saved_state::GuestVfState::AvailableAdvertised,
569            ) => write!(f, "restoring from guest notified of vfid"),
570            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
571                write!(f, "restoring from vf present")
572            }
573            PrimaryChannelGuestVfState::Restoring(
574                saved_state::GuestVfState::DataPathSwitchPending {
575                    to_guest, result, ..
576                },
577            ) => {
578                write!(
579                    f,
580                    "restoring from client requested data path switch: to {} {}",
581                    if *to_guest { "guest" } else { "synthetic" },
582                    if let Some(result) = result {
583                        if *result { "succeeded\"" } else { "failed\"" }
584                    } else {
585                        "in progress\""
586                    }
587                )
588            }
589            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
590                write!(f, "restoring from data path in guest")
591            }
592            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
593            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
594                write!(f, "\"unavailable (previously available)\"")
595            }
596            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
597                write!(f, "unavailable (previously switching data path)")
598            }
599            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
600                write!(f, "\"unavailable (previously using guest VF)\"")
601            }
602            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
603            PrimaryChannelGuestVfState::AvailableAdvertised => {
604                write!(f, "\"available, guest notified\"")
605            }
606            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
607            PrimaryChannelGuestVfState::DataPathSwitchPending {
608                to_guest, result, ..
609            } => {
610                write!(
611                    f,
612                    "\"switching to {} {}",
613                    if *to_guest { "guest" } else { "synthetic" },
614                    if let Some(result) = result {
615                        if *result { "succeeded\"" } else { "failed\"" }
616                    } else {
617                        "in progress\""
618                    }
619                )
620            }
621            PrimaryChannelGuestVfState::DataPathSwitched => {
622                write!(f, "\"available and data path switched\"")
623            }
624            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
625                f,
626                "\"available but data path switched back to synthetic due to external state change\""
627            ),
628        }
629    }
630}
631
impl Inspect for PrimaryChannelGuestVfState {
    // Report the state as its `Display` rendering so inspect output is
    // human-readable.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}
637
/// A pending link-state change for the guest; the `bool` is the desired
/// link-up state.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action pending.
    Default,
    /// A link update should be sent now.
    Active(bool),
    /// A link update is waiting out the link-change delay.
    Delay(bool),
}
644
/// State tracked only by the primary channel (channel 0).
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    /// Control messages queued for processing.
    control_messages: VecDeque<ControlMessage>,
    /// Total byte length of the queued control message payloads.
    control_messages_len: usize,
    /// Reserved receive buffers currently free for control messages.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    /// The offloads currently in effect for the guest.
    offload_config: OffloadConfig,
    /// Whether an offload-change notification still needs to be sent.
    pending_offload_change: bool,
    tx_spread_sent: bool,
    /// The link state last reported to the guest.
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
660
661impl Inspect for PrimaryChannelState {
662    fn inspect(&self, req: inspect::Request<'_>) {
663        req.respond()
664            .sensitivity_field(
665                "guest_vf_state",
666                SensitivityLevel::Safe,
667                self.guest_vf_state,
668            )
669            .sensitivity_field(
670                "data_path_switched",
671                SensitivityLevel::Safe,
672                self.is_data_path_switched,
673            )
674            .sensitivity_field(
675                "pending_control_messages",
676                SensitivityLevel::Safe,
677                self.control_messages.len(),
678            )
679            .sensitivity_field(
680                "free_control_message_buffers",
681                SensitivityLevel::Safe,
682                self.free_control_buffers.len(),
683            )
684            .sensitivity_field(
685                "pending_offload_change",
686                SensitivityLevel::Safe,
687                self.pending_offload_change,
688            )
689            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
690            .sensitivity_field(
691                "offload_config",
692                SensitivityLevel::Safe,
693                &self.offload_config,
694            )
695            .sensitivity_field(
696                "tx_spread_sent",
697                SensitivityLevel::Safe,
698                self.tx_spread_sent,
699            )
700            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
701            .sensitivity_field(
702                "pending_link_action",
703                SensitivityLevel::Safe,
704                match &self.pending_link_action {
705                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
706                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
707                    PendingLinkAction::Default => "None".to_string(),
708                },
709            );
710    }
711}
712
/// Task-offload configuration: checksum offloads plus large-send offloads.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    /// Transmit-side checksum offloads.
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    /// Receive-side checksum offloads.
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// IPv4 large send offload (LSOv2).
    #[inspect(safe)]
    lso4: bool,
    /// IPv6 large send offload (LSOv2).
    #[inspect(safe)]
    lso6: bool,
}
724
725impl OffloadConfig {
726    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
727        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
728        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
729        self.lso4 &= supported.lso4;
730        self.lso6 &= supported.lso6;
731    }
732}
733
/// Per-protocol checksum-offload enables.
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    /// IPv4 header checksum.
    #[inspect(safe)]
    ipv4_header: bool,
    /// TCP-over-IPv4 checksum.
    #[inspect(safe)]
    tcp4: bool,
    /// UDP-over-IPv4 checksum.
    #[inspect(safe)]
    udp4: bool,
    /// TCP-over-IPv6 checksum.
    #[inspect(safe)]
    tcp6: bool,
    /// UDP-over-IPv6 checksum.
    #[inspect(safe)]
    udp6: bool,
}
747
748impl ChecksumOffloadConfig {
749    fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
750        self.ipv4_header &= supported.ipv4_header;
751        self.tcp4 &= supported.tcp4;
752        self.udp4 &= supported.udp4;
753        self.tcp6 &= supported.tcp6;
754        self.udp6 &= supported.udp6;
755    }
756
757    fn flags(
758        &self,
759    ) -> (
760        rndisprot::Ipv4ChecksumOffload,
761        rndisprot::Ipv6ChecksumOffload,
762    ) {
763        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
764        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
765        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
766        if self.ipv4_header {
767            v4.set_ip_options_supported(on);
768            v4.set_ip_checksum(on);
769        }
770        if self.tcp4 {
771            v4.set_ip_options_supported(on);
772            v4.set_tcp_options_supported(on);
773            v4.set_tcp_checksum(on);
774        }
775        if self.tcp6 {
776            v6.set_ip_extension_headers_supported(on);
777            v6.set_tcp_options_supported(on);
778            v6.set_tcp_checksum(on);
779        }
780        if self.udp4 {
781            v4.set_ip_options_supported(on);
782            v4.set_udp_checksum(on);
783        }
784        if self.udp6 {
785            v6.set_ip_extension_headers_supported(on);
786            v6.set_udp_checksum(on);
787        }
788        (v4, v6)
789    }
790}
791
792impl OffloadConfig {
793    fn ndis_offload(&self) -> rndisprot::NdisOffload {
794        let checksum = {
795            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
796            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
797            rndisprot::TcpIpChecksumOffload {
798                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
799                ipv4_tx_flags,
800                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
801                ipv4_rx_flags,
802                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
803                ipv6_tx_flags,
804                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
805                ipv6_rx_flags,
806            }
807        };
808
809        let lso_v2 = {
810            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
811            if self.lso4 {
812                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
813                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
814                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
815            }
816            if self.lso6 {
817                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
818                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
819                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
820                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
821                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
822                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
823            }
824            lso
825        };
826
827        rndisprot::NdisOffload {
828            header: rndisprot::NdisObjectHeader {
829                object_type: rndisprot::NdisObjectType::OFFLOAD,
830                revision: 3,
831                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
832            },
833            checksum,
834            lso_v2,
835            ..FromZeros::new_zeroed()
836        }
837    }
838}
839
/// The lifecycle state of the RNDIS device.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Not yet initialized.
    Initializing,
    /// Initialized and operational.
    Operational,
    /// Halted.
    Halted,
}
846
impl PrimaryChannelState {
    /// Creates the initial primary channel state.
    ///
    /// All reserved control buffers start out free, RSS is unconfigured, and
    /// the guest link starts in the up state.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds the primary channel state from saved state.
    ///
    /// Fails with [`NetRestoreError::ReducedIndirectionTableSize`] if the
    /// saved RSS indirection table is larger than the adapter's current table
    /// size, and with [`NetRestoreError::InvalidRssKeySize`] if the saved RSS
    /// key has an unexpected length.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        // A reserved control buffer is free unless some saved in-use rx
        // buffer still refers to it.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of indirection table can cause unexpected and hard to investigate issues
                    // with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of indirection table is done by duplicating the existing entries until
                    // the desired size is reached.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        // Carry the saved VF state forward, marked as restoring.
        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        // Translate the saved offload configuration field-by-field.
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
977
/// A queued RNDIS control message.
struct ControlMessage {
    // The RNDIS message type code.
    message_type: u32,
    // The raw message payload.
    data: Box<[u8]>,
}
982
983const TX_PACKET_QUOTA: usize = 1024;
984
985impl ActiveState {
986    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
987        Self {
988            primary,
989            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
990            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
991            pending_tx_completions: VecDeque::new(),
992            pending_rx_packets: VecDeque::new(),
993            rx_bufs: RxBuffers::new(recv_buffer_count),
994            stats: Default::default(),
995        }
996    }
997
998    fn restore(
999        channel: &saved_state::Channel,
1000        recv_buffer_count: u32,
1001    ) -> Result<Self, NetRestoreError> {
1002        let mut active = Self::new(None, recv_buffer_count);
1003        let saved_state::Channel {
1004            pending_tx_completions,
1005            in_use_rx,
1006        } = channel;
1007        for rx in in_use_rx {
1008            active
1009                .rx_bufs
1010                .allocate(rx.buffers.as_slice().iter().copied())?;
1011        }
1012        for &transaction_id in pending_tx_completions {
1013            // Consume tx quota if any is available. If not, still
1014            // allow the restore since tx quota might change from
1015            // release to release.
1016            let tx_id = active.free_tx_packets.pop();
1017            if let Some(id) = tx_id {
1018                // This shouldn't be referenced, but set it in case it is in the future.
1019                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
1020            }
1021            // Save/Restore does not preserve the status of pending tx completions,
1022            // completing any pending completions with 'success' to avoid making changes to saved_state.
1023            active
1024                .pending_tx_completions
1025                .push_back(PendingTxCompletion {
1026                    transaction_id,
1027                    tx_id,
1028                    status: protocol::Status::SUCCESS,
1029                });
1030        }
1031        Ok(active)
1032    }
1033}
1034
/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    // NOTE(review): appears to count outstanding endpoint packets for this
    // transaction — confirm against the tx completion path.
    pending_packet_count: usize,
    // The vmbus transaction ID to complete once the packet finishes.
    transaction_id: u64,
}
1042
/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
///
/// These occupy the first sub-allocation IDs (`0..16`) of the receive buffer.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1051
/// A network adapter.
pub struct Nic {
    // VMBus instance ID used in this adapter's channel offer.
    instance_id: Guid,
    // Resources provided by the vmbus infrastructure via `install`.
    resources: DeviceResources,
    // Task running the coordinator, which owns the per-queue worker tasks.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    // Sender for messages to the running coordinator task; populated when the
    // coordinator is inserted.
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    // Adapter configuration shared with the workers.
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
1061
/// Builder for [`Nic`] instances; see [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    // Requested queue cap; clamped against endpoint support in `build`.
    max_queues: u16,
    // Optional callback for querying the current guest OS ID.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1068
impl NicBuilder {
    /// Limits the effective size of the outgoing ring buffer; see `build`.
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Caps the number of queues offered to the guest.
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Provides a virtual function device associated with this NIC.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Provides a callback for querying the current guest OS ID, used to work
    /// around guest-specific quirks (see `can_use_ring_opt`).
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp the requested queue count between 1 and what both the
        // endpoint and the netvsp protocol support.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            // LSOv4 requires both TSO and IPv4 header checksum support,
            // because the TAP/virtio GSO engine needs a valid IPv4 header
            // checksum that NDIS LSO packets don't provide.
            lso4: tx_offloads.tso && tx_offloads.ipv4_header,
            lso6: tx_offloads.tso,
        };

        // The adapter holds the static configuration shared by all workers.
        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        // The coordinator task itself is inserted lazily when the primary
        // channel is opened; only its state is prepared here.
        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1183
1184fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1185    let Some(guest_os_id) = guest_os_id else {
1186        // guest os id not available.
1187        return false;
1188    };
1189
1190    if !queue.split().0.supports_pending_send_size() {
1191        // guest does not support pending send size.
1192        return false;
1193    }
1194
1195    let Some(open_source_os) = guest_os_id.open_source() else {
1196        // guest os is proprietary (ex: Windows)
1197        return true;
1198    };
1199
1200    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1201        // Although FreeBSD indicates support for `pending send size`, it doesn't
1202        // implement it correctly. This was fixed in FreeBSD version `1400097`.
1203        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1204        // Linux kernels prior to 3.11 have issues with pending send size optimization
1205        // which can affect certain Asynchronous I/O (AIO) network operations.
1206        // Disable ring size optimization for these older kernels to avoid flow control issues.
1207        HvGuestOsOpenSourceType::LINUX => {
1208            // Linux version is encoded as: ((major << 16) | (minor << 8) | patch)
1209            // Linux 3.11.0 = (3 << 16) | (11 << 8) | 0 = 199424
1210            open_source_os.version() >= 199424
1211        }
1212        _ => true,
1213    }
1214}
1215
impl Nic {
    /// Returns a builder with default settings: no VF, unlimited ring usage,
    /// and the queue count left to the endpoint's capabilities.
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            // !0 == u16::MAX; `build` clamps this to what the endpoint supports.
            max_queues: !0,
            get_guest_os_id: None,
        }
    }

    /// Tears down the NIC, returning the backend endpoint and MAC address so
    /// they can be reused.
    pub fn shutdown(self) -> (Box<dyn Endpoint>, MacAddress) {
        let (state, _) = self.coordinator.into_inner();
        (state.endpoint, self.adapter.mac_address)
    }
}
1231
impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        // Delegate inspection to the coordinator task.
        self.coordinator.inspect_mut(req);
    }
}
1237
1238#[async_trait]
1239impl VmbusDevice for Nic {
1240    fn offer(&self) -> OfferParams {
1241        OfferParams {
1242            interface_name: "net".to_owned(),
1243            instance_id: self.instance_id,
1244            interface_id: Guid {
1245                data1: 0xf8615163,
1246                data2: 0xdf3e,
1247                data3: 0x46c5,
1248                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
1249            },
1250            subchannel_index: 0,
1251            mnf_interrupt_latency: Some(Duration::from_micros(100)),
1252            ..Default::default()
1253        }
1254    }
1255
1256    fn max_subchannels(&self) -> u16 {
1257        self.adapter.max_queues
1258    }
1259
1260    fn install(&mut self, resources: DeviceResources) {
1261        self.resources = resources;
1262    }
1263
1264    async fn open(
1265        &mut self,
1266        channel_idx: u16,
1267        open_request: &OpenRequest,
1268    ) -> Result<(), ChannelOpenError> {
1269        // Start the coordinator task if this is the primary channel.
1270        let state = if channel_idx == 0 {
1271            self.insert_coordinator(1, None);
1272            WorkerState::Init(None)
1273        } else {
1274            self.coordinator.stop().await;
1275            // Get the buffers created when the primary channel was opened.
1276            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
1277            WorkerState::Ready(ReadyState {
1278                state: ActiveState::new(None, buffers.recv_buffer.count),
1279                buffers,
1280                data: ProcessingData::new(),
1281            })
1282        };
1283
1284        let num_opened = self
1285            .adapter
1286            .num_sub_channels_opened
1287            .fetch_add(1, Ordering::SeqCst);
1288        let r = self.insert_worker(channel_idx, open_request, state, true);
1289        if channel_idx != 0
1290            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
1291        {
1292            let coordinator = &mut self.coordinator.state_mut().unwrap();
1293            coordinator.workers[0].stop().await;
1294        }
1295
1296        if r.is_err() && channel_idx == 0 {
1297            self.coordinator.remove();
1298        } else {
1299            // The coordinator will restart any stopped workers.
1300            self.coordinator.start();
1301        }
1302        r?;
1303        Ok(())
1304    }
1305
1306    async fn close(&mut self, channel_idx: u16) {
1307        if !self.coordinator.has_state() {
1308            tracing::error!(
1309                channel_idx,
1310                instance_id = %self.instance_id,
1311                "Close called while vmbus channel is already closed"
1312            );
1313            return;
1314        }
1315
1316        // Stop the coordinator to get access to the workers.
1317        let restart = self.coordinator.stop().await;
1318
1319        // Stop and remove the channel worker.
1320        {
1321            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1322            worker.stop().await;
1323            if worker.has_state() {
1324                worker.remove();
1325            }
1326        }
1327
1328        self.adapter
1329            .num_sub_channels_opened
1330            .fetch_sub(1, Ordering::SeqCst);
1331        // Disable the endpoint.
1332        if channel_idx == 0 {
1333            for worker in &mut self.coordinator.state_mut().unwrap().workers {
1334                worker.task_mut().queue_state = None;
1335            }
1336
1337            // Note that this await is not restartable.
1338            self.coordinator.task_mut().endpoint.stop().await;
1339
1340            // Keep any VF's added to the guest. This is required to keep guest compat as
1341            // some apps (such as DPDK) relies on the VF sticking around even after vmbus
1342            // channel is closed.
1343            // The coordinator's job is done.
1344            self.coordinator.remove();
1345        } else {
1346            // Restart the coordinator.
1347            if restart {
1348                self.coordinator.start();
1349            }
1350        }
1351    }
1352
1353    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
1354        if !self.coordinator.has_state() {
1355            return;
1356        }
1357
1358        // Stop the coordinator and worker associated with this channel.
1359        let coordinator_running = self.coordinator.stop().await;
1360        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1361        worker.stop().await;
1362        let (net_queue, worker_state) = worker.get_mut();
1363
1364        // Update the target VP on the driver.
1365        net_queue.driver.retarget_vp(target_vp);
1366
1367        if let Some(worker_state) = worker_state {
1368            // Update the target VP in the worker state.
1369            worker_state.target_vp = Some(target_vp);
1370            if let Some(queue_state) = &mut net_queue.queue_state {
1371                // Tell the worker to re-set the target VP on next run.
1372                queue_state.target_vp_set = false;
1373            }
1374        }
1375
1376        // The coordinator will restart any stopped workers.
1377        if coordinator_running {
1378            self.coordinator.start();
1379        }
1380    }
1381
1382    fn start(&mut self) {
1383        if !self.coordinator.is_running() {
1384            self.coordinator.start();
1385        }
1386    }
1387
1388    async fn stop(&mut self) {
1389        self.coordinator.stop().await;
1390        if let Some(coordinator) = self.coordinator.state_mut() {
1391            coordinator.stop_workers().await;
1392        }
1393    }
1394
1395    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1396        Some(self)
1397    }
1398}
1399
1400#[async_trait]
1401impl SaveRestoreVmbusDevice for Nic {
1402    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1403        let state = self.saved_state();
1404        Ok(SavedStateBlob::new(state))
1405    }
1406
1407    async fn restore(
1408        &mut self,
1409        control: RestoreControl<'_>,
1410        state: SavedStateBlob,
1411    ) -> Result<(), RestoreError> {
1412        let state: saved_state::SavedState = state.parse()?;
1413        if let Err(err) = self.restore_state(control, state).await {
1414            tracing::error!(
1415                error = &err as &dyn std::error::Error,
1416                instance_id = %self.instance_id,
1417                "Failed restoring network vmbus state"
1418            );
1419            Err(err.into())
1420        } else {
1421            Ok(())
1422        }
1423    }
1424}
1425
impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    ///
    /// Builds the worker's vmbus queue over the channel's ring, retargets its
    /// driver to the channel's VP, and optionally starts the worker task.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        // N.B. VMBus doesn't provide a target VP if the channel is not using interrupts. Run on VP
        //      0 in that case.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp.unwrap_or_default());

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // Decide once, at open time, whether the pending-send-size ring
        // optimization is usable for this guest.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1487
/// Coordinator parameters carried through a restore operation.
struct RestoreCoordinatorState {
    // The packet filter that was active when the state was saved.
    active_packet_filter: u32,
}
1491
impl Nic {
    /// Inserts the coordinator task with `num_queues` initial queues.
    ///
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                // One task slot per potential queue; workers are inserted
                // later as channels open.
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                // When restoring, reuse the saved packet filter; otherwise
                // start with NDIS_PACKET_TYPE_NONE.
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}
1534
/// Errors that can occur while restoring saved netvsp state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    // Shrinking the indirection table on restore is rejected; see
    // `PrimaryChannelState::restore`.
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1554
1555impl From<NetRestoreError> for RestoreError {
1556    fn from(err: NetRestoreError) -> Self {
1557        RestoreError::InvalidSavedState(anyhow::Error::new(err))
1558    }
1559}
1560
1561impl Nic {
1562    async fn restore_state(
1563        &mut self,
1564        mut control: RestoreControl<'_>,
1565        state: saved_state::SavedState,
1566    ) -> Result<(), NetRestoreError> {
1567        let mut saved_packet_filter = 0u32;
1568        if let Some(state) = state.open {
1569            let open = match &state.primary {
1570                saved_state::Primary::Version => vec![true],
1571                saved_state::Primary::Init(_) => vec![true],
1572                saved_state::Primary::Ready(ready) => {
1573                    ready.channels.iter().map(|x| x.is_some()).collect()
1574                }
1575            };
1576
1577            let mut states: Vec<_> = open.iter().map(|_| None).collect();
1578
1579            // N.B. This will restore the vmbus view of open channels, so any
1580            //      failures after this point could result in inconsistent
1581            //      state (vmbus believes the channel is open/active). There
1582            //      are a number of failure paths after this point because this
1583            //      call also restores vmbus device state, like the GPADL map.
1584            let requests = control
1585                .restore(&open)
1586                .await
1587                .map_err(NetRestoreError::Channel)?;
1588
1589            match state.primary {
1590                saved_state::Primary::Version => {
1591                    states[0] = Some(WorkerState::Init(None));
1592                }
1593                saved_state::Primary::Init(init) => {
1594                    let version = check_version(init.version)
1595                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;
1596
1597                    let recv_buffer = init
1598                        .receive_buffer
1599                        .map(|recv_buffer| {
1600                            ReceiveBuffer::new(
1601                                &self.resources.gpadl_map,
1602                                recv_buffer.gpadl_id,
1603                                recv_buffer.id,
1604                                recv_buffer.sub_allocation_size,
1605                            )
1606                        })
1607                        .transpose()?;
1608
1609                    let send_buffer = init
1610                        .send_buffer
1611                        .map(|send_buffer| {
1612                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
1613                        })
1614                        .transpose()?;
1615
1616                    let state = InitState {
1617                        version,
1618                        ndis_config: init.ndis_config.map(
1619                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
1620                                mtu,
1621                                capabilities: capabilities.into(),
1622                            },
1623                        ),
1624                        ndis_version: init.ndis_version.map(
1625                            |saved_state::NdisVersion { major, minor }| NdisVersion {
1626                                major,
1627                                minor,
1628                            },
1629                        ),
1630                        recv_buffer,
1631                        send_buffer,
1632                    };
1633                    states[0] = Some(WorkerState::Init(Some(state)));
1634                }
1635                saved_state::Primary::Ready(ready) => {
1636                    let saved_state::ReadyPrimary {
1637                        version,
1638                        receive_buffer,
1639                        send_buffer,
1640                        mut control_messages,
1641                        mut rss_state,
1642                        channels,
1643                        ndis_version,
1644                        ndis_config,
1645                        rndis_state,
1646                        guest_vf_state,
1647                        offload_config,
1648                        pending_offload_change,
1649                        tx_spread_sent,
1650                        guest_link_down,
1651                        pending_link_action,
1652                        packet_filter,
1653                    } = ready;
1654
1655                    // If saved state does not have a packet filter set, default to directed, multicast, and broadcast.
1656                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);
1657
1658                    let version = check_version(version)
1659                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;
1660
1661                    let request = requests[0].as_ref().unwrap();
1662                    let buffers = Arc::new(ChannelBuffers {
1663                        version,
1664                        mem: self.resources.offer_resources.guest_memory(request).clone(),
1665                        recv_buffer: ReceiveBuffer::new(
1666                            &self.resources.gpadl_map,
1667                            receive_buffer.gpadl_id,
1668                            receive_buffer.id,
1669                            receive_buffer.sub_allocation_size,
1670                        )?,
1671                        send_buffer: {
1672                            if let Some(send_buffer) = send_buffer {
1673                                Some(SendBuffer::new(
1674                                    &self.resources.gpadl_map,
1675                                    send_buffer.gpadl_id,
1676                                )?)
1677                            } else {
1678                                None
1679                            }
1680                        },
1681                        ndis_version: {
1682                            let saved_state::NdisVersion { major, minor } = ndis_version;
1683                            NdisVersion { major, minor }
1684                        },
1685                        ndis_config: {
1686                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
1687                            NdisConfig {
1688                                mtu,
1689                                capabilities: capabilities.into(),
1690                            }
1691                        },
1692                    });
1693
1694                    for (channel_idx, channel) in channels.iter().enumerate() {
1695                        let channel = if let Some(channel) = channel {
1696                            channel
1697                        } else {
1698                            continue;
1699                        };
1700
1701                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1702
1703                        // Restore primary channel state.
1704                        if channel_idx == 0 {
1705                            let primary = PrimaryChannelState::restore(
1706                                &guest_vf_state,
1707                                &rndis_state,
1708                                &offload_config,
1709                                pending_offload_change,
1710                                channels.len() as u16,
1711                                self.adapter.indirection_table_size,
1712                                &active.rx_bufs,
1713                                std::mem::take(&mut control_messages),
1714                                rss_state.take(),
1715                                tx_spread_sent,
1716                                guest_link_down,
1717                                pending_link_action,
1718                            )?;
1719                            active.primary = Some(primary);
1720                        }
1721
1722                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1723                            buffers: buffers.clone(),
1724                            state: active,
1725                            data: ProcessingData::new(),
1726                        }));
1727                    }
1728                }
1729            }
1730
1731            // Insert the coordinator and mark that it should try to start the
1732            // network endpoint when it starts running.
1733            self.insert_coordinator(
1734                states.len() as u16,
1735                Some(RestoreCoordinatorState {
1736                    active_packet_filter: saved_packet_filter,
1737                }),
1738            );
1739
1740            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1741                if let Some(state) = state {
1742                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1743                }
1744            }
1745        } else {
1746            control
1747                .restore(&[false])
1748                .await
1749                .map_err(NetRestoreError::Channel)?;
1750        }
1751        Ok(())
1752    }
1753
    /// Serializes the current device state into a [`saved_state::SavedState`]
    /// for save/restore.
    ///
    /// If the coordinator has state (presumably meaning the channel is open —
    /// TODO confirm against the restore path), the primary worker's state is
    /// captured; otherwise `open` is `None`.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            // Worker 0 is the primary channel; its state drives which
            // `Primary` variant is saved.
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                // Version negotiated but nothing else established yet.
                WorkerState::Init(None) => saved_state::Primary::Version,
                // Mid-initialization: save whichever pieces the guest has
                // sent so far (each is independently optional).
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                // Fully initialized (whether or not the coordinator has
                // picked the worker up yet): save the complete ready state.
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    // Field-by-field copy into the saved-state mirror types.
                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued-but-unprocessed control messages are saved
                    // verbatim so they can be replayed after restore.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // Active and delayed link actions are saved the same way;
                    // the distinction is not preserved across save/restore.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Per-queue channel state, only for the active queues.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // In flight tx will be considered as dropped packets through save/restore, but need
                                    // to complete the requests back to the guest.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            // Receive suballocations still held by
                                            // the guest.
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine onto the smaller
                    // set of saved states; transient states map to the
                    // nearest durable one.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        // A restore that never completed: persist the original
                        // saved state unchanged.
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The packet filter lives on worker 0's channel state.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        // Saved with inverted polarity relative to the runtime
                        // field.
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
1955}
1956
/// Fatal errors that terminate a channel worker's processing loop.
///
/// The `#[error]` strings are the user-facing descriptions; several variants
/// wrap a lower-level source error.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    // Converted via the `From` impl below rather than `#[from]`.
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
}
2010
2011impl From<task_control::Cancelled> for WorkerError {
2012    fn from(value: task_control::Cancelled) -> Self {
2013        Self::Cancelled(value)
2014    }
2015}
2016
/// Errors that can occur while opening the vmbus channel.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
2024
/// Errors parsing a single incoming vmbus packet (see `parse_packet`).
#[derive(Debug, Error)]
enum PacketError {
    // An NVSP message type that is unknown or not valid at the negotiated
    // protocol version.
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
2036
/// Protocol-state violations: a message arrived that is inconsistent with
/// the current negotiation/setup state.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2056
/// Decoded NVSP message payload, one variant per recognized message type.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer),
    RevokeSendBuffer(protocol::Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // Completions for host-initiated messages, recognized by their reserved
    // transaction IDs rather than by a message header (see `parse_packet`).
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2074
/// A parsed incoming vmbus packet.
#[derive(Debug)]
struct Packet<'a> {
    /// The decoded NVSP message payload.
    data: PacketData,
    /// The vmbus transaction ID, if the sender expects a completion.
    transaction_id: Option<u64>,
    /// External ranges referenced by the packet (GPA direct ranges or a
    /// send-buffer suballocation), filled in by `parse_packet`.
    external_data: &'a MultiPagedRangeBuf,
}
2081
/// Reader over a packet's external data ranges, backed by guest memory.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2083
2084impl Packet<'_> {
2085    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
2086        PagedRanges::new(self.external_data.iter()).reader(mem)
2087    }
2088}
2089
2090fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2091    reader: &mut impl MemoryRead,
2092) -> Result<T, PacketError> {
2093    reader.read_plain().map_err(PacketError::Access)
2094}
2095
2096fn parse_packet<'a, T: RingMem>(
2097    packet_ref: &queue::PacketRef<'_, T>,
2098    send_buffer: Option<&SendBuffer>,
2099    version: Option<Version>,
2100    external_data: &'a mut MultiPagedRangeBuf,
2101) -> Result<Packet<'a>, PacketError> {
2102    external_data.clear();
2103    let packet = match packet_ref.as_ref() {
2104        IncomingPacket::Data(data) => data,
2105        IncomingPacket::Completion(completion) => {
2106            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
2107                PacketData::SendVfAssociationCompletion
2108            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
2109                PacketData::SwitchDataPathCompletion
2110            } else {
2111                let mut reader = completion.reader();
2112                let header: protocol::MessageHeader =
2113                    reader.read_plain().map_err(PacketError::Access)?;
2114                match header.message_type {
2115                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
2116                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
2117                    }
2118                    typ => return Err(PacketError::UnknownType(typ)),
2119                }
2120            };
2121            return Ok(Packet {
2122                data,
2123                transaction_id: Some(completion.transaction_id()),
2124                external_data,
2125            });
2126        }
2127    };
2128
2129    let mut reader = packet.reader();
2130    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2131    let data = match header.message_type {
2132        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2133        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2134            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2135        }
2136        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2137            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2138        }
2139        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2140            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2141        }
2142        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2143            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2144        }
2145        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2146            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2147        }
2148        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2149            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
2150            if message.send_buffer_section_index != 0xffffffff {
2151                let send_buffer_suballocation = send_buffer
2152                    .ok_or(PacketError::InvalidSendBufferIndex)?
2153                    .gpadl
2154                    .first()
2155                    .unwrap()
2156                    .try_subrange(
2157                        message.send_buffer_section_index as usize * 6144,
2158                        message.send_buffer_section_size as usize,
2159                    )
2160                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2161
2162                external_data.push_range(send_buffer_suballocation);
2163            }
2164            PacketData::RndisPacket(message)
2165        }
2166        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2167            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2168        }
2169        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2170            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2171        }
2172        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2173            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2174        }
2175        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2176            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2177        }
2178        typ => return Err(PacketError::UnknownType(typ)),
2179    };
2180    packet
2181        .read_external_ranges(external_data)
2182        .map_err(PacketError::ExternalData)?;
2183    Ok(Packet {
2184        data,
2185        transaction_id: packet.transaction_id(),
2186        external_data,
2187    })
2188}
2189
/// A fixed-size buffer holding an encoded NVSP protocol message.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    // Backing storage sized for the largest (v6.1) packet; stored as u64
    // words so the payload is 8-byte aligned for the ring.
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    // Which protocol packet size to use when sending this message.
    size: PacketSize,
}
2195
2196impl NvspMessage {
2197    fn new<P: IntoBytes + Immutable + KnownLayout>(
2198        size: PacketSize,
2199        message_type: u32,
2200        data: P,
2201    ) -> Self {
2202        // Assert at compile time that the packet will fit in the message
2203        // buffer. Note that we are checking against the v1 message size here.
2204        // It's possible this is a v6.1+ message, in which case we could compare
2205        // against the larger size. So far this has not been necessary. If
2206        // needed, make a `new_v61` method that does the more relaxed check,
2207        // rater than weakening this one.
2208        const {
2209            assert!(
2210                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
2211                "packet might not fit in message"
2212            )
2213        };
2214        let mut message = NvspMessage {
2215            buf: [0; protocol::PACKET_SIZE_V61 / 8],
2216            size,
2217        };
2218        let header = protocol::MessageHeader { message_type };
2219        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
2220        data.write_to_prefix(
2221            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
2222        )
2223        .unwrap();
2224        message
2225    }
2226
2227    fn aligned_payload(&self) -> &[u64] {
2228        // Note that vmbus packets are always 8-byte multiples, so round the
2229        // protocol package size up.
2230        let len = match self.size {
2231            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.div_ceil(8) },
2232            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.div_ceil(8) },
2233        };
2234        &self.buf[..len]
2235    }
2236}
2237
2238impl<T: RingMem> NetChannel<T> {
2239    fn message<P: IntoBytes + Immutable + KnownLayout>(
2240        &self,
2241        message_type: u32,
2242        data: P,
2243    ) -> NvspMessage {
2244        NvspMessage::new(self.packet_size, message_type, data)
2245    }
2246
2247    fn send_completion(
2248        &mut self,
2249        transaction_id: Option<u64>,
2250        message: Option<&NvspMessage>,
2251    ) -> Result<(), WorkerError> {
2252        match transaction_id {
2253            None => Ok(()),
2254            Some(transaction_id) => Ok(self
2255                .queue
2256                .split()
2257                .1
2258                .batched()
2259                .try_write_aligned(
2260                    transaction_id,
2261                    OutgoingPacketType::Completion,
2262                    message.map_or(&[], |m| m.aligned_payload()),
2263                )
2264                .map_err(|err| match err {
2265                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2266                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2267                })?),
2268        }
2269    }
2270}
2271
/// Protocol versions this device will negotiate, in ascending order.
/// Note that no version 3 entry is offered.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2280
2281fn check_version(requested_version: u32) -> Option<Version> {
2282    SUPPORTED_VERSIONS
2283        .iter()
2284        .find(|version| **version as u32 == requested_version)
2285        .copied()
2286}
2287
/// The guest-supplied receive buffer, carved into fixed-size suballocations
/// used to deliver data to the guest.
#[derive(Debug)]
struct ReceiveBuffer {
    /// Mapped view of the guest's receive-buffer GPADL.
    gpadl: GpadlView,
    /// Identifier the guest assigned to this buffer.
    id: u16,
    /// Number of suballocations that fit in the buffer.
    count: u32,
    /// Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2295
/// Errors validating a guest-supplied send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    // The GPADL must describe a contiguous, page-aligned region.
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2305
2306impl ReceiveBuffer {
2307    fn new(
2308        gpadl_map: &GpadlMapView,
2309        gpadl_id: GpadlId,
2310        id: u16,
2311        sub_allocation_size: u32,
2312    ) -> Result<Self, BufferError> {
2313        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2314            return Err(BufferError::UnsupportedSuballocationSize(
2315                sub_allocation_size,
2316            ));
2317        }
2318        let gpadl = gpadl_map.map(gpadl_id)?;
2319        let range = gpadl
2320            .contiguous_aligned()
2321            .ok_or(BufferError::UnalignedGpadl)?;
2322        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2323        if num_sub_allocations == 0 {
2324            return Err(BufferError::UnsupportedSuballocationSize(
2325                sub_allocation_size,
2326            ));
2327        }
2328        let recv_buffer = Self {
2329            gpadl,
2330            id,
2331            count: num_sub_allocations,
2332            sub_allocation_size,
2333        };
2334        Ok(recv_buffer)
2335    }
2336
2337    fn range(&self, index: u32) -> PagedRange<'_> {
2338        self.gpadl.first().unwrap().subrange(
2339            (index * self.sub_allocation_size) as usize,
2340            self.sub_allocation_size as usize,
2341        )
2342    }
2343
2344    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2345        assert!(len <= self.sub_allocation_size as usize);
2346        ring::TransferPageRange {
2347            byte_offset: index * self.sub_allocation_size,
2348            byte_count: len as u32,
2349        }
2350    }
2351
2352    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2353        saved_state::ReceiveBuffer {
2354            gpadl_id: self.gpadl.id(),
2355            id: self.id,
2356            sub_allocation_size: self.sub_allocation_size,
2357        }
2358    }
2359}
2360
/// The guest-supplied send buffer, used by the guest to pass small transmit
/// packets by section index.
#[derive(Debug)]
struct SendBuffer {
    /// Mapped view of the guest's send-buffer GPADL.
    gpadl: GpadlView,
}
2365
2366impl SendBuffer {
2367    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2368        let gpadl = gpadl_map.map(gpadl_id)?;
2369        gpadl
2370            .contiguous_aligned()
2371            .ok_or(BufferError::UnalignedGpadl)?;
2372        Ok(Self { gpadl })
2373    }
2374}
2375
2376impl<T: RingMem> NetChannel<T> {
    /// Process a single non-packet RNDIS message.
    ///
    /// Control messages that need a response are queued on the primary
    /// channel state and serviced later from the main dispatch loop, since
    /// responding requires a free receive-buffer suballocation which may not
    /// be available right now.
    fn handle_rndis_message(
        &mut self,
        state: &mut ActiveState,
        message_type: u32,
        mut reader: PacketReader<'_>,
    ) -> Result<(), WorkerError> {
        assert_ne!(
            message_type,
            rndisprot::MESSAGE_TYPE_PACKET_MSG,
            "handled elsewhere"
        );
        // Control messages are only valid on the primary channel.
        let control = state
            .primary
            .as_mut()
            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;

        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
            // Currently ignored and does not require a response.
            return Ok(());
        }

        // This is a control message that needs a response. Responding
        // will require a suballocation to be available, which it may
        // not be right now. Enqueue the suballocation to a queue and
        // process the queue as suballocations become available.
        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
        if reader.len() == 0 {
            return Err(WorkerError::RndisMessageTooSmall);
        }
        // Do not let the queue get too large--the guest should not be
        // sending very many control messages at a time.
        //
        // N.B. `control_messages_len` never exceeds the maximum (enforced
        // here), so the subtraction cannot underflow.
        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
            return Err(WorkerError::TooManyControlMessages);
        }

        control.control_messages_len += reader.len();
        control.control_messages.push_back(ControlMessage {
            message_type,
            data: reader.read_all()?.into(),
        });

        // The control message queue will be processed in the main dispatch
        // loop.
        Ok(())
    }
2423
    /// Process RNDIS packet messages, which may contain multiple RNDIS packets
    /// in a single vmbus message.
    ///
    /// On entry, the reader has already read the RNDIS message header of the
    /// first RNDIS packet in the message.
    ///
    /// Returns the number of RNDIS packets consumed from the message.
    fn handle_rndis_packet_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        mut message_len: usize,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        // There may be multiple RNDIS packets in a single message, concatenated
        // with each other. Consume them until there is no more data in the
        // RNDIS message.
        let mut num_packets = 0;
        loop {
            // Remaining bytes of the current packet past its message header;
            // this is also the distance to the next concatenated packet.
            let next_message_offset = message_len
                .checked_sub(size_of::<rndisprot::MessageHeader>())
                .ok_or(WorkerError::RndisMessageTooSmall)?;

            self.handle_rndis_packet_message(
                id,
                reader.clone(),
                &buffers.mem,
                segments,
                &mut state.stats,
            )?;
            num_packets += 1;

            // Advance to the next concatenated RNDIS message, if any.
            reader.skip(next_message_offset)?;
            if reader.len() == 0 {
                break;
            }
            // Only packet messages may follow a packet message.
            let header: rndisprot::MessageHeader = reader.read_plain()?;
            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
            }
            message_len = header.message_length as usize;
        }
        Ok(num_packets)
    }
2468
2469    /// Process an RNDIS package message (used to send an Ethernet frame).
2470    fn handle_rndis_packet_message(
2471        &mut self,
2472        id: TxId,
2473        reader: PacketReader<'_>,
2474        mem: &GuestMemory,
2475        segments: &mut Vec<TxSegment>,
2476        stats: &mut QueueStats,
2477    ) -> Result<(), WorkerError> {
2478        // Headers are guaranteed to be in a single PagedRange.
2479        let headers = reader
2480            .clone()
2481            .into_inner()
2482            .paged_ranges()
2483            .next()
2484            .ok_or(WorkerError::RndisMessageTooSmall)?;
2485        let mut data = reader.into_inner();
2486        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2487        if request.num_oob_data_elements != 0
2488            || request.oob_data_length != 0
2489            || request.oob_data_offset != 0
2490            || request.vc_handle != 0
2491        {
2492            return Err(WorkerError::UnsupportedRndisBehavior);
2493        }
2494
2495        if data.len() < request.data_offset as usize
2496            || (data.len() - request.data_offset as usize) < request.data_length as usize
2497            || request.data_length == 0
2498        {
2499            return Err(WorkerError::RndisMessageTooSmall);
2500        }
2501
2502        data.skip(request.data_offset as usize);
2503        data.truncate(request.data_length as usize);
2504
2505        let mut metadata = net_backend::TxMetadata {
2506            id,
2507            len: request.data_length,
2508            ..Default::default()
2509        };
2510
2511        if request.per_packet_info_length != 0 {
2512            let mut ppi = headers
2513                .try_subrange(
2514                    request.per_packet_info_offset as usize,
2515                    request.per_packet_info_length as usize,
2516                )
2517                .ok_or(WorkerError::RndisMessageTooSmall)?;
2518            while !ppi.is_empty() {
2519                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2520                if h.size == 0 {
2521                    return Err(WorkerError::RndisMessageTooSmall);
2522                }
2523                let (this, rest) = ppi
2524                    .try_split(h.size as usize)
2525                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2526                let (_, d) = this
2527                    .try_split(h.per_packet_information_offset as usize)
2528                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2529                match h.typ {
2530                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2531                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2532
2533                        metadata.flags.set_offload_tcp_checksum(
2534                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2535                        );
2536                        metadata.flags.set_offload_udp_checksum(
2537                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2538                        );
2539                        metadata
2540                            .flags
2541                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2542                        metadata.flags.set_is_ipv4(n.is_ipv4());
2543                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2544                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2545                        if metadata.flags.offload_tcp_checksum()
2546                            || metadata.flags.offload_udp_checksum()
2547                        {
2548                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2549                                n.tcp_header_offset() - metadata.l2_len as u16
2550                            } else if n.is_ipv4() {
2551                                let mut reader = data.clone().reader(mem);
2552                                reader.skip(metadata.l2_len as usize)?;
2553                                let mut b = 0;
2554                                reader.read(std::slice::from_mut(&mut b))?;
2555                                (b as u16 >> 4) * 4
2556                            } else {
2557                                // Hope there are no extensions.
2558                                40
2559                            };
2560                        }
2561                    }
2562                    rndisprot::PPI_LSO => {
2563                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2564
2565                        metadata.flags.set_offload_tcp_segmentation(true);
2566                        metadata.flags.set_offload_tcp_checksum(true);
2567                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2568                        metadata.flags.set_is_ipv4(n.is_ipv4());
2569                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2570                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2571                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2572                            return Err(WorkerError::InvalidTcpHeaderOffset);
2573                        }
2574                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2575                        metadata.l4_len = {
2576                            let mut reader = data.clone().reader(mem);
2577                            reader
2578                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2579                            let mut b = 0;
2580                            reader.read(std::slice::from_mut(&mut b))?;
2581                            (b >> 4) * 4
2582                        };
2583                        metadata.max_tcp_segment_size = n.mss() as u16;
2584
2585                        if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2586                            // Not strictly enforced.
2587                            stats.tx_invalid_lso_packets.increment();
2588                        }
2589                    }
2590                    _ => {}
2591                }
2592                ppi = rest;
2593            }
2594        }
2595
2596        let start = segments.len();
2597        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2598            let range = range.map_err(WorkerError::InvalidGpadl)?;
2599            segments.push(TxSegment {
2600                ty: net_backend::TxSegmentType::Tail,
2601                gpa: range.start,
2602                len: range.len() as u32,
2603            });
2604        }
2605
2606        metadata.segment_count = (segments.len() - start) as u8;
2607
2608        stats.tx_packets.increment();
2609        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2610            stats.tx_checksum_packets.increment();
2611        }
2612        if metadata.flags.offload_tcp_segmentation() {
2613            stats.tx_lso_packets.increment();
2614        }
2615
2616        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2617
2618        Ok(())
2619    }
2620
    /// Notify the adapter that the guest VF state has changed and it may
    /// need to send a message to the guest.
    ///
    /// Sends `MESSAGE4_TYPE_SEND_VF_ASSOCIATION` when the negotiated protocol
    /// (V4+) and the guest's NDIS config (SR-IOV capability) allow it.
    /// Returns `Ok(true)` if the association message was sent, `Ok(false)`
    /// if it was skipped.
    fn guest_vf_is_available(
        &mut self,
        guest_vf_id: Option<u32>,
        version: Version,
        config: NdisConfig,
    ) -> Result<bool, WorkerError> {
        // `None` means no VF is available; the message then advertises
        // "not allocated" with a zero serial number.
        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
        if version >= Version::V4 && config.capabilities.sriov() {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                "sending VF association message"
            );
            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
            let message = {
                self.message(
                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
                    protocol::Message4SendVfAssociation {
                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
                        serial_number: serial_number.unwrap_or(0),
                    },
                )
            };
            self.queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    VF_ASSOCIATION_TRANSACTION_ID,
                    OutgoingPacketType::InBandWithCompletion,
                    message.aligned_payload(),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(len) => {
                        tracing::error!(len, "failed to write vf association message");
                        WorkerError::OutOfSpace
                    }
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?;
            Ok(true)
        } else {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                major = version.major(),
                minor = version.minor(),
                sriov_capable = config.capabilities.sriov(),
                "Skipping NvspMessage4TypeSendVFAssociation message"
            );
            Ok(false)
        }
    }
2675
2676    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2677    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2678        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2679        if version < Version::V5 {
2680            return;
2681        }
2682
2683        #[repr(C)]
2684        #[derive(IntoBytes, Immutable, KnownLayout)]
2685        struct SendIndirectionMsg {
2686            pub message: protocol::Message5SendIndirectionTable,
2687            pub send_indirection_table:
2688                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2689        }
2690
2691        // The offset to the send indirection table from the beginning of the NVSP message.
2692        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2693            + size_of::<protocol::MessageHeader>();
2694        let mut data = SendIndirectionMsg {
2695            message: protocol::Message5SendIndirectionTable {
2696                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2697                table_offset: send_indirection_table_offset as u32,
2698            },
2699            send_indirection_table: Default::default(),
2700        };
2701
2702        for i in 0..data.send_indirection_table.len() {
2703            data.send_indirection_table[i] = i as u32 % num_channels_opened;
2704        }
2705
2706        let header = protocol::MessageHeader {
2707            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2708        };
2709        let result = self
2710            .queue
2711            .split()
2712            .1
2713            .try_write(&queue::OutgoingPacket {
2714                transaction_id: 0,
2715                packet_type: OutgoingPacketType::InBandNoCompletion,
2716                payload: &[header.as_bytes(), data.as_bytes()],
2717            })
2718            .map_err(|err| match err {
2719                queue::TryWriteError::Full(len) => {
2720                    tracing::error!(len, "failed to write send indirection table message");
2721                    WorkerError::OutOfSpace
2722                }
2723                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2724            });
2725        if let Err(err) = result {
2726            tracing::error!(
2727                error = &err as &dyn std::error::Error,
2728                "Failed to notify guest about the send indirection table"
2729            );
2730        }
2731    }
2732
2733    /// Notify the guest that the data path has been switched back to synthetic
2734    /// due to some external state change.
2735    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2736        let header = protocol::MessageHeader {
2737            message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2738        };
2739        let data = protocol::Message4SwitchDataPath {
2740            active_data_path: protocol::DataPath::SYNTHETIC.0,
2741        };
2742        let result = self
2743            .queue
2744            .split()
2745            .1
2746            .try_write(&queue::OutgoingPacket {
2747                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2748                packet_type: OutgoingPacketType::InBandWithCompletion,
2749                payload: &[header.as_bytes(), data.as_bytes()],
2750            })
2751            .map_err(|err| match err {
2752                queue::TryWriteError::Full(len) => {
2753                    tracing::error!(len, "failed to write switch data path message");
2754                    WorkerError::OutOfSpace
2755                }
2756                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2757            });
2758        if let Err(err) = result {
2759            tracing::error!(
2760                error = &err as &dyn std::error::Error,
2761                "Failed to notify guest that data path is now synthetic"
2762            );
2763        }
2764    }
2765
    /// Process an internal state change
    ///
    /// Advances the guest VF state machine, sending whatever guest
    /// notifications each transition requires, until a steady state is
    /// reached. Returns a coordinator message when the coordinator needs to
    /// do follow-up work.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
        //      1. completion of switch data path request
        //      2. Switch data path notification (back to synthetic)
        //      3. Disassociate VF adapter.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Notify guest that a VF capability has recently arrived.
            // Advertising requires an operational RNDIS state; otherwise wait.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Drive the remaining transitional states until one with no further
        // outgoing notifications is reached.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Notify guest that the VF is unavailable. It has already been surprise removed.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Notify guest that the data path is now synthetic.
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    // Complete the data path switch request.
                    self.send_completion(id, None)?;

                    match (to_guest, result) {
                        // Switching to guest VF successful.
                        (true, true) => PrimaryChannelGuestVfState::DataPathSwitched,
                        // Switching to guest VF failed, stay synthetic.
                        (true, false) => {
                            tracing::error!(
                                "Failure switching to guest VF, remaining on synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        }
                        // Switching to synthetic successful.
                        (false, true) => PrimaryChannelGuestVfState::Ready,
                        // Switching to synthetic failed, assume VF remains active.
                        (false, false) => {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // All other states are steady; stop iterating.
                _ => break,
            };
        }
        Ok(None)
    }
2866
    /// Process a control message, writing the response to the provided receive
    /// buffer suballocation.
    ///
    /// `id` identifies the receive-buffer suballocation reserved for the
    /// response; the completed response is handed to the guest via
    /// `send_rndis_control_message`.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // Guest memory range of the response suballocation.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // Initialization is only valid once, before RNDIS is
                // operational.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise any pending VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        // Ideally the VF would not be presented to the guest
                        // until the completion packet has arrived, so that the
                        // guest is prepared. This is most interesting for the
                        // case of a VF associated with multiple guest
                        // adapters, using a concept like vports. In this
                        // scenario it would be better if all of the adapters
                        // were aware a VF was available before the device
                        // arrived. This is not currently possible because the
                        // Linux netvsc driver ignores the completion requested
                        // flag on inband packets and won't send a completion
                        // packet.
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Split the suballocation into the completion header and the
                // OID payload area written by the adapter.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Restart the endpoint if the OID changed some critical
                        // endpoint property.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate packet-filter changes to the coordinator.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            oid = ?request.oid,
                            "oid set failure"
                        );
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
3047
3048    fn try_send_rndis_message(
3049        &mut self,
3050        transaction_id: u64,
3051        channel_type: u32,
3052        recv_buffer_id: u16,
3053        transfer_pages: &[ring::TransferPageRange],
3054    ) -> Result<Option<usize>, WorkerError> {
3055        let message = self.message(
3056            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
3057            protocol::Message1SendRndisPacket {
3058                channel_type,
3059                send_buffer_section_index: 0xffffffff,
3060                send_buffer_section_size: 0,
3061            },
3062        );
3063        let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3064            transaction_id,
3065            OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3066            message.aligned_payload(),
3067        ) {
3068            Ok(()) => None,
3069            Err(queue::TryWriteError::Full(n)) => Some(n),
3070            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3071        };
3072        Ok(pending_send_size)
3073    }
3074
3075    fn send_rndis_control_message(
3076        &mut self,
3077        buffers: &ChannelBuffers,
3078        id: u32,
3079        message_length: usize,
3080    ) -> Result<(), WorkerError> {
3081        let result = self.try_send_rndis_message(
3082            id as u64,
3083            protocol::CONTROL_CHANNEL_TYPE,
3084            buffers.recv_buffer.id,
3085            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3086        )?;
3087
3088        // Ring size is checked before control messages are processed, so failure to write is unexpected.
3089        match result {
3090            None => Ok(()),
3091            Some(len) => {
3092                tracelimit::error_ratelimited!(len, "failed to write control message completion");
3093                Err(WorkerError::OutOfSpace)
3094            }
3095        }
3096    }
3097
3098    fn indicate_status(
3099        &mut self,
3100        buffers: &ChannelBuffers,
3101        id: u32,
3102        status: u32,
3103        payload: &[u8],
3104    ) -> Result<(), WorkerError> {
3105        let buffer = &buffers.recv_buffer.range(id);
3106        let mut writer = buffer.writer(&buffers.mem);
3107        let message_length = write_rndis_message(
3108            &mut writer,
3109            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3110            payload.len(),
3111            &rndisprot::IndicateStatus {
3112                status,
3113                status_buffer_length: payload.len() as u32,
3114                status_buffer_offset: if payload.is_empty() {
3115                    0
3116                } else {
3117                    size_of::<rndisprot::IndicateStatus>() as u32
3118                },
3119            },
3120        )?;
3121        writer.write(payload)?;
3122        self.send_rndis_control_message(buffers, id, message_length)?;
3123        Ok(())
3124    }
3125
    /// Processes pending control messages until all are processed or there are
    /// no available suballocations.
    ///
    /// Each completion (or pending offload-change indication) consumes one
    /// free control suballocation of the receive buffer and some outgoing
    /// ring space; the loop stops as soon as either runs out.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages are only handled on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure the ring buffer has enough room to successfully complete control message handling.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                // Request a wakeup once this much ring space is available.
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // A free receive-buffer suballocation is needed to stage the
            // completion/indication for the guest.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the receive buffer in use to allow the guest to release it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued message, so (by the loop condition) an offload
                // change is pending: indicate the current config to the guest.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                // The loop condition guarantees one of the two cases above.
                unreachable!();
            }
        }
        Ok(())
    }
3179
3180    fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3181        if self.restart.is_none() {
3182            self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3183                guest_vf_state: guest_vf,
3184                filter_state: packet_filter,
3185            }));
3186        } else if let Some(CoordinatorMessage::Restart) = self.restart {
3187            // If a restart message is pending, do nothing.
3188            // A restart will try to switch the data path based on primary.guest_vf_state.
3189            // A restart will apply packet filter changes.
3190        } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3191            // Add the new update to the existing message.
3192            update.guest_vf_state |= guest_vf;
3193            update.filter_state |= packet_filter;
3194        }
3195    }
3196
    /// Queues a coordinator update covering a guest VF state change.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3200
    /// Queues a coordinator update covering a packet filter change.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3204}
3205
3206/// Writes an RNDIS message to `writer`.
3207fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3208    writer: &mut impl MemoryWrite,
3209    message_type: u32,
3210    extra: usize,
3211    payload: &T,
3212) -> Result<usize, AccessError> {
3213    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3214    writer.write(
3215        rndisprot::MessageHeader {
3216            message_type,
3217            message_length: message_length as u32,
3218        }
3219        .as_bytes(),
3220    )?;
3221    writer.write(payload.as_bytes())?;
3222    Ok(message_length)
3223}
3224
/// Errors arising while handling an RNDIS OID query or set request.
#[derive(Debug, Error)]
enum OidError {
    // Guest memory access failed while reading or writing OID data.
    #[error(transparent)]
    Access(#[from] AccessError),
    // The OID is not recognized by this implementation.
    #[error("unknown oid")]
    UnknownOid,
    // The OID payload failed validation; the offending field is named.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    // The request is not valid for the negotiated NDIS version.
    #[error("bad ndis version")]
    BadVersion,
    // The requested feature is not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3238
3239impl OidError {
3240    fn as_status(&self) -> u32 {
3241        match self {
3242            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3243            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3244            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3245            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3246        }
3247    }
3248}
3249
// Default frame size: 1500-byte payload plus the 14-byte Ethernet header.
const DEFAULT_MTU: u32 = 1514;
// Smallest MTU the guest may configure.
const MIN_MTU: u32 = DEFAULT_MTU;
// Largest (jumbo frame) MTU the guest may configure.
const MAX_MTU: u32 = 9216;

// Ethernet header length (dest MAC + source MAC + ethertype), in bytes.
const ETHERNET_HEADER_LEN: u32 = 14;
3255
3256impl Adapter {
3257    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3258        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3259            // For enlightened guests (which is only Windows at the moment), send the
3260            // adapter index, which was previously set as the vport serial number.
3261            if guest_os_id
3262                .microsoft()
3263                .unwrap_or(HvGuestOsMicrosoft::from(0))
3264                .os_id()
3265                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3266            {
3267                self.adapter_index
3268            } else {
3269                vfid
3270            }
3271        } else {
3272            vfid
3273        }
3274    }
3275
    /// Handles an RNDIS OID query, writing the response payload to `writer`.
    ///
    /// Returns the number of bytes written on success.
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        // Captured up front so the bytes written can be computed at the end.
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    // Ethernet objects operation characteristics
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    // Ethernet objects statistics
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    // PNP operations characteristics
                    // rndisprot::Oid::OID_PNP_SET_POWER,
                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
                    // RNDIS OIDS
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // NDIS6 OIDs
                let supported_oids_6 = &[
                    // Link State OID
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    // NDIS 6 statistics OID
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    // Offload related OID
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
                ];

                // NDIS 6.30+ (RSS) OIDs
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                // The advertised list depends on the negotiated NDIS version.
                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; // NdisHardwareStatusReady
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                // Frame payload size excludes the Ethernet header.
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                // This value is meaningless for virtual NICs. Return what vmswitch returns.
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Like vmswitch, use the first N bytes of Microsoft's MAC address
                // prefix as the vendor ID.
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                // UTF-16 name, zero-padded to the fixed FriendlyName size.
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            // Collision/alignment statistics are always zero for a virtual NIC.
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            // NDIS6 OIDs:
            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, /* MediaConnectStateConnected */
                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                // What the device supports (vs. what is currently enabled).
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                // The currently active offload configuration.
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            // DPDK gets confused if the table size is zero,
                            // even if there is only one queue.
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3515
3516    fn handle_oid_set(
3517        &self,
3518        primary: &mut PrimaryChannelState,
3519        oid: rndisprot::Oid,
3520        reader: impl MemoryRead + Clone,
3521    ) -> Result<(bool, Option<u32>), OidError> {
3522        tracing::debug!(?oid, "oid set");
3523
3524        let mut restart_endpoint = false;
3525        let mut packet_filter = None;
3526        match oid {
3527            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3528                packet_filter = self.oid_set_packet_filter(reader)?;
3529            }
3530            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3531                self.oid_set_offload_parameters(reader, primary)?;
3532            }
3533            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3534                self.oid_set_offload_encapsulation(reader)?;
3535            }
3536            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3537                self.oid_set_rndis_config_parameter(reader, primary)?;
3538            }
3539            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3540                // TODO
3541            }
3542            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3543                let rss_was_enabled = self.oid_set_rss_parameters(reader, primary)?;
3544
3545                // Endpoints cannot currently change RSS parameters without
3546                // being restarted. This was a limitation driven by some DPDK
3547                // PMDs, and should be fixed.
3548                //
3549                // Skip the restart if RSS was already disabled and the guest
3550                // is disabling it again — nothing has changed.
3551                if rss_was_enabled || primary.rss_state.is_some() {
3552                    restart_endpoint = true;
3553                }
3554            }
3555            _ => {
3556                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3557                return Err(OidError::UnknownOid);
3558            }
3559        }
3560        Ok((restart_endpoint, packet_filter))
3561    }
3562
    /// Handles OID_GEN_RECEIVE_SCALE_PARAMETERS, updating `primary.rss_state`.
    ///
    /// Returns whether RSS was enabled *before* this call so the caller can
    /// decide whether the endpoint needs a restart.
    fn oid_set_rss_parameters(
        &self,
        mut reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<bool, OidError> {
        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
        let len = reader.len().min(size_of_val(&params));
        reader.clone().read(&mut params.as_mut_bytes()[..len])?;

        let rss_was_enabled = primary.rss_state.is_some();

        // Disabling RSS (explicitly via the flag, or implicitly by clearing
        // the hash function) just drops the current RSS state.
        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
        {
            primary.rss_state = None;
            return Ok(rss_was_enabled);
        }

        if params.hash_secret_key_size != 40 {
            return Err(OidError::InvalidInput("hash_secret_key_size"));
        }
        if params.indirection_table_size % 4 != 0 {
            return Err(OidError::InvalidInput("indirection_table_size"));
        }
        // Guest table size is in bytes of u32 entries; cap the entry count at
        // the device-advertised table size.
        let indirection_table_size =
            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
        let mut key = [0; 40];
        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
        reader
            .clone()
            .skip(params.hash_secret_key_offset as usize)?
            .read(&mut key)?;
        reader
            .skip(params.indirection_table_offset as usize)?
            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
        // Every table entry must reference a valid queue index.
        if indirection_table
            .iter()
            .any(|&x| x >= self.max_queues as u32)
        {
            return Err(OidError::InvalidInput("indirection_table"));
        }
        // If the guest provided fewer entries than the device table size,
        // fill the remainder by cycling through the provided entries.
        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
            .flatten()
            .zip(indir_uninit)
        {
            *dest = src;
        }
        primary.rss_state = Some(RssState {
            key,
            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
        });
        Ok(rss_was_enabled)
    }
3619
3620    fn oid_set_packet_filter(
3621        &self,
3622        reader: impl MemoryRead + Clone,
3623    ) -> Result<Option<u32>, OidError> {
3624        let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3625        tracing::debug!(filter, "set packet filter");
3626        Ok(Some(filter))
3627    }
3628
3629    fn oid_set_offload_parameters(
3630        &self,
3631        reader: impl MemoryRead + Clone,
3632        primary: &mut PrimaryChannelState,
3633    ) -> Result<(), OidError> {
3634        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3635            reader,
3636            rndisprot::NdisObjectType::DEFAULT,
3637            1,
3638            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3639        )?;
3640
3641        tracing::debug!(?offload, "offload parameters");
3642        let rndisprot::NdisOffloadParameters {
3643            header: _,
3644            ipv4_checksum,
3645            tcp4_checksum,
3646            udp4_checksum,
3647            tcp6_checksum,
3648            udp6_checksum,
3649            lsov1,
3650            ipsec_v1: _,
3651            lsov2_ipv4,
3652            lsov2_ipv6,
3653            tcp_connection_ipv4: _,
3654            tcp_connection_ipv6: _,
3655            reserved: _,
3656            flags: _,
3657        } = offload;
3658
3659        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3660            return Err(OidError::NotSupported("lsov1"));
3661        }
3662        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3663            primary.offload_config.checksum_tx.ipv4_header = tx;
3664            primary.offload_config.checksum_rx.ipv4_header = rx;
3665        }
3666        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3667            primary.offload_config.checksum_tx.tcp4 = tx;
3668            primary.offload_config.checksum_rx.tcp4 = rx;
3669        }
3670        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3671            primary.offload_config.checksum_tx.tcp6 = tx;
3672            primary.offload_config.checksum_rx.tcp6 = rx;
3673        }
3674        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3675            primary.offload_config.checksum_tx.udp4 = tx;
3676            primary.offload_config.checksum_rx.udp4 = rx;
3677        }
3678        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3679            primary.offload_config.checksum_tx.udp6 = tx;
3680            primary.offload_config.checksum_rx.udp6 = rx;
3681        }
3682        if let Some(enable) = lsov2_ipv4.enable() {
3683            primary.offload_config.lso4 = enable;
3684        }
3685        if let Some(enable) = lsov2_ipv6.enable() {
3686            primary.offload_config.lso6 = enable;
3687        }
3688        primary
3689            .offload_config
3690            .mask_to_supported(&self.offload_support);
3691        primary.pending_offload_change = true;
3692        Ok(())
3693    }
3694
3695    fn oid_set_offload_encapsulation(
3696        &self,
3697        reader: impl MemoryRead + Clone,
3698    ) -> Result<(), OidError> {
3699        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3700            reader,
3701            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3702            1,
3703            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3704        )?;
3705        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3706            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3707                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3708        {
3709            return Err(OidError::NotSupported("ipv4 encap"));
3710        }
3711        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3712            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3713                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3714        {
3715            return Err(OidError::NotSupported("ipv6 encap"));
3716        }
3717        Ok(())
3718    }
3719
3720    fn oid_set_rndis_config_parameter(
3721        &self,
3722        reader: impl MemoryRead + Clone,
3723        primary: &mut PrimaryChannelState,
3724    ) -> Result<(), OidError> {
3725        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3726        if info.name_length > 255 {
3727            return Err(OidError::InvalidInput("name_length"));
3728        }
3729        if info.value_length > 255 {
3730            return Err(OidError::InvalidInput("value_length"));
3731        }
3732        let name = reader
3733            .clone()
3734            .skip(info.name_offset as usize)?
3735            .read_n::<u16>(info.name_length as usize / 2)?;
3736        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3737        let mut value = reader;
3738        value.skip(info.value_offset as usize)?;
3739        let mut value = value.limit(info.value_length as usize);
3740        match info.parameter_type {
3741            rndisprot::NdisParameterType::STRING => {
3742                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3743                let value =
3744                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3745                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3746                let tx = as_num & 1 != 0;
3747                let rx = as_num & 2 != 0;
3748
3749                tracing::debug!(name, value, "rndis config");
3750                match name.as_str() {
3751                    "*IPChecksumOffloadIPv4" => {
3752                        primary.offload_config.checksum_tx.ipv4_header = tx;
3753                        primary.offload_config.checksum_rx.ipv4_header = rx;
3754                    }
3755                    "*LsoV2IPv4" => {
3756                        primary.offload_config.lso4 = as_num != 0;
3757                    }
3758                    "*LsoV2IPv6" => {
3759                        primary.offload_config.lso6 = as_num != 0;
3760                    }
3761                    "*TCPChecksumOffloadIPv4" => {
3762                        primary.offload_config.checksum_tx.tcp4 = tx;
3763                        primary.offload_config.checksum_rx.tcp4 = rx;
3764                    }
3765                    "*TCPChecksumOffloadIPv6" => {
3766                        primary.offload_config.checksum_tx.tcp6 = tx;
3767                        primary.offload_config.checksum_rx.tcp6 = rx;
3768                    }
3769                    "*UDPChecksumOffloadIPv4" => {
3770                        primary.offload_config.checksum_tx.udp4 = tx;
3771                        primary.offload_config.checksum_rx.udp4 = rx;
3772                    }
3773                    "*UDPChecksumOffloadIPv6" => {
3774                        primary.offload_config.checksum_tx.udp6 = tx;
3775                        primary.offload_config.checksum_rx.udp6 = rx;
3776                    }
3777                    _ => {}
3778                }
3779                primary
3780                    .offload_config
3781                    .mask_to_supported(&self.offload_support);
3782            }
3783            rndisprot::NdisParameterType::INTEGER => {
3784                let value: u32 = value.read_plain()?;
3785                tracing::debug!(name, value, "rndis config");
3786            }
3787            parameter_type => tracelimit::warn_ratelimited!(
3788                name,
3789                ?parameter_type,
3790                "unhandled rndis config parameter type"
3791            ),
3792        }
3793        Ok(())
3794    }
3795}
3796
3797fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3798    mut reader: impl MemoryRead,
3799    object_type: rndisprot::NdisObjectType,
3800    min_revision: u8,
3801    min_size: usize,
3802) -> Result<T, OidError> {
3803    let mut buffer = T::new_zeroed();
3804    let sent_size = reader.len();
3805    let len = sent_size.min(size_of_val(&buffer));
3806    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3807    validate_ndis_object_header(
3808        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3809            .unwrap()
3810            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3811        sent_size,
3812        object_type,
3813        min_revision,
3814        min_size,
3815    )?;
3816    Ok(buffer)
3817}
3818
3819fn validate_ndis_object_header(
3820    header: &rndisprot::NdisObjectHeader,
3821    sent_size: usize,
3822    object_type: rndisprot::NdisObjectType,
3823    min_revision: u8,
3824    min_size: usize,
3825) -> Result<(), OidError> {
3826    if header.object_type != object_type {
3827        return Err(OidError::InvalidInput("header.object_type"));
3828    }
3829    if sent_size < header.size as usize {
3830        return Err(OidError::InvalidInput("header.size"));
3831    }
3832    if header.revision < min_revision {
3833        return Err(OidError::InvalidInput("header.revision"));
3834    }
3835    if (header.size as usize) < min_size {
3836        return Err(OidError::InvalidInput("header.size"));
3837    }
3838    Ok(())
3839}
3840
/// Mutable state of the coordinator task, which owns the per-channel worker
/// tasks and reacts to messages from the workers, the endpoint, and the
/// virtual function.
struct Coordinator {
    /// Receives [`CoordinatorMessage`]s from the channel workers.
    recv: mpsc::Receiver<CoordinatorMessage>,
    /// Control interface for the vmbus channel.
    channel_control: ChannelControl,
    /// When set, the queues are torn down and restarted on the next pass of
    /// the coordinator loop.
    restart: bool,
    /// Worker task per channel; index 0 is the primary channel.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// Channel buffers saved from the primary channel for use by the
    /// subchannel workers when queues are restarted.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of currently active queues (bounds the worker slice reported
    /// via inspect).
    num_queues: u16,
    /// The packet filter last propagated from the primary channel to the
    /// subchannel workers.
    active_packet_filter: u32,
    /// If set, a timer wakes the coordinator at this deadline to activate a
    /// delayed link action.
    sleep_deadline: Option<Instant>,
}
3851
/// Removing the VF may result in the guest sending messages to switch the data
/// path, so these operations need to happen asynchronously with message
/// processing.
enum CoordinatorStatePendingVfState {
    /// No pending updates.
    Ready,
    /// Delay before adding VF.
    Delay {
        /// Timer used to wait out the delay.
        timer: PolledTimer,
        /// Once this deadline passes, the VF device is offered to the guest.
        delay_until: Instant,
    },
    /// A VF update is pending.
    Pending,
}
3866
/// State used to drive the coordinator task, kept separate from
/// [`Coordinator`] so the two can be borrowed independently while the task
/// runs.
struct CoordinatorState {
    /// The backend network endpoint.
    endpoint: Box<dyn Endpoint>,
    /// Shared adapter configuration.
    adapter: Arc<Adapter>,
    /// The accelerated (VF) device, if one is configured.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of any in-flight VF offer/update to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3873
/// Inspect support for the coordinator: reports adapter- and endpoint-level
/// properties, and per-queue/primary-channel state when the coordinator task
/// state is available.
impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Static adapter configuration plus the mutable ring size limit.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Writable inspect field: stores the parsed value and reports the
            // current one.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    // Bounce each task so that it sees the new value.
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // One child node per active queue.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Get the shared channel state from the primary channel.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3949
/// Task entry point: running the coordinator task just drives
/// [`Coordinator::process`] until the channel disconnects or the task is
/// cancelled.
impl AsyncRun<Coordinator> for CoordinatorState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3959
3960impl Coordinator {
    /// The coordinator's main loop: restarts the queues when requested, keeps
    /// the worker tasks running, and dispatches messages from the channel
    /// workers, the endpoint, the VF, and the internal sleep timer.
    ///
    /// Returns `Ok(())` when the channel is disconnected, or
    /// `Err(Cancelled)` if `stop` fires at a restartable await point.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                // The queue restart operation is not restartable, so do not
                // poll on `stop` here.
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Re-query the endpoint for the data path state, which may
                // have changed across the restart.
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Ensure that all workers except the primary are started. The
            // primary is started below if there are no outstanding messages.
            for worker in &mut self.workers[1..] {
                worker.start();
            }
            if !self.workers[0].is_running()
                && self.workers[0].state().is_none_or(|worker| {
                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
                })
            {
                self.workers[0].start();
            }

            // The possible wakeup sources for one pass of the loop.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // guest_ready_for_device is not restartable, so do not poll on
                // stop.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Optional wakeup at `sleep_deadline`, used to activate
                // delayed link actions.
                let timer_sleep = async {
                    if let Some(deadline) = self.sleep_deadline {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(deadline).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                // Race all applicable wakeup sources; the first to complete
                // wins.
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Fires once the VF offer delay (if any)
                                // elapses.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::Internal(msg) => {
                    self.handle_coordinator_message(msg, state).await;
                }
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Stop the primary worker before mutating its state; the
                    // top of the loop restarts it.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Kick the worker as requested.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                            primary.pending_link_action = PendingLinkAction::Active(up);
                        }
                    }
                    self.sleep_deadline = None;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    // These are the only link state transitions that are tracked.
                    // 1. up -> down or down -> up
                    // 2. up -> down -> up or down -> up -> down.
                    // All other state transitions are coalesced into one of the above cases.
                    // For example, up -> down -> up -> down is treated as up -> down.
                    // N.B - Always queue up the incoming state to minimize the effects of loss
                    //       of any notifications (for example, during vtl2 servicing).
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    // If there is any existing sleep timer running, cancel it out.
                    self.sleep_deadline = None;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4143
    /// Processes a single [`CoordinatorMessage`] received from a channel
    /// worker.
    ///
    /// The primary worker is stopped first so its state can be mutated
    /// safely; the main loop restarts it on the next pass.
    async fn handle_coordinator_message(
        &mut self,
        msg: CoordinatorMessage,
        state: &mut CoordinatorState,
    ) {
        self.workers[0].stop().await;
        if let Some(worker) = self.workers[0].state_mut() {
            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
                // The worker parked itself waiting on the coordinator; take
                // the saved ready state back out and mark the worker ready
                // again so it resumes when restarted.
                let WorkerState::WaitingForCoordinator(Some(ready)) =
                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
                else {
                    unreachable!("valid ready state")
                };
                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
            }
        }
        match msg {
            CoordinatorMessage::Update(update_type) => {
                if update_type.filter_state {
                    // Propagate the primary channel's packet filter to all
                    // subchannel workers.
                    self.stop_workers().await;
                    self.active_packet_filter =
                        self.workers[0].state().unwrap().channel.packet_filter;
                    self.workers.iter_mut().skip(1).for_each(|worker| {
                        if let Some(state) = worker.state_mut() {
                            state.channel.packet_filter = self.active_packet_filter;
                            tracing::debug!(
                                packet_filter = ?self.active_packet_filter,
                                channel_idx = state.channel_idx,
                                "update packet filter"
                            );
                        }
                    });
                }

                if update_type.guest_vf_state {
                    self.update_guest_vf_state(state).await;
                }
            }
            CoordinatorMessage::StartTimer(deadline) => {
                // Arm the loop's sleep timer; `TimerExpired` handling fires
                // the delayed link action.
                self.sleep_deadline = Some(deadline);
            }
            CoordinatorMessage::Restart => self.restart = true,
        }
    }
4188
4189    async fn stop_workers(&mut self) {
4190        for worker in &mut self.workers {
4191            worker.stop().await;
4192        }
4193    }
4194
    /// Reconciles the guest VF state with the currently available VF and the
    /// endpoint's data path, e.g. after a queue restart or a restore.
    ///
    /// Depending on the saved state this may re-offer the VF to the guest
    /// (possibly after a delay) and/or switch the data path, and it rewrites
    /// `primary.guest_vf_state` to the reconciled value.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Update guest VF state based on current endpoint properties.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // Ensure guest VF is in proper state.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // The VF was advertised but the data path was never
                    // switched: re-offer after the usual delay.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // ensure data path is switched as expected
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                // If the save was after the data path switch already occurred, don't do it again.
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // A data path switch was requested or previously in
                    // effect; re-issue it against the endpoint.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // Cancel any outstanding delay timers for VF offers if the data path is
                    // getting switched, since the guest is already issuing
                    // commands assuming a VF.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        // A previous instantiation already switched the data
                        // path.
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // If the device was just removed, make sure the data path is synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any pending (delayed) VF offer; the VF is gone.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Notify guest if VF is no longer available
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4412
    /// Rebuilds all endpoint queues and redistributes receive buffers across
    /// the active queues.
    ///
    /// Drops every worker's queue state, stops the backend endpoint, and then
    /// recreates the queues using the current RSS configuration (if any) from
    /// the primary channel. Returns `Ok(())` without restarting anything if
    /// the primary channel has not completed initialization yet.
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        // The first worker is always the primary channel; the rest are
        // subchannel workers.
        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        // Nothing to restart until the primary channel is in the ready state.
        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        // Save the channel buffers for use in the subchannel workers.
        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                // Active queue count is computed as the number of unique entries in the indirection table
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                // Entries referring to queues beyond the requested count are
                // invalid; drop them.
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    // No valid entries at all: fall back to treating every
                    // queue as active.
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };

        // Distribute the rx buffers to only the active queues.
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        let mut per_queue_rx: Vec<Vec<RxId>> = Vec::new();
        let guest_buffers;
        {
            let buffers = &state.buffers;
            guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            // Get the list of free rx buffers from each task, then partition
            // the list per-active-queue, and produce the queue configuration.
            let mut queue_config = Vec::new();
            let initial_rx;
            {
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                // A sub-allocation is available for rx only if every worker
                // considers it free.
                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    // If the primary queue is excluded from the guest supplied
                    // indirection table, it is assigned just the reserved
                    // buffers.
                    queue_config.push(QueueConfig {
                        driver: Box::new(drivers[0].clone()),
                    });
                    per_queue_rx.push(Vec::new());
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    // Queues absent from the (sorted) indirection table get an
                    // empty buffer range and no remote buffer id receiver.
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            // The last queue gets all the remaining buffers.
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            // Queue zero always includes the reserved buffers.
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        (range_start, 0, None)
                    };

                    // Split off this queue's share of the free rx buffer ids.
                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    per_queue_rx.push(this.to_vec());
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // Set the subchannel count.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
        // Provide the queue and receive buffer ranges for each worker.
        for (((worker, mut queue), rx_buffer), initial) in self
            .workers
            .iter_mut()
            .zip(queues)
            .zip(rx_buffers)
            .zip(per_queue_rx)
        {
            let mut pool = BufferPool::new(guest_buffers.clone());
            if !initial.is_empty() {
                queue.rx_avail(&mut pool, &initial);
            }
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                pool,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
            // Update the receive packet filter for the subchannel worker.
            if let Some(worker) = worker.state_mut() {
                worker.channel.packet_filter = self.active_packet_filter;
                // Clear any pending RxIds as buffers were redistributed
                // and reset TX tracking after the endpoint stop.
                if let Some(ready_state) = worker.state.ready_mut() {
                    ready_state.state.pending_rx_packets.clear();
                    ready_state.reset_tx_after_endpoint_stop();
                }
            }
        }

        Ok(())
    }
4627
4628    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4629        self.workers[0]
4630            .state_mut()
4631            .unwrap()
4632            .state
4633            .ready_mut()?
4634            .state
4635            .primary
4636            .as_mut()
4637    }
4638
    /// Stops the primary worker and re-applies the guest VF state, so any
    /// change in VF availability is propagated to the guest.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        // The worker must be stopped before its state can be mutated safely.
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4643}
4644
4645impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4646    async fn run(
4647        &mut self,
4648        stop: &mut StopTask<'_>,
4649        worker: &mut Worker<T>,
4650    ) -> Result<(), task_control::Cancelled> {
4651        match worker.process(stop, self).await {
4652            Ok(()) | Err(WorkerError::BufferRevoked) => {}
4653            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4654            Err(err) => {
4655                tracing::error!(
4656                    error = &err as &dyn std::error::Error,
4657                    channel_idx = worker.channel_idx,
4658                    "netvsp error"
4659                );
4660            }
4661        }
4662        Ok(())
4663    }
4664}
4665
impl<T: RingMem + 'static> Worker<T> {
    /// Runs the channel worker state machine until cancelled or a fatal
    /// error occurs.
    ///
    /// States:
    /// - `Init` (primary channel only): runs the NetVSP initialization
    ///   handshake, then asks the coordinator to restart/start the queues.
    /// - `WaitingForCoordinator`: parked until the coordinator restarts this
    ///   worker.
    /// - `Ready`: runs the main packet-processing loop once the queue is
    ///   available.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        // Be careful not to wait on actions with unbounded blocking time (e.g.
        // guest actions, or waiting for network packets to arrive) without
        // wrapping the wait on `stop.until_stopped`.
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    // Only the primary channel performs initialization.
                    assert_eq!(self.channel_idx, 0);

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Wake up the coordinator task to start the queues.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::WaitingForCoordinator(Some(state));
                }
                WorkerState::WaitingForCoordinator(_) => {
                    assert_eq!(self.channel_idx, 0);
                    // Waiting for the coordinator to process a message and
                    // restart the primary worker.
                    stop.until_stopped(pending()).await?
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Lazily bind the queue's interrupt to the guest's
                        // requested target VP, the first time through.
                        if !queue_state.target_vp_set {
                            if let Some(target_vp) = self.target_vp {
                                tracing::debug!(
                                    channel_idx = self.channel_idx,
                                    target_vp,
                                    "updating target VP"
                                );
                                queue_state.queue.update_target_vp(target_vp).await;
                                queue_state.target_vp_set = true;
                            }
                        }

                        queue_state
                    } else {
                        // This task will be restarted when the queues are ready.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    let msg = match result {
                        Ok(restart) => {
                            // Only the primary channel's main loop returns a
                            // coordinator message on success.
                            assert_eq!(self.channel_idx, 0);
                            restart
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            CoordinatorMessage::Restart
                        }
                        Err(err) => return Err(err),
                    };

                    // Move the ready state out of `self.state` and park it in
                    // WaitingForCoordinator so the coordinator can restart
                    // this worker with its state intact.
                    let WorkerState::Ready(ready) = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(None),
                    ) else {
                        unreachable!("must be running in ready state")
                    };
                    let _ = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(Some(ready)),
                    );
                    self.coordinator_send
                        .try_send(msg)
                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4759
4760impl<T: 'static + RingMem> NetChannel<T> {
4761    fn try_next_packet<'a>(
4762        &mut self,
4763        send_buffer: Option<&SendBuffer>,
4764        version: Option<Version>,
4765        external_data: &'a mut MultiPagedRangeBuf,
4766    ) -> Result<Option<Packet<'a>>, WorkerError> {
4767        let (mut read, _) = self.queue.split();
4768        let packet = match read.try_read() {
4769            Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4770                .map_err(WorkerError::Packet)?,
4771            Err(queue::TryReadError::Empty) => return Ok(None),
4772            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4773        };
4774
4775        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4776        Ok(Some(packet))
4777    }
4778
    /// Reads and parses the next vmbus packet, waiting until one arrives.
    ///
    /// If the parsed packet is an RNDIS packet, the ring read is reverted so
    /// the same packet can be consumed again by the post-initialization
    /// processing path.
    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
        external_data: &'a mut MultiPagedRangeBuf,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
            .map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            // In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            // Leave the packet in the ring (un-consume it) so that it is
            // read again once the channel transitions to the ready state.
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }
4797
4798    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4799        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4800            && initializing.ndis_version.is_some()
4801            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4802            && initializing.recv_buffer.is_some()
4803    }
4804
    /// Runs the NetVSP initialization handshake with the guest.
    ///
    /// Loops processing control packets until the guest has negotiated a
    /// protocol version and supplied the NDIS version, NDIS config (for V2+),
    /// receive buffer, and send buffer — or until an early RNDIS packet
    /// signals that initialization should complete with whatever has been
    /// received so far. Returns the assembled channel buffers and the initial
    /// active state for the primary channel.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        let mut has_init_packet_arrived = false;
        loop {
            // Check whether the accumulated handshake state is sufficient to
            // finish initialization.
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // NDIS config may be absent (pre-V2 or early RNDIS
                        // packet); fall back to default MTU and capabilities.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough room in the ring to avoid needing to track
            // completion packets.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                // A version has been negotiated; process post-init handshake
                // messages. Each setter rejects duplicates.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // As in the vmswitch, if the MTU is invalid then use the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The MTU determines the sub-allocation size, so it
                        // must be known (or defaulted for pre-V2) by now.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // An early RNDIS packet is only legal once everything
                        // except possibly the send buffer has been received.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet: the only acceptable packet is
                // the protocol version negotiation request.
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                // Update the packet size so that the appropriate padding is
                                // appended for picky Windows guests.
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    // parse_packet only yields Init before a version is
                    // negotiated, so other variants are impossible here.
                    _ => unreachable!(),
                }
            }
        }
    }
5019
5020    async fn main_loop(
5021        &mut self,
5022        stop: &mut StopTask<'_>,
5023        ready_state: &mut ReadyState,
5024        queue_state: &mut QueueState,
5025    ) -> Result<CoordinatorMessage, WorkerError> {
5026        let buffers = &ready_state.buffers;
5027        let state = &mut ready_state.state;
5028        let data = &mut ready_state.data;
5029
5030        let ring_spare_capacity = {
5031            let (_, send) = self.queue.split();
5032            let mut limit = if self.can_use_ring_size_opt {
5033                self.adapter.ring_size_limit.load(Ordering::Relaxed)
5034            } else {
5035                0
5036            };
5037            if limit == 0 {
5038                limit = send.capacity() - 2048;
5039            }
5040            send.capacity() - limit
5041        };
5042
5043        // If the packet filter has changed to allow rx packets, add any pended RxIds.
5044        if !state.pending_rx_packets.is_empty()
5045            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
5046        {
5047            let (front, back) = state.pending_rx_packets.as_slices();
5048            queue_state.queue.rx_avail(&mut queue_state.pool, front);
5049            queue_state.queue.rx_avail(&mut queue_state.pool, back);
5050            state.pending_rx_packets.clear();
5051        }
5052
5053        // Handle any guest state changes since last run.
5054        if let Some(primary) = state.primary.as_mut() {
5055            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
5056                let num_channels_opened =
5057                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
5058                if num_channels_opened == primary.requested_num_queues as usize {
5059                    let (_, mut send) = self.queue.split();
5060                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5061                        .await??;
5062                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
5063                    primary.tx_spread_sent = true;
5064                }
5065            }
5066            if let PendingLinkAction::Active(up) = primary.pending_link_action {
5067                let (_, mut send) = self.queue.split();
5068                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5069                    .await??;
5070                if let Some(id) = primary.free_control_buffers.pop() {
5071                    let connect = if primary.guest_link_up != up {
5072                        primary.pending_link_action = PendingLinkAction::Default;
5073                        up
5074                    } else {
5075                        // For the up -> down -> up OR down -> up -> down case, the first transition
5076                        // is sent immediately and the second transition is queued with a delay.
5077                        primary.pending_link_action =
5078                            PendingLinkAction::Delay(primary.guest_link_up);
5079                        !primary.guest_link_up
5080                    };
5081                    // Mark the receive buffer in use to allow the guest to release it.
5082                    assert!(state.rx_bufs.is_free(id.0));
5083                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
5084                    let state_to_send = if connect {
5085                        rndisprot::STATUS_MEDIA_CONNECT
5086                    } else {
5087                        rndisprot::STATUS_MEDIA_DISCONNECT
5088                    };
5089                    tracing::info!(
5090                        connect,
5091                        mac_address = %self.adapter.mac_address,
5092                        "sending link status"
5093                    );
5094
5095                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
5096                    primary.guest_link_up = connect;
5097                } else {
5098                    primary.pending_link_action = PendingLinkAction::Delay(up);
5099                }
5100
5101                match primary.pending_link_action {
5102                    PendingLinkAction::Delay(_) => {
5103                        return Ok(CoordinatorMessage::StartTimer(
5104                            Instant::now() + LINK_DELAY_DURATION,
5105                        ));
5106                    }
5107                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
5108                    _ => {}
5109                }
5110            }
5111            match primary.guest_vf_state {
5112                PrimaryChannelGuestVfState::Available { .. }
5113                | PrimaryChannelGuestVfState::UnavailableFromAvailable
5114                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
5115                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
5116                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
5117                | PrimaryChannelGuestVfState::DataPathSynthetic => {
5118                    let (_, mut send) = self.queue.split();
5119                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
5120                        .await??;
5121                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
5122                        return Ok(message);
5123                    }
5124                }
5125                _ => (),
5126            }
5127        }
5128
5129        loop {
5130            // If the ring is almost full, then do not poll the endpoint,
5131            // since this will cause us to spend time dropping packets and
5132            // reprogramming receive buffers. It's more efficient to let the
5133            // backend drop packets.
5134            let ring_full = {
5135                let (_, mut send) = self.queue.split();
5136                !send.can_write(ring_spare_capacity)?
5137            };
5138
5139            let did_some_work = (!ring_full
5140                && self.process_endpoint_rx(
5141                    buffers,
5142                    state,
5143                    data,
5144                    queue_state.queue.as_mut(),
5145                    &mut queue_state.pool,
5146                )?)
5147                | self.process_ring_buffer(buffers, state, data, queue_state)?
5148                | (!ring_full
5149                    && self.process_endpoint_tx(
5150                        state,
5151                        data,
5152                        queue_state.queue.as_mut(),
5153                        &mut queue_state.pool,
5154                    )?)
5155                | self.transmit_pending_segments(state, data, queue_state)?
5156                | self.send_pending_packets(state)?;
5157
5158            if !did_some_work {
5159                state.stats.spurious_wakes.increment();
5160            }
5161
5162            // Process any outstanding control messages before sleeping in case
5163            // there are now suballocations available for use.
5164            self.process_control_messages(buffers, state)?;
5165
5166            // This should be the only await point waiting on network traffic or
5167            // guest actions. Wrap it in `stop.until_stopped` to allow
5168            // cancellation.
5169            let restart = stop
5170                .until_stopped(std::future::poll_fn(
5171                    |cx| -> Poll<Option<CoordinatorMessage>> {
5172                        // If the ring is almost full, then don't wait for endpoint
5173                        // interrupts. This allows the interrupt rate to fall when the
5174                        // guest cannot keep up with the load.
5175                        if !ring_full {
5176                            // Check the network endpoint for tx completion or rx.
5177                            if queue_state
5178                                .queue
5179                                .poll_ready(cx, &mut queue_state.pool)
5180                                .is_ready()
5181                            {
5182                                tracing::trace!("endpoint ready");
5183                                return Poll::Ready(None);
5184                            }
5185                        }
5186
5187                        // Check the incoming ring for tx, but only if there are enough
5188                        // free tx packets and no pending tx segments.
5189                        let (mut recv, mut send) = self.queue.split();
5190                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
5191                            && data.tx_segments.is_empty()
5192                            && recv.poll_ready(cx).is_ready()
5193                        {
5194                            tracing::trace!("incoming ring ready");
5195                            return Poll::Ready(None);
5196                        }
5197
5198                        // Check the outgoing ring for space to send rx completions or
5199                        // control message sends, if any are pending.
5200                        //
5201                        // Also, if endpoint processing has been suspended due to the
5202                        // ring being nearly full, then ask the guest to wake us up when
5203                        // there is space again.
5204                        let mut pending_send_size = self.pending_send_size;
5205                        if ring_full {
5206                            pending_send_size = ring_spare_capacity;
5207                        }
5208                        if pending_send_size != 0
5209                            && send.poll_ready(cx, pending_send_size).is_ready()
5210                        {
5211                            tracing::trace!("outgoing ring ready");
5212                            return Poll::Ready(None);
5213                        }
5214
5215                        // Collect any of this queue's receive buffers that were
5216                        // completed by a remote channel. This only happens when the
5217                        // subchannel count changes, so that receive buffer ownership
5218                        // moves between queues while some receive buffers are still in
5219                        // use.
5220                        if let Some(remote_buffer_id_recv) =
5221                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
5222                        {
5223                            while let Poll::Ready(Some(id)) =
5224                                remote_buffer_id_recv.poll_next_unpin(cx)
5225                            {
5226                                if id >= RX_RESERVED_CONTROL_BUFFERS {
5227                                    queue_state
5228                                        .queue
5229                                        .rx_avail(&mut queue_state.pool, &[RxId(id)]);
5230                                } else {
5231                                    state
5232                                        .primary
5233                                        .as_mut()
5234                                        .unwrap()
5235                                        .free_control_buffers
5236                                        .push(ControlMessageId(id));
5237                                }
5238                            }
5239                        }
5240
5241                        if let Some(restart) = self.restart.take() {
5242                            return Poll::Ready(Some(restart));
5243                        }
5244
5245                        tracing::trace!("network waiting");
5246                        Poll::Pending
5247                    },
5248                ))
5249                .await?;
5250
5251            if let Some(restart) = restart {
5252                break Ok(restart);
5253            }
5254        }
5255    }
5256
    /// Polls the endpoint queue for received packets and forwards them to the
    /// guest as a single RNDIS message over the data channel.
    ///
    /// Returns `Ok(false)` when no packets were ready, or when ready packets
    /// were pended because the packet filter is NONE; returns `Ok(true)`
    /// otherwise (even if the packets were then dropped because the outgoing
    /// ring was full).
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
        pool: &mut BufferPool,
    ) -> Result<bool, WorkerError> {
        // Collect the IDs of receive suballocations that have new data.
        let n = epqueue
            .rx_poll(pool, &mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Pend the newly available RxIds until the packet filter is updated.
            // Under high load this will eventually lead to no available RxIds,
            // which will cause the backend to drop the packets instead of
            // processing them here.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The first rx buffer ID doubles as the transaction ID so that the
        // guest's completion can identify which buffers to release.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        // Mark the buffers in use until the guest completes the message.
        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        // Always use the full suballocation size to avoid tracking the
        // message length. See RxBuf::header() for details.
        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                // packet was sent
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring buffer is full. Drop the packets and free the rx
                // buffers. When the ring has limited space, the main loop will
                // stop polling for receive packets.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(pool, &data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5323
5324    fn process_endpoint_tx(
5325        &mut self,
5326        state: &mut ActiveState,
5327        data: &mut ProcessingData,
5328        epqueue: &mut dyn net_backend::Queue,
5329        pool: &mut BufferPool,
5330    ) -> Result<bool, WorkerError> {
5331        // Drain completed transmits.
5332        let result = epqueue.tx_poll(pool, &mut data.tx_done);
5333
5334        match result {
5335            Ok(n) => {
5336                if n == 0 {
5337                    return Ok(false);
5338                }
5339
5340                for &id in &data.tx_done[..n] {
5341                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5342                    assert!(tx_packet.pending_packet_count > 0);
5343                    tx_packet.pending_packet_count -= 1;
5344                    if tx_packet.pending_packet_count == 0 {
5345                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5346                    }
5347                }
5348
5349                Ok(true)
5350            }
5351            Err(TxError::TryRestart(err)) => {
5352                // In-flight TX packets will be cleaned up by
5353                // `reset_tx_after_endpoint_stop` during the queue restart.
5354                Err(WorkerError::EndpointRequiresQueueRestart(err))
5355            }
5356            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
5357        }
5358    }
5359
5360    fn switch_data_path(
5361        &mut self,
5362        state: &mut ActiveState,
5363        use_guest_vf: bool,
5364        transaction_id: Option<u64>,
5365    ) -> Result<(), WorkerError> {
5366        let primary = state.primary.as_mut().unwrap();
5367        let mut queue_switch_operation = false;
5368        match primary.guest_vf_state {
5369            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5370                // Allow the guest to switch to VTL0, or if the current state
5371                // of the data path is unknown, allow a switch away from VTL0.
5372                // The latter case handles the scenario where the data path has
5373                // been switched but the synthetic device has been restarted.
5374                // The device is queried for the current state of the data path
5375                // but if it is unknown (upgraded from older version that
5376                // doesn't track this data, or failure during query) then the
5377                // safest option is to pass the request through.
5378                if use_guest_vf || primary.is_data_path_switched.is_none() {
5379                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5380                        to_guest: use_guest_vf,
5381                        id: transaction_id,
5382                        result: None,
5383                    };
5384                    queue_switch_operation = true;
5385                }
5386            }
5387            PrimaryChannelGuestVfState::DataPathSwitched => {
5388                if !use_guest_vf {
5389                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5390                        to_guest: false,
5391                        id: transaction_id,
5392                        result: None,
5393                    };
5394                    queue_switch_operation = true;
5395                }
5396            }
5397            _ if use_guest_vf => {
5398                tracing::warn!(
5399                    state = %primary.guest_vf_state,
5400                    use_guest_vf,
5401                    "Data path switch requested while device is in wrong state"
5402                );
5403            }
5404            _ => (),
5405        };
5406        if queue_switch_operation {
5407            self.send_coordinator_update_vf();
5408        } else {
5409            self.send_completion(transaction_id, None)?;
5410        }
5411        Ok(())
5412    }
5413
    /// Processes incoming packets from the guest's ring buffer: RNDIS tx
    /// packets, rx completions, and (on the primary channel) control
    /// requests such as subchannel allocation and data path switches.
    ///
    /// Returns `Ok(true)` if at least one packet was consumed from the ring.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if !data.tx_segments.is_empty() {
            // There are still segments pending transmission. Skip polling the
            // ring until they are sent to increase backpressure and minimize
            // unnecessary wakeups.
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Stop when out of free tx IDs; the ring will be polled again
            // once completions free some up.
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                // A guest-to-host RNDIS message (data or control).
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // No endpoint packets were produced (e.g. a
                            // control message), so complete immediately.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                // The guest is done with a previously indicated rx message;
                // release the buffers and return them to the endpoint.
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state
                        .queue
                        .rx_avail(&mut queue_state.pool, &data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    // The number of requested subchannels has to stay below the maximum queue limit
                    // because one queue is always reserved for the primary channel. In other words,
                    // the subchannels plus the primary channel must fit within the max_queues value,
                    // which means subchannels + 1 ≤ max_queues, so the subchannel count must be
                    // strictly less than max_queues.
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    if subchannel_count > 0 {
                        // Ask the coordinator to restart so the newly
                        // requested queues get set up.
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(protocol::Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                // No operation for VF association completion packets as not all clients send them
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Push any segments built up above to the endpoint.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5568
5569    // Transmit any pending segments. Returns Ok(true) if work was done--if any
5570    // segments were transmitted.
5571    fn transmit_pending_segments(
5572        &mut self,
5573        state: &mut ActiveState,
5574        data: &mut ProcessingData,
5575        queue_state: &mut QueueState,
5576    ) -> Result<bool, WorkerError> {
5577        if data.tx_segments.is_empty() {
5578            return Ok(false);
5579        }
5580        let sent = data.tx_segments_sent;
5581        let did_work =
5582            self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5583        Ok(did_work)
5584    }
5585
    /// Hands the remaining pending tx segments (those at and beyond
    /// `data.tx_segments_sent`) to the endpoint.
    ///
    /// If the endpoint reports synchronous completion (`sync`), the guest
    /// packets for the accepted segments are completed immediately here.
    ///
    /// Returns true if all pending segments were transmitted; in that case
    /// the pending segment list is reset.
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&mut queue_state.pool, segments)
            .map_err(WorkerError::Endpoint)?;

        // Only the prefix the endpoint accepted counts as sent.
        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        if sync {
            // Complete the packets now.
            // Walk the accepted segments packet by packet: each packet is a
            // Head segment followed by `segment_count - 1` continuations.
            while let Some(head) = segments.first() {
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            // Everything went out; reset the pending list for reuse.
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5625
5626    fn handle_rndis(
5627        &mut self,
5628        buffers: &ChannelBuffers,
5629        id: TxId,
5630        state: &mut ActiveState,
5631        packet: &Packet<'_>,
5632        segments: &mut Vec<TxSegment>,
5633    ) -> Result<usize, WorkerError> {
5634        let mut total_packets = 0;
5635        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5636        assert_eq!(tx_packet.pending_packet_count, 0);
5637        tx_packet.transaction_id = packet
5638            .transaction_id
5639            .ok_or(WorkerError::MissingTransactionId)?;
5640
5641        // Probe the data to catch accesses that are out of bounds. This
5642        // simplifies error handling for backends that use
5643        // [`GuestMemory::iova`].
5644        packet
5645            .external_data
5646            .iter()
5647            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
5648            .map_err(WorkerError::GpaDirectError)?;
5649
5650        let mut reader = packet.rndis_reader(&buffers.mem);
5651        let header: rndisprot::MessageHeader = reader.read_plain()?;
5652        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
5653            let start = segments.len();
5654            match self.handle_rndis_packet_messages(
5655                buffers,
5656                state,
5657                id,
5658                header.message_length as usize,
5659                reader,
5660                segments,
5661            ) {
5662                Ok(n) => {
5663                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
5664                    total_packets += n;
5665                }
5666                Err(err) => {
5667                    // Roll back any segments added for this message.
5668                    segments.truncate(start);
5669                    return Err(err);
5670                }
5671            }
5672        } else {
5673            self.handle_rndis_message(state, header.message_type, reader)?;
5674        }
5675
5676        Ok(total_packets)
5677    }
5678
5679    fn try_send_tx_packet(
5680        &mut self,
5681        transaction_id: u64,
5682        status: protocol::Status,
5683    ) -> Result<bool, WorkerError> {
5684        let message = self.message(
5685            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5686            protocol::Message1SendRndisPacketComplete { status },
5687        );
5688        let result = self.queue.split().1.batched().try_write_aligned(
5689            transaction_id,
5690            OutgoingPacketType::Completion,
5691            message.aligned_payload(),
5692        );
5693        let sent = match result {
5694            Ok(()) => true,
5695            Err(queue::TryWriteError::Full(n)) => {
5696                self.pending_send_size = n;
5697                false
5698            }
5699            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5700        };
5701        Ok(sent)
5702    }
5703
5704    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5705        let mut did_some_work = false;
5706        while let Some(pending) = state.pending_tx_completions.front() {
5707            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5708                return Ok(did_some_work);
5709            }
5710            did_some_work = true;
5711            if let Some(id) = pending.tx_id {
5712                state.free_tx_packets.push(id);
5713            }
5714            tracing::trace!(?pending, "sent tx completion");
5715            state.pending_tx_completions.pop_front();
5716        }
5717
5718        self.pending_send_size = 0;
5719        Ok(did_some_work)
5720    }
5721
5722    fn complete_tx_packet(
5723        &mut self,
5724        state: &mut ActiveState,
5725        id: TxId,
5726        status: protocol::Status,
5727    ) -> Result<(), WorkerError> {
5728        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5729        assert_eq!(tx_packet.pending_packet_count, 0);
5730        if self.pending_send_size == 0
5731            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5732        {
5733            tracing::trace!(id = id.0, "sent tx completion");
5734            state.free_tx_packets.push(id);
5735        } else {
5736            tracing::trace!(id = id.0, "pended tx completion");
5737            state.pending_tx_completions.push_back(PendingTxCompletion {
5738                transaction_id: tx_packet.transaction_id,
5739                tx_id: Some(id),
5740                status,
5741            });
5742        }
5743        Ok(())
5744    }
5745}
5746
5747impl ActiveState {
5748    fn release_recv_buffers(
5749        &mut self,
5750        transaction_id: u64,
5751        rx_buffer_range: &RxBufferRange,
5752        done: &mut Vec<RxId>,
5753    ) -> Option<()> {
5754        // The transaction ID specifies the first rx buffer ID.
5755        let first_id: u32 = transaction_id.try_into().ok()?;
5756        let ids = self.rx_bufs.free(first_id)?;
5757        for id in ids {
5758            if !rx_buffer_range.send_if_remote(id) {
5759                if id >= RX_RESERVED_CONTROL_BUFFERS {
5760                    done.push(RxId(id));
5761                } else {
5762                    self.primary
5763                        .as_mut()?
5764                        .free_control_buffers
5765                        .push(ControlMessageId(id));
5766                }
5767            }
5768        }
5769        Some(())
5770    }
5771}