netvsp/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The user-mode netvsp VMBus device implementation.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

mod buffers;
mod protocol;
pub mod resolver;
mod rndisprot;
mod rx_bufs;
mod saved_state;
mod test;

use crate::buffers::GuestBuffers;
use crate::protocol::Message1RevokeReceiveBuffer;
use crate::protocol::Message1RevokeSendBuffer;
use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
use crate::protocol::Version;
use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
use async_trait::async_trait;
pub use buffers::BufferPool;
use buffers::sub_allocation_size_for_mtu;
use futures::FutureExt;
use futures::StreamExt;
use futures::channel::mpsc;
use futures_concurrency::future::Race;
use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRanges;
use guestmem::ranges::PagedRangesReader;
use guid::Guid;
use hvdef::hypercall::HvGuestOsId;
use hvdef::hypercall::HvGuestOsMicrosoft;
use hvdef::hypercall::HvGuestOsMicrosoftIds;
use hvdef::hypercall::HvGuestOsOpenSourceType;
use inspect::Inspect;
use inspect::InspectMut;
use inspect::SensitivityLevel;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use mesh::rpc::Rpc;
use net_backend::Endpoint;
use net_backend::EndpointAction;
use net_backend::L3Protocol;
use net_backend::QueueConfig;
use net_backend::RxId;
use net_backend::TxError;
use net_backend::TxId;
use net_backend::TxSegment;
use net_backend_resources::mac_address::MacAddress;
use pal_async::timer::Instant;
use pal_async::timer::PolledTimer;
use ring::gparange::MultiPagedRangeIter;
use rx_bufs::RxBuffers;
use rx_bufs::SubAllocationInUse;
use std::cmp;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::future::pending;
use std::mem::offset_of;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Poll;
use std::time::Duration;
use task_control::AsyncRun;
use task_control::InspectTaskMut;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing::Instrument;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::channel::ChannelControl;
use vmbus_channel::channel::ChannelOpenError;
use vmbus_channel::channel::ChannelRestoreError;
use vmbus_channel::channel::DeviceResources;
use vmbus_channel::channel::RestoreControl;
use vmbus_channel::channel::SaveRestoreVmbusDevice;
use vmbus_channel::channel::VmbusDevice;
use vmbus_channel::gpadl::GpadlId;
use vmbus_channel::gpadl::GpadlMapView;
use vmbus_channel::gpadl::GpadlView;
use vmbus_channel::gpadl::UnknownGpadlId;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_channel::gpadl_ring::gpadl_channel;
use vmbus_ring as ring;
use vmbus_ring::OutgoingPacketType;
use vmbus_ring::RingMem;
use vmbus_ring::gparange::GpnList;
use vmbus_ring::gparange::MultiPagedRangeBuf;
use vmcore::save_restore::RestoreError;
use vmcore::save_restore::SaveError;
use vmcore::save_restore::SavedStateBlob;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

// The minimum ring space required to handle a control message. Most control messages only need to send a completion
// packet, but also need room for an additional SEND_VF_ASSOCIATION message.
const MIN_CONTROL_RING_SIZE: usize = 144;

// The minimum ring space required to handle external state changes. Worst case requires a completion message plus two
// additional inband messages (SWITCH_DATA_PATH and SEND_VF_ASSOCIATION).
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Assign the VF_ASSOCIATION message a specific transaction ID so that the completion packet can be identified easily.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
// Assign the SWITCH_DATA_PATH message a specific transaction ID so that the completion packet can be identified easily.
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;
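// (Both IDs have the top bit set; presumably this keeps them distinct from
// the transaction IDs used for ordinary guest-initiated packets.)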

const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Arbitrary delay before adding the VF device to the guest. Older Linux
// clients can race when initializing the synthetic nic: the network
// negotiation is done first, and then the device is asynchronously queued to
// receive a name (e.g. eth0). If the accelerated networking (AN) device is
// offered too quickly, it could get the "eth0" name instead. In provisioning
// scenarios, scripts make assumptions about which interface should be used,
// with eth0 being the default.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Linux guests are known to not act on link state change notifications if
// they happen in quick succession.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);

#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel.
    /// This includes adding the guest VF device and switching the data path.
    UpdateGuestVfState,
    /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current
    /// expectations.
    Restart,
    /// Start a timer.
    StartTimer(Instant),
}

struct Worker<T: RingMem> {
    channel_idx: u16,
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}

impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}

#[expect(clippy::large_enum_variant)]
enum WorkerState {
    Init(Option<InitState>),
    Ready(ReadyState),
}

impl WorkerState {
    fn ready(&self) -> Option<&ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }

    fn ready_mut(&mut self) -> Option<&mut ReadyState> {
        if let Self::Ready(state) = self {
            Some(state)
        } else {
            None
        }
    }
}

struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

struct ReadyState {
    buffers: Arc<ChannelBuffers>,
    state: ActiveState,
    data: ProcessingData,
}

/// Represents a virtual function (VF) device used to expose accelerated
/// networking to the guest.
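///
/// # Example
///
/// A minimal sketch of an implementation (illustrative only; the `NullVf`
/// type is hypothetical):
///
/// ```ignore
/// struct NullVf;
///
/// #[async_trait]
/// impl VirtualFunction for NullVf {
///     async fn id(&self) -> Option<u32> {
///         None // no VF is ever available
///     }
///     async fn guest_ready_for_device(&mut self) {}
///     async fn wait_for_state_change(&mut self) -> Rpc<(), ()> {
///         std::future::pending().await
///     }
/// }
/// ```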
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Unique ID of the device. Used by the client to associate a device with
    /// its synthetic counterpart. A value of None signifies that the VF is not
    /// currently available for use.
    async fn id(&self) -> Option<u32>;
    /// Dynamically expose the device in the guest.
    async fn guest_ready_for_device(&mut self);
    /// Returns when there is a change in VF availability. The Rpc result will
    /// indicate if the change was successfully handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}

struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    max_queues: u16,
    indirection_table_size: u16,
    offload_support: OffloadConfig,
    ring_size_limit: AtomicUsize,
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    target_vp_set: bool,
}

struct RxBufferRange {
    id_range: Range<u32>,
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}

impl RxBufferRange {
    fn new(
        ranges: Arc<RxBufferRanges>,
        id_range: Range<u32>,
        remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    ) -> Self {
        Self {
            id_range,
            remote_buffer_id_recv,
            remote_ranges: ranges,
        }
    }

    fn send_if_remote(&self, id: u32) -> bool {
        // Only queue 0 should get reserved buffer IDs. Otherwise check if the
        // ID is owned by the current range.
        if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
            false
        } else {
            let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
            // The total number of receive buffers may not evenly divide among
            // the active queues. Any extra buffers are given to the last
            // queue, so redirect any larger values there.
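            // For example (illustrative numbers): with 64 total buffers, 16
            // reserved, and 3 queues, buffers_per_queue is 16, so buffer ID
            // 55 maps to index (55 - 16) / 16 = 2, the last queue, which also
            // receives any remainder.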
            let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
            let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
            true
        }
    }
}

struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}

impl RxBufferRanges {
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}

struct RssState {
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// The internal channel state.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: usize,
    pending_send_size: usize,
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
}

/// Buffers used during packet processing.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
}

impl ProcessingData {
    fn new() -> Self {
        Self {
            tx_segments: Vec::new(),
            tx_done: vec![TxId(0); 8192].into(),
            rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
            rx_done: Vec::with_capacity(RX_BATCH_SIZE),
            transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
        }
    }
}

/// Buffers used during channel processing. Separated out from the mutable state
/// to allow multiple concurrent references.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// An ID assigned to a control message. This is also its receive buffer index.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for a channel that has finished negotiation.
struct ActiveState {
    primary: Option<PrimaryChannelState>,

    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,

    rx_bufs: RxBuffers,

    stats: QueueStats,
}

#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    tx_id: Option<TxId>,
    status: protocol::Status,
}

#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// No state has been assigned yet.
    Initializing,
    /// State is being restored from a save.
    Restoring(saved_state::GuestVfState),
    /// No VF is available for the guest.
    Unavailable,
    /// A VF was previously available to the guest, but is no longer available.
    UnavailableFromAvailable,
    /// A VF was previously available while a data path switch was pending, but
    /// it is no longer available.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// A VF was previously available with the data path switched to it, but it
    /// is no longer available.
    UnavailableFromDataPathSwitched,
    /// A VF is available for the guest.
    Available { vfid: u32 },
    /// A VF is available for the guest and has been advertised to the guest.
    AvailableAdvertised,
    /// A VF is ready for guest use.
    Ready,
    /// A VF is ready in the guest and the guest has requested a data path switch.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// A VF is ready in the guest and is currently acting as the data path.
    DataPathSwitched,
    /// A VF is ready in the guest and was acting as the data path, but an external
    /// state change has moved it back to synthetic.
    DataPathSynthetic,
}

impl std::fmt::Display for PrimaryChannelGuestVfState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
514                    "restoring from client requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "\"unavailable (previously switching data path)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}

impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}

#[derive(Debug)]
enum PendingLinkAction {
    Default,
    Active(bool),
    Delay(bool),
}

struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}

impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}

#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}

#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}

impl ChecksumOffloadConfig {
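    /// Builds the NDIS IPv4/IPv6 checksum offload flag sets advertised to the
    /// guest. For example (illustrative): a config with only `tcp4` set yields
    /// v4 flags with IP options, TCP options, and TCP checksum marked
    /// supported, and all-zero v6 flags.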
    fn flags(
        &self,
    ) -> (
        rndisprot::Ipv4ChecksumOffload,
        rndisprot::Ipv6ChecksumOffload,
    ) {
        let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
        let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
        let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
        if self.ipv4_header {
            v4.set_ip_options_supported(on);
            v4.set_ip_checksum(on);
        }
        if self.tcp4 {
            v4.set_ip_options_supported(on);
            v4.set_tcp_options_supported(on);
            v4.set_tcp_checksum(on);
        }
        if self.tcp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_tcp_options_supported(on);
            v6.set_tcp_checksum(on);
        }
        if self.udp4 {
            v4.set_ip_options_supported(on);
            v4.set_udp_checksum(on);
        }
        if self.udp6 {
            v6.set_ip_extension_headers_supported(on);
            v6.set_udp_checksum(on);
        }
        (v4, v6)
    }
}

impl OffloadConfig {
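    /// Builds the RNDIS offload descriptor (checksum plus LSOv2) that
    /// describes this configuration to the guest.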
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            // Use the same maximum as vmswitch.
            const MAX_OFFLOAD_SIZE: u32 = 0xF53C;
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = 2;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = 2;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            ..FromZeros::new_zeroed()
        }
    }
}

#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
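/// The state of the RNDIS control plane.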
pub enum RndisState {
    Initializing,
    Operational,
    Halted,
}

impl PrimaryChannelState {
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Restore control messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Compute the free control buffers.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    // Dynamic reduction of the indirection table can cause unexpected and
                    // hard-to-investigate issues with performance and processor overloading.
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Dynamic increase of the indirection table is done by duplicating
                    // the existing entries until the desired size is reached.
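                    // For example (illustrative): a saved table of [0, 1, 2, 3]
                    // grown to a target size of 6 becomes [0, 1, 2, 3, 0, 1].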
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}

struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

const TX_PACKET_QUOTA: usize = 1024;

impl ActiveState {
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Consume tx quota if any is available. If not, still
            // allow the restore since tx quota might change from
            // release to release.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                // This shouldn't be referenced, but set it in case it is in the future.
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            // Save/restore does not preserve the status of pending tx completions, so
            // complete any pending completions with 'success' to avoid changing the saved state.
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}

/// The state for an rndis tx packet that's currently pending in the backend
/// endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    pending_packet_count: usize,
    transaction_id: u64,
}

/// The maximum batch size.
///
/// TODO: An even larger value is supported when RSC is enabled, so look into
/// this.
const RX_BATCH_SIZE: usize = 375;

/// The number of receive buffers to reserve for control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;

/// A network adapter.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}

pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}

impl NicBuilder {
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC.
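    ///
    /// # Example
    ///
    /// A hedged sketch (the `driver_source`, `instance_id`, `endpoint`, and
    /// `mac` values are assumed to exist in the caller's context):
    ///
    /// ```ignore
    /// let nic = Nic::builder()
    ///     .max_queues(4)
    ///     .build(&driver_source, instance_id, endpoint, mac, 0);
    /// ```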
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // If requested, limit the effective size of the outgoing ring buffer.
        // In a configuration where the NIC is processed synchronously, this
        // will ensure that we don't process incoming rx packets and tx packet
        // completions until the guest has processed the data it already has.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // If the endpoint completes tx packets quickly, then avoid polling the
        // incoming ring (and thus avoid arming the signal from the guest) as
        // long as there are any tx packets in flight. This can significantly
        // reduce the signal rate from the guest, improving batching.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            // Avoid getting into a situation where there is always barely
            // enough quota.
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Always claim support for rx offloads since we can mark any given
        // packet as having unknown checksum state.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}

fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
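    // Enable the "pending send size" ring optimization only when the ring
    // supports it and the guest OS is known to honor it correctly.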
    let Some(guest_os_id) = guest_os_id else {
        // guest os id not available.
        return false;
    };

    if !queue.split().0.supports_pending_send_size() {
        // guest does not support pending send size.
        return false;
    }

    let Some(open_source_os) = guest_os_id.open_source() else {
        // guest os is proprietary (e.g., Windows)
        return true;
    };

    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        // Although FreeBSD indicates support for `pending send size`, it doesn't
        // implement it correctly. This was fixed in FreeBSD version `1400097`.
        // No known issues with other open source OS.
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        _ => true,
    }
}

impl Nic {
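    /// Returns a builder for constructing a [`Nic`].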
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            max_queues: !0,
            get_guest_os_id: None,
        }
    }

    pub fn shutdown(self) -> Box<dyn Endpoint> {
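        // Tear down the coordinator and return the backend endpoint to the
        // caller.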
        let (state, _) = self.coordinator.into_inner();
        state.endpoint
    }
}

impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}

#[async_trait]
impl VmbusDevice for Nic {
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
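            // The well-known vmbus interface (class) ID for synthetic
            // networking: f8615163-df3e-46c5-913f-f2d2f965ed0e.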
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        // Start the coordinator task if this is the primary channel.
        let state = if channel_idx == 0 {
            self.insert_coordinator(1, false);
            WorkerState::Init(None)
        } else {
            self.coordinator.stop().await;
            // Get the buffers created when the primary channel was opened.
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            self.coordinator.remove();
        } else {
            // The coordinator will restart any stopped workers.
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Stop the coordinator to get access to the workers.
        let restart = self.coordinator.stop().await;

        // Stop and remove the channel worker.
        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        // Disable the endpoint.
        if channel_idx == 0 {
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }
            // Note that this await is not restartable.
            self.coordinator.task_mut().endpoint.stop().await;

            // Keep any VFs added to the guest. This is required for guest
            // compatibility, as some apps (such as DPDK) rely on the VF
            // sticking around even after the vmbus channel is closed.
            // The coordinator's job is done.
            self.coordinator.remove();
        } else {
            // Restart the coordinator.
            if restart {
                self.coordinator.start();
            }
        }
    }

    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        // Stop the coordinator and worker associated with this channel.
        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Update the target VP on the driver.
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            // Update the target VP in the worker state.
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                // Tell the worker to re-set the target VP on next run.
                queue_state.target_vp_set = false;
            }
        }

        // The coordinator will restart any stopped workers.
        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}

#[async_trait]
impl SaveRestoreVmbusDevice for Nic {
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        let state = self.saved_state();
        Ok(SavedStateBlob::new(state))
    }

    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        let state: saved_state::SavedState = state.parse()?;
        if let Err(err) = self.restore_state(control, state).await {
            tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
            Err(err.into())
        } else {
            Ok(())
        }
    }
}

impl Nic {
    /// Allocates and inserts a worker.
    ///
    /// The coordinator must be stopped.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the driver now that the channel is open.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: protocol::PACKET_SIZE_V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}

impl Nic {
    /// If `restoring`, then restart the queues as soon as the coordinator starts.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: bool) {
        let mut driver_builder = self.driver_source.builder();
        // Target each driver to VP 0 initially. This will be updated when the
        // channel is opened.
        driver_builder.target_vp(0);
        // If tx completions arrive quickly, then just do tx processing
        // on whatever processor the guest happens to signal from.
        // Subsequent transmits will be pulled from the completion
        // processor.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] // TODO
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring,
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
            },
        );
    }
}

#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}

impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}

impl Nic {
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        if let Some(state) = state.open {
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // N.B. This will restore the vmbus view of open channels, so any
            //      failures after this point could result in inconsistent
            //      state (vmbus believes the channel is open/active). There
            //      are a number of failure paths after this point because this
            //      call also restores vmbus device state, like the GPADL map.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                    } = ready;

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        let channel = if let Some(channel) = channel {
                            channel
1581                        } else {
1582                            continue;
1583                        };
1584
1585                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;
1586
1587                        // Restore primary channel state.
1588                        if channel_idx == 0 {
1589                            let primary = PrimaryChannelState::restore(
1590                                &guest_vf_state,
1591                                &rndis_state,
1592                                &offload_config,
1593                                pending_offload_change,
1594                                channels.len() as u16,
1595                                self.adapter.indirection_table_size,
1596                                &active.rx_bufs,
1597                                std::mem::take(&mut control_messages),
1598                                rss_state.take(),
1599                                tx_spread_sent,
1600                                guest_link_down,
1601                                pending_link_action,
1602                            )?;
1603                            active.primary = Some(primary);
1604                        }
1605
1606                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
1607                            buffers: buffers.clone(),
1608                            state: active,
1609                            data: ProcessingData::new(),
1610                        }));
1611                    }
1612                }
1613            }
1614
1615            // Insert the coordinator and mark that it should try to start the
1616            // network endpoint when it starts running.
1617            self.insert_coordinator(states.len() as u16, true);
1618
1619            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
1620                if let Some(state) = state {
1621                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
1622                }
1623            }
1624        } else {
1625            control
1626                .restore(&[false])
1627                .await
1628                .map_err(NetRestoreError::Channel)?;
1629        }
1630        Ok(())
1631    }
1632
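    /// Builds the saved state for this device. The primary channel records
    /// the negotiated version, buffers, and RNDIS/VF state; every open
    /// channel records its pending tx completions and in-use rx buffers.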
1633    fn saved_state(&self) -> saved_state::SavedState {
1634        let open = if let Some(coordinator) = self.coordinator.state() {
1635            let primary = coordinator.workers[0].state().unwrap();
1636            let primary = match &primary.state {
1637                WorkerState::Init(None) => saved_state::Primary::Version,
1638                WorkerState::Init(Some(init)) => {
1639                    saved_state::Primary::Init(saved_state::InitPrimary {
1640                        version: init.version as u32,
1641                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
1642                            saved_state::NdisConfig {
1643                                mtu,
1644                                capabilities: capabilities.into(),
1645                            }
1646                        }),
1647                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
1648                            saved_state::NdisVersion { major, minor }
1649                        }),
1650                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
1651                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
1652                            gpadl_id: x.gpadl.id(),
1653                        }),
1654                    })
1655                }
1656                WorkerState::Ready(ready) => {
1657                    let primary = ready.state.primary.as_ref().unwrap();
1658
1659                    let rndis_state = match primary.rndis_state {
1660                        RndisState::Initializing => saved_state::RndisState::Initializing,
1661                        RndisState::Operational => saved_state::RndisState::Operational,
1662                        RndisState::Halted => saved_state::RndisState::Halted,
1663                    };
1664
1665                    let offload_config = saved_state::OffloadConfig {
1666                        checksum_tx: saved_state::ChecksumOffloadConfig {
1667                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
1668                            tcp4: primary.offload_config.checksum_tx.tcp4,
1669                            udp4: primary.offload_config.checksum_tx.udp4,
1670                            tcp6: primary.offload_config.checksum_tx.tcp6,
1671                            udp6: primary.offload_config.checksum_tx.udp6,
1672                        },
1673                        checksum_rx: saved_state::ChecksumOffloadConfig {
1674                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
1675                            tcp4: primary.offload_config.checksum_rx.tcp4,
1676                            udp4: primary.offload_config.checksum_rx.udp4,
1677                            tcp6: primary.offload_config.checksum_rx.tcp6,
1678                            udp6: primary.offload_config.checksum_rx.udp6,
1679                        },
1680                        lso4: primary.offload_config.lso4,
1681                        lso6: primary.offload_config.lso6,
1682                    };
1683
1684                    let control_messages = primary
1685                        .control_messages
1686                        .iter()
1687                        .map(|message| saved_state::IncomingControlMessage {
1688                            message_type: message.message_type,
1689                            data: message.data.to_vec(),
1690                        })
1691                        .collect();
1692
1693                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
1694                        key: rss.key.into(),
1695                        indirection_table: rss.indirection_table.clone(),
1696                    });
1697
1698                    let pending_link_action = match primary.pending_link_action {
1699                        PendingLinkAction::Default => None,
1700                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
1701                            Some(action)
1702                        }
1703                    };
1704
1705                    let channels = coordinator.workers[..coordinator.num_queues as usize]
1706                        .iter()
1707                        .map(|worker| {
1708                            worker.state().map(|worker| {
1709                                let channel = if let Some(ready) = worker.state.ready() {
1710                                    // In-flight tx packets are treated as dropped across save/restore,
1711                                    // but their requests must still be completed back to the guest.
1712                                    let pending_tx_completions = ready
1713                                        .state
1714                                        .pending_tx_completions
1715                                        .iter()
1716                                        .map(|pending| pending.transaction_id)
1717                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
1718                                            |inflight| {
1719                                                (inflight.pending_packet_count > 0)
1720                                                    .then_some(inflight.transaction_id)
1721                                            },
1722                                        ))
1723                                        .collect();
1724
1725                                    saved_state::Channel {
1726                                        pending_tx_completions,
1727                                        in_use_rx: {
1728                                            ready
1729                                                .state
1730                                                .rx_bufs
1731                                                .allocated()
1732                                                .map(|id| saved_state::Rx {
1733                                                    buffers: id.collect(),
1734                                                })
1735                                                .collect()
1736                                        },
1737                                    }
1738                                } else {
1739                                    saved_state::Channel {
1740                                        pending_tx_completions: Vec::new(),
1741                                        in_use_rx: Vec::new(),
1742                                    }
1743                                };
1744                                channel
1745                            })
1746                        })
1747                        .collect();
1748
1749                    let guest_vf_state = match primary.guest_vf_state {
1750                        PrimaryChannelGuestVfState::Initializing
1751                        | PrimaryChannelGuestVfState::Unavailable
1752                        | PrimaryChannelGuestVfState::Available { .. } => {
1753                            saved_state::GuestVfState::NoState
1754                        }
1755                        PrimaryChannelGuestVfState::UnavailableFromAvailable
1756                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
1757                            saved_state::GuestVfState::AvailableAdvertised
1758                        }
1759                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
1760                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
1761                            to_guest,
1762                            id,
1763                        } => saved_state::GuestVfState::DataPathSwitchPending {
1764                            to_guest,
1765                            id,
1766                            result: None,
1767                        },
1768                        PrimaryChannelGuestVfState::DataPathSwitchPending {
1769                            to_guest,
1770                            id,
1771                            result,
1772                        } => saved_state::GuestVfState::DataPathSwitchPending {
1773                            to_guest,
1774                            id,
1775                            result,
1776                        },
1777                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
1778                        | PrimaryChannelGuestVfState::DataPathSwitched
1779                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
1780                            saved_state::GuestVfState::DataPathSwitched
1781                        }
1782                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
1783                    };
1784
1785                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
1786                        version: ready.buffers.version as u32,
1787                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
1788                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
1789                            saved_state::SendBuffer {
1790                                gpadl_id: sb.gpadl.id(),
1791                            }
1792                        }),
1793                        rndis_state,
1794                        guest_vf_state,
1795                        offload_config,
1796                        pending_offload_change: primary.pending_offload_change,
1797                        control_messages,
1798                        rss_state,
1799                        channels,
1800                        ndis_config: {
1801                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
1802                            saved_state::NdisConfig {
1803                                mtu,
1804                                capabilities: capabilities.into(),
1805                            }
1806                        },
1807                        ndis_version: {
1808                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
1809                            saved_state::NdisVersion { major, minor }
1810                        },
1811                        tx_spread_sent: primary.tx_spread_sent,
1812                        guest_link_down: !primary.guest_link_up,
1813                        pending_link_action,
1814                    })
1815                }
1816            };
1817
1818            let state = saved_state::OpenState { primary };
1819            Some(state)
1820        } else {
1821            None
1822        };
1823
1824        saved_state::SavedState { open }
1825    }
1826}
1827
1828#[derive(Debug, Error)]
1829enum WorkerError {
1830    #[error("packet error")]
1831    Packet(#[source] PacketError),
1832    #[error("unexpected packet order: {0}")]
1833    UnexpectedPacketOrder(#[source] PacketOrderError),
1834    #[error("unknown rndis message type: {0}")]
1835    UnknownRndisMessageType(u32),
1836    #[error("memory access error")]
1837    Access(#[from] AccessError),
1838    #[error("rndis message too small")]
1839    RndisMessageTooSmall,
1840    #[error("unsupported rndis behavior")]
1841    UnsupportedRndisBehavior,
1842    #[error("vmbus queue error")]
1843    Queue(#[from] queue::Error),
1844    #[error("too many control messages")]
1845    TooManyControlMessages,
1846    #[error("invalid rndis packet completion")]
1847    InvalidRndisPacketCompletion,
1848    #[error("missing transaction id")]
1849    MissingTransactionId,
1850    #[error("invalid gpadl")]
1851    InvalidGpadl(#[source] guestmem::InvalidGpn),
1852    #[error("gpadl error")]
1853    GpadlError(#[source] GuestMemoryError),
1854    #[error("gpa direct error")]
1855    GpaDirectError(#[source] GuestMemoryError),
1856    #[error("endpoint")]
1857    Endpoint(#[source] anyhow::Error),
1858    #[error("message not supported on sub channel: {0}")]
1859    NotSupportedOnSubChannel(u32),
1860    #[error("the ring buffer ran out of space, which should not be possible")]
1861    OutOfSpace,
1862    #[error("send/receive buffer error")]
1863    Buffer(#[from] BufferError),
1864    #[error("invalid rndis state")]
1865    InvalidRndisState,
1866    #[error("rndis message type not implemented")]
1867    RndisMessageTypeNotImplemented,
1868    #[error("invalid TCP header offset")]
1869    InvalidTcpHeaderOffset,
1870    #[error("cancelled")]
1871    Cancelled(task_control::Cancelled),
1872    #[error("tearing down because send/receive buffer is revoked")]
1873    BufferRevoked,
1874    #[error("endpoint requires queue restart: {0}")]
1875    EndpointRequiresQueueRestart(#[source] anyhow::Error),
1876}
1877
1878impl From<task_control::Cancelled> for WorkerError {
1879    fn from(value: task_control::Cancelled) -> Self {
1880        Self::Cancelled(value)
1881    }
1882}
1883
1884#[derive(Debug, Error)]
1885enum OpenError {
1886    #[error("error establishing ring buffer")]
1887    Ring(#[source] vmbus_channel::gpadl_ring::Error),
1888    #[error("error establishing vmbus queue")]
1889    Queue(#[source] queue::Error),
1890}
1891
1892#[derive(Debug, Error)]
1893enum PacketError {
1894    #[error("UnknownType {0}")]
1895    UnknownType(u32),
1896    #[error("Access")]
1897    Access(#[source] AccessError),
1898    #[error("ExternalData")]
1899    ExternalData(#[source] ExternalDataError),
1900    #[error("InvalidSendBufferIndex")]
1901    InvalidSendBufferIndex,
1902}
1903
1904#[derive(Debug, Error)]
1905enum PacketOrderError {
1906    #[error("Invalid PacketData")]
1907    InvalidPacketData,
1908    #[error("Unexpected RndisPacket")]
1909    UnexpectedRndisPacket,
1910    #[error("SendNdisVersion already exists")]
1911    SendNdisVersionExists,
1912    #[error("SendNdisConfig already exists")]
1913    SendNdisConfigExists,
1914    #[error("SendReceiveBuffer already exists")]
1915    SendReceiveBufferExists,
1916    #[error("SendReceiveBuffer missing MTU")]
1917    SendReceiveBufferMissingMTU,
1918    #[error("SendSendBuffer already exists")]
1919    SendSendBufferExists,
1920    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
1921    SwitchDataPathCompletionPrimaryChannelState,
1922}
1923
1924#[derive(Debug)]
1925enum PacketData {
1926    Init(protocol::MessageInit),
1927    SendNdisVersion(protocol::Message1SendNdisVersion),
1928    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
1929    SendSendBuffer(protocol::Message1SendSendBuffer),
1930    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
1931    RevokeSendBuffer(Message1RevokeSendBuffer),
1932    RndisPacket(protocol::Message1SendRndisPacket),
1933    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
1934    SendNdisConfig(protocol::Message2SendNdisConfig),
1935    SwitchDataPath(protocol::Message4SwitchDataPath),
1936    OidQueryEx(protocol::Message5OidQueryEx),
1937    SubChannelRequest(protocol::Message5SubchannelRequest),
1938    SendVfAssociationCompletion,
1939    SwitchDataPathCompletion,
1940}
1941
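/// A parsed netvsp packet, along with its external GPA ranges and the send
/// buffer suballocation (if any) that it references.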
1942#[derive(Debug)]
1943struct Packet<'a> {
1944    data: PacketData,
1945    transaction_id: Option<u64>,
1946    external_data: MultiPagedRangeBuf<GpnList>,
1947    send_buffer_suballocation: PagedRange<'a>,
1948}
1949
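/// A reader over the send buffer suballocation chained with the packet's
/// external data ranges, covering the full RNDIS message body.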
1950type PacketReader<'a> = PagedRangesReader<
1951    'a,
1952    std::iter::Chain<std::iter::Once<PagedRange<'a>>, MultiPagedRangeIter<'a>>,
1953>;
1954
1955impl Packet<'_> {
1956    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
1957        PagedRanges::new(
1958            std::iter::once(self.send_buffer_suballocation).chain(self.external_data.iter()),
1959        )
1960        .reader(mem)
1961    }
1962}
1963
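/// Reads a fixed-size `T` from the packet payload, mapping read failures to
/// `PacketError::Access`.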
1964fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
1965    reader: &mut impl MemoryRead,
1966) -> Result<T, PacketError> {
1967    reader.read_plain().map_err(PacketError::Access)
1968}
1969
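/// Parses an incoming ring packet into a typed netvsp `Packet`, rejecting
/// message types that are not valid for the negotiated protocol `version`.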
1970fn parse_packet<'a, T: RingMem>(
1971    packet_ref: &queue::PacketRef<'_, T>,
1972    send_buffer: Option<&'a SendBuffer>,
1973    version: Option<Version>,
1974) -> Result<Packet<'a>, PacketError> {
1975    let packet = match packet_ref.as_ref() {
1976        IncomingPacket::Data(data) => data,
1977        IncomingPacket::Completion(completion) => {
1978            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
1979                PacketData::SendVfAssociationCompletion
1980            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
1981                PacketData::SwitchDataPathCompletion
1982            } else {
1983                let mut reader = completion.reader();
1984                let header: protocol::MessageHeader =
1985                    reader.read_plain().map_err(PacketError::Access)?;
1986                match header.message_type {
1987                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
1988                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
1989                    }
1990                    typ => return Err(PacketError::UnknownType(typ)),
1991                }
1992            };
1993            return Ok(Packet {
1994                data,
1995                transaction_id: Some(completion.transaction_id()),
1996                external_data: MultiPagedRangeBuf::empty(),
1997                send_buffer_suballocation: PagedRange::empty(),
1998            });
1999        }
2000    };
2001
2002    let mut reader = packet.reader();
2003    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
2004    let mut send_buffer_suballocation = PagedRange::empty();
2005    let data = match header.message_type {
2006        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
2007        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
2008            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
2009        }
2010        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2011            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
2012        }
2013        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
2014            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
2015        }
2016        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
2017            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
2018        }
2019        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
2020            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
2021        }
2022        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
2023            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
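            // A section index of 0xffffffff means the RNDIS message is not in
            // the send buffer. Otherwise, sections are 6144 bytes apart, and
            // the message starts at the indicated section.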
2024            if message.send_buffer_section_index != 0xffffffff {
2025                send_buffer_suballocation = send_buffer
2026                    .ok_or(PacketError::InvalidSendBufferIndex)?
2027                    .gpadl
2028                    .first()
2029                    .unwrap()
2030                    .try_subrange(
2031                        message.send_buffer_section_index as usize * 6144,
2032                        message.send_buffer_section_size as usize,
2033                    )
2034                    .ok_or(PacketError::InvalidSendBufferIndex)?;
2035            }
2036            PacketData::RndisPacket(message)
2037        }
2038        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
2039            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
2040        }
2041        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
2042            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
2043        }
2044        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
2045            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
2046        }
2047        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
2048            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
2049        }
2050        typ => return Err(PacketError::UnknownType(typ)),
2051    };
2052    Ok(Packet {
2053        data,
2054        transaction_id: packet.transaction_id(),
2055        external_data: packet
2056            .read_external_ranges()
2057            .map_err(PacketError::ExternalData)?,
2058        send_buffer_suballocation,
2059    })
2060}
2061
2062#[derive(Debug, Copy, Clone)]
2063struct NvspMessage<T> {
2064    header: protocol::MessageHeader,
2065    data: T,
2066    padding: &'static [u8],
2067}
2068
2069impl<T: IntoBytes + Immutable + KnownLayout> NvspMessage<T> {
2070    fn payload(&self) -> [&[u8]; 3] {
2071        [self.header.as_bytes(), self.data.as_bytes(), self.padding]
2072    }
2073}
2074
2075impl<T: RingMem> NetChannel<T> {
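    /// Wraps `data` in an NVSP message of the given type, padded up to the
    /// negotiated packet size.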
2076    fn message<P: IntoBytes + Immutable + KnownLayout>(
2077        &self,
2078        message_type: u32,
2079        data: P,
2080    ) -> NvspMessage<P> {
2081        let padding = self.padding(&data);
2082        NvspMessage {
2083            header: protocol::MessageHeader { message_type },
2084            data,
2085            padding,
2086        }
2087    }
2088
2089    /// Returns zero padding bytes to round the payload up to the packet size.
2090    /// Only needed for Windows guests, which are picky about packet sizes.
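    /// For example, with `packet_size` of `PACKET_SIZE_V61`, a header plus
    /// body shorter than that is zero-padded to exactly `packet_size` bytes,
    /// while larger payloads get an empty padding slice.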
2091    fn padding<P: IntoBytes + Immutable + KnownLayout>(&self, data: &P) -> &'static [u8] {
2092        static PADDING: &[u8] = &[0; protocol::PACKET_SIZE_V61];
2093        let padding_len = self.packet_size
2094            - cmp::min(
2095                self.packet_size,
2096                size_of::<protocol::MessageHeader>() + data.as_bytes().len(),
2097            );
2098        &PADDING[..padding_len]
2099    }
2100
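    /// Sends a completion packet for `transaction_id`, if one was requested.
    /// Ring space for completions is expected to be reserved, so a full ring
    /// surfaces as `WorkerError::OutOfSpace`.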
2101    fn send_completion(
2102        &mut self,
2103        transaction_id: Option<u64>,
2104        payload: &[&[u8]],
2105    ) -> Result<(), WorkerError> {
2106        match transaction_id {
2107            None => Ok(()),
2108            Some(transaction_id) => Ok(self
2109                .queue
2110                .split()
2111                .1
2112                .try_write(&queue::OutgoingPacket {
2113                    transaction_id,
2114                    packet_type: OutgoingPacketType::Completion,
2115                    payload,
2116                })
2117                .map_err(|err| match err {
2118                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2119                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2120                })?),
2121        }
2122    }
2123}
2124
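// Note that protocol version 3 is not in this list, so `check_version` will
// reject it.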
2125static SUPPORTED_VERSIONS: &[Version] = &[
2126    Version::V1,
2127    Version::V2,
2128    Version::V4,
2129    Version::V5,
2130    Version::V6,
2131    Version::V61,
2132];
2133
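/// Maps the version number requested by the guest to a supported [`Version`].
///
/// Illustrative sketch (not a doc test; assumes no supported version encodes
/// as `0xffff_ffff`):
/// ```ignore
/// assert_eq!(check_version(Version::V61 as u32), Some(Version::V61));
/// assert_eq!(check_version(0xffff_ffff), None);
/// ```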
2134fn check_version(requested_version: u32) -> Option<Version> {
2135    SUPPORTED_VERSIONS
2136        .iter()
2137        .find(|version| **version as u32 == requested_version)
2138        .copied()
2139}
2140
2141#[derive(Debug)]
2142struct ReceiveBuffer {
2143    gpadl: GpadlView,
2144    id: u16,
2145    count: u32,
2146    sub_allocation_size: u32,
2147}
2148
2149#[derive(Debug, Error)]
2150enum BufferError {
2151    #[error("unsupported suballocation size {0}")]
2152    UnsupportedSuballocationSize(u32),
2153    #[error("unaligned gpadl")]
2154    UnalignedGpadl,
2155    #[error("unknown gpadl ID")]
2156    UnknownGpadlId(#[from] UnknownGpadlId),
2157}
2158
2159impl ReceiveBuffer {
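    /// Maps the receive buffer GPADL and carves it into fixed-size
    /// suballocations, rejecting buffers that are unaligned or too small to
    /// hold a single suballocation.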
2160    fn new(
2161        gpadl_map: &GpadlMapView,
2162        gpadl_id: GpadlId,
2163        id: u16,
2164        sub_allocation_size: u32,
2165    ) -> Result<Self, BufferError> {
2166        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2167            return Err(BufferError::UnsupportedSuballocationSize(
2168                sub_allocation_size,
2169            ));
2170        }
2171        let gpadl = gpadl_map.map(gpadl_id)?;
2172        let range = gpadl
2173            .contiguous_aligned()
2174            .ok_or(BufferError::UnalignedGpadl)?;
2175        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2176        if num_sub_allocations == 0 {
2177            return Err(BufferError::UnsupportedSuballocationSize(
2178                sub_allocation_size,
2179            ));
2180        }
2181        let recv_buffer = Self {
2182            gpadl,
2183            id,
2184            count: num_sub_allocations,
2185            sub_allocation_size,
2186        };
2187        Ok(recv_buffer)
2188    }
2189
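    /// Returns the guest memory range of suballocation `index`, i.e. one
    /// suballocation starting at byte `index * sub_allocation_size`.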
2190    fn range(&self, index: u32) -> PagedRange<'_> {
2191        self.gpadl.first().unwrap().subrange(
2192            (index * self.sub_allocation_size) as usize,
2193            self.sub_allocation_size as usize,
2194        )
2195    }
2196
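    /// Returns the transfer page range covering the first `len` bytes of
    /// suballocation `index`.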
2197    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2198        assert!(len <= self.sub_allocation_size as usize);
2199        ring::TransferPageRange {
2200            byte_offset: index * self.sub_allocation_size,
2201            byte_count: len as u32,
2202        }
2203    }
2204
2205    fn saved_state(&self) -> saved_state::ReceiveBuffer {
2206        saved_state::ReceiveBuffer {
2207            gpadl_id: self.gpadl.id(),
2208            id: self.id,
2209            sub_allocation_size: self.sub_allocation_size,
2210        }
2211    }
2212}
2213
2214#[derive(Debug)]
2215struct SendBuffer {
2216    gpadl: GpadlView,
2217}
2218
2219impl SendBuffer {
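    /// Maps the send buffer GPADL, requiring it to be contiguous and
    /// page-aligned.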
2220    fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2221        let gpadl = gpadl_map.map(gpadl_id)?;
2222        gpadl
2223            .contiguous_aligned()
2224            .ok_or(BufferError::UnalignedGpadl)?;
2225        Ok(Self { gpadl })
2226    }
2227}
2228
2229impl<T: RingMem> NetChannel<T> {
2230    /// Process a single RNDIS message, returning whether it was a data packet.
2231    fn handle_rndis_message(
2232        &mut self,
2233        buffers: &ChannelBuffers,
2234        state: &mut ActiveState,
2235        id: TxId,
2236        message_type: u32,
2237        mut reader: PacketReader<'_>,
2238        segments: &mut Vec<TxSegment>,
2239    ) -> Result<bool, WorkerError> {
2240        let is_packet = match message_type {
2241            rndisprot::MESSAGE_TYPE_PACKET_MSG => {
2242                self.handle_rndis_packet_message(
2243                    id,
2244                    reader,
2245                    &buffers.mem,
2246                    segments,
2247                    &mut state.stats,
2248                )?;
2249                true
2250            }
2251            rndisprot::MESSAGE_TYPE_HALT_MSG => false,
2252            n => {
2253                let control = state
2254                    .primary
2255                    .as_mut()
2256                    .ok_or(WorkerError::NotSupportedOnSubChannel(n))?;
2257
2258                // This is a control message that needs a response. Responding
2259                // will require a suballocation to be available, which it may
2260                // not be right now. Enqueue the message and process the queue
2261                // as suballocations become available.
2262                const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
2263                if reader.len() == 0 {
2264                    return Err(WorkerError::RndisMessageTooSmall);
2265                }
2266                // Do not let the queue get too large--the guest should not be
2267                // sending very many control messages at a time.
2268                if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
2269                    return Err(WorkerError::TooManyControlMessages);
2270                }
2271                control.control_messages_len += reader.len();
2272                control.control_messages.push_back(ControlMessage {
2273                    message_type,
2274                    data: reader.read_all()?.into(),
2275                });
2276
2277                false
2278                // The queue will be processed in the main dispatch loop.
2279            }
2280        };
2281        Ok(is_packet)
2282    }
2283
2284    /// Process an RNDIS packet message (used to send an Ethernet frame).
2285    fn handle_rndis_packet_message(
2286        &mut self,
2287        id: TxId,
2288        reader: PacketReader<'_>,
2289        mem: &GuestMemory,
2290        segments: &mut Vec<TxSegment>,
2291        stats: &mut QueueStats,
2292    ) -> Result<(), WorkerError> {
2293        // Headers are guaranteed to be in a single PagedRange.
2294        let headers = reader
2295            .clone()
2296            .into_inner()
2297            .paged_ranges()
2298            .find(|r| !r.is_empty())
2299            .ok_or(WorkerError::RndisMessageTooSmall)?;
2300        let mut data = reader.into_inner();
2301        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2302        if request.num_oob_data_elements != 0
2303            || request.oob_data_length != 0
2304            || request.oob_data_offset != 0
2305            || request.vc_handle != 0
2306        {
2307            return Err(WorkerError::UnsupportedRndisBehavior);
2308        }
2309
2310        if data.len() < request.data_offset as usize
2311            || (data.len() - request.data_offset as usize) < request.data_length as usize
2312            || request.data_length == 0
2313        {
2314            return Err(WorkerError::RndisMessageTooSmall);
2315        }
2316
2317        data.skip(request.data_offset as usize);
2318        data.truncate(request.data_length as usize);
2319
2320        let mut metadata = net_backend::TxMetadata {
2321            id,
2322            len: request.data_length as usize,
2323            ..Default::default()
2324        };
2325
2326        if request.per_packet_info_length != 0 {
2327            let mut ppi = headers
2328                .try_subrange(
2329                    request.per_packet_info_offset as usize,
2330                    request.per_packet_info_length as usize,
2331                )
2332                .ok_or(WorkerError::RndisMessageTooSmall)?;
2333            while !ppi.is_empty() {
2334                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2335                if h.size == 0 {
2336                    return Err(WorkerError::RndisMessageTooSmall);
2337                }
2338                let (this, rest) = ppi
2339                    .try_split(h.size as usize)
2340                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2341                let (_, d) = this
2342                    .try_split(h.per_packet_information_offset as usize)
2343                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2344                match h.typ {
2345                    rndisprot::PPI_TCP_IP_CHECKSUM => {
2346                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2347
2348                        metadata.offload_tcp_checksum =
2349                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
2350                        metadata.offload_udp_checksum =
2351                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum();
2352                        metadata.offload_ip_header_checksum = n.is_ipv4() && n.ip_header_checksum();
2353                        metadata.l3_protocol = if n.is_ipv4() {
2354                            L3Protocol::Ipv4
2355                        } else if n.is_ipv6() {
2356                            L3Protocol::Ipv6
2357                        } else {
2358                            L3Protocol::Unknown
2359                        };
2360                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2361                        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2362                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2363                                n.tcp_header_offset() - metadata.l2_len as u16
2364                            } else if n.is_ipv4() {
2365                                let mut reader = data.clone().reader(mem);
2366                                reader.skip(metadata.l2_len as usize)?;
2367                                let mut b = 0;
2368                                reader.read(std::slice::from_mut(&mut b))?;
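                                // The IPv4 header length (IHL) is the low
                                // nibble of the first header byte, in units
                                // of 32-bit words.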
2369                                (b as u16 & 0xf) * 4
2370                            } else {
2371                                // Assume a fixed 40-byte IPv6 header with no extension headers.
2372                                40
2373                            };
2374                        }
2375                    }
2376                    rndisprot::PPI_LSO => {
2377                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2378
2379                        metadata.offload_tcp_segmentation = true;
2380                        metadata.offload_tcp_checksum = true;
2381                        metadata.offload_ip_header_checksum = n.is_ipv4();
2382                        metadata.l3_protocol = if n.is_ipv4() {
2383                            L3Protocol::Ipv4
2384                        } else {
2385                            L3Protocol::Ipv6
2386                        };
2387                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2388                        if n.tcp_header_offset() < metadata.l2_len as u16 {
2389                            return Err(WorkerError::InvalidTcpHeaderOffset);
2390                        }
2391                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2392                        metadata.l4_len = {
2393                            let mut reader = data.clone().reader(mem);
2394                            reader
2395                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2396                            let mut b = 0;
2397                            reader.read(std::slice::from_mut(&mut b))?;
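                            // The TCP data offset (header length in 32-bit
                            // words) is the high nibble of byte 12 of the
                            // TCP header.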
2398                            (b >> 4) * 4
2399                        };
2400                        metadata.max_tcp_segment_size = n.mss() as u16;
2401                    }
2402                    _ => {}
2403                }
2404                ppi = rest;
2405            }
2406        }
2407
2408        let start = segments.len();
2409        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2410            let range = range.map_err(WorkerError::InvalidGpadl)?;
2411            segments.push(TxSegment {
2412                ty: net_backend::TxSegmentType::Tail,
2413                gpa: range.start,
2414                len: range.len() as u32,
2415            });
2416        }
2417
2418        metadata.segment_count = segments.len() - start;
2419
2420        stats.tx_packets.increment();
2421        if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2422            stats.tx_checksum_packets.increment();
2423        }
2424        if metadata.offload_tcp_segmentation {
2425            stats.tx_lso_packets.increment();
2426        }
2427
2428        segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2429
2430        Ok(())
2431    }
2432
2433    /// Notify the guest whether a VF is available by sending the VF
2434    /// association message, if the negotiated version and config allow it.
2435    fn guest_vf_is_available(
2436        &mut self,
2437        guest_vf_id: Option<u32>,
2438        version: Version,
2439        config: NdisConfig,
2440    ) -> Result<bool, WorkerError> {
2441        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2442        if version >= Version::V4 && config.capabilities.sriov() {
2443            tracing::info!(
2444                available = serial_number.is_some(),
2445                serial_number,
2446                "sending VF association message."
2447            );
2448            // N.B. MIN_CONTROL_RING_SIZE reserves room to send this packet.
2449            let message = {
2450                self.message(
2451                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2452                    protocol::Message4SendVfAssociation {
2453                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2454                        serial_number: serial_number.unwrap_or(0),
2455                    },
2456                )
2457            };
2458            self.queue
2459                .split()
2460                .1
2461                .try_write(&queue::OutgoingPacket {
2462                    transaction_id: VF_ASSOCIATION_TRANSACTION_ID,
2463                    packet_type: OutgoingPacketType::InBandWithCompletion,
2464                    payload: &message.payload(),
2465                })
2466                .map_err(|err| match err {
2467                    queue::TryWriteError::Full(len) => {
2468                        tracing::error!(len, "failed to write vf association message");
2469                        WorkerError::OutOfSpace
2470                    }
2471                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2472                })?;
2473            Ok(true)
2474        } else {
2475            tracing::info!(
2476                available = serial_number.is_some(),
2477                serial_number,
2478                major = version.major(),
2479                minor = version.minor(),
2480                sriov_capable = config.capabilities.sriov(),
2481                "Skipping NvspMessage4TypeSendVFAssociation message"
2482            );
2483            Ok(false)
2484        }
2485    }
2486
2487    /// Send the `NvspMessage5TypeSendIndirectionTable` message.
2488    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2489        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending the indirection table.
2490        if version < Version::V5 {
2491            return;
2492        }
2493
2494        #[repr(C)]
2495        #[derive(IntoBytes, Immutable, KnownLayout)]
2496        struct SendIndirectionMsg {
2497            pub message: protocol::Message5SendIndirectionTable,
2498            pub send_indirection_table:
2499                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2500        }
2501
2502        // The offset to the send indirection table from the beginning of the NVSP message.
2503        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2504            + size_of::<protocol::MessageHeader>();
2505        let mut data = SendIndirectionMsg {
2506            message: protocol::Message5SendIndirectionTable {
2507                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2508                table_offset: send_indirection_table_offset as u32,
2509            },
2510            send_indirection_table: Default::default(),
2511        };
2512
2513        for (i, entry) in data.send_indirection_table.iter_mut().enumerate() {
2514            *entry = i as u32 % num_channels_opened;
2515        }
2516
2517        let message = NvspMessage {
2518            header: protocol::MessageHeader {
2519                message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2520            },
2521            data,
2522            padding: &[],
2523        };
2524        let result = self
2525            .queue
2526            .split()
2527            .1
2528            .try_write(&queue::OutgoingPacket {
2529                transaction_id: 0,
2530                packet_type: OutgoingPacketType::InBandNoCompletion,
2531                payload: &message.payload(),
2532            })
2533            .map_err(|err| match err {
2534                queue::TryWriteError::Full(len) => {
2535                    tracing::error!(len, "failed to write send indirection table message");
2536                    WorkerError::OutOfSpace
2537                }
2538                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2539            });
2540        if let Err(err) = result {
2541            tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2542        }
2543    }
2544
2545    /// Notify the guest that the data path has been switched back to synthetic
2546    /// due to some external state change.
2547    fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2548        let message = NvspMessage {
2549            header: protocol::MessageHeader {
2550                message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2551            },
2552            data: protocol::Message4SwitchDataPath {
2553                active_data_path: protocol::DataPath::SYNTHETIC.0,
2554            },
2555            padding: &[],
2556        };
2557        let result = self
2558            .queue
2559            .split()
2560            .1
2561            .try_write(&queue::OutgoingPacket {
2562                transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2563                packet_type: OutgoingPacketType::InBandWithCompletion,
2564                payload: &message.payload(),
2565            })
2566            .map_err(|err| match err {
2567                queue::TryWriteError::Full(len) => {
2568                    tracing::error!(
2569                        len,
2570                        "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2571                    );
2572                    WorkerError::OutOfSpace
2573                }
2574                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2575            });
2576        if let Err(err) = result {
2577            tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2578        }
2579    }
2580
2581    /// Process an internal state change
2582    async fn handle_state_change(
2583        &mut self,
2584        primary: &mut PrimaryChannelState,
2585        buffers: &ChannelBuffers,
2586    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
2587        // N.B. MIN_STATE_CHANGE_RING_SIZE needs to be large enough to support sending state change messages.
2588        // The worst case is UnavailableFromDataPathSwitchPending, which will send three messages:
2589        //      1. Complete the switch data path request.
2590        //      2. Notify that the data path switched (back to synthetic).
2591        //      3. Disassociate the VF adapter.
2592        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2593            // Notify guest that a VF capability has recently arrived.
2594            if primary.rndis_state == RndisState::Operational {
2595                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
2596                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2597                    return Ok(Some(CoordinatorMessage::UpdateGuestVfState));
2598                } else if let Some(true) = primary.is_data_path_switched {
2599                    tracing::error!(
2600                        "Data path switched, but current guest negotiation does not support VTL0 VF"
2601                    );
2602                }
2603            }
2604            return Ok(None);
2605        }
2606        loop {
2607            primary.guest_vf_state = match primary.guest_vf_state {
2608                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
2609                    // Notify guest that the VF is unavailable. It has already been surprise removed.
2610                    if primary.rndis_state == RndisState::Operational {
2611                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
2612                    }
2613                    PrimaryChannelGuestVfState::Unavailable
2614                }
2615                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
2616                    to_guest,
2617                    id,
2618                } => {
2619                    // Complete the data path switch request.
2620                    self.send_completion(id, &[])?;
2621                    if to_guest {
2622                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
2623                    } else {
2624                        PrimaryChannelGuestVfState::UnavailableFromAvailable
2625                    }
2626                }
2627                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
2628                    // Notify guest that the data path is now synthetic.
2629                    self.guest_vf_data_path_switched_to_synthetic();
2630                    PrimaryChannelGuestVfState::UnavailableFromAvailable
2631                }
2632                PrimaryChannelGuestVfState::DataPathSynthetic => {
2633                    // Notify guest that the data path is now synthetic.
2634                    self.guest_vf_data_path_switched_to_synthetic();
2635                    PrimaryChannelGuestVfState::Ready
2636                }
2637                PrimaryChannelGuestVfState::DataPathSwitchPending {
2638                    to_guest,
2639                    id,
2640                    result,
2641                } => {
2642                    let result = result.expect("DataPathSwitchPending should have been processed");
2643                    // Complete the data path switch request.
2644                    self.send_completion(id, &[])?;
2645                    if result {
2646                        if to_guest {
2647                            PrimaryChannelGuestVfState::DataPathSwitched
2648                        } else {
2649                            PrimaryChannelGuestVfState::Ready
2650                        }
2651                    } else {
2652                        if to_guest {
2653                            PrimaryChannelGuestVfState::DataPathSynthetic
2654                        } else {
2655                            tracing::error!(
2656                                "Failure when guest requested switch back to synthetic"
2657                            );
2658                            PrimaryChannelGuestVfState::DataPathSwitched
2659                        }
2660                    }
2661                }
2662                PrimaryChannelGuestVfState::Initializing
2663                | PrimaryChannelGuestVfState::Restoring(_) => {
2664                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
2665                }
2666                _ => break,
2667            };
2668        }
2669        Ok(None)
2670    }
2671
2672    /// Process a control message, writing the response to the provided receive
2673    /// buffer suballocation.
2674    fn handle_rndis_control_message(
2675        &mut self,
2676        primary: &mut PrimaryChannelState,
2677        buffers: &ChannelBuffers,
2678        message_type: u32,
2679        mut reader: impl MemoryRead + Clone,
2680        id: u32,
2681    ) -> Result<(), WorkerError> {
2682        let mem = &buffers.mem;
2683        let buffer_range = &buffers.recv_buffer.range(id);
2684        match message_type {
2685            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
2686                if primary.rndis_state != RndisState::Initializing {
2687                    return Err(WorkerError::InvalidRndisState);
2688                }
2689
2690                let request: rndisprot::InitializeRequest = reader.read_plain()?;
2691
2692                tracing::trace!(
2693                    ?request,
2694                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
2695                );
2696
2697                primary.rndis_state = RndisState::Operational;
2698
2699                let mut writer = buffer_range.writer(mem);
2700                let message_length = write_rndis_message(
2701                    &mut writer,
2702                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
2703                    0,
2704                    &rndisprot::InitializeComplete {
2705                        request_id: request.request_id,
2706                        status: rndisprot::STATUS_SUCCESS,
2707                        major_version: rndisprot::MAJOR_VERSION,
2708                        minor_version: rndisprot::MINOR_VERSION,
2709                        device_flags: rndisprot::DF_CONNECTIONLESS,
2710                        medium: rndisprot::MEDIUM_802_3,
2711                        max_packets_per_message: 8,
2712                        max_transfer_size: 0xEFFFFFFF,
2713                        packet_alignment_factor: 3,
2714                        af_list_offset: 0,
2715                        af_list_size: 0,
2716                    },
2717                )?;
2718                self.send_rndis_control_message(buffers, id, message_length)?;
2719                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
2720                    if self.guest_vf_is_available(
2721                        Some(vfid),
2722                        buffers.version,
2723                        buffers.ndis_config,
2724                    )? {
2725                        // Ideally the VF would not be presented to the guest
2726                        // until the completion packet has arrived, so that the
2727                        // guest is prepared. This is most interesting for the
2728                        // case of a VF associated with multiple guest
2729                        // adapters, using a concept like vports. In this
2730                        // scenario it would be better if all of the adapters
2731                        // were aware a VF was available before the device
2732                        // arrived. This is not currently possible because the
2733                        // Linux netvsc driver ignores the completion requested
2734                        // flag on inband packets and won't send a completion
2735                        // packet.
2736                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
2737                        // A restart will also add the VF based on the guest_vf_state.
2738                        if self.restart.is_none() {
2739                            self.restart = Some(CoordinatorMessage::UpdateGuestVfState);
2740                        }
2741                    } else if let Some(true) = primary.is_data_path_switched {
2742                        tracing::error!(
2743                            "Data path switched, but current guest negotiation does not support VTL0 VF"
2744                        );
2745                    }
2746                }
2747            }
2748            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
2749                let request: rndisprot::QueryRequest = reader.read_plain()?;
2750
2751                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");
2752
2753                let (header, body) = buffer_range
2754                    .try_split(
2755                        size_of::<rndisprot::MessageHeader>()
2756                            + size_of::<rndisprot::QueryComplete>(),
2757                    )
2758                    .ok_or(WorkerError::RndisMessageTooSmall)?;
2759                let (status, tx) = match self.adapter.handle_oid_query(
2760                    buffers,
2761                    primary,
2762                    request.oid,
2763                    body.writer(mem),
2764                ) {
2765                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
2766                    Err(err) => (err.as_status(), 0),
2767                };
2768
2769                let message_length = write_rndis_message(
2770                    &mut header.writer(mem),
2771                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
2772                    tx,
2773                    &rndisprot::QueryComplete {
2774                        request_id: request.request_id,
2775                        status,
2776                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
2777                        information_buffer_length: tx as u32,
2778                    },
2779                )?;
2780                self.send_rndis_control_message(buffers, id, message_length)?;
2781            }
2782            rndisprot::MESSAGE_TYPE_SET_MSG => {
2783                let request: rndisprot::SetRequest = reader.read_plain()?;
2784
2785                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");
2786
2787                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
2788                    Ok(restart_endpoint) => {
2789                        // Restart the endpoint if the OID changed some critical
2790                        // endpoint property.
2791                        if restart_endpoint {
2792                            self.restart = Some(CoordinatorMessage::Restart);
2793                        }
2794                        rndisprot::STATUS_SUCCESS
2795                    }
2796                    Err(err) => {
2797                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
2798                        err.as_status()
2799                    }
2800                };
2801
2802                let message_length = write_rndis_message(
2803                    &mut buffer_range.writer(mem),
2804                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
2805                    0,
2806                    &rndisprot::SetComplete {
2807                        request_id: request.request_id,
2808                        status,
2809                    },
2810                )?;
2811                self.send_rndis_control_message(buffers, id, message_length)?;
2812            }
2813            rndisprot::MESSAGE_TYPE_RESET_MSG => {
2814                return Err(WorkerError::RndisMessageTypeNotImplemented);
2815            }
2816            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
2817                return Err(WorkerError::RndisMessageTypeNotImplemented);
2818            }
2819            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
2820                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;
2821
2822                tracing::trace!(
2823                    ?request,
2824                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
2825                );
2826
2827                let message_length = write_rndis_message(
2828                    &mut buffer_range.writer(mem),
2829                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
2830                    0,
2831                    &rndisprot::KeepaliveComplete {
2832                        request_id: request.request_id,
2833                        status: rndisprot::STATUS_SUCCESS,
2834                    },
2835                )?;
2836                self.send_rndis_control_message(buffers, id, message_length)?;
2837            }
2838            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
2839                return Err(WorkerError::RndisMessageTypeNotImplemented);
2840            }
2841            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
2842        };
2843        Ok(())
2844    }
2845
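    /// Tries to write a MESSAGE1_TYPE_SEND_RNDIS_PACKET packet referencing
    /// `transfer_pages` within the receive buffer `recv_buffer_id`. Returns
    /// `Ok(None)` if the packet was written, or `Ok(Some(n))` with the ring
    /// space required if the outgoing ring is currently full.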
2846    fn try_send_rndis_message(
2847        &mut self,
2848        transaction_id: u64,
2849        channel_type: u32,
2850        recv_buffer_id: u16,
2851        transfer_pages: &[ring::TransferPageRange],
2852    ) -> Result<Option<usize>, WorkerError> {
2853        let message = self.message(
2854            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2855            protocol::Message1SendRndisPacket {
2856                channel_type,
2857                send_buffer_section_index: 0xffffffff,
2858                send_buffer_section_size: 0,
2859            },
2860        );
2861        let pending_send_size = match self.queue.split().1.try_write(&queue::OutgoingPacket {
2862            transaction_id,
2863            packet_type: OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
2864            payload: &message.payload(),
2865        }) {
2866            Ok(()) => None,
2867            Err(queue::TryWriteError::Full(n)) => Some(n),
2868            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
2869        };
2870        Ok(pending_send_size)
2871    }
2872
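    /// Sends an RNDIS control message response of `message_length` bytes from
    /// receive buffer suballocation `id`. Ring space is reserved before
    /// control messages are processed, so a full ring here is unexpected and
    /// reported as `WorkerError::OutOfSpace`.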
2873    fn send_rndis_control_message(
2874        &mut self,
2875        buffers: &ChannelBuffers,
2876        id: u32,
2877        message_length: usize,
2878    ) -> Result<(), WorkerError> {
2879        let result = self.try_send_rndis_message(
2880            id as u64,
2881            protocol::CONTROL_CHANNEL_TYPE,
2882            buffers.recv_buffer.id,
2883            std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
2884        )?;
2885
2886        // Ring size is checked before control messages are processed, so failure to write is unexpected.
2887        match result {
2888            None => Ok(()),
2889            Some(len) => {
2890                tracelimit::error_ratelimited!(len, "failed to write control message completion");
2891                Err(WorkerError::OutOfSpace)
2892            }
2893        }
2894    }
2895
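    /// Sends an RNDIS status indication with `payload` appended immediately
    /// after the `IndicateStatus` header in receive buffer suballocation
    /// `id`.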
2896    fn indicate_status(
2897        &mut self,
2898        buffers: &ChannelBuffers,
2899        id: u32,
2900        status: u32,
2901        payload: &[u8],
2902    ) -> Result<(), WorkerError> {
2903        let buffer = &buffers.recv_buffer.range(id);
2904        let mut writer = buffer.writer(&buffers.mem);
2905        let message_length = write_rndis_message(
2906            &mut writer,
2907            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
2908            payload.len(),
2909            &rndisprot::IndicateStatus {
2910                status,
2911                status_buffer_length: payload.len() as u32,
2912                status_buffer_offset: if payload.is_empty() {
2913                    0
2914                } else {
2915                    size_of::<rndisprot::IndicateStatus>() as u32
2916                },
2917            },
2918        )?;
2919        writer.write(payload)?;
2920        self.send_rndis_control_message(buffers, id, message_length)?;
2921        Ok(())
2922    }
2923
2924    /// Processes pending control messages until all are processed or there are
2925    /// no available suballocations.
2926    fn process_control_messages(
2927        &mut self,
2928        buffers: &ChannelBuffers,
2929        state: &mut ActiveState,
2930    ) -> Result<(), WorkerError> {
2931        let Some(primary) = &mut state.primary else {
2932            return Ok(());
2933        };
2934
2935        while !primary.control_messages.is_empty()
2936            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
2937        {
2938            // Ensure the ring buffer has enough room to successfully complete control message handling.
2939            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
2940                self.pending_send_size = MIN_CONTROL_RING_SIZE;
2941                break;
2942            }
2943            let Some(id) = primary.free_control_buffers.pop() else {
2944                break;
2945            };
2946
2947            // Mark the receive buffer as in use so that the guest can later release it.
2948            assert!(state.rx_bufs.is_free(id.0));
2949            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
2950
2951            if let Some(message) = primary.control_messages.pop_front() {
2952                primary.control_messages_len -= message.data.len();
2953                self.handle_rndis_control_message(
2954                    primary,
2955                    buffers,
2956                    message.message_type,
2957                    message.data.as_ref(),
2958                    id.0,
2959                )?;
2960            } else if primary.pending_offload_change
2961                && primary.rndis_state == RndisState::Operational
2962            {
2963                let ndis_offload = primary.offload_config.ndis_offload();
2964                self.indicate_status(
2965                    buffers,
2966                    id.0,
2967                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
2968                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
2969                )?;
2970                primary.pending_offload_change = false;
2971            } else {
2972                unreachable!();
2973            }
2974        }
2975        Ok(())
2976    }
2977}
2978
2979/// Writes an RNDIS message to `writer`.
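///
/// The message is framed as a `MessageHeader` followed by `payload`; `extra`
/// bytes of trailing data (written separately by the caller) are counted in
/// the returned message length but are not written here.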
2980fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
2981    writer: &mut impl MemoryWrite,
2982    message_type: u32,
2983    extra: usize,
2984    payload: &T,
2985) -> Result<usize, AccessError> {
2986    let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
2987    writer.write(
2988        rndisprot::MessageHeader {
2989            message_type,
2990            message_length: message_length as u32,
2991        }
2992        .as_bytes(),
2993    )?;
2994    writer.write(payload.as_bytes())?;
2995    Ok(message_length)
2996}
2997
2998#[derive(Debug, Error)]
2999enum OidError {
3000    #[error(transparent)]
3001    Access(#[from] AccessError),
3002    #[error("unknown oid")]
3003    UnknownOid,
3004    #[error("invalid oid input, bad field {0}")]
3005    InvalidInput(&'static str),
3006    #[error("bad ndis version")]
3007    BadVersion,
3008    #[error("feature {0} not supported")]
3009    NotSupported(&'static str),
3010}
3011
3012impl OidError {
3013    fn as_status(&self) -> u32 {
3014        match self {
3015            OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3016            OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3017            OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3018            OidError::Access(_) => rndisprot::STATUS_FAILURE,
3019        }
3020    }
3021}
3022
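// N.B. These MTU values include the 14-byte Ethernet header, so the default
// of 1514 corresponds to a 1500-byte payload.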
3023const DEFAULT_MTU: u32 = 1514;
3024const MIN_MTU: u32 = DEFAULT_MTU;
3025const MAX_MTU: u32 = 9216;
3026
3027const ETHERNET_HEADER_LEN: u32 = 14;
3028
3029impl Adapter {
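    /// Returns the serial number to advertise for the guest VF: the adapter
    /// index for Windows guests (matching the vport serial number previously
    /// assigned), and the vfid itself for all other guests.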
3030    fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3031        if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3032            // For enlightened guests (currently only Windows), send the
3033            // adapter index, which was previously set as the vport serial number.
3034            if guest_os_id
3035                .microsoft()
3036                .unwrap_or(HvGuestOsMicrosoft::from(0))
3037                .os_id()
3038                == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3039            {
3040                self.adapter_index
3041            } else {
3042                vfid
3043            }
3044        } else {
3045            vfid
3046        }
3047    }
3048
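    /// Handles an RNDIS OID query, writing the response payload to `writer`.
    /// Returns the number of bytes written; errors are mapped to RNDIS status
    /// codes by the caller.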
3049    fn handle_oid_query(
3050        &self,
3051        buffers: &ChannelBuffers,
3052        primary: &PrimaryChannelState,
3053        oid: rndisprot::Oid,
3054        mut writer: impl MemoryWrite,
3055    ) -> Result<usize, OidError> {
3056        tracing::debug!(?oid, "oid query");
3057        let available_len = writer.len();
3058        match oid {
3059            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
3060                let supported_oids_common = &[
3061                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
3062                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
3063                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
3064                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
3065                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
3066                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
3067                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
3068                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
3069                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
3070                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
3071                    rndisprot::Oid::OID_GEN_LINK_SPEED,
3072                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
3073                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
3074                    rndisprot::Oid::OID_GEN_VENDOR_ID,
3075                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
3076                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
3077                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
3078                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
3079                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
3080                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
3081                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
3082                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
3083                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
3084                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
3085                    // Ethernet objects operation characteristics
3086                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
3087                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
3088                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
3089                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
3090                    // Ethernet objects statistics
3091                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
3092                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
3093                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
3094                    // PNP operations characteristics
3095                    // rndisprot::Oid::OID_PNP_SET_POWER,
3096                    // rndisprot::Oid::OID_PNP_QUERY_POWER,
3097                    // RNDIS OIDs
3098                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
3099                ];
3100
3101                // NDIS6 OIDs
3102                let supported_oids_6 = &[
3103                    // Link State OID
3104                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
3105                    rndisprot::Oid::OID_GEN_LINK_STATE,
3106                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
3107                    // NDIS 6 statistics OID
3108                    rndisprot::Oid::OID_GEN_BYTES_RCV,
3109                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
3110                    // Offload related OID
3111                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
3112                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
3113                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
3114                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
3115                    // rndisprot::Oid::OID_802_3_ADD_MULTICAST_ADDRESS,
3116                    // rndisprot::Oid::OID_802_3_DELETE_MULTICAST_ADDRESS,
3117                ];
3118
3119                let supported_oids_63 = &[
3120                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
3121                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
3122                ];
3123
3124                match buffers.ndis_version.major {
3125                    5 => {
3126                        writer.write(supported_oids_common.as_bytes())?;
3127                    }
3128                    6 => {
3129                        writer.write(supported_oids_common.as_bytes())?;
3130                        writer.write(supported_oids_6.as_bytes())?;
3131                        if buffers.ndis_version.minor >= 30 {
3132                            writer.write(supported_oids_63.as_bytes())?;
3133                        }
3134                    }
3135                    _ => return Err(OidError::BadVersion),
3136                }
3137            }
3138            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
3139                let status: u32 = 0; // NdisHardwareStatusReady
3140                writer.write(status.as_bytes())?;
3141            }
3142            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
3143                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
3144            }
3145            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
3146            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
3147            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
3148                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
3149                writer.write(len.as_bytes())?;
3150            }
3151            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
3152            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
3153            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
3154                let len: u32 = buffers.ndis_config.mtu;
3155                writer.write(len.as_bytes())?;
3156            }
3157            rndisprot::Oid::OID_GEN_LINK_SPEED => {
3158                let speed: u32 = (self.link_speed / 100) as u32; // In 100bps units
3159                writer.write(speed.as_bytes())?;
3160            }
3161            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
3162            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
3163                // This value is meaningless for virtual NICs. Return what vmswitch returns.
3164                writer.write((256u32 * 1024).as_bytes())?
3165            }
3166            rndisprot::Oid::OID_GEN_VENDOR_ID => {
3167                // Like vmswitch, report Microsoft's MAC address OUI
3168                // (00-15-5D) as the vendor ID.
3169                writer.write(0x0000155du32.as_bytes())?;
3170            }
3171            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
3172            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
3173            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
3174                writer.write(0x0100u16.as_bytes())? // 1.0. Vmswitch reports 19.0 for Mn.
3175            }
3176            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
3177            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
3178                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
3179                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
3180                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
3181                writer.write(options.as_bytes())?;
3182            }
3183            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
3184                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
3185            }
3186            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
3187            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
3188                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
3189                let mut name = rndisprot::FriendlyName::new_zeroed();
3190                name.name[..name16.len()].copy_from_slice(&name16);
3191                writer.write(name.as_bytes())?
3192            }
3193            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
3194            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
3195                writer.write(&self.mac_address.to_bytes())?
3196            }
3197            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
3198                writer.write(0u32.as_bytes())?;
3199            }
3200            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
3201            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
3202            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,
3203
3204            // NDIS6 OIDs:
3205            rndisprot::Oid::OID_GEN_LINK_STATE => {
3206                let link_state = rndisprot::LinkState {
3207                    header: rndisprot::NdisObjectHeader {
3208                        object_type: rndisprot::NdisObjectType::DEFAULT,
3209                        revision: 1,
3210                        size: size_of::<rndisprot::LinkState>() as u16,
3211                    },
3212                    media_connect_state: 1, /* MediaConnectStateConnected */
3213                    media_duplex_state: 0,  /* MediaDuplexStateUnknown */
3214                    padding: 0,
3215                    xmit_link_speed: self.link_speed,
3216                    rcv_link_speed: self.link_speed,
3217                    pause_functions: 0, /* NdisPauseFunctionsUnsupported */
3218                    auto_negotiation_flags: 0,
3219                };
3220                writer.write(link_state.as_bytes())?;
3221            }
3222            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
3223                let link_speed = rndisprot::LinkSpeed {
3224                    xmit: self.link_speed,
3225                    rcv: self.link_speed,
3226                };
3227                writer.write(link_speed.as_bytes())?;
3228            }
3229            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
3230                let ndis_offload = self.offload_support.ndis_offload();
3231                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3232            }
3233            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
3234                let ndis_offload = primary.offload_config.ndis_offload();
3235                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3236            }
3237            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3238                writer.write(
3239                    &rndisprot::NdisOffloadEncapsulation {
3240                        header: rndisprot::NdisObjectHeader {
3241                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3242                            revision: 1,
3243                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
3244                        },
3245                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3246                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3247                        ipv4_header_size: ETHERNET_HEADER_LEN,
3248                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3249                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3250                        ipv6_header_size: ETHERNET_HEADER_LEN,
3251                    }
3252                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
3253                )?;
3254            }
3255            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
3256                writer.write(
3257                    &rndisprot::NdisReceiveScaleCapabilities {
3258                        header: rndisprot::NdisObjectHeader {
3259                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
3260                            revision: 2,
3261                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
3262                                as u16,
3263                        },
3264                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
3265                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
3266                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
3267                        number_of_interrupt_messages: 1,
3268                        number_of_receive_queues: self.max_queues.into(),
3269                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
3270                            self.indirection_table_size
3271                        } else {
3272                            // DPDK gets confused if the table size is zero,
3273                            // even if there is only one queue.
3274                            128
3275                        },
3276                        padding: 0,
3277                    }
3278                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
3279                )?;
3280            }
3281            _ => {
3282                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
3283                return Err(OidError::UnknownOid);
3284            }
3285        };
3286        Ok(available_len - writer.len())
3287    }
3288
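    /// Handles an RNDIS OID set request. Returns `true` if the change
    /// requires the endpoint to be restarted before it takes effect.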
3289    fn handle_oid_set(
3290        &self,
3291        primary: &mut PrimaryChannelState,
3292        oid: rndisprot::Oid,
3293        reader: impl MemoryRead + Clone,
3294    ) -> Result<bool, OidError> {
3295        tracing::debug!(?oid, "oid set");
3296
3297        let mut restart_endpoint = false;
3298        match oid {
3299            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3300                // TODO
3301            }
3302            rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3303                self.oid_set_offload_parameters(reader, primary)?;
3304            }
3305            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3306                self.oid_set_offload_encapsulation(reader)?;
3307            }
3308            rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3309                self.oid_set_rndis_config_parameter(reader, primary)?;
3310            }
3311            rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3312                // TODO
3313            }
3314            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3315                self.oid_set_rss_parameters(reader, primary)?;
3316
3317                // Endpoints cannot currently change RSS parameters without
3318                // being restarted. This was a limitation driven by some DPDK
3319                // PMDs, and should be fixed.
3320                restart_endpoint = true;
3321            }
3322            _ => {
3323                tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3324                return Err(OidError::UnknownOid);
3325            }
3326        }
3327        Ok(restart_endpoint)
3328    }
3329
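    /// Applies OID_GEN_RECEIVE_SCALE_PARAMETERS to the primary channel. A
    /// request that disables RSS or selects no hash function clears any
    /// existing RSS state instead.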
3330    fn oid_set_rss_parameters(
3331        &self,
3332        mut reader: impl MemoryRead + Clone,
3333        primary: &mut PrimaryChannelState,
3334    ) -> Result<(), OidError> {
3335        // Vmswitch doesn't validate the NDIS header on this object, so read it manually.
3336        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3337        let len = reader.len().min(size_of_val(&params));
3338        reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3339
3340        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3341            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3342        {
3343            primary.rss_state = None;
3344            return Ok(());
3345        }
3346
3347        if params.hash_secret_key_size != 40 {
3348            return Err(OidError::InvalidInput("hash_secret_key_size"));
3349        }
3350        if params.indirection_table_size % 4 != 0 {
3351            return Err(OidError::InvalidInput("indirection_table_size"));
3352        }
3353        let indirection_table_size =
3354            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3355        let mut key = [0; 40];
3356        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3357        reader
3358            .clone()
3359            .skip(params.hash_secret_key_offset as usize)?
3360            .read(&mut key)?;
3361        reader
3362            .skip(params.indirection_table_offset as usize)?
3363            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3364        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3365        if indirection_table
3366            .iter()
3367            .any(|&x| x >= self.max_queues as u32)
3368        {
3369            return Err(OidError::InvalidInput("indirection_table"));
3370        }
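        // The guest may provide fewer entries than the table holds; fill the
        // remainder by repeating the provided entries in order.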
3371        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3372        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3373            .flatten()
3374            .zip(indir_uninit)
3375        {
3376            *dest = src;
3377        }
3378        primary.rss_state = Some(RssState {
3379            key,
3380            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3381        });
3382        Ok(())
3383    }
3384
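    /// Applies OID_TCP_OFFLOAD_PARAMETERS to the primary channel's offload
    /// configuration, clamping LSO to what the endpoint supports, and flags a
    /// pending status indication so the guest sees the resulting config.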
3385    fn oid_set_offload_parameters(
3386        &self,
3387        reader: impl MemoryRead + Clone,
3388        primary: &mut PrimaryChannelState,
3389    ) -> Result<(), OidError> {
3390        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3391            reader,
3392            rndisprot::NdisObjectType::DEFAULT,
3393            1,
3394            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3395        )?;
3396
3397        tracing::debug!(?offload, "offload parameters");
3398        let rndisprot::NdisOffloadParameters {
3399            header: _,
3400            ipv4_checksum,
3401            tcp4_checksum,
3402            udp4_checksum,
3403            tcp6_checksum,
3404            udp6_checksum,
3405            lsov1,
3406            ipsec_v1: _,
3407            lsov2_ipv4,
3408            lsov2_ipv6,
3409            tcp_connection_ipv4: _,
3410            tcp_connection_ipv6: _,
3411            reserved: _,
3412            flags: _,
3413        } = offload;
3414
3415        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3416            return Err(OidError::NotSupported("lsov1"));
3417        }
3418        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3419            primary.offload_config.checksum_tx.ipv4_header = tx;
3420            primary.offload_config.checksum_rx.ipv4_header = rx;
3421        }
3422        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3423            primary.offload_config.checksum_tx.tcp4 = tx;
3424            primary.offload_config.checksum_rx.tcp4 = rx;
3425        }
3426        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3427            primary.offload_config.checksum_tx.tcp6 = tx;
3428            primary.offload_config.checksum_rx.tcp6 = rx;
3429        }
3430        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3431            primary.offload_config.checksum_tx.udp4 = tx;
3432            primary.offload_config.checksum_rx.udp4 = rx;
3433        }
3434        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3435            primary.offload_config.checksum_tx.udp6 = tx;
3436            primary.offload_config.checksum_rx.udp6 = rx;
3437        }
3438        if let Some(enable) = lsov2_ipv4.enable() {
3439            primary.offload_config.lso4 = enable && self.offload_support.lso4;
3440        }
3441        if let Some(enable) = lsov2_ipv6.enable() {
3442            primary.offload_config.lso6 = enable && self.offload_support.lso6;
3443        }
3444        primary.pending_offload_change = true;
3445        Ok(())
3446    }
3447
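    /// Validates OID_OFFLOAD_ENCAPSULATION: only IEEE 802.3 encapsulation
    /// with a standard Ethernet header size is supported.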
3448    fn oid_set_offload_encapsulation(
3449        &self,
3450        reader: impl MemoryRead + Clone,
3451    ) -> Result<(), OidError> {
3452        let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3453            reader,
3454            rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3455            1,
3456            rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3457        )?;
3458        if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3459            && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3460                || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3461        {
3462            return Err(OidError::NotSupported("ipv4 encap"));
3463        }
3464        if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3465            && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3466                || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3467        {
3468            return Err(OidError::NotSupported("ipv6 encap"));
3469        }
3470        Ok(())
3471    }
3472
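    /// Handles OID_GEN_RNDIS_CONFIG_PARAMETER, which carries a named
    /// parameter (mirroring NDIS registry keywords such as `*LsoV2IPv4`) and
    /// a string or integer value.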
3473    fn oid_set_rndis_config_parameter(
3474        &self,
3475        reader: impl MemoryRead + Clone,
3476        primary: &mut PrimaryChannelState,
3477    ) -> Result<(), OidError> {
3478        let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3479        if info.name_length > 255 {
3480            return Err(OidError::InvalidInput("name_length"));
3481        }
3482        if info.value_length > 255 {
3483            return Err(OidError::InvalidInput("value_length"));
3484        }
3485        let name = reader
3486            .clone()
3487            .skip(info.name_offset as usize)?
3488            .read_n::<u16>(info.name_length as usize / 2)?;
3489        let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3490        let mut value = reader;
3491        value.skip(info.value_offset as usize)?;
3492        let mut value = value.limit(info.value_length as usize);
3493        match info.parameter_type {
3494            rndisprot::NdisParameterType::STRING => {
3495                let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3496                let value =
3497                    String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
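                // Standard NDIS keyword values: 0 = disabled, 1 = tx enabled,
                // 2 = rx enabled, 3 = tx and rx enabled.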
3498                let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3499                let tx = as_num & 1 != 0;
3500                let rx = as_num & 2 != 0;
3501
3502                tracing::debug!(name, value, "rndis config");
3503                match name.as_str() {
3504                    "*IPChecksumOffloadIPv4" => {
3505                        primary.offload_config.checksum_tx.ipv4_header = tx;
3506                        primary.offload_config.checksum_rx.ipv4_header = rx;
3507                    }
3508                    "*LsoV2IPv4" => {
3509                        primary.offload_config.lso4 = as_num != 0 && self.offload_support.lso4;
3510                    }
3511                    "*LsoV2IPv6" => {
3512                        primary.offload_config.lso6 = as_num != 0 && self.offload_support.lso6;
3513                    }
3514                    "*TCPChecksumOffloadIPv4" => {
3515                        primary.offload_config.checksum_tx.tcp4 = tx;
3516                        primary.offload_config.checksum_rx.tcp4 = rx;
3517                    }
3518                    "*TCPChecksumOffloadIPv6" => {
3519                        primary.offload_config.checksum_tx.tcp6 = tx;
3520                        primary.offload_config.checksum_rx.tcp6 = rx;
3521                    }
3522                    "*UDPChecksumOffloadIPv4" => {
3523                        primary.offload_config.checksum_tx.udp4 = tx;
3524                        primary.offload_config.checksum_rx.udp4 = rx;
3525                    }
3526                    "*UDPChecksumOffloadIPv6" => {
3527                        primary.offload_config.checksum_tx.udp6 = tx;
3528                        primary.offload_config.checksum_rx.udp6 = rx;
3529                    }
3530                    _ => {}
3531                }
3532            }
3533            rndisprot::NdisParameterType::INTEGER => {
3534                let value: u32 = value.read_plain()?;
3535                tracing::debug!(name, value, "rndis config");
3536            }
3537            parameter_type => tracelimit::warn_ratelimited!(
3538                name,
3539                ?parameter_type,
3540                "unhandled rndis config parameter type"
3541            ),
3542        }
3543        Ok(())
3544    }
3545}
3546
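/// Reads an NDIS object of type `T` from `reader`, zero-filling any bytes the
/// sender did not provide, and validates the embedded `NdisObjectHeader`
/// against the expected type, minimum revision, and minimum size.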
3547fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3548    mut reader: impl MemoryRead,
3549    object_type: rndisprot::NdisObjectType,
3550    min_revision: u8,
3551    min_size: usize,
3552) -> Result<T, OidError> {
3553    let mut buffer = T::new_zeroed();
3554    let sent_size = reader.len();
3555    let len = sent_size.min(size_of_val(&buffer));
3556    reader.read(&mut buffer.as_mut_bytes()[..len])?;
3557    validate_ndis_object_header(
3558        &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3559            .unwrap()
3560            .0, // TODO: zerocopy: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
3561        sent_size,
3562        object_type,
3563        min_revision,
3564        min_size,
3565    )?;
3566    Ok(buffer)
3567}
3568
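/// Validates an `NdisObjectHeader`: the object type must match exactly, the
/// sender must have provided at least `header.size` bytes, and the revision
/// and size must meet the given minimums.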
3569fn validate_ndis_object_header(
3570    header: &rndisprot::NdisObjectHeader,
3571    sent_size: usize,
3572    object_type: rndisprot::NdisObjectType,
3573    min_revision: u8,
3574    min_size: usize,
3575) -> Result<(), OidError> {
3576    if header.object_type != object_type {
3577        return Err(OidError::InvalidInput("header.object_type"));
3578    }
3579    if sent_size < header.size as usize {
3580        return Err(OidError::InvalidInput("header.size"));
3581    }
3582    if header.revision < min_revision {
3583        return Err(OidError::InvalidInput("header.revision"));
3584    }
3585    if (header.size as usize) < min_size {
3586        return Err(OidError::InvalidInput("header.size"));
3587    }
3588    Ok(())
3589}
3590
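/// The coordinator task's mutable state: it owns the per-queue workers and
/// serializes endpoint restarts, link state changes, and guest VF updates.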
3591struct Coordinator {
3592    recv: mpsc::Receiver<CoordinatorMessage>,
3593    channel_control: ChannelControl,
3594    restart: bool,
3595    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
3596    buffers: Option<Arc<ChannelBuffers>>,
3597    num_queues: u16,
3598}
3599
3600/// Removing the VF may result in the guest sending messages to switch the data
3601/// path, so these operations need to happen asynchronously with message
3602/// processing.
3603enum CoordinatorStatePendingVfState {
3604    /// No pending updates.
3605    Ready,
3606    /// Delay before adding VF.
3607    Delay {
3608        timer: PolledTimer,
3609        delay_until: Instant,
3610    },
3611    /// A VF update is pending.
3612    Pending,
3613}
3614
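/// The endpoint, adapter, and virtual function resources driven by the
/// coordinator task.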
3615struct CoordinatorState {
3616    endpoint: Box<dyn Endpoint>,
3617    adapter: Arc<Adapter>,
3618    virtual_function: Option<Box<dyn VirtualFunction>>,
3619    pending_vf_state: CoordinatorStatePendingVfState,
3620}
3621
3622impl InspectTaskMut<Coordinator> for CoordinatorState {
3623    fn inspect_mut(
3624        &mut self,
3625        req: inspect::Request<'_>,
3626        mut coordinator: Option<&mut Coordinator>,
3627    ) {
3628        let mut resp = req.respond();
3629
3630        let adapter = self.adapter.as_ref();
3631        resp.field("mac_address", adapter.mac_address)
3632            .field("max_queues", adapter.max_queues)
3633            .sensitivity_field(
3634                "offload_support",
3635                SensitivityLevel::Safe,
3636                &adapter.offload_support,
3637            )
3638            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
3639                if let Some(v) = v {
3640                    let v: usize = v.parse()?;
3641                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
3642                    // Bounce each task so that it sees the new value.
3643                    if let Some(this) = &mut coordinator {
3644                        for worker in &mut this.workers {
3645                            worker.update_with(|_, _| ());
3646                        }
3647                    }
3648                }
3649                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
3650            });
3651
3652        resp.field("endpoint_type", self.endpoint.endpoint_type())
3653            .field(
3654                "endpoint_max_queues",
3655                self.endpoint.multiqueue_support().max_queues,
3656            )
3657            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());
3658
3659        if let Some(coordinator) = coordinator {
3660            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
3661                let mut resp = req.respond();
3662                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
3663                    .iter_mut()
3664                    .enumerate()
3665                {
3666                    resp.field_mut(&i.to_string(), q);
3667                }
3668            });
3669
3670            // Get the shared channel state from the primary channel.
3671            resp.merge(inspect::adhoc_mut(|req| {
3672                let deferred = req.defer();
3673                coordinator.workers[0].update_with(|_, worker| {
3674                    if let Some(state) = worker.and_then(|worker| worker.state.ready()) {
3675                        deferred.respond(|resp| {
3676                            resp.merge(&state.buffers);
3677                            resp.sensitivity_field(
3678                                "primary_channel_state",
3679                                SensitivityLevel::Safe,
3680                                &state.state.primary,
3681                            );
3682                        });
3683                    }
3684                })
3685            }));
3686        }
3687    }
3688}
3689
3690impl AsyncRun<Coordinator> for CoordinatorState {
3691    async fn run(
3692        &mut self,
3693        stop: &mut StopTask<'_>,
3694        coordinator: &mut Coordinator,
3695    ) -> Result<(), task_control::Cancelled> {
3696        coordinator.process(stop, self).await
3697    }
3698}
3699
3700impl Coordinator {
3701    async fn process(
3702        &mut self,
3703        stop: &mut StopTask<'_>,
3704        state: &mut CoordinatorState,
3705    ) -> Result<(), task_control::Cancelled> {
3706        let mut sleep_duration: Option<Instant> = None;
3707        loop {
3708            if self.restart {
3709                stop.until_stopped(self.stop_workers()).await?;
3710                // The queue restart operation cannot be safely interrupted
3711                // and resumed, so do not poll on `stop` here.
3712                if let Err(err) = self
3713                    .restart_queues(state)
3714                    .instrument(tracing::info_span!("netvsp_restart_queues"))
3715                    .await
3716                {
3717                    tracing::error!(
3718                        error = &err as &dyn std::error::Error,
3719                        "failed to restart queues"
3720                    );
3721                }
3722                if let Some(primary) = self.primary_mut() {
3723                    primary.is_data_path_switched =
3724                        state.endpoint.get_data_path_to_guest_vf().await.ok();
3725                    tracing::info!(
3726                        is_data_path_switched = primary.is_data_path_switched,
3727                        "Query data path state"
3728                    );
3729                }
3730                self.restore_guest_vf_state(state).await;
3731                self.restart = false;
3732            }
3733
3734            // Ensure that all workers except the primary are started. The
3735            // primary is started below if there are no outstanding messages.
3736            for worker in &mut self.workers[1..] {
3737                worker.start();
3738            }
3739
3740            enum Message {
3741                Internal(CoordinatorMessage),
3742                ChannelDisconnected,
3743                UpdateFromEndpoint(EndpointAction),
3744                UpdateFromVf(Rpc<(), ()>),
3745                OfferVfDevice,
3746                PendingVfStateComplete,
3747                TimerExpired,
3748            }
3749            let timer_sleep = async {
3750                if let Some(sleep_duration) = sleep_duration {
3751                    let mut timer = PolledTimer::new(&state.adapter.driver);
3752                    timer.sleep_until(sleep_duration).await;
3753                } else {
3754                    pending::<()>().await;
3755                }
3756                Message::TimerExpired
3757            };
3758            let message = {
3759                let wait_for_message = async {
3760                    let internal_msg = self
3761                        .recv
3762                        .next()
3763                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
3764                    let endpoint_restart = state
3765                        .endpoint
3766                        .wait_for_endpoint_action()
3767                        .map(Message::UpdateFromEndpoint);
3768                    if let Some(vf) = state.virtual_function.as_mut() {
3769                        match state.pending_vf_state {
3770                            CoordinatorStatePendingVfState::Ready
3771                            | CoordinatorStatePendingVfState::Delay { .. } => {
3772                                let offer_device = async {
3773                                    if let CoordinatorStatePendingVfState::Delay {
3774                                        timer,
3775                                        delay_until,
3776                                    } = &mut state.pending_vf_state
3777                                    {
3778                                        timer.sleep_until(*delay_until).await;
3779                                    } else {
3780                                        pending::<()>().await;
3781                                    }
3782                                    Message::OfferVfDevice
3783                                };
3784                                (
3785                                    internal_msg,
3786                                    offer_device,
3787                                    endpoint_restart,
3788                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
3789                                    timer_sleep,
3790                                )
3791                                    .race()
3792                                    .await
3793                            }
3794                            CoordinatorStatePendingVfState::Pending => {
3795                                // Allow the network workers to continue while
3796                                // waiting for the VF add/remove call to
3797                                // complete, but block any other notifications
3798                                // while it is running. This is necessary to
3799                                // support VF removal, which may trigger the
3800                                // guest to send a switch data path request and
3801                                // wait for a completion message as part of
3802                                // its eject handling. The switch data path
3803                                // request won't send a message here because
3804                                // the VF is not available -- it will be a
3805                                // no-op.
3806                                vf.guest_ready_for_device().await;
3807                                Message::PendingVfStateComplete
3808                            }
3809                        }
3810                    } else {
3811                        (internal_msg, endpoint_restart, timer_sleep).race().await
3812                    }
3813                };
3814
3815                let mut wait_for_message = std::pin::pin!(wait_for_message);
3816                match (&mut wait_for_message).now_or_never() {
3817                    Some(message) => message,
3818                    None => {
3819                        self.workers[0].start();
3820                        stop.until_stopped(wait_for_message).await?
3821                    }
3822                }
3823            };
3824            match message {
3825                Message::UpdateFromVf(rpc) => {
3826                    rpc.handle(async |_| {
3827                        self.update_guest_vf_state(state).await;
3828                    })
3829                    .await;
3830                }
3831                Message::OfferVfDevice => {
3832                    self.workers[0].stop().await;
3833                    if let Some(primary) = self.primary_mut() {
3834                        if matches!(
3835                            primary.guest_vf_state,
3836                            PrimaryChannelGuestVfState::AvailableAdvertised
3837                        ) {
3838                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
3839                        }
3840                    }
3841
3842                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
3843                }
3844                Message::PendingVfStateComplete => {
3845                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
3846                }
3847                Message::TimerExpired => {
3848                    // Kick the worker as requested.
3849                    if self.workers[0].is_running() {
3850                        self.workers[0].stop().await;
3851                        if let Some(primary) = self.primary_mut() {
3852                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
3853                                primary.pending_link_action = PendingLinkAction::Active(up);
3854                            }
3855                        }
3856                    }
3857                    sleep_duration = None;
3858                }
3859                Message::Internal(CoordinatorMessage::UpdateGuestVfState) => {
3860                    self.update_guest_vf_state(state).await;
3861                }
3862                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
3863                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
3864                    self.workers[0].stop().await;
3865
3866                    // These are the only link state transitions that are tracked.
3867                    // 1. up -> down or down -> up
3868                    // 2. up -> down -> up or down -> up -> down.
3869                    // All other state transitions are coalesced into one of the above cases.
3870                    // For example, up -> down -> up -> down is treated as up -> down.
3871                    // N.B. Always queue up the incoming state to minimize the effects of loss
3872                    //       of any notifications (for example, during vtl2 servicing).
3873                    if let Some(primary) = self.primary_mut() {
3874                        primary.pending_link_action = PendingLinkAction::Active(connect);
3875                    }
3876
3877                    // If there is an existing sleep timer running, cancel it.
3878                    sleep_duration = None;
3879                }
3880                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
3881                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
3882                    sleep_duration = Some(duration);
3883                    // Restart primary task.
3884                    self.workers[0].stop().await;
3885                }
3886                Message::ChannelDisconnected => {
3887                    break;
3888                }
3889            };
3890        }
3891        Ok(())
3892    }
3893
3894    async fn stop_workers(&mut self) {
3895        for worker in &mut self.workers {
3896            worker.stop().await;
3897        }
3898    }
3899
3900    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
3901        let primary = match self.primary_mut() {
3902            Some(primary) => primary,
3903            None => return,
3904        };
3905
3906        // Update guest VF state based on current endpoint properties.
3907        let virtual_function = c_state.virtual_function.as_mut();
3908        let guest_vf_id = match &virtual_function {
3909            Some(vf) => vf.id().await,
3910            None => None,
3911        };
3912        if let Some(guest_vf_id) = guest_vf_id {
3913            // Ensure guest VF is in proper state.
3914            match primary.guest_vf_state {
3915                PrimaryChannelGuestVfState::AvailableAdvertised
3916                | PrimaryChannelGuestVfState::Restoring(
3917                    saved_state::GuestVfState::AvailableAdvertised,
3918                ) => {
3919                    if !primary.is_data_path_switched.unwrap_or(false) {
3920                        let timer = PolledTimer::new(&c_state.adapter.driver);
3921                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
3922                            timer,
3923                            delay_until: Instant::now() + VF_DEVICE_DELAY,
3924                        };
3925                    }
3926                }
3927                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
3928                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
3929                | PrimaryChannelGuestVfState::Ready
3930                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
3931                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
3932                | PrimaryChannelGuestVfState::Restoring(
3933                    saved_state::GuestVfState::DataPathSwitchPending { .. },
3934                )
3935                | PrimaryChannelGuestVfState::DataPathSwitched
3936                | PrimaryChannelGuestVfState::Restoring(
3937                    saved_state::GuestVfState::DataPathSwitched,
3938                )
3939                | PrimaryChannelGuestVfState::DataPathSynthetic => {
3940                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
3941                }
3942                _ => (),
3943            };
3944            // Ensure the data path is switched as expected.
3945            if let PrimaryChannelGuestVfState::Restoring(
3946                saved_state::GuestVfState::DataPathSwitchPending {
3947                    to_guest,
3948                    id,
3949                    result,
3950                },
3951            ) = primary.guest_vf_state
3952            {
3953                // If the save was after the data path switch already occurred, don't do it again.
3954                if result.is_some() {
3955                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
3956                        to_guest,
3957                        id,
3958                        result,
3959                    };
3960                    return;
3961                }
3962            }
3963            primary.guest_vf_state = match primary.guest_vf_state {
3964                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
3965                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
3966                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
3967                | PrimaryChannelGuestVfState::Restoring(
3968                    saved_state::GuestVfState::DataPathSwitchPending { .. },
3969                )
3970                | PrimaryChannelGuestVfState::DataPathSynthetic => {
3971                    let (to_guest, id) = match primary.guest_vf_state {
3972                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
3973                            to_guest,
3974                            id,
3975                        }
3976                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
3977                            to_guest, id, ..
3978                        }
3979                        | PrimaryChannelGuestVfState::Restoring(
3980                            saved_state::GuestVfState::DataPathSwitchPending {
3981                                to_guest, id, ..
3982                            },
3983                        ) => (to_guest, id),
3984                        _ => (true, None),
3985                    };
3986                    // Cancel any outstanding delay timers for VF offers if the data path is
                    // getting switched. Those timers are essentially a no-op at this point.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        // A previous instantiation already switched the data
                        // path.
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // If the device was just removed, make sure the data path is synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Notify the guest if the VF is no longer available.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }

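    // Coordinator-side queue restart. Summarizing the steps below: drop every
    // worker's queue state, stop the backend endpoint, recompute the set of
    // active queues from the RSS indirection table, repartition the receive
    // buffers across those queues, and then fetch fresh queues from the
    // endpoint.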
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop all the queues and stop the endpoint. Collect the worker drivers to pass to the queues.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        // Save the channel buffers for use in the subchannel workers.
        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                // The active queue count is the number of unique entries in the indirection table.
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };
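        // Worked example (illustrative values): with num_queues == 4 and an
        // indirection table of [0, 1, 1, 3, 0, 3], the unique in-range entries
        // are [0, 1, 3], so active_queue_count is 3. The inactive queue 2
        // still gets a QueueConfig below, just with no receive buffers.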

        // Distribute the rx buffers to only the active queues.
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        {
            let buffers = &state.buffers;
            let guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            // Get the list of free rx buffers from each task, then partition
            // the list per-active-queue, and produce the queue configuration.
            let mut queue_config = Vec::new();
            let initial_rx;
            {
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| {
                        states
                            .clone()
                            .flatten()
                            .all(|s| (s.state.rx_bufs.is_free(n)))
                    })
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    // If the primary queue is excluded from the guest-supplied
                    // indirection table, it is assigned just the reserved
                    // buffers.
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: &[],
                        driver: Box::new(drivers[0].clone()),
                    });
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
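                // Invariants for the loop below: queue zero (when active) also
                // owns the RX_RESERVED_CONTROL_BUFFERS reserved suballocations,
                // the last active queue absorbs any remainder, and inactive
                // queues get an empty initial_rx and no remote buffer channel.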
                for queue_index in first_queue..num_queues {
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            // The last queue gets all the remaining buffers.
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            // Queue zero always includes the reserved buffers.
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        (range_start, 0, None)
                    };

                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: this,
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // Set the subchannel count.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        // Provide the queue and receive buffer ranges for each worker.
        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
        }

        Ok(())
    }

    fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
        self.workers[0]
            .state_mut()
            .unwrap()
            .state
            .ready_mut()?
            .state
            .primary
            .as_mut()
    }

    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
}

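// `AsyncRun` is the `task_control` entry point for a queue worker: the
// coordinator starts and stops workers through a `TaskControl`, with
// cancellation delivered via the `StopTask` passed to `run`. Buffer
// revocation is treated as a clean exit here so channel processing winds
// down without logging an error.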
impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        match worker.process(stop, self).await {
            Ok(()) | Err(WorkerError::BufferRevoked) => {}
            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
            Err(err) => {
                tracing::error!(
                    channel_idx = worker.channel_idx,
                    error = &err as &dyn std::error::Error,
                    "netvsp error"
                );
            }
        }
        Ok(())
    }
}

impl<T: RingMem + 'static> Worker<T> {
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        // Be careful not to wait on actions with unbounded blocking time (e.g.
        // guest actions, or waiting for network packets to arrive) without
        // wrapping the wait on `stop.until_stopped`.
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    if self.channel_idx != 0 {
                        // Still waiting for the coordinator to provide receive
                        // buffers. The task will be restarted when they are available.
                        stop.until_stopped(pending()).await?
                    }

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Wake up the coordinator task to start the queues.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::Ready(state);
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        if !queue_state.target_vp_set
                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
                        {
                            tracing::debug!(
                                channel_idx = self.channel_idx,
                                target_vp = self.target_vp,
                                "updating target VP"
                            );
                            queue_state.queue.update_target_vp(self.target_vp).await;
                            queue_state.target_vp_set = true;
                        }

                        queue_state
                    } else {
                        // This task will be restarted when the queues are ready.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    match result {
                        Ok(restart) => {
                            assert_eq!(self.channel_idx, 0);
                            let _ = self.coordinator_send.try_send(restart);
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            if let Err(try_send_err) =
                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
                            {
                                tracing::error!(
                                    try_send_err = &try_send_err as &dyn std::error::Error,
                                    "failed to restart queues"
                                );
                                return Err(WorkerError::Endpoint(err));
                            }
                        }
                        Err(err) => return Err(err),
                    }

                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}

impl<T: 'static + RingMem> NetChannel<T> {
    fn try_next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Option<Packet<'a>>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let packet = match read.try_read() {
            Ok(packet) => {
                parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
            }
            Err(queue::TryReadError::Empty) => return Ok(None),
            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
        };

        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(Some(packet))
    }

    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet =
            parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            // In WorkerState::Init, if an RNDIS packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG.
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }

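    // Readiness mirrors the protocol handshake: the NDIS config is optional
    // before V2 (a default MTU is assumed later), the send buffer may be
    // missing when the caller allows it (e.g. an early RNDIS initialize
    // message), but the NDIS version and receive buffer are always required.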
    fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
        (initializing.ndis_config.is_some() || initializing.version < Version::V2)
            && initializing.ndis_version.is_some()
            && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
            && initializing.recv_buffer.is_some()
    }

    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough room in the ring to avoid needing to track
            // completion packets.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let packet = self
                .next_packet(None, initializing.as_ref().map(|x| x.version))
                .await?;

            if let Some(initializing) = &mut *initializing {
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // As in the vmswitch, if the MTU is invalid then use the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        // The UEFI client expects a completion packet, which can be empty.
                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                    protocol::Message1SendReceiveBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        num_sections: 1,
                                        sections: [protocol::ReceiveBufferSection {
                                            offset: 0,
                                            sub_allocation_size: recv_buffer.sub_allocation_size,
                                            num_sub_allocations: recv_buffer.count,
                                            end_offset: recv_buffer.sub_allocation_size
                                                * recv_buffer.count,
                                        }],
                                    },
                                )
                                .payload(),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                    protocol::Message1SendSendBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        section_size: 6144,
                                    },
                                )
                                .payload(),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut message = self.message(
                            protocol::MESSAGE_TYPE_INIT_COMPLETE,
                            protocol::MessageInitComplete {
                                deprecated: protocol::INVALID_PROTOCOL_VERSION,
                                maximum_mdl_chain_length: 34,
                                status: protocol::Status::NONE,
                            },
                        );
                        if let Some(version) = version {
                            if version == Version::V1 {
                                message.data.deprecated = Version::V1 as u32;
                            }
                            message.data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        self.send_completion(packet.transaction_id, &message.payload())?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                // Update the packet size so that the appropriate padding is
                                // appended for picky Windows guests.
                                self.packet_size = protocol::PACKET_SIZE_V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }
    }

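    // The main loop alternates between draining work (endpoint rx/tx, the
    // incoming ring, pending completions) and a single poll_fn await that
    // waits on the endpoint, the incoming ring, and outgoing ring space all
    // at once. Returning a CoordinatorMessage hands control back to the
    // coordinator task.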
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };
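        // Example (illustrative numbers): when the ring size optimization is
        // unavailable, limit == 0 falls back to capacity - 2048, so
        // ring_spare_capacity is 2048 bytes; the loop below considers the
        // ring "full" once less than that much write space remains.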

        // Handle any guest state changes since last run.
        if let Some(primary) = state.primary.as_mut() {
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // For the up -> down -> up OR down -> up -> down case, the first transition
                        // is sent immediately and the second transition is queued with a delay.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    // Mark the receive buffer in use to allow the guest to release it.
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // If the ring is almost full, then do not poll the endpoint,
            // since this will cause us to spend time dropping packets and
            // reprogramming receive buffers. It's more efficient to let the
            // backend drop packets.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            // Process any outstanding control messages before sleeping in case
            // there are now suballocations available for use.
            self.process_control_messages(buffers, state)?;

            // This should be the only await point waiting on network traffic or
            // guest actions. Wrap it in `stop.until_stopped` to allow
            // cancellation.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        // If the ring is almost full, then don't wait for endpoint
                        // interrupts. This allows the interrupt rate to fall when the
                        // guest cannot keep up with the load.
                        if !ring_full {
                            // Check the network endpoint for tx completion or rx.
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // Check the incoming ring for tx, but only if there are enough
                        // free tx packets.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // Check the outgoing ring for space to send rx completions or
                        // control message sends, if any are pending.
                        //
                        // Also, if endpoint processing has been suspended due to the
                        // ring being nearly full, then ask the guest to wake us up when
                        // there is space again.
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Collect any of this queue's receive buffers that were
                        // completed by a remote channel. This only happens when the
                        // subchannel count changes, so that receive buffer ownership
                        // moves between queues while some receive buffers are still in
                        // use.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }

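    // Receive path: all rx IDs that the backend reports ready are batched
    // into a single RNDIS message over the guest's receive buffer, and the
    // first rx ID doubles as the vmbus transaction ID so the guest's
    // completion can identify which suballocations to release.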
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        // Always use the full suballocation size to avoid tracking the
        // message length. See RxBuf::header() for details.
        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                // The packet was sent.
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // The ring buffer is full. Drop the packets and free the rx
                // buffers.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);
        Ok(true)
    }

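    // Transmit completions: one guest tx packet may fan out into multiple
    // backend packets, so pending_packet_count tracks the fan-out and the
    // vmbus completion is only sent once the last backend packet finishes.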
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Drain completed transmits.
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Complete any pending tx prior to restarting queues.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }

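    // A data path switch request is only queued (as
    // CoordinatorMessage::UpdateGuestVfState) when it would change state:
    // AvailableAdvertised/Ready may switch toward the VF (or away from it if
    // the current switch state is unknown), and DataPathSwitched may switch
    // back to synthetic. All other cases complete immediately without action.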
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // Allow the guest to switch to VTL0, or if the current state
                // of the data path is unknown, allow a switch away from VTL0.
                // The latter case handles the scenario where the data path has
                // been switched but the synthetic device has been restarted.
                // The device is queried for the current state of the data path
                // but if it is unknown (upgraded from older version that
                // doesn't track this data, or failure during query) then the
                // safest option is to pass the request through.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while device is in wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            // A restart will also try to switch the data path based on primary.guest_vf_state.
            if self.restart.is_none() {
                self.restart = Some(CoordinatorMessage::UpdateGuestVfState)
            };
        } else {
            self.send_completion(transaction_id, &[])?;
        }
        Ok(())
    }

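    // Ring processing: each guest RNDIS packet is assigned a free tx packet
    // ID up front; the loop stops once no IDs are free or segments are still
    // queued, and resumes after completions return capacity.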
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            if state.free_tx_packets.is_empty() || !data.tx_segments.is_empty() {
                break;
            }
            let Some(packet) =
                self.try_next_packet(buffers.send_buffer.as_ref(), Some(buffers.version))?
            else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    assert!(data.tx_segments.is_empty());
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    let num_packets = match result {
                        Ok(num_packets) => num_packets,
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                            continue;
                        }
                    };
                    total_packets += num_packets as u64;
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += num_packets;

                    if num_packets != 0 {
                        if self.transmit_segments(state, data, queue_state, id, num_packets)?
                            < num_packets
                        {
                            state.stats.tx_stalled.increment();
                        }
                    } else {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                                protocol::Message5SubchannelComplete {
                                    status,
                                    num_sub_channels: subchannel_count,
                                },
                            )
                            .payload(),
                    )?;

                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                                protocol::Message5OidQueryExComplete {
                                    status: rndisprot::STATUS_NOT_SUPPORTED,
                                    bytes: 0,
                                },
                            )
                            .payload(),
                    )?;
                }
                p => {
                    // The guarded arms above (e.g. SwitchDataPathCompletion)
                    // fall through to here when there is no primary channel
                    // state.
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }

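    // Retries a transmit that previously filled the backend queue. The first
    // remaining segment is always a `TxSegmentType::Head` by construction, so
    // the TxId and pending packet count can be recovered from it. Returns
    // true only if every remaining packet was accepted.
    //
    // Segment layout, informally (illustrative; non-head segments shown as
    // "body"):
    //
    //     [Head { id, .. }, body, body, Head { id', .. }, body, ...]
    //      \_____ packet 0 ______/      \____ packet 1 _____/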
    fn transmit_pending_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        if data.tx_segments.is_empty() {
            return Ok(false);
        }
        let net_backend::TxSegmentType::Head(metadata) = &data.tx_segments[0].ty else {
            unreachable!()
        };
        let id = metadata.id;
        let num_packets = state.pending_tx_packets[id.0 as usize].pending_packet_count;
        let packets_sent = self.transmit_segments(state, data, queue_state, id, num_packets)?;
        Ok(num_packets == packets_sent)
    }

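    // Hands segments to the backend endpoint. The backend may accept only a
    // prefix of the segments; the accepted prefix is drained here and the
    // remainder stays in `data.tx_segments` for transmit_pending_segments to
    // retry. When the backend completes transmits synchronously (`sync`), the
    // guest-visible completion can be sent as soon as the pending count
    // drains to zero; otherwise it waits for the endpoint's asynchronous tx
    // completion path.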
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
        id: TxId,
        num_packets: usize,
    ) -> Result<usize, WorkerError> {
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&data.tx_segments)
            .map_err(WorkerError::Endpoint)?;

        assert!(segments_sent <= data.tx_segments.len());

        let packets_sent = if segments_sent == data.tx_segments.len() {
            num_packets
        } else {
            net_backend::packet_count(&data.tx_segments[..segments_sent])
        };

        data.tx_segments.drain(..segments_sent);

        if sync {
            state.pending_tx_packets[id.0 as usize].pending_packet_count -= packets_sent;
        }

        if state.pending_tx_packets[id.0 as usize].pending_packet_count == 0 {
            self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
        }

        Ok(packets_sent)
    }

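    // Parses one guest tx submission into zero or more backend packets. The
    // external data is probed up front so later access cannot fault, then
    // each concatenated RNDIS message in the payload is consumed in turn.
    //
    // Wire layout, roughly (illustrative; each stride comes from that
    // message's rndisprot::MessageHeader::message_length):
    //
    //     +--------+--------+--------+--------+----
    //     | header |  body  | header |  body  | ...
    //     +--------+--------+--------+--------+----
    //      \__ message 0 __/ \__ message 1 __/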
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe the data to catch accesses that are out of bounds. This
        // simplifies error handling for backends that use
        // [`GuestMemory::iova`].
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        // There may be multiple RNDIS packets in a single
        // message, concatenated with each other. Consume
        // them until there is no more data in the RNDIS
        // message.
        while reader.len() > 0 {
            let mut this_reader = reader.clone();
            let header: rndisprot::MessageHeader = this_reader.read_plain()?;
            if self.handle_rndis_message(
                buffers,
                state,
                id,
                header.message_type,
                this_reader,
                segments,
            )? {
                num_packets += 1;
            }
            reader.skip(header.message_length as usize)?;
        }

        Ok(num_packets)
    }

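    // Attempts a non-blocking write of an RNDIS send completion back to the
    // guest. On TryWriteError::Full the ring has no room; the required size
    // is remembered in `pending_send_size` so the worker can wait for that
    // much ring space before retrying.
    //
    // Callers branch on the returned boolean, e.g. (illustrative):
    //
    //     if !self.try_send_tx_packet(transaction_id, status)? {
    //         // queue a PendingTxCompletion and retry later
    //     }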
    fn try_send_tx_packet(
        &mut self,
        transaction_id: u64,
        status: protocol::Status,
    ) -> Result<bool, WorkerError> {
        let message = self.message(
            protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
            protocol::Message1SendRndisPacketComplete { status },
        );
        let result = self.queue.split().1.try_write(&queue::OutgoingPacket {
            transaction_id,
            packet_type: OutgoingPacketType::Completion,
            payload: &message.payload(),
        });
        let sent = match result {
            Ok(()) => true,
            Err(queue::TryWriteError::Full(n)) => {
                self.pending_send_size = n;
                false
            }
            Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
        };
        Ok(sent)
    }

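    // Flushes queued tx completions in FIFO order, stopping at the first one
    // that does not fit in the ring. `pending_send_size` is cleared only once
    // the backlog fully drains, re-enabling the fast path in
    // complete_tx_packet.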
    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
        let mut did_some_work = false;
        while let Some(pending) = state.pending_tx_completions.front() {
            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
                return Ok(did_some_work);
            }
            did_some_work = true;
            if let Some(id) = pending.tx_id {
                state.free_tx_packets.push(id);
            }
            tracing::trace!(?pending, "sent tx completion");
            state.pending_tx_completions.pop_front();
        }

        self.pending_send_size = 0;
        Ok(did_some_work)
    }

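    // Completes a guest tx packet: the fast path writes the completion
    // immediately when there is no backlog; otherwise the completion (and
    // the TxId it frees) is queued so send_pending_packets can retry once
    // ring space frees up.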
    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
}

impl ActiveState {
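    // Returns a completed batch of rx buffers to circulation. The guest
    // echoes the transaction id from the original receive indication, which
    // by construction is the first suballocation id of the batch. Ids below
    // RX_RESERVED_CONTROL_BUFFERS go back to the primary channel's control
    // message pool; the rest become available rx ids again, unless
    // send_if_remote routes them to the channel owning that id range.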
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID specifies the first rx buffer ID.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
}