1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::Message1RevokeReceiveBuffer;
19use crate::protocol::Message1RevokeSendBuffer;
20use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
21use crate::protocol::Version;
22use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
23use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
24use async_trait::async_trait;
25pub use buffers::BufferPool;
26use buffers::sub_allocation_size_for_mtu;
27use futures::FutureExt;
28use futures::StreamExt;
29use futures::channel::mpsc;
30use futures_concurrency::future::Race;
31use guestmem::AccessError;
32use guestmem::GuestMemory;
33use guestmem::GuestMemoryError;
34use guestmem::MemoryRead;
35use guestmem::MemoryWrite;
36use guestmem::ranges::PagedRange;
37use guestmem::ranges::PagedRanges;
38use guestmem::ranges::PagedRangesReader;
39use guid::Guid;
40use hvdef::hypercall::HvGuestOsId;
41use hvdef::hypercall::HvGuestOsMicrosoft;
42use hvdef::hypercall::HvGuestOsMicrosoftIds;
43use hvdef::hypercall::HvGuestOsOpenSourceType;
44use inspect::Inspect;
45use inspect::InspectMut;
46use inspect::SensitivityLevel;
47use inspect_counters::Counter;
48use inspect_counters::Histogram;
49use mesh::rpc::Rpc;
50use net_backend::Endpoint;
51use net_backend::EndpointAction;
52use net_backend::L3Protocol;
53use net_backend::QueueConfig;
54use net_backend::RxId;
55use net_backend::TxError;
56use net_backend::TxId;
57use net_backend::TxSegment;
58use net_backend_resources::mac_address::MacAddress;
59use pal_async::timer::Instant;
60use pal_async::timer::PolledTimer;
61use ring::gparange::MultiPagedRangeIter;
62use rx_bufs::RxBuffers;
63use rx_bufs::SubAllocationInUse;
64use std::cmp;
65use std::collections::VecDeque;
66use std::fmt::Debug;
67use std::future::pending;
68use std::mem::offset_of;
69use std::ops::Range;
70use std::sync::Arc;
71use std::sync::atomic::AtomicUsize;
72use std::sync::atomic::Ordering;
73use std::task::Poll;
74use std::time::Duration;
75use task_control::AsyncRun;
76use task_control::InspectTaskMut;
77use task_control::StopTask;
78use task_control::TaskControl;
79use thiserror::Error;
80use tracing::Instrument;
81use vmbus_async::queue;
82use vmbus_async::queue::ExternalDataError;
83use vmbus_async::queue::IncomingPacket;
84use vmbus_async::queue::Queue;
85use vmbus_channel::bus::OfferParams;
86use vmbus_channel::bus::OpenRequest;
87use vmbus_channel::channel::ChannelControl;
88use vmbus_channel::channel::ChannelOpenError;
89use vmbus_channel::channel::ChannelRestoreError;
90use vmbus_channel::channel::DeviceResources;
91use vmbus_channel::channel::RestoreControl;
92use vmbus_channel::channel::SaveRestoreVmbusDevice;
93use vmbus_channel::channel::VmbusDevice;
94use vmbus_channel::gpadl::GpadlId;
95use vmbus_channel::gpadl::GpadlMapView;
96use vmbus_channel::gpadl::GpadlView;
97use vmbus_channel::gpadl::UnknownGpadlId;
98use vmbus_channel::gpadl_ring::GpadlRingMem;
99use vmbus_channel::gpadl_ring::gpadl_channel;
100use vmbus_ring as ring;
101use vmbus_ring::OutgoingPacketType;
102use vmbus_ring::RingMem;
103use vmbus_ring::gparange::GpnList;
104use vmbus_ring::gparange::MultiPagedRangeBuf;
105use vmcore::save_restore::RestoreError;
106use vmcore::save_restore::SaveError;
107use vmcore::save_restore::SavedStateBlob;
108use vmcore::vm_task::VmTaskDriver;
109use vmcore::vm_task::VmTaskDriverSource;
110use zerocopy::FromBytes;
111use zerocopy::FromZeros;
112use zerocopy::Immutable;
113use zerocopy::IntoBytes;
114use zerocopy::KnownLayout;
115
// Minimum outgoing ring space (in bytes) required before processing
// control-path messages. NOTE(review): derivation of 144 not visible in this
// chunk — confirm against the protocol completion message sizes.
const MIN_CONTROL_RING_SIZE: usize = 144;

// Minimum outgoing ring space (in bytes) required before sending unsolicited
// state-change messages. NOTE(review): confirm derivation of 196.
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Transaction ids for host-initiated (unsolicited) messages. The high bit is
// set, presumably so they cannot collide with guest-supplied transaction ids
// — TODO confirm against guest behavior.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// Upper bound on subchannels offered for a single vNIC.
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Delay before presenting the VF device to the guest; shortened under
// `cfg(test)` so unit tests run quickly.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Delay applied to link state change processing; shortened under `cfg(test)`.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
/// Flags indicating which pieces of coordinator-owned state need to be
/// re-evaluated when handling a [`CoordinatorMessage::Update`].
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    // Re-evaluate guest VF (virtual function) state.
    guest_vf_state: bool,
    // Re-apply packet filter state.
    filter_state: bool,
}

/// Messages sent from channel workers to the coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Re-evaluate the indicated subset of device state.
    Update(CoordinatorMessageUpdateType),
    /// Stop and restart the endpoint and workers.
    Restart,
    /// Arm the coordinator's timer to fire at the given instant.
    StartTimer(Instant),
}
169
/// Per-channel worker that processes one vmbus ring (primary channel or
/// subchannel).
struct Worker<T: RingMem> {
    // Index of the channel this worker serves; 0 is the primary channel.
    channel_idx: u16,
    // Virtual processor the channel's interrupts target.
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    // Protocol/initialization state for this channel.
    state: WorkerState,
    // Channel for notifying the coordinator task (restart, timer, updates).
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}

/// Backend queue state owned by the worker task structure; survives worker
/// insertion/removal.
struct NetQueue {
    driver: VmTaskDriver,
    // `None` until a backend queue has been created for this channel.
    queue_state: Option<QueueState>,
}
183
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Inspect handler that merges queue state (always owned by the task)
    /// with worker state (present only while the worker is inserted).
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // With neither a worker nor a queue there is nothing to report.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                // Summarize the protocol negotiation phase as a string.
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            // Traffic counters only exist once the channel is ready.
            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    // Quota minus free ids = packets currently in flight.
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
232
/// Protocol state of a channel worker.
// NOTE(review): `Ready` is much larger than `Init`; the allow suggests boxing
// was judged not worthwhile — confirm if sizes change.
#[expect(clippy::large_enum_variant)]
enum WorkerState {
    /// Negotiating; `None` until a protocol version has been agreed.
    Init(Option<InitState>),
    /// Fully initialized and processing traffic.
    Ready(ReadyState),
}
238
239impl WorkerState {
240 fn ready(&self) -> Option<&ReadyState> {
241 if let Self::Ready(state) = self {
242 Some(state)
243 } else {
244 None
245 }
246 }
247
248 fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249 if let Self::Ready(state) = self {
250 Some(state)
251 } else {
252 None
253 }
254 }
255}
256
/// State accumulated during protocol negotiation, before the channel becomes
/// ready.
struct InitState {
    version: Version,
    // Each piece below arrives in its own protocol message; all are needed
    // before the channel can transition to ready.
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}

/// NDIS version reported by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}

/// NDIS configuration reported by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}

/// State for a channel that has completed initialization.
struct ReadyState {
    // Negotiated buffers, shared across channels.
    buffers: Arc<ChannelBuffers>,
    // Mutable per-channel processing state.
    state: ActiveState,
    // Preallocated scratch buffers for the processing loops.
    data: ProcessingData,
}
286
/// Interface to an optional virtual function (accelerated networking) device
/// that can be offered to the guest alongside the synthetic data path.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// The VF's id, or `None` when no VF is currently available.
    async fn id(&self) -> Option<u32>;
    /// Signals that the guest is ready for the VF device to be presented.
    async fn guest_ready_for_device(&mut self);
    /// Waits for a change in VF availability; the returned RPC is used to
    /// acknowledge completion of the state change.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
301
/// Immutable adapter configuration and cross-channel shared state.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    // Maximum number of queues/subchannels offered to the guest.
    max_queues: u16,
    indirection_table_size: u16,
    // Offloads supported by the backend endpoint.
    offload_support: OffloadConfig,
    // When nonzero, limits outstanding ring usage. NOTE(review): semantics
    // inferred from `limit_ring_buffer` in the builder — confirm at use sites.
    ring_size_limit: AtomicUsize,
    // Free tx-id threshold used by the processing loop; set to the full quota
    // for fast-completion endpoints and a quarter otherwise (see `build`).
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    // Optional callback to query the guest OS id (used for ring size
    // optimization decisions).
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}

/// Backend queue plus the rx buffer range assigned to it.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    rx_buffer_range: RxBufferRange,
    // Whether the backend queue has been retargeted to the current VP.
    target_vp_set: bool,
}

/// The contiguous rx buffer id range owned by one queue, plus channels for
/// forwarding ids that belong to other queues.
struct RxBufferRange {
    id_range: Range<u32>,
    // Receives buffer ids returned by other queues that belong to this range.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}
328
329impl RxBufferRange {
330 fn new(
331 ranges: Arc<RxBufferRanges>,
332 id_range: Range<u32>,
333 remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
334 ) -> Self {
335 Self {
336 id_range,
337 remote_buffer_id_recv,
338 remote_ranges: ranges,
339 }
340 }
341
342 fn send_if_remote(&self, id: u32) -> bool {
343 if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
346 false
347 } else {
348 let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
349 let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
353 let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
354 true
355 }
356 }
357}
358
/// Cross-queue table describing how rx buffer ids are partitioned, with one
/// sender per queue for returning remotely owned ids.
struct RxBufferRanges {
    // Number of buffers assigned to each queue (excluding the reserved
    // control buffers).
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
363
364impl RxBufferRanges {
365 fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
366 let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
367 #[expect(clippy::disallowed_methods)] let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
369 (
370 Self {
371 buffers_per_queue,
372 buffer_id_send: send,
373 },
374 recv,
375 )
376 }
377}
378
/// RSS (receive side scaling) parameters set by the guest.
struct RssState {
    // Toeplitz-style hash key. NOTE(review): 40 bytes matches the NDIS RSS
    // key size — confirm against the rndisprot definitions.
    key: [u8; 40],
    indirection_table: Vec<u16>,
}

/// Per-channel vmbus state shared by the tx and rx processing paths.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    packet_size: usize,
    // Pending send size to request on the outgoing ring; 0 when disabled.
    pending_send_size: usize,
    // Message to hand to the coordinator when processing resumes.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}

/// Preallocated scratch buffers used by the packet processing loops so the
/// hot path does not allocate.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
}
404
405impl ProcessingData {
406 fn new() -> Self {
407 Self {
408 tx_segments: Vec::new(),
409 tx_done: vec![TxId(0); 8192].into(),
410 rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
411 rx_done: Vec::with_capacity(RX_BATCH_SIZE),
412 transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
413 }
414 }
415}
416
/// Buffers and configuration negotiated during channel setup; immutable once
/// the channel is ready and shared across channels.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}

/// Identifier of one of the reserved control-message receive buffers.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);

/// Mutable state for an active (ready) channel.
struct ActiveState {
    // Present only on the primary channel (channel 0).
    primary: Option<PrimaryChannelState>,

    // Indexed by TxId; tracks the guest transaction for in-flight packets.
    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    // Completions waiting for outgoing ring space.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    // Tracks which receive-buffer suballocations are in use.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}

/// Per-queue diagnostic counters surfaced via inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}

/// A tx completion that has not yet been written to the guest's ring.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    // `None` when the completion was restored without an associated tx id.
    tx_id: Option<TxId>,
    status: protocol::Status,
}
471
/// State machine for guest VF (accelerated networking) handling on the
/// primary channel.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// Initial state; VF availability not yet evaluated.
    Initializing,
    /// Restoring from a saved state.
    Restoring(saved_state::GuestVfState),
    /// No VF is available.
    Unavailable,
    /// VF became unavailable after previously being available.
    UnavailableFromAvailable,
    /// VF became unavailable while a data path switch was pending.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// VF became unavailable after the data path had been switched to it.
    UnavailableFromDataPathSwitched,
    /// A VF is available with the given id, but the guest has not been told.
    Available { vfid: u32 },
    /// The guest has been notified of the VF id.
    AvailableAdvertised,
    /// The VF is present in the guest.
    Ready,
    /// A data path switch is in flight.
    DataPathSwitchPending {
        to_guest: bool,
        // Transaction id to complete when the switch finishes, if any.
        id: Option<u64>,
        // `Some` once the switch attempt has finished (success/failure).
        result: Option<bool>,
    },
    /// The data path is switched to the guest VF.
    DataPathSwitched,
    /// VF still present but the data path was switched back to synthetic due
    /// to an external state change.
    DataPathSynthetic,
}
504
505impl std::fmt::Display for PrimaryChannelGuestVfState {
506 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507 match self {
508 PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
509 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
510 write!(f, "restoring")
511 }
512 PrimaryChannelGuestVfState::Restoring(
513 saved_state::GuestVfState::AvailableAdvertised,
514 ) => write!(f, "restoring from guest notified of vfid"),
515 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
516 write!(f, "restoring from vf present")
517 }
518 PrimaryChannelGuestVfState::Restoring(
519 saved_state::GuestVfState::DataPathSwitchPending {
520 to_guest, result, ..
521 },
522 ) => {
523 write!(
524 f,
525 "restoring from client requested data path switch: to {} {}",
526 if *to_guest { "guest" } else { "synthetic" },
527 if let Some(result) = result {
528 if *result { "succeeded\"" } else { "failed\"" }
529 } else {
530 "in progress\""
531 }
532 )
533 }
534 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
535 write!(f, "restoring from data path in guest")
536 }
537 PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
538 PrimaryChannelGuestVfState::UnavailableFromAvailable => {
539 write!(f, "\"unavailable (previously available)\"")
540 }
541 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
542 write!(f, "unavailable (previously switching data path)")
543 }
544 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
545 write!(f, "\"unavailable (previously using guest VF)\"")
546 }
547 PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
548 PrimaryChannelGuestVfState::AvailableAdvertised => {
549 write!(f, "\"available, guest notified\"")
550 }
551 PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
552 PrimaryChannelGuestVfState::DataPathSwitchPending {
553 to_guest, result, ..
554 } => {
555 write!(
556 f,
557 "\"switching to {} {}",
558 if *to_guest { "guest" } else { "synthetic" },
559 if let Some(result) = result {
560 if *result { "succeeded\"" } else { "failed\"" }
561 } else {
562 "in progress\""
563 }
564 )
565 }
566 PrimaryChannelGuestVfState::DataPathSwitched => {
567 write!(f, "\"available and data path switched\"")
568 }
569 PrimaryChannelGuestVfState::DataPathSynthetic => write!(
570 f,
571 "\"available but data path switched back to synthetic due to external state change\""
572 ),
573 }
574 }
575}
576
impl Inspect for PrimaryChannelGuestVfState {
    /// Reports the state as its display string.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}

/// Pending link-state change for the guest; the `bool` is the target link-up
/// value.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action pending.
    Default,
    /// A link change is actively being processed.
    Active(bool),
    /// A link change is being delayed. NOTE(review): presumably to debounce
    /// rapid link flaps (see LINK_DELAY_DURATION) — confirm at use sites.
    Delay(bool),
}
589
/// State that exists only on the primary channel (channel 0): control-plane
/// messages, RSS, offload negotiation, and VF/link bookkeeping.
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    // `None` until the data path state has been determined.
    is_data_path_switched: Option<bool>,
    // Control messages queued because no buffer/quota was available.
    control_messages: VecDeque<ControlMessage>,
    // Total bytes across `control_messages`, used for quota accounting.
    control_messages_len: usize,
    // Reserved control receive buffers not currently in use.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    // Offloads currently active, as negotiated with the guest.
    offload_config: OffloadConfig,
    // True when an offload-change notification still needs to be sent.
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
605
impl Inspect for PrimaryChannelState {
    /// Reports primary-channel state; every field is marked `Safe` since none
    /// of the values here contain guest data.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                // PendingLinkAction has no Inspect impl; render it as text.
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
657
/// Offload capabilities/configuration: checksum offload per direction and
/// large send offload (LSO) per IP version.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    // Large send offload (TCP segmentation) for IPv4.
    #[inspect(safe)]
    lso4: bool,
    // Large send offload (TCP segmentation) for IPv6.
    #[inspect(safe)]
    lso6: bool,
}

/// Checksum offload capabilities for one direction (tx or rx).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    // IPv4 header checksum.
    #[inspect(safe)]
    ipv4_header: bool,
    // TCP/UDP checksums over IPv4 and IPv6.
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
683
684impl ChecksumOffloadConfig {
685 fn flags(
686 &self,
687 ) -> (
688 rndisprot::Ipv4ChecksumOffload,
689 rndisprot::Ipv6ChecksumOffload,
690 ) {
691 let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
692 let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
693 let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
694 if self.ipv4_header {
695 v4.set_ip_options_supported(on);
696 v4.set_ip_checksum(on);
697 }
698 if self.tcp4 {
699 v4.set_ip_options_supported(on);
700 v4.set_tcp_options_supported(on);
701 v4.set_tcp_checksum(on);
702 }
703 if self.tcp6 {
704 v6.set_ip_extension_headers_supported(on);
705 v6.set_tcp_options_supported(on);
706 v6.set_tcp_checksum(on);
707 }
708 if self.udp4 {
709 v4.set_ip_options_supported(on);
710 v4.set_udp_checksum(on);
711 }
712 if self.udp6 {
713 v6.set_ip_extension_headers_supported(on);
714 v6.set_udp_checksum(on);
715 }
716 (v4, v6)
717 }
718}
719
impl OffloadConfig {
    /// Builds the NDIS offload structure (revision 3) describing the
    /// checksum and LSOv2 capabilities in this configuration.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            // Maximum bytes offloaded per LSO send (0xF53C = 62780).
            // NOTE(review): matches a common NDIS maximum — confirm origin.
            const MAX_OFFLOAD_SIZE: u32 = 0xF53C;
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = 2;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = 2;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            // All remaining capability fields stay zeroed (unsupported).
            ..FromZeros::new_zeroed()
        }
    }
}
769
/// State of the RNDIS control plane.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Waiting for the guest's RNDIS initialize exchange.
    Initializing,
    /// Initialized and processing RNDIS messages.
    Operational,
    /// Halted by the guest.
    Halted,
}
776
impl PrimaryChannelState {
    /// Creates fresh primary-channel state with all reserved control buffers
    /// free and the given offload configuration active.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from a saved state.
    ///
    /// Fails if the saved RSS key has the wrong size or the saved indirection
    /// table is larger than the current adapter supports; a smaller saved
    /// table is grown by cycling its existing entries.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Recompute the byte-quota accounting from the saved messages.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // A control buffer is free only if the rx tracker says it is not
        // still in use by an in-flight message.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                // Shrinking the table could drop queues the guest is using.
                if rss.indirection_table.len() > indirection_table_size as usize {
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Grow by repeating the saved entries so the queue
                    // distribution is preserved.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        // A saved pending link change resumes as an active action.
        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            // Re-derived after restore rather than persisted.
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
907
/// An RNDIS control message queued for later processing.
struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

/// Maximum number of in-flight tx packets per channel.
const TX_PACKET_QUOTA: usize = 1024;
914
impl ActiveState {
    /// Creates active-channel state with the full tx-id quota free and an rx
    /// buffer tracker sized for `recv_buffer_count` suballocations.
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            // Reversed so ids are handed out in ascending order via pop().
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    /// Rebuilds active-channel state from a saved channel: re-marks in-use rx
    /// suballocations and requeues saved tx completions.
    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Consume a tx id per saved completion when available; if the
            // quota is exhausted the completion is queued without an id.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}
964
/// Per-tx-id bookkeeping for an in-flight transmit.
#[derive(Default, Clone)]
struct PendingTxPacket {
    // Outstanding backend packet count for this transaction.
    // NOTE(review): unit inferred from the name — confirm at use sites.
    pending_packet_count: usize,
    transaction_id: u64,
}

/// Maximum receives processed per wake of the processing loop.
const RX_BATCH_SIZE: usize = 375;

/// Number of receive buffers reserved for RNDIS control messages (excluded
/// from the data-path rx pool).
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
981
/// A netvsp synthetic network device instance.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    // Coordinator task; owns the endpoint and the per-channel workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}

/// Builder for [`Nic`]; construct via [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    // Requested queue cap; clamped to endpoint/protocol limits in `build`.
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
998
impl NicBuilder {
    /// Limits how much of the ring buffer may be used at once.
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Caps the number of queues offered to the guest; clamped in `build`.
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Attaches a virtual function (accelerated networking) device.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Provides a callback for querying the guest OS id, used to decide
    /// whether ring-size optimizations are safe for the running guest.
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Builds the NIC around `endpoint`, deriving queue counts, offload
    /// support, and tx thresholds from the endpoint's capabilities.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp between 1 and the smaller of the endpoint's and the
        // protocol's queue limits.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // 0 disables the limit. NOTE(review): 1024's derivation unclear.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // Fast-completion endpoints can use the full quota; others keep a
        // larger reserve so completions are delivered promptly.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Rx checksum offload is always advertised; tx follows the endpoint.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1110
/// Returns whether the pending-send-size ring optimization can be used for
/// this channel, based on ring support and the guest OS identity.
fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
    // Without a known guest OS id, conservatively disable the optimization.
    let Some(guest_os_id) = guest_os_id else {
        return false;
    };

    // The incoming ring itself must support the pending send size feature.
    if !queue.split().0.supports_pending_send_size() {
        return false;
    }

    // Guests not reporting an open-source OS id are assumed to handle it.
    let Some(open_source_os) = guest_os_id.open_source() else {
        return true;
    };

    // Older FreeBSD/Linux netvsc clients are excluded. NOTE(review): the
    // version cutoffs (1400097 / 199424) are taken on trust here — confirm
    // against the respective guest driver histories.
    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        HvGuestOsOpenSourceType::LINUX => {
            open_source_os.version() >= 199424
        }
        _ => true,
    }
}
1142
1143impl Nic {
1144 pub fn builder() -> NicBuilder {
1145 NicBuilder {
1146 virtual_function: None,
1147 limit_ring_buffer: false,
1148 max_queues: !0,
1149 get_guest_os_id: None,
1150 }
1151 }
1152
1153 pub fn shutdown(self) -> Box<dyn Endpoint> {
1154 let (state, _) = self.coordinator.into_inner();
1155 state.endpoint
1156 }
1157}
1158
impl InspectMut for Nic {
    /// Delegates inspection to the coordinator task, which reports worker and
    /// queue state.
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}
1164
1165#[async_trait]
1166impl VmbusDevice for Nic {
1167 fn offer(&self) -> OfferParams {
1168 OfferParams {
1169 interface_name: "net".to_owned(),
1170 instance_id: self.instance_id,
1171 interface_id: Guid {
1172 data1: 0xf8615163,
1173 data2: 0xdf3e,
1174 data3: 0x46c5,
1175 data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
1176 },
1177 subchannel_index: 0,
1178 mnf_interrupt_latency: Some(Duration::from_micros(100)),
1179 ..Default::default()
1180 }
1181 }
1182
1183 fn max_subchannels(&self) -> u16 {
1184 self.adapter.max_queues
1185 }
1186
1187 fn install(&mut self, resources: DeviceResources) {
1188 self.resources = resources;
1189 }
1190
1191 async fn open(
1192 &mut self,
1193 channel_idx: u16,
1194 open_request: &OpenRequest,
1195 ) -> Result<(), ChannelOpenError> {
1196 let state = if channel_idx == 0 {
1198 self.insert_coordinator(1, false);
1199 WorkerState::Init(None)
1200 } else {
1201 self.coordinator.stop().await;
1202 let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
1204 WorkerState::Ready(ReadyState {
1205 state: ActiveState::new(None, buffers.recv_buffer.count),
1206 buffers,
1207 data: ProcessingData::new(),
1208 })
1209 };
1210
1211 let num_opened = self
1212 .adapter
1213 .num_sub_channels_opened
1214 .fetch_add(1, Ordering::SeqCst);
1215 let r = self.insert_worker(channel_idx, open_request, state, true);
1216 if channel_idx != 0
1217 && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
1218 {
1219 let coordinator = &mut self.coordinator.state_mut().unwrap();
1220 coordinator.workers[0].stop().await;
1221 }
1222
1223 if r.is_err() && channel_idx == 0 {
1224 self.coordinator.remove();
1225 } else {
1226 self.coordinator.start();
1228 }
1229 r?;
1230 Ok(())
1231 }
1232
    /// Closes channel `channel_idx`, removing its worker. Closing the primary
    /// (index 0) tears down all queue state, stops the endpoint, and removes
    /// the coordinator entirely.
    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Remember whether the coordinator was running so a subchannel close
        // can restore it afterwards.
        let restart = self.coordinator.stop().await;

        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        if channel_idx == 0 {
            // Primary close: drop every queue, stop the backend endpoint, and
            // discard the coordinator (it is recreated on the next open).
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            self.coordinator.task_mut().endpoint.stop().await;

            self.coordinator.remove();
        } else {
            if restart {
                self.coordinator.start();
            }
        }
    }
1275
    /// Moves the interrupt/processing affinity of channel `channel_idx` to
    /// `target_vp`. No-op if the device is closed.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        // Retarget the driver immediately; the queue itself is lazily
        // retargeted by the worker (target_vp_set cleared below).
        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            worker_state.target_vp = target_vp;
            if let Some(queue_state) = &mut net_queue.queue_state {
                queue_state.target_vp_set = false;
            }
        }

        // Only restart the coordinator if it was running before this call.
        if coordinator_running {
            self.coordinator.start();
        }
    }
1304
1305 fn start(&mut self) {
1306 if !self.coordinator.is_running() {
1307 self.coordinator.start();
1308 }
1309 }
1310
1311 async fn stop(&mut self) {
1312 self.coordinator.stop().await;
1313 if let Some(coordinator) = self.coordinator.state_mut() {
1314 coordinator.stop_workers().await;
1315 }
1316 }
1317
    /// This device supports save/restore; expose the trait object for it.
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
1321}
1322
1323#[async_trait]
1324impl SaveRestoreVmbusDevice for Nic {
1325 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1326 let state = self.saved_state();
1327 Ok(SavedStateBlob::new(state))
1328 }
1329
1330 async fn restore(
1331 &mut self,
1332 control: RestoreControl<'_>,
1333 state: SavedStateBlob,
1334 ) -> Result<(), RestoreError> {
1335 let state: saved_state::SavedState = state.parse()?;
1336 if let Err(err) = self.restore_state(control, state).await {
1337 tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
1338 Err(err.into())
1339 } else {
1340 Ok(())
1341 }
1342 }
1343}
1344
impl Nic {
    /// Creates the worker for `channel_idx`, wiring up its ring-buffer queue
    /// and channel state, and optionally starts it.
    ///
    /// Requires the coordinator state to exist (panics otherwise).
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        // Retarget before creating the queue so the ring is set up on the
        // guest-requested VP.
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // Some guest OSes misreport ring state; probe whether the ring-size
        // optimization is usable for this guest.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: protocol::PACKET_SIZE_V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                // Filter starts fully closed; the guest opens it via RNDIS.
                packet_filter: rndisprot::NDIS_PACKET_TYPE_NONE,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1404
impl Nic {
    /// Creates the coordinator task state, including one (not yet started)
    /// worker slot per possible queue and the mpsc channel workers use to
    /// signal the coordinator.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: bool) {
        let mut driver_builder = self.driver_source.builder();
        // Initially target VP 0; retargeted per-channel at open time.
        driver_builder.target_vp(0);
        // If tx completions are fast, run tasks on the current processor
        // rather than pinning to the target VP.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)]
        let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                // When restoring, the coordinator needs an initial restart
                // pass to resynchronize endpoint state.
                restart: restoring,
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
            },
        );
    }
}
1443
/// Errors that can occur while restoring saved network vmbus state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1463
1464impl From<NetRestoreError> for RestoreError {
1465 fn from(err: NetRestoreError) -> Self {
1466 RestoreError::InvalidSavedState(anyhow::Error::new(err))
1467 }
1468}
1469
impl Nic {
    /// Rebuilds coordinator and worker state from a parsed saved state.
    ///
    /// Mirrors `saved_state`: the primary channel's phase (`Version`, `Init`,
    /// or `Ready`) determines how much per-channel state is reconstructed.
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            // One bool per channel: was it open at save time? Pre-Ready
            // states only ever have the primary channel.
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // Re-establish the vmbus channels; returns per-channel open
            // requests for those that were open.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // Older saved states have no packet filter; default to
                    // the standard filter rather than leaving traffic blocked.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    // The primary channel's request supplies the guest memory.
                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        let channel = if let Some(channel) = channel {
                            channel
                        } else {
                            continue;
                        };

                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;

                        // Only the primary carries RNDIS/VF/RSS state.
                        if channel_idx == 0 {
                            let primary = PrimaryChannelState::restore(
                                &guest_vf_state,
                                &rndis_state,
                                &offload_config,
                                pending_offload_change,
                                channels.len() as u16,
                                self.adapter.indirection_table_size,
                                &active.rx_bufs,
                                std::mem::take(&mut control_messages),
                                rss_state.take(),
                                tx_spread_sent,
                                guest_link_down,
                                pending_link_action,
                            )?;
                            active.primary = Some(primary);
                        }

                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
                            buffers: buffers.clone(),
                            state: active,
                            data: ProcessingData::new(),
                        }));
                    }
                }
            }

            self.insert_coordinator(states.len() as u16, true);

            // Create workers stopped; the coordinator starts them later.
            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
                if let Some(state) = state {
                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
                }
            }
            // Propagate the primary's saved packet filter to every worker.
            for worker in self.coordinator.state_mut().unwrap().workers.iter_mut() {
                if let Some(worker_state) = worker.state_mut() {
                    worker_state.channel.packet_filter = saved_packet_filter;
                }
            }
        } else {
            // Device was closed at save time: restore a single closed channel.
            control
                .restore(&[false])
                .await
                .map_err(NetRestoreError::Channel)?;
        }
        Ok(())
    }

    /// Captures the device's current state for save.
    ///
    /// Mirrors `restore_state`: serializes the primary channel's phase and,
    /// when Ready, the shared buffers plus each channel's in-flight tx/rx.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                WorkerState::Init(None) => saved_state::Primary::Version,
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued-but-unprocessed guest control messages.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // Both Active and Delay collapse to the action itself.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // Save both completions not yet sent to the
                                    // guest and tx packets still in the backend.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine into the coarser
                    // saved-state representation.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The packet filter is saved from worker 0; restore_state
                    // fans it back out to every worker.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
}
1862
/// Errors that terminate a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
}
1912
1913impl From<task_control::Cancelled> for WorkerError {
1914 fn from(value: task_control::Cancelled) -> Self {
1915 Self::Cancelled(value)
1916 }
1917}
1918
/// Errors that can occur while opening a channel's ring buffer and queue.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
1926
/// Errors raised while parsing an incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
1938
/// Protocol-ordering violations: the guest sent a message that is not valid
/// in the channel's current negotiation state.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
1958
/// Decoded payload of an incoming NVSP message, one variant per message type
/// handled by `parse_packet`.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
    RevokeSendBuffer(Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // Completions for host-initiated messages, matched by transaction id.
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
1976
/// A fully parsed incoming packet: the decoded message plus the guest memory
/// ranges (external data and optional send-buffer suballocation) that carry
/// its RNDIS payload.
#[derive(Debug)]
struct Packet<'a> {
    data: PacketData,
    transaction_id: Option<u64>,
    external_data: MultiPagedRangeBuf<GpnList>,
    send_buffer_suballocation: PagedRange<'a>,
}
1984
/// Reader over a packet's payload: the send-buffer suballocation chained with
/// the packet's external data ranges (see `Packet::rndis_reader`).
type PacketReader<'a> = PagedRangesReader<
    'a,
    std::iter::Chain<std::iter::Once<PagedRange<'a>>, MultiPagedRangeIter<'a>>,
>;
1989
1990impl Packet<'_> {
1991 fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
1992 PagedRanges::new(
1993 std::iter::once(self.send_buffer_suballocation).chain(self.external_data.iter()),
1994 )
1995 .reader(mem)
1996 }
1997}
1998
1999fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2000 reader: &mut impl MemoryRead,
2001) -> Result<T, PacketError> {
2002 reader.read_plain().map_err(PacketError::Access)
2003}
2004
/// Parses one incoming ring packet into a `Packet`, dispatching on the NVSP
/// message type and gating each message on the negotiated protocol `version`.
///
/// Completion packets are matched first by well-known transaction ids (VF
/// association, data-path switch); anything else must be an RNDIS packet
/// completion. Unknown message types fail with `PacketError::UnknownType`.
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&'a SendBuffer>,
    version: Option<Version>,
) -> Result<Packet<'a>, PacketError> {
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            // Completions carry no external data or send-buffer payload.
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data: MultiPagedRangeBuf::empty(),
                send_buffer_suballocation: PagedRange::empty(),
            });
        }
    };

    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    let mut send_buffer_suballocation = PagedRange::empty();
    // Each arm is guarded by the minimum protocol version that introduced
    // the message; out-of-version messages fall through to UnknownType.
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // 0xffffffff means "no send-buffer section"; otherwise resolve
            // the referenced suballocation (fixed 6144-byte sections).
            if message.send_buffer_section_index != 0xffffffff {
                send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data: packet
            .read_external_ranges()
            .map_err(PacketError::ExternalData)?,
        send_buffer_suballocation,
    })
}
2096
/// An outgoing NVSP message: header, typed body, and trailing zero padding
/// that pads the packet to the negotiated packet size.
#[derive(Debug, Copy, Clone)]
struct NvspMessage<T> {
    header: protocol::MessageHeader,
    data: T,
    padding: &'static [u8],
}
2103
2104impl<T: IntoBytes + Immutable + KnownLayout> NvspMessage<T> {
2105 fn payload(&self) -> [&[u8]; 3] {
2106 [self.header.as_bytes(), self.data.as_bytes(), self.padding]
2107 }
2108}
2109
impl<T: RingMem> NetChannel<T> {
    /// Builds an NVSP message of the given type, padded out to the
    /// channel's negotiated packet size.
    fn message<P: IntoBytes + Immutable + KnownLayout>(
        &self,
        message_type: u32,
        data: P,
    ) -> NvspMessage<P> {
        let padding = self.padding(&data);
        NvspMessage {
            header: protocol::MessageHeader { message_type },
            data,
            padding,
        }
    }

    /// Returns a zero-filled slice that pads header + `data` up to
    /// `self.packet_size`; empty if the message already fills the packet.
    fn padding<P: IntoBytes + Immutable + KnownLayout>(&self, data: &P) -> &'static [u8] {
        // PACKET_SIZE_V61 is the largest negotiated packet size, so this
        // static buffer always covers the needed padding.
        static PADDING: &[u8] = &[0; protocol::PACKET_SIZE_V61];
        // cmp::min guards the subtraction against messages larger than the
        // packet size (padding_len saturates at 0).
        let padding_len = self.packet_size
            - cmp::min(
                self.packet_size,
                size_of::<protocol::MessageHeader>() + data.as_bytes().len(),
            );
        &PADDING[..padding_len]
    }

    /// Writes a completion packet for `transaction_id` to the outgoing ring.
    /// No-op when there is no transaction id (guest did not ask for one).
    ///
    /// A full ring is an error here (`OutOfSpace`): completions are only sent
    /// when space was reserved, so running out indicates a protocol bug.
    fn send_completion(
        &mut self,
        transaction_id: Option<u64>,
        payload: &[&[u8]],
    ) -> Result<(), WorkerError> {
        match transaction_id {
            None => Ok(()),
            Some(transaction_id) => Ok(self
                .queue
                .split()
                .1
                .try_write(&queue::OutgoingPacket {
                    transaction_id,
                    packet_type: OutgoingPacketType::Completion,
                    payload,
                })
                .map_err(|err| match err {
                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?),
        }
    }
}
2159
/// NVSP protocol versions this implementation accepts, in ascending order.
/// Note V3 is intentionally absent from this list.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2168
2169fn check_version(requested_version: u32) -> Option<Version> {
2170 SUPPORTED_VERSIONS
2171 .iter()
2172 .find(|version| **version as u32 == requested_version)
2173 .copied()
2174}
2175
/// The guest-supplied receive buffer, carved into `count` fixed-size
/// suballocations used to deliver incoming packets.
#[derive(Debug)]
struct ReceiveBuffer {
    gpadl: GpadlView,
    id: u16,
    count: u32,
    sub_allocation_size: u32,
}
2183
/// Errors validating a guest-supplied send or receive buffer gpadl.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2193
impl ReceiveBuffer {
    /// Validates and maps the guest's receive-buffer gpadl.
    ///
    /// Rejects suballocations smaller than what the default MTU requires, and
    /// gpadls that are unaligned or too small for even one suballocation.
    fn new(
        gpadl_map: &GpadlMapView,
        gpadl_id: GpadlId,
        id: u16,
        sub_allocation_size: u32,
    ) -> Result<Self, BufferError> {
        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let gpadl = gpadl_map.map(gpadl_id)?;
        let range = gpadl
            .contiguous_aligned()
            .ok_or(BufferError::UnalignedGpadl)?;
        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
        if num_sub_allocations == 0 {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let recv_buffer = Self {
            gpadl,
            id,
            count: num_sub_allocations,
            sub_allocation_size,
        };
        Ok(recv_buffer)
    }

    /// Returns the guest memory range for suballocation `index`.
    /// NOTE(review): assumes `index < self.count` so the u32 multiply cannot
    /// overflow — callers appear to guarantee this; confirm.
    fn range(&self, index: u32) -> PagedRange<'_> {
        self.gpadl.first().unwrap().subrange(
            (index * self.sub_allocation_size) as usize,
            self.sub_allocation_size as usize,
        )
    }

    /// Builds the transfer-page descriptor for delivering `len` bytes via
    /// suballocation `index`.
    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
        assert!(len <= self.sub_allocation_size as usize);
        ring::TransferPageRange {
            byte_offset: index * self.sub_allocation_size,
            byte_count: len as u32,
        }
    }

    /// Serializes the buffer's identity for save/restore; `count` is
    /// recomputed from the gpadl on restore.
    fn saved_state(&self) -> saved_state::ReceiveBuffer {
        saved_state::ReceiveBuffer {
            gpadl_id: self.gpadl.id(),
            id: self.id,
            sub_allocation_size: self.sub_allocation_size,
        }
    }
}
2248
/// The guest-supplied send buffer; RNDIS packets may reference fixed-size
/// sections of it instead of carrying external data ranges.
#[derive(Debug)]
struct SendBuffer {
    gpadl: GpadlView,
}
2253
2254impl SendBuffer {
2255 fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2256 let gpadl = gpadl_map.map(gpadl_id)?;
2257 gpadl
2258 .contiguous_aligned()
2259 .ok_or(BufferError::UnalignedGpadl)?;
2260 Ok(Self { gpadl })
2261 }
2262}
2263
2264impl<T: RingMem> NetChannel<T> {
    /// Dispatches one RNDIS message from the guest.
    ///
    /// Data packets are translated into tx `segments` immediately; halt is
    /// ignored; every other message type is queued as a control message for
    /// the primary channel to process. Returns whether the message was a
    /// data packet (and thus produced segments).
    fn handle_rndis_message(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        message_type: u32,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<bool, WorkerError> {
        let is_packet = match message_type {
            rndisprot::MESSAGE_TYPE_PACKET_MSG => {
                self.handle_rndis_packet_message(
                    id,
                    reader,
                    &buffers.mem,
                    segments,
                    &mut state.stats,
                )?;
                true
            }
            rndisprot::MESSAGE_TYPE_HALT_MSG => false,
            n => {
                // Control messages are only valid on the primary channel.
                let control = state
                    .primary
                    .as_mut()
                    .ok_or(WorkerError::NotSupportedOnSubChannel(n))?;

                // Cap total queued control-message bytes to bound guest-driven
                // memory use.
                const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
                if reader.len() == 0 {
                    return Err(WorkerError::RndisMessageTooSmall);
                }
                // NOTE(review): assumes control_messages_len never exceeds the
                // cap (enforced here), so this subtraction cannot underflow.
                if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
                    return Err(WorkerError::TooManyControlMessages);
                }
                control.control_messages_len += reader.len();
                control.control_messages.push_back(ControlMessage {
                    message_type,
                    data: reader.read_all()?.into(),
                });

                false
            }
        };
        Ok(is_packet)
    }
2318
2319 fn handle_rndis_packet_message(
2321 &mut self,
2322 id: TxId,
2323 reader: PacketReader<'_>,
2324 mem: &GuestMemory,
2325 segments: &mut Vec<TxSegment>,
2326 stats: &mut QueueStats,
2327 ) -> Result<(), WorkerError> {
2328 let headers = reader
2330 .clone()
2331 .into_inner()
2332 .paged_ranges()
2333 .find(|r| !r.is_empty())
2334 .ok_or(WorkerError::RndisMessageTooSmall)?;
2335 let mut data = reader.into_inner();
2336 let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2337 if request.num_oob_data_elements != 0
2338 || request.oob_data_length != 0
2339 || request.oob_data_offset != 0
2340 || request.vc_handle != 0
2341 {
2342 return Err(WorkerError::UnsupportedRndisBehavior);
2343 }
2344
2345 if data.len() < request.data_offset as usize
2346 || (data.len() - request.data_offset as usize) < request.data_length as usize
2347 || request.data_length == 0
2348 {
2349 return Err(WorkerError::RndisMessageTooSmall);
2350 }
2351
2352 data.skip(request.data_offset as usize);
2353 data.truncate(request.data_length as usize);
2354
2355 let mut metadata = net_backend::TxMetadata {
2356 id,
2357 len: request.data_length as usize,
2358 ..Default::default()
2359 };
2360
2361 if request.per_packet_info_length != 0 {
2362 let mut ppi = headers
2363 .try_subrange(
2364 request.per_packet_info_offset as usize,
2365 request.per_packet_info_length as usize,
2366 )
2367 .ok_or(WorkerError::RndisMessageTooSmall)?;
2368 while !ppi.is_empty() {
2369 let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2370 if h.size == 0 {
2371 return Err(WorkerError::RndisMessageTooSmall);
2372 }
2373 let (this, rest) = ppi
2374 .try_split(h.size as usize)
2375 .ok_or(WorkerError::RndisMessageTooSmall)?;
2376 let (_, d) = this
2377 .try_split(h.per_packet_information_offset as usize)
2378 .ok_or(WorkerError::RndisMessageTooSmall)?;
2379 match h.typ {
2380 rndisprot::PPI_TCP_IP_CHECKSUM => {
2381 let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2382
2383 metadata.offload_tcp_checksum =
2384 (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
2385 metadata.offload_udp_checksum =
2386 (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum();
2387 metadata.offload_ip_header_checksum = n.is_ipv4() && n.ip_header_checksum();
2388 metadata.l3_protocol = if n.is_ipv4() {
2389 L3Protocol::Ipv4
2390 } else if n.is_ipv6() {
2391 L3Protocol::Ipv6
2392 } else {
2393 L3Protocol::Unknown
2394 };
2395 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2396 if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2397 metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2398 n.tcp_header_offset() - metadata.l2_len as u16
2399 } else if n.is_ipv4() {
2400 let mut reader = data.clone().reader(mem);
2401 reader.skip(metadata.l2_len as usize)?;
2402 let mut b = 0;
2403 reader.read(std::slice::from_mut(&mut b))?;
2404 (b as u16 >> 4) * 4
2405 } else {
2406 40
2408 };
2409 }
2410 }
2411 rndisprot::PPI_LSO => {
2412 let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2413
2414 metadata.offload_tcp_segmentation = true;
2415 metadata.offload_tcp_checksum = true;
2416 metadata.offload_ip_header_checksum = n.is_ipv4();
2417 metadata.l3_protocol = if n.is_ipv4() {
2418 L3Protocol::Ipv4
2419 } else {
2420 L3Protocol::Ipv6
2421 };
2422 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2423 if n.tcp_header_offset() < metadata.l2_len as u16 {
2424 return Err(WorkerError::InvalidTcpHeaderOffset);
2425 }
2426 metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2427 metadata.l4_len = {
2428 let mut reader = data.clone().reader(mem);
2429 reader
2430 .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2431 let mut b = 0;
2432 reader.read(std::slice::from_mut(&mut b))?;
2433 (b >> 4) * 4
2434 };
2435 metadata.max_tcp_segment_size = n.mss() as u16;
2436 }
2437 _ => {}
2438 }
2439 ppi = rest;
2440 }
2441 }
2442
2443 let start = segments.len();
2444 for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2445 let range = range.map_err(WorkerError::InvalidGpadl)?;
2446 segments.push(TxSegment {
2447 ty: net_backend::TxSegmentType::Tail,
2448 gpa: range.start,
2449 len: range.len() as u32,
2450 });
2451 }
2452
2453 metadata.segment_count = segments.len() - start;
2454
2455 stats.tx_packets.increment();
2456 if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2457 stats.tx_checksum_packets.increment();
2458 }
2459 if metadata.offload_tcp_segmentation {
2460 stats.tx_lso_packets.increment();
2461 }
2462
2463 segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2464
2465 Ok(())
2466 }
2467
2468 fn guest_vf_is_available(
2471 &mut self,
2472 guest_vf_id: Option<u32>,
2473 version: Version,
2474 config: NdisConfig,
2475 ) -> Result<bool, WorkerError> {
2476 let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2477 if version >= Version::V4 && config.capabilities.sriov() {
2478 tracing::info!(
2479 available = serial_number.is_some(),
2480 serial_number,
2481 "sending VF association message."
2482 );
2483 let message = {
2485 self.message(
2486 protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2487 protocol::Message4SendVfAssociation {
2488 vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2489 serial_number: serial_number.unwrap_or(0),
2490 },
2491 )
2492 };
2493 self.queue
2494 .split()
2495 .1
2496 .try_write(&queue::OutgoingPacket {
2497 transaction_id: VF_ASSOCIATION_TRANSACTION_ID,
2498 packet_type: OutgoingPacketType::InBandWithCompletion,
2499 payload: &message.payload(),
2500 })
2501 .map_err(|err| match err {
2502 queue::TryWriteError::Full(len) => {
2503 tracing::error!(len, "failed to write vf association message");
2504 WorkerError::OutOfSpace
2505 }
2506 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2507 })?;
2508 Ok(true)
2509 } else {
2510 tracing::info!(
2511 available = serial_number.is_some(),
2512 serial_number,
2513 major = version.major(),
2514 minor = version.minor(),
2515 sriov_capable = config.capabilities.sriov(),
2516 "Skipping NvspMessage4TypeSendVFAssociation message"
2517 );
2518 Ok(false)
2519 }
2520 }
2521
2522 fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2524 if version < Version::V5 {
2526 return;
2527 }
2528
2529 #[repr(C)]
2530 #[derive(IntoBytes, Immutable, KnownLayout)]
2531 struct SendIndirectionMsg {
2532 pub message: protocol::Message5SendIndirectionTable,
2533 pub send_indirection_table:
2534 [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2535 }
2536
2537 let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2539 + size_of::<protocol::MessageHeader>();
2540 let mut data = SendIndirectionMsg {
2541 message: protocol::Message5SendIndirectionTable {
2542 table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2543 table_offset: send_indirection_table_offset as u32,
2544 },
2545 send_indirection_table: Default::default(),
2546 };
2547
2548 for i in 0..data.send_indirection_table.len() {
2549 data.send_indirection_table[i] = i as u32 % num_channels_opened;
2550 }
2551
2552 let message = NvspMessage {
2553 header: protocol::MessageHeader {
2554 message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2555 },
2556 data,
2557 padding: &[],
2558 };
2559 let result = self
2560 .queue
2561 .split()
2562 .1
2563 .try_write(&queue::OutgoingPacket {
2564 transaction_id: 0,
2565 packet_type: OutgoingPacketType::InBandNoCompletion,
2566 payload: &message.payload(),
2567 })
2568 .map_err(|err| match err {
2569 queue::TryWriteError::Full(len) => {
2570 tracing::error!(len, "failed to write send indirection table message");
2571 WorkerError::OutOfSpace
2572 }
2573 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2574 });
2575 if let Err(err) = result {
2576 tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2577 }
2578 }
2579
2580 fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2583 let message = NvspMessage {
2584 header: protocol::MessageHeader {
2585 message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2586 },
2587 data: protocol::Message4SwitchDataPath {
2588 active_data_path: protocol::DataPath::SYNTHETIC.0,
2589 },
2590 padding: &[],
2591 };
2592 let result = self
2593 .queue
2594 .split()
2595 .1
2596 .try_write(&queue::OutgoingPacket {
2597 transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2598 packet_type: OutgoingPacketType::InBandWithCompletion,
2599 payload: &message.payload(),
2600 })
2601 .map_err(|err| match err {
2602 queue::TryWriteError::Full(len) => {
2603 tracing::error!(
2604 len,
2605 "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2606 );
2607 WorkerError::OutOfSpace
2608 }
2609 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2610 });
2611 if let Err(err) = result {
2612 tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2613 }
2614 }
2615
    /// Advances the primary channel's guest VF state machine, sending any
    /// guest notifications (VF association, data path switch completions)
    /// implied by the current state.
    ///
    /// Returns a coordinator message when the coordinator must take further
    /// action, otherwise `None`.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // A VF has become available: advertise it once RNDIS is operational
        // and ask the coordinator to refresh guest VF state.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Run the remaining transitions to a fixed point.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                // VF removed; retract the earlier availability advertisement.
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                // VF removed while a data path switch was in flight; complete
                // the guest's pending request before continuing.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    self.send_completion(id, &[])?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                // VF removed while the data path was switched; tell the guest
                // the data path is synthetic again.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                // A data path switch request has been processed; complete it
                // and move to the resulting state.
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    self.send_completion(id, &[])?;
                    if result {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSwitched
                        } else {
                            PrimaryChannelGuestVfState::Ready
                        }
                    } else {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        } else {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                // These states must have been resolved before reaching here.
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // All remaining states are stable; nothing more to do.
                _ => break,
            };
        }
        Ok(None)
    }
2711
    /// Dispatches a single RNDIS control message, writing the completion (if
    /// any) into the receive buffer suballocation `id` and sending it to the
    /// guest.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // Initialization is only valid once, from the Initializing state.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise any pending VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Reserve space for the completion header up front; the OID
                // handler writes the information buffer directly after it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Some OID sets (e.g. RSS) require restarting the
                        // endpoint to take effect.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                // Always report success; no liveness checks are performed.
                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
2888
2889 fn try_send_rndis_message(
2890 &mut self,
2891 transaction_id: u64,
2892 channel_type: u32,
2893 recv_buffer_id: u16,
2894 transfer_pages: &[ring::TransferPageRange],
2895 ) -> Result<Option<usize>, WorkerError> {
2896 let message = self.message(
2897 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2898 protocol::Message1SendRndisPacket {
2899 channel_type,
2900 send_buffer_section_index: 0xffffffff,
2901 send_buffer_section_size: 0,
2902 },
2903 );
2904 let pending_send_size = match self.queue.split().1.try_write(&queue::OutgoingPacket {
2905 transaction_id,
2906 packet_type: OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
2907 payload: &message.payload(),
2908 }) {
2909 Ok(()) => None,
2910 Err(queue::TryWriteError::Full(n)) => Some(n),
2911 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
2912 };
2913 Ok(pending_send_size)
2914 }
2915
2916 fn send_rndis_control_message(
2917 &mut self,
2918 buffers: &ChannelBuffers,
2919 id: u32,
2920 message_length: usize,
2921 ) -> Result<(), WorkerError> {
2922 let result = self.try_send_rndis_message(
2923 id as u64,
2924 protocol::CONTROL_CHANNEL_TYPE,
2925 buffers.recv_buffer.id,
2926 std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
2927 )?;
2928
2929 match result {
2931 None => Ok(()),
2932 Some(len) => {
2933 tracelimit::error_ratelimited!(len, "failed to write control message completion");
2934 Err(WorkerError::OutOfSpace)
2935 }
2936 }
2937 }
2938
2939 fn indicate_status(
2940 &mut self,
2941 buffers: &ChannelBuffers,
2942 id: u32,
2943 status: u32,
2944 payload: &[u8],
2945 ) -> Result<(), WorkerError> {
2946 let buffer = &buffers.recv_buffer.range(id);
2947 let mut writer = buffer.writer(&buffers.mem);
2948 let message_length = write_rndis_message(
2949 &mut writer,
2950 rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
2951 payload.len(),
2952 &rndisprot::IndicateStatus {
2953 status,
2954 status_buffer_length: payload.len() as u32,
2955 status_buffer_offset: if payload.is_empty() {
2956 0
2957 } else {
2958 size_of::<rndisprot::IndicateStatus>() as u32
2959 },
2960 },
2961 )?;
2962 writer.write(payload)?;
2963 self.send_rndis_control_message(buffers, id, message_length)?;
2964 Ok(())
2965 }
2966
    /// Processes queued RNDIS control messages and any pending offload
    /// change notification, while ring space and receive buffer
    /// suballocations remain available.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages only arrive on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Require enough ring space for the response; otherwise request
            // a wakeup once that much space is free.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Reserve a receive buffer suballocation for the response; stop
            // if none are free (one will be released on guest completion).
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the suballocation as in use until the guest completes it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // Notify the guest of the current task offload configuration.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                // The loop condition guarantees one of the branches above.
                unreachable!();
            }
        }
        Ok(())
    }
3020
3021 fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3022 if self.restart.is_none() {
3023 self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3024 guest_vf_state: guest_vf,
3025 filter_state: packet_filter,
3026 }));
3027 } else if let Some(CoordinatorMessage::Restart) = self.restart {
3028 } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3032 update.guest_vf_state |= guest_vf;
3034 update.filter_state |= packet_filter;
3035 }
3036 }
3037
    /// Requests that the coordinator refresh guest VF state.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3041
    /// Requests that the coordinator refresh the packet filter state.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3045}
3046
3047fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3049 writer: &mut impl MemoryWrite,
3050 message_type: u32,
3051 extra: usize,
3052 payload: &T,
3053) -> Result<usize, AccessError> {
3054 let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3055 writer.write(
3056 rndisprot::MessageHeader {
3057 message_type,
3058 message_length: message_length as u32,
3059 }
3060 .as_bytes(),
3061 )?;
3062 writer.write(payload.as_bytes())?;
3063 Ok(message_length)
3064}
3065
/// Errors produced while handling RNDIS OID queries and sets.
#[derive(Debug, Error)]
enum OidError {
    /// Guest memory access failed while reading or writing the OID buffer.
    #[error(transparent)]
    Access(#[from] AccessError),
    /// The OID is not recognized by this implementation.
    #[error("unknown oid")]
    UnknownOid,
    /// The OID input buffer was malformed; names the offending field.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    /// The negotiated NDIS version does not support this operation.
    #[error("bad ndis version")]
    BadVersion,
    /// The named feature is not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3079
3080impl OidError {
3081 fn as_status(&self) -> u32 {
3082 match self {
3083 OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3084 OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3085 OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3086 OidError::Access(_) => rndisprot::STATUS_FAILURE,
3087 }
3088 }
3089}
3090
/// Default MTU: a standard Ethernet frame (1500-byte payload + 14-byte header).
const DEFAULT_MTU: u32 = 1514;
/// Smallest MTU a guest may configure.
const MIN_MTU: u32 = DEFAULT_MTU;
/// Largest (jumbo frame) MTU a guest may configure.
const MAX_MTU: u32 = 9216;

/// Length of an Ethernet (802.3) header.
const ETHERNET_HEADER_LEN: u32 = 14;
3096
3097impl Adapter {
3098 fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3099 if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3100 if guest_os_id
3103 .microsoft()
3104 .unwrap_or(HvGuestOsMicrosoft::from(0))
3105 .os_id()
3106 == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3107 {
3108 self.adapter_index
3109 } else {
3110 vfid
3111 }
3112 } else {
3113 vfid
3114 }
3115 }
3116
    /// Handles an RNDIS OID query, writing the result into `writer` (the
    /// information buffer of the query completion).
    ///
    /// Returns the number of bytes written. Unknown OIDs and version
    /// mismatches are reported via [`OidError`], which the caller converts
    /// to an RNDIS status code.
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                // OIDs supported on all negotiated NDIS versions.
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // OIDs added for NDIS 6.x.
                let supported_oids_6 = &[
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                ];

                // RSS OIDs, added for NDIS 6.30+.
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                let status: u32 = 0; writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            // Lookahead/frame size exclude the Ethernet header.
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            // RNDIS reports link speed in units of 100 bps.
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                let speed: u32 = (self.link_speed / 100) as u32; writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                writer.write((256u32 * 1024).as_bytes())?
            }
            // Microsoft's IEEE-registered OUI.
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    media_connect_state: 1, media_duplex_state: 0, padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            // Offload structures are truncated to the size declared in
            // their NDIS object header.
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        // Fall back to a default table size when none is
                        // configured.
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3356
3357 fn handle_oid_set(
3358 &self,
3359 primary: &mut PrimaryChannelState,
3360 oid: rndisprot::Oid,
3361 reader: impl MemoryRead + Clone,
3362 ) -> Result<(bool, Option<u32>), OidError> {
3363 tracing::debug!(?oid, "oid set");
3364
3365 let mut restart_endpoint = false;
3366 let mut packet_filter = None;
3367 match oid {
3368 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3369 packet_filter = self.oid_set_packet_filter(reader)?;
3370 }
3371 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3372 self.oid_set_offload_parameters(reader, primary)?;
3373 }
3374 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3375 self.oid_set_offload_encapsulation(reader)?;
3376 }
3377 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3378 self.oid_set_rndis_config_parameter(reader, primary)?;
3379 }
3380 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3381 }
3383 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3384 self.oid_set_rss_parameters(reader, primary)?;
3385
3386 restart_endpoint = true;
3390 }
3391 _ => {
3392 tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3393 return Err(OidError::UnknownOid);
3394 }
3395 }
3396 Ok((restart_endpoint, packet_filter))
3397 }
3398
3399 fn oid_set_rss_parameters(
3400 &self,
3401 mut reader: impl MemoryRead + Clone,
3402 primary: &mut PrimaryChannelState,
3403 ) -> Result<(), OidError> {
3404 let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3406 let len = reader.len().min(size_of_val(¶ms));
3407 reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3408
3409 if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3410 || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3411 {
3412 primary.rss_state = None;
3413 return Ok(());
3414 }
3415
3416 if params.hash_secret_key_size != 40 {
3417 return Err(OidError::InvalidInput("hash_secret_key_size"));
3418 }
3419 if params.indirection_table_size % 4 != 0 {
3420 return Err(OidError::InvalidInput("indirection_table_size"));
3421 }
3422 let indirection_table_size =
3423 (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3424 let mut key = [0; 40];
3425 let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3426 reader
3427 .clone()
3428 .skip(params.hash_secret_key_offset as usize)?
3429 .read(&mut key)?;
3430 reader
3431 .skip(params.indirection_table_offset as usize)?
3432 .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3433 tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3434 if indirection_table
3435 .iter()
3436 .any(|&x| x >= self.max_queues as u32)
3437 {
3438 return Err(OidError::InvalidInput("indirection_table"));
3439 }
3440 let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3441 for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3442 .flatten()
3443 .zip(indir_uninit)
3444 {
3445 *dest = src;
3446 }
3447 primary.rss_state = Some(RssState {
3448 key,
3449 indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3450 });
3451 Ok(())
3452 }
3453
3454 fn oid_set_packet_filter(
3455 &self,
3456 reader: impl MemoryRead + Clone,
3457 ) -> Result<Option<u32>, OidError> {
3458 let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3459 tracing::debug!(filter, "set packet filter");
3460 Ok(Some(filter))
3461 }
3462
    /// Handles `OID_TCP_OFFLOAD_PARAMETERS`, updating the primary channel's
    /// checksum and large-send-offload (LSO) configuration.
    ///
    /// LSOv1 is rejected as unsupported. LSOv2 settings are only honored when
    /// the adapter's `offload_support` also allows them.
    fn oid_set_offload_parameters(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::DEFAULT,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
        )?;

        tracing::debug!(?offload, "offload parameters");
        // Exhaustively destructure so that adding a field to the protocol
        // struct forces this handler to be revisited.
        let rndisprot::NdisOffloadParameters {
            header: _,
            ipv4_checksum,
            tcp4_checksum,
            udp4_checksum,
            tcp6_checksum,
            udp6_checksum,
            lsov1,
            ipsec_v1: _,
            lsov2_ipv4,
            lsov2_ipv6,
            tcp_connection_ipv4: _,
            tcp_connection_ipv6: _,
            reserved: _,
            flags: _,
        } = offload;

        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
            return Err(OidError::NotSupported("lsov1"));
        }
        // For each checksum setting, `tx_rx()` returning `None` is treated as
        // "no change"; `Some((tx, rx))` overwrites both directions.
        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.ipv4_header = tx;
            primary.offload_config.checksum_rx.ipv4_header = rx;
        }
        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp4 = tx;
            primary.offload_config.checksum_rx.tcp4 = rx;
        }
        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp6 = tx;
            primary.offload_config.checksum_rx.tcp6 = rx;
        }
        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp4 = tx;
            primary.offload_config.checksum_rx.udp4 = rx;
        }
        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp6 = tx;
            primary.offload_config.checksum_rx.udp6 = rx;
        }
        // LSO can only be enabled if the backend supports it.
        if let Some(enable) = lsov2_ipv4.enable() {
            primary.offload_config.lso4 = enable && self.offload_support.lso4;
        }
        if let Some(enable) = lsov2_ipv6.enable() {
            primary.offload_config.lso6 = enable && self.offload_support.lso6;
        }
        // Record that the updated offload configuration still needs to be
        // communicated; consumed elsewhere (not visible in this block).
        primary.pending_offload_change = true;
        Ok(())
    }
3525
3526 fn oid_set_offload_encapsulation(
3527 &self,
3528 reader: impl MemoryRead + Clone,
3529 ) -> Result<(), OidError> {
3530 let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3531 reader,
3532 rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3533 1,
3534 rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3535 )?;
3536 if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3537 && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3538 || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3539 {
3540 return Err(OidError::NotSupported("ipv4 encap"));
3541 }
3542 if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3543 && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3544 || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3545 {
3546 return Err(OidError::NotSupported("ipv6 encap"));
3547 }
3548 Ok(())
3549 }
3550
3551 fn oid_set_rndis_config_parameter(
3552 &self,
3553 reader: impl MemoryRead + Clone,
3554 primary: &mut PrimaryChannelState,
3555 ) -> Result<(), OidError> {
3556 let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3557 if info.name_length > 255 {
3558 return Err(OidError::InvalidInput("name_length"));
3559 }
3560 if info.value_length > 255 {
3561 return Err(OidError::InvalidInput("value_length"));
3562 }
3563 let name = reader
3564 .clone()
3565 .skip(info.name_offset as usize)?
3566 .read_n::<u16>(info.name_length as usize / 2)?;
3567 let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3568 let mut value = reader;
3569 value.skip(info.value_offset as usize)?;
3570 let mut value = value.limit(info.value_length as usize);
3571 match info.parameter_type {
3572 rndisprot::NdisParameterType::STRING => {
3573 let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3574 let value =
3575 String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3576 let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3577 let tx = as_num & 1 != 0;
3578 let rx = as_num & 2 != 0;
3579
3580 tracing::debug!(name, value, "rndis config");
3581 match name.as_str() {
3582 "*IPChecksumOffloadIPv4" => {
3583 primary.offload_config.checksum_tx.ipv4_header = tx;
3584 primary.offload_config.checksum_rx.ipv4_header = rx;
3585 }
3586 "*LsoV2IPv4" => {
3587 primary.offload_config.lso4 = as_num != 0 && self.offload_support.lso4;
3588 }
3589 "*LsoV2IPv6" => {
3590 primary.offload_config.lso6 = as_num != 0 && self.offload_support.lso6;
3591 }
3592 "*TCPChecksumOffloadIPv4" => {
3593 primary.offload_config.checksum_tx.tcp4 = tx;
3594 primary.offload_config.checksum_rx.tcp4 = rx;
3595 }
3596 "*TCPChecksumOffloadIPv6" => {
3597 primary.offload_config.checksum_tx.tcp6 = tx;
3598 primary.offload_config.checksum_rx.tcp6 = rx;
3599 }
3600 "*UDPChecksumOffloadIPv4" => {
3601 primary.offload_config.checksum_tx.udp4 = tx;
3602 primary.offload_config.checksum_rx.udp4 = rx;
3603 }
3604 "*UDPChecksumOffloadIPv6" => {
3605 primary.offload_config.checksum_tx.udp6 = tx;
3606 primary.offload_config.checksum_rx.udp6 = rx;
3607 }
3608 _ => {}
3609 }
3610 }
3611 rndisprot::NdisParameterType::INTEGER => {
3612 let value: u32 = value.read_plain()?;
3613 tracing::debug!(name, value, "rndis config");
3614 }
3615 parameter_type => tracelimit::warn_ratelimited!(
3616 name,
3617 ?parameter_type,
3618 "unhandled rndis config parameter type"
3619 ),
3620 }
3621 Ok(())
3622 }
3623}
3624
3625fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3626 mut reader: impl MemoryRead,
3627 object_type: rndisprot::NdisObjectType,
3628 min_revision: u8,
3629 min_size: usize,
3630) -> Result<T, OidError> {
3631 let mut buffer = T::new_zeroed();
3632 let sent_size = reader.len();
3633 let len = sent_size.min(size_of_val(&buffer));
3634 reader.read(&mut buffer.as_mut_bytes()[..len])?;
3635 validate_ndis_object_header(
3636 &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3637 .unwrap()
3638 .0, sent_size,
3640 object_type,
3641 min_revision,
3642 min_size,
3643 )?;
3644 Ok(buffer)
3645}
3646
3647fn validate_ndis_object_header(
3648 header: &rndisprot::NdisObjectHeader,
3649 sent_size: usize,
3650 object_type: rndisprot::NdisObjectType,
3651 min_revision: u8,
3652 min_size: usize,
3653) -> Result<(), OidError> {
3654 if header.object_type != object_type {
3655 return Err(OidError::InvalidInput("header.object_type"));
3656 }
3657 if sent_size < header.size as usize {
3658 return Err(OidError::InvalidInput("header.size"));
3659 }
3660 if header.revision < min_revision {
3661 return Err(OidError::InvalidInput("header.revision"));
3662 }
3663 if (header.size as usize) < min_size {
3664 return Err(OidError::InvalidInput("header.size"));
3665 }
3666 Ok(())
3667}
3668
/// The coordinator task's owned state: the per-queue worker tasks plus the
/// control-message channel used to request restarts, timers, and updates.
struct Coordinator {
    /// Control messages from workers and other components.
    recv: mpsc::Receiver<CoordinatorMessage>,
    /// Handle used to enable vmbus subchannels.
    channel_control: ChannelControl,
    /// When set, the endpoint queues are torn down and rebuilt on the next
    /// iteration of the coordinator loop.
    restart: bool,
    /// One worker per potential queue; index 0 is the primary channel.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// Channel buffer state captured from the primary worker, once available.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of currently active queues.
    num_queues: u16,
}
3677
/// Progress of offering the virtual function (VF) device to the guest.
enum CoordinatorStatePendingVfState {
    /// No VF offer is in flight.
    Ready,
    /// The VF offer is deferred until `delay_until`.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// The offer has been issued; waiting for the guest to be ready.
    Pending,
}
3692
/// State used by the coordinator task, kept separately from [`Coordinator`]
/// (it is also consulted by the inspect handler below).
struct CoordinatorState {
    /// The backend network endpoint providing the queues.
    endpoint: Box<dyn Endpoint>,
    /// Shared adapter configuration.
    adapter: Arc<Adapter>,
    /// The guest-visible virtual function (accelerated) device, if any.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of offering the VF device to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3699
impl InspectTaskMut<Coordinator> for CoordinatorState {
    /// Responds to an inspect request with adapter, endpoint, and (when the
    /// coordinator state is available) per-queue and primary-channel details.
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Adapter-wide configuration.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Writable field: stores the new ring size limit and pokes each
            // worker with a no-op update so the new value is picked up.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        // Per-queue and primary-channel details require access to the
        // coordinator's own state.
        if let Some(coordinator) = coordinator {
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // The primary worker's buffers/primary state are reported at the
            // top level; the response is deferred until the worker state can
            // be examined.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3775
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Runs the coordinator's main loop until the channel disconnects or the
    /// task is cancelled.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3785
3786impl Coordinator {
    /// The coordinator's main event loop.
    ///
    /// Each iteration: rebuilds the queues if a restart was requested, keeps
    /// the subchannel workers running, then races the various event sources
    /// (control messages, endpoint actions, VF state changes, VF offer
    /// timing, and the worker wake-up timer) and handles whichever fires.
    /// Returns `Ok(())` when the channel disconnects.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        // When set, the primary worker should be woken at this deadline.
        let mut sleep_duration: Option<Instant> = None;
        loop {
            if self.restart {
                // Tear down all workers, then rebuild the endpoint queues.
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Re-query whether the data path currently targets the guest
                // VF, since the endpoint was just restarted.
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Subchannel workers always run; the primary worker (index 0) is
            // managed explicitly below.
            for worker in &mut self.workers[1..] {
                worker.start();
            }

            // The set of events that can wake the coordinator.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // A VF offer is in flight; complete it before servicing any
                // other events. Keep the primary worker running unless a
                // data path switch is still awaiting its result.
                if !self.workers[0].is_running()
                    && self.primary_mut().is_none_or(|primary| {
                        !matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::DataPathSwitchPending { result: None, .. }
                        )
                    })
                {
                    self.workers[0].start();
                }

                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Fires when the requested worker wake-up deadline passes.
                let timer_sleep = async {
                    if let Some(sleep_duration) = sleep_duration {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(sleep_duration).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Offer the VF device once the delay (if
                                // any) elapses; in the Ready state there is
                                // no offer pending, so never fire.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            // Pending is handled by the branch above.
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                // If an event is already available, handle it before starting
                // the primary worker; otherwise start the worker and wait.
                let mut wait_for_message = std::pin::pin!(wait_for_message);
                match (&mut wait_for_message).now_or_never() {
                    Some(message) => message,
                    None => {
                        self.workers[0].start();
                        stop.until_stopped(wait_for_message).await?
                    }
                }
            };
            match message {
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Pause the primary worker while moving the VF offer to
                    // the Pending state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Promote a delayed link action so the primary worker
                    // processes it when it restarts.
                    if self.workers[0].is_running() {
                        self.workers[0].stop().await;
                        if let Some(primary) = self.primary_mut() {
                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                                primary.pending_link_action = PendingLinkAction::Active(up);
                            }
                        }
                    }
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Update(update_type)) => {
                    if update_type.filter_state {
                        // Propagate the primary channel's packet filter to
                        // every subchannel worker.
                        self.stop_workers().await;
                        let worker_0_packet_filter =
                            self.workers[0].state().unwrap().channel.packet_filter;
                        self.workers.iter_mut().skip(1).for_each(|worker| {
                            if let Some(state) = worker.state_mut() {
                                state.channel.packet_filter = worker_0_packet_filter;
                                tracing::debug!(
                                    packet_filter = ?worker_0_packet_filter,
                                    channel_idx = state.channel_idx,
                                    "update packet filter"
                                );
                            }
                        });
                    }

                    if update_type.guest_vf_state {
                        self.update_guest_vf_state(state).await;
                    }
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    // Stop the primary worker and record the link change for
                    // it to act on when restarted.
                    self.workers[0].stop().await;

                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
                    // Park the primary worker until the timer expires.
                    sleep_duration = Some(duration);
                    self.workers[0].stop().await;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4013
4014 async fn stop_workers(&mut self) {
4015 for worker in &mut self.workers {
4016 worker.stop().await;
4017 }
4018 }
4019
    /// Reconciles the guest VF state machine with the VF's actual
    /// availability, including after save/restore (the `Restoring(..)`
    /// states) and VF hot add/remove, switching the data path as needed.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // A VF is usable only if one is present and reports an ID.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // First, ensure the VF offer progresses appropriately for the
            // current guest VF state.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // The VF was advertised but the data path isn't switched
                    // yet: schedule a delayed device offer.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // A restored switch request that already has a result only needs
            // its state republished; don't redo the switch below.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                // States that require (re)issuing a data path switch.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Recover the switch direction/transaction id from the
                    // state; default to switching toward the guest.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // The data path is being switched now, so skip any
                    // pending offer delay and offer immediately.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        // In-flight switch requests record the result so it
                        // can be reported to the guest; otherwise derive the
                        // final state directly from success/failure.
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // No VF is available. If the data path was (or was being)
            // switched to the VF, fall back to the synthetic path.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any pending offer for the now-missing VF.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Map each state to its "VF unavailable" counterpart.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4236
    /// Rebuilds the backend endpoint queues from the current channel state:
    /// computes the active queue set (honoring the RSS indirection table),
    /// partitions the receive buffers between queues, fetches fresh queues
    /// from the endpoint, and hands them to the workers.
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop each worker's queue (so the endpoint can be stopped) while
        // collecting the per-queue drivers for the new configuration.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        // Nothing to do until the primary channel finishes initializing.
        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        // With RSS, only queues referenced by the indirection table are
        // active; without RSS (or with an invalid table), all requested
        // queues are active.
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };

        // Partition the receive buffer IDs across the active queues.
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        {
            let buffers = &state.buffers;
            let guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            let mut queue_config = Vec::new();
            let initial_rx;
            {
                // Only receive buffer IDs that are free on every worker can
                // be handed to the new queues immediately.
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| {
                        states
                            .clone()
                            .flatten()
                            .all(|s| (s.state.rx_bufs.is_free(n)))
                    })
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                // If RSS excludes queue 0, it still needs a queue config to
                // own the reserved control buffers, but no initial rx IDs.
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: &[],
                        driver: Box::new(drivers[0].clone()),
                    });
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    // active_queues is sorted/deduped above, so binary
                    // search is valid; empty means "all queues active".
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        // The last active queue takes all remaining buffers;
                        // queue 0 also absorbs the reserved control buffers.
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        // Inactive queues get an empty buffer range.
                        (range_start, 0, None)
                    };

                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: this,
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // Expose the subchannels to the guest.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        // Hand each worker its new queue and rx buffer range, propagating
        // the primary channel's packet filter and clearing stale pending rx.
        let worker_0_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
            if let Some(worker) = worker.state_mut() {
                worker.channel.packet_filter = worker_0_packet_filter;
                if let Some(ready_state) = worker.state.ready_mut() {
                    ready_state.state.pending_rx_packets.clear();
                }
            }
        }

        Ok(())
    }
4443
4444 fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4445 self.workers[0]
4446 .state_mut()
4447 .unwrap()
4448 .state
4449 .ready_mut()?
4450 .state
4451 .primary
4452 .as_mut()
4453 }
4454
    /// Stops the primary worker, then re-evaluates the guest VF state
    /// machine (e.g. after the VF's availability changed).
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4459}
4460
4461impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4462 async fn run(
4463 &mut self,
4464 stop: &mut StopTask<'_>,
4465 worker: &mut Worker<T>,
4466 ) -> Result<(), task_control::Cancelled> {
4467 match worker.process(stop, self).await {
4468 Ok(()) | Err(WorkerError::BufferRevoked) => {}
4469 Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4470 Err(err) => {
4471 tracing::error!(
4472 channel_idx = worker.channel_idx,
4473 error = &err as &dyn std::error::Error,
4474 "netvsp error"
4475 );
4476 }
4477 }
4478 Ok(())
4479 }
4480}
4481
impl<T: RingMem + 'static> Worker<T> {
    /// Top-level worker loop: runs channel initialization (primary channel
    /// only) and then steady-state processing, until stopped or a fatal
    /// error occurs.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    // Only the primary channel performs protocol
                    // initialization; subchannels just park until stopped.
                    if self.channel_idx != 0 {
                        stop.until_stopped(pending()).await?
                    }

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Best effort: if the coordinator channel is full, a
                    // restart is already queued.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::Ready(state);
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Propagate a pending target VP change to the backend
                        // queue before processing.
                        if !queue_state.target_vp_set
                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
                        {
                            tracing::debug!(
                                channel_idx = self.channel_idx,
                                target_vp = self.target_vp,
                                "updating target VP"
                            );
                            queue_state.queue.update_target_vp(self.target_vp).await;
                            queue_state.target_vp_set = true;
                        }

                        queue_state
                    } else {
                        // No backend queue assigned yet; wait for a restart.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    match result {
                        Ok(restart) => {
                            // Only the primary channel produces coordinator
                            // messages from the main loop.
                            assert_eq!(self.channel_idx, 0);
                            let _ = self.coordinator_send.try_send(restart);
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            // If the restart request can't even be queued,
                            // surface the original endpoint error.
                            if let Err(try_send_err) =
                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
                            {
                                tracing::error!(
                                    try_send_err = &try_send_err as &dyn std::error::Error,
                                    "failed to restart queues"
                                );
                                return Err(WorkerError::Endpoint(err));
                            }
                        }
                        Err(err) => return Err(err),
                    }

                    // Wait for the coordinator to stop/restart this worker.
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4568
4569impl<T: 'static + RingMem> NetChannel<T> {
4570 fn try_next_packet<'a>(
4571 &mut self,
4572 send_buffer: Option<&'a SendBuffer>,
4573 version: Option<Version>,
4574 ) -> Result<Option<Packet<'a>>, WorkerError> {
4575 let (mut read, _) = self.queue.split();
4576 let packet = match read.try_read() {
4577 Ok(packet) => {
4578 parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
4579 }
4580 Err(queue::TryReadError::Empty) => return Ok(None),
4581 Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4582 };
4583
4584 tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4585 Ok(Some(packet))
4586 }
4587
    /// Reads and parses the next incoming vmbus packet, waiting for one to
    /// arrive.
    ///
    /// RNDIS data packets are not consumed: the read is reverted so the
    /// packet remains in the ring for the normal processing path.
    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet =
            parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            // Leave the packet in the ring; the caller only uses its
            // presence as a signal.
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }
4605
4606 fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4607 (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4608 && initializing.ndis_version.is_some()
4609 && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4610 && initializing.recv_buffer.is_some()
4611 }
4612
    /// Runs the channel negotiation state machine until the channel has
    /// gathered enough state (protocol version, NDIS version/config, receive
    /// buffer, and usually a send buffer) to begin normal processing.
    ///
    /// Returns the established channel buffers and the initial active state.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                // An early RNDIS packet also forces initialization to finish
                // (the send buffer may legitimately still be missing then).
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Pre-V2 guests never send an NDIS config; fall back
                        // to the defaults.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Ensure there is room to write a completion before consuming
            // the next request.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let packet = self
                .next_packet(None, initializing.as_ref().map(|x| x.version))
                .await?;

            if let Some(initializing) = &mut *initializing {
                // Version negotiated: handle the per-message setup packets.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // Clamp an out-of-range MTU to the default instead of
                        // failing negotiation.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The receive buffer is carved into MTU-sized
                        // sub-allocations, so the MTU must be known (or
                        // defaulted for pre-V2) by this point.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        // Report the buffer layout back as a single section.
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                    protocol::Message1SendReceiveBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        num_sections: 1,
                                        sections: [protocol::ReceiveBufferSection {
                                            offset: 0,
                                            sub_allocation_size: recv_buffer.sub_allocation_size,
                                            num_sub_allocations: recv_buffer.count,
                                            end_offset: recv_buffer.sub_allocation_size
                                                * recv_buffer.count,
                                        }],
                                    },
                                )
                                .payload(),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                    protocol::Message1SendSendBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        section_size: 6144,
                                    },
                                )
                                .payload(),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // Some guests send RNDIS traffic before negotiation
                        // fully completes; accept it only if the send buffer
                        // is the sole outstanding item.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet: only the init message is valid.
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut message = self.message(
                            protocol::MESSAGE_TYPE_INIT_COMPLETE,
                            protocol::MessageInitComplete {
                                deprecated: protocol::INVALID_PROTOCOL_VERSION,
                                maximum_mdl_chain_length: 34,
                                status: protocol::Status::NONE,
                            },
                        );
                        if let Some(version) = version {
                            if version == Version::V1 {
                                message.data.deprecated = Version::V1 as u32;
                            }
                            message.data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        self.send_completion(packet.transaction_id, &message.payload())?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            // V6.1 and later use the larger packet header.
                            if version >= Version::V61 {
                                self.packet_size = protocol::PACKET_SIZE_V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    // parse_packet is expected to yield only Init before a
                    // version has been negotiated.
                    _ => unreachable!(),
                }
            }
        }
    }
4828
    /// Runs steady-state channel processing, moving packets between the
    /// vmbus rings and the backend endpoint queue.
    ///
    /// Returns when a coordinator message must be delivered; the coordinator
    /// re-enters the loop afterwards.
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        // Compute how much outgoing ring capacity to keep in reserve. With
        // the ring size optimization enabled, the limit comes from the
        // adapter; a zero limit falls back to capacity minus 2048 bytes.
        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };

        // Re-arm receives that were parked while the packet filter was
        // NDIS_PACKET_TYPE_NONE.
        if !state.pending_rx_packets.is_empty()
            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
        {
            let epqueue = queue_state.queue.as_mut();
            let (front, back) = state.pending_rx_packets.as_slices();
            epqueue.rx_avail(front);
            epqueue.rx_avail(back);
            state.pending_rx_packets.clear();
        }

        if let Some(primary) = state.primary.as_mut() {
            // Once all requested subchannels are open, publish the send
            // indirection table so the guest spreads tx across channels.
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            // Deliver any pending link status change to the guest.
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // The link is already in the requested state: toggle
                        // it now and restore it after a delay.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    // No control buffer available; retry after a delay.
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            // Guest VF transitions that need a state-change message.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // Stop producing into the outgoing ring once the spare-capacity
            // threshold is reached; completions continue to drain.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            // Non-short-circuiting `|` so every stage runs each iteration.
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            self.process_control_messages(buffers, state)?;

            // Wait for more work to arrive or for a coordinator message.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        if !ring_full {
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        let (mut recv, mut send) = self.queue.split();
                        // Only poll the incoming ring when enough tx IDs are
                        // free and no transmit is mid-flight.
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Reclaim receive buffer IDs returned by other
                        // channels on this channel's behalf.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }
5049
    /// Polls the backend endpoint for received packets and forwards them to
    /// the guest through the receive buffer.
    ///
    /// Returns `Ok(true)` if any packets were polled from the endpoint.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        // An empty packet filter means the guest wants no traffic: park the
        // receive IDs until the filter changes (re-armed in main_loop).
        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The first receive ID doubles as the vmbus transaction ID; the
        // guest's completion releases the whole batch (release_recv_buffers).
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // The outgoing ring was full: undo the reservation and hand
                // the IDs back to the endpoint.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5115
    /// Polls the backend endpoint for transmit completions and completes the
    /// corresponding guest transactions.
    ///
    /// Returns `Ok(true)` if any completions were processed.
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    // A guest transaction may span multiple backend packets;
                    // complete it only when the last one finishes.
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // The endpoint needs its queues rebuilt. Queue success
                // completions for all in-flight transmits so the guest is
                // not left waiting across the restart.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }
5167
    /// Handles a guest request to switch the data path between the synthetic
    /// path and the guest VF.
    ///
    /// The completion for `transaction_id` is either sent immediately (when
    /// no switch is needed or the request is invalid) or deferred until the
    /// queued switch operation finishes.
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // A switch to synthetic is only queued here when the current
                // data path state is still unknown.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                // Requesting the VF path in any other state is unexpected;
                // log and complete without switching.
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while device is in wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            // The coordinator performs the switch and sends the completion.
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, &[])?;
        }
        Ok(())
    }
5221
    /// Processes guest-to-host packets from the incoming vmbus ring.
    ///
    /// Returns `Ok(true)` if any packets were consumed.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Stop when out of tx IDs or while a transmit is still queued.
            if state.free_tx_packets.is_empty() || !data.tx_segments.is_empty() {
                break;
            }
            let packet = if let Some(packet) =
                self.try_next_packet(buffers.send_buffer.as_ref(), Some(buffers.version))?
            {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    assert!(data.tx_segments.is_empty());
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    let num_packets = match result {
                        Ok(num_packets) => num_packets,
                        Err(err) => {
                            // Fail just this transaction; keep the channel
                            // running.
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                            continue;
                        }
                    };
                    total_packets += num_packets as u64;
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += num_packets;

                    if num_packets != 0 {
                        // A partial send stalls until the endpoint drains.
                        if self.transmit_segments(state, data, queue_state, id, num_packets)?
                            < num_packets
                        {
                            state.stats.tx_stalled.increment();
                        }
                    } else {
                        // Nothing was queued for transmit; complete now.
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }
                PacketData::RndisPacketComplete(_completion) => {
                    // The guest finished with a batch of receive buffers;
                    // return them to the endpoint.
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                                protocol::Message5SubchannelComplete {
                                    status,
                                    num_sub_channels: subchannel_count,
                                },
                            )
                            .payload(),
                    )?;

                    if subchannel_count > 0 {
                        // Ask the coordinator to rebuild the queues with the
                        // new channel count.
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    // OID queries are not supported; report that explicitly.
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                                protocol::Message5OidQueryExComplete {
                                    status: rndisprot::STATUS_NOT_SUPPORTED,
                                    bytes: 0,
                                },
                            )
                            .payload(),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5364
5365 fn transmit_pending_segments(
5366 &mut self,
5367 state: &mut ActiveState,
5368 data: &mut ProcessingData,
5369 queue_state: &mut QueueState,
5370 ) -> Result<bool, WorkerError> {
5371 if data.tx_segments.is_empty() {
5372 return Ok(false);
5373 }
5374 let net_backend::TxSegmentType::Head(metadata) = &data.tx_segments[0].ty else {
5375 unreachable!()
5376 };
5377 let id = metadata.id;
5378 let num_packets = state.pending_tx_packets[id.0 as usize].pending_packet_count;
5379 let packets_sent = self.transmit_segments(state, data, queue_state, id, num_packets)?;
5380 Ok(num_packets == packets_sent)
5381 }
5382
    /// Hands the queued transmit segments to the backend endpoint.
    ///
    /// Returns the number of packets accepted; unaccepted segments remain in
    /// `data.tx_segments` for a later retry via `transmit_pending_segments`.
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
        id: TxId,
        num_packets: usize,
    ) -> Result<usize, WorkerError> {
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&data.tx_segments)
            .map_err(WorkerError::Endpoint)?;

        assert!(segments_sent <= data.tx_segments.len());

        // If only part of the queue was accepted, count the packets in the
        // accepted prefix.
        let packets_sent = if segments_sent == data.tx_segments.len() {
            num_packets
        } else {
            net_backend::packet_count(&data.tx_segments[..segments_sent])
        };

        data.tx_segments.drain(..segments_sent);

        // `sync` indicates the endpoint completed these inline, so they will
        // not be reported again through tx_poll.
        if sync {
            state.pending_tx_packets[id.0 as usize].pending_packet_count -= packets_sent;
        }

        if state.pending_tx_packets[id.0 as usize].pending_packet_count == 0 {
            self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
        }

        Ok(packets_sent)
    }
5416
    /// Handles a guest RNDIS packet, which may contain several concatenated
    /// RNDIS messages.
    ///
    /// Returns the number of backend transmit packets queued into `segments`.
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert!(tx_packet.pending_packet_count == 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Validate all guest addresses up front so that later accesses
        // cannot fail partway through processing.
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        // Each iteration handles one RNDIS message; the header's
        // message_length advances the reader to the next message.
        while reader.len() > 0 {
            let mut this_reader = reader.clone();
            let header: rndisprot::MessageHeader = this_reader.read_plain()?;
            if self.handle_rndis_message(
                buffers,
                state,
                id,
                header.message_type,
                this_reader,
                segments,
            )? {
                num_packets += 1;
            }
            reader.skip(header.message_length as usize)?;
        }

        Ok(num_packets)
    }
5464
5465 fn try_send_tx_packet(
5466 &mut self,
5467 transaction_id: u64,
5468 status: protocol::Status,
5469 ) -> Result<bool, WorkerError> {
5470 let message = self.message(
5471 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5472 protocol::Message1SendRndisPacketComplete { status },
5473 );
5474 let result = self.queue.split().1.try_write(&queue::OutgoingPacket {
5475 transaction_id,
5476 packet_type: OutgoingPacketType::Completion,
5477 payload: &message.payload(),
5478 });
5479 let sent = match result {
5480 Ok(()) => true,
5481 Err(queue::TryWriteError::Full(n)) => {
5482 self.pending_send_size = n;
5483 false
5484 }
5485 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5486 };
5487 Ok(sent)
5488 }
5489
5490 fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5491 let mut did_some_work = false;
5492 while let Some(pending) = state.pending_tx_completions.front() {
5493 if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5494 return Ok(did_some_work);
5495 }
5496 did_some_work = true;
5497 if let Some(id) = pending.tx_id {
5498 state.free_tx_packets.push(id);
5499 }
5500 tracing::trace!(?pending, "sent tx completion");
5501 state.pending_tx_completions.pop_front();
5502 }
5503
5504 self.pending_send_size = 0;
5505 Ok(did_some_work)
5506 }
5507
    /// Completes a guest transmit transaction, either immediately or (if the
    /// outgoing ring is full) by queuing the completion for later delivery.
    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        // Only attempt an immediate send when no earlier completion is
        // already waiting for ring space; this preserves completion order.
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
5531}
5532
impl ActiveState {
    /// Releases the batch of receive buffers identified by `transaction_id`
    /// (which encodes the batch's first receive ID).
    ///
    /// Non-reserved IDs are pushed to `done` for return to the endpoint;
    /// reserved control-message IDs go back to the primary channel's free
    /// list; IDs owned by another channel are forwarded to it instead.
    /// Returns `None` if the transaction ID does not match an allocated
    /// batch.
    fn release_recv_buffers(
        &mut self,
        transaction_id: u64,
        rx_buffer_range: &RxBufferRange,
        done: &mut Vec<RxId>,
    ) -> Option<()> {
        // The transaction ID was constructed from a 32-bit receive ID; a
        // larger value cannot be valid.
        let first_id: u32 = transaction_id.try_into().ok()?;
        let ids = self.rx_bufs.free(first_id)?;
        for id in ids {
            if !rx_buffer_range.send_if_remote(id) {
                if id >= RX_RESERVED_CONTROL_BUFFERS {
                    done.push(RxId(id));
                } else {
                    self.primary
                        .as_mut()?
                        .free_control_buffers
                        .push(ControlMessageId(id));
                }
            }
        }
        Some(())
    }
}