1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::Message1RevokeReceiveBuffer;
19use crate::protocol::Message1RevokeSendBuffer;
20use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
21use crate::protocol::Version;
22use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
23use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
24use async_trait::async_trait;
25pub use buffers::BufferPool;
26use buffers::sub_allocation_size_for_mtu;
27use futures::FutureExt;
28use futures::StreamExt;
29use futures::channel::mpsc;
30use futures_concurrency::future::Race;
31use guestmem::AccessError;
32use guestmem::GuestMemory;
33use guestmem::GuestMemoryError;
34use guestmem::MemoryRead;
35use guestmem::MemoryWrite;
36use guestmem::ranges::PagedRange;
37use guestmem::ranges::PagedRanges;
38use guestmem::ranges::PagedRangesReader;
39use guid::Guid;
40use hvdef::hypercall::HvGuestOsId;
41use hvdef::hypercall::HvGuestOsMicrosoft;
42use hvdef::hypercall::HvGuestOsMicrosoftIds;
43use hvdef::hypercall::HvGuestOsOpenSourceType;
44use inspect::Inspect;
45use inspect::InspectMut;
46use inspect::SensitivityLevel;
47use inspect_counters::Counter;
48use inspect_counters::Histogram;
49use mesh::rpc::Rpc;
50use net_backend::Endpoint;
51use net_backend::EndpointAction;
52use net_backend::L3Protocol;
53use net_backend::QueueConfig;
54use net_backend::RxId;
55use net_backend::TxError;
56use net_backend::TxId;
57use net_backend::TxSegment;
58use net_backend_resources::mac_address::MacAddress;
59use pal_async::timer::Instant;
60use pal_async::timer::PolledTimer;
61use ring::gparange::MultiPagedRangeIter;
62use rx_bufs::RxBuffers;
63use rx_bufs::SubAllocationInUse;
64use std::cmp;
65use std::collections::VecDeque;
66use std::fmt::Debug;
67use std::future::pending;
68use std::mem::offset_of;
69use std::ops::Range;
70use std::sync::Arc;
71use std::sync::atomic::AtomicUsize;
72use std::sync::atomic::Ordering;
73use std::task::Poll;
74use std::time::Duration;
75use task_control::AsyncRun;
76use task_control::InspectTaskMut;
77use task_control::StopTask;
78use task_control::TaskControl;
79use thiserror::Error;
80use tracing::Instrument;
81use vmbus_async::queue;
82use vmbus_async::queue::ExternalDataError;
83use vmbus_async::queue::IncomingPacket;
84use vmbus_async::queue::Queue;
85use vmbus_channel::bus::OfferParams;
86use vmbus_channel::bus::OpenRequest;
87use vmbus_channel::channel::ChannelControl;
88use vmbus_channel::channel::ChannelOpenError;
89use vmbus_channel::channel::ChannelRestoreError;
90use vmbus_channel::channel::DeviceResources;
91use vmbus_channel::channel::RestoreControl;
92use vmbus_channel::channel::SaveRestoreVmbusDevice;
93use vmbus_channel::channel::VmbusDevice;
94use vmbus_channel::gpadl::GpadlId;
95use vmbus_channel::gpadl::GpadlMapView;
96use vmbus_channel::gpadl::GpadlView;
97use vmbus_channel::gpadl::UnknownGpadlId;
98use vmbus_channel::gpadl_ring::GpadlRingMem;
99use vmbus_channel::gpadl_ring::gpadl_channel;
100use vmbus_ring as ring;
101use vmbus_ring::OutgoingPacketType;
102use vmbus_ring::RingMem;
103use vmbus_ring::gparange::GpnList;
104use vmbus_ring::gparange::MultiPagedRangeBuf;
105use vmcore::save_restore::RestoreError;
106use vmcore::save_restore::SaveError;
107use vmcore::save_restore::SavedStateBlob;
108use vmcore::vm_task::VmTaskDriver;
109use vmcore::vm_task::VmTaskDriverSource;
110use zerocopy::FromBytes;
111use zerocopy::FromZeros;
112use zerocopy::Immutable;
113use zerocopy::IntoBytes;
114use zerocopy::KnownLayout;
115
// Minimum free space (bytes) required on the outgoing ring before control
// messages are processed. NOTE(review): value presumably derived from the
// largest control-message completion packet — confirm against protocol sizes.
const MIN_CONTROL_RING_SIZE: usize = 144;

// Minimum free space (bytes) required on the outgoing ring before a state
// change (e.g. link status) message is sent. NOTE(review): confirm derivation.
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Transaction IDs for host-initiated (unsolicited) messages. NOTE(review):
// high bit set, presumably to avoid colliding with guest transaction IDs —
// confirm.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// Maximum number of subchannels offered per vNIC.
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Delay before presenting the VF device to the guest; shortened under test so
// unit tests complete quickly.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Delay applied to pending link actions; shortened under test.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
149
/// Messages sent from channel workers (and other sources) to the coordinator
/// task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Re-evaluate the guest VF state machine.
    UpdateGuestVfState,
    /// Restart the endpoint/queues.
    Restart,
    /// Arm the coordinator's timer to fire at the given instant.
    StartTimer(Instant),
}
161
/// Per-channel worker: owns one vmbus ring (primary channel or subchannel)
/// and its protocol state.
struct Worker<T: RingMem> {
    /// Index of the channel this worker serves (0 = primary channel).
    channel_idx: u16,
    /// Virtual processor to target for this channel's interrupts.
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    /// Channel for notifying the coordinator (e.g. to request a restart).
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
170
/// A backend queue plus the task driver it runs on.
struct NetQueue {
    driver: VmTaskDriver,
    /// `None` until the backend queue has been created for this channel.
    queue_state: Option<QueueState>,
}
175
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Reports the queue and (when the worker task is present) the worker's
    /// protocol and packet-accounting state.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report if both the worker and the queue are absent.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            if let WorkerState::Ready(state) = &worker.state {
                // Tx ids not on the free list are outstanding with the
                // backend.
                resp.field(
                    "outstanding_tx_packets",
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
224
/// Protocol state of a channel worker.
#[expect(clippy::large_enum_variant)]
enum WorkerState {
    /// Protocol initialization; `None` until a version has been negotiated.
    Init(Option<InitState>),
    /// Fully initialized; processing packets.
    Ready(ReadyState),
}
230
231impl WorkerState {
232 fn ready(&self) -> Option<&ReadyState> {
233 if let Self::Ready(state) = self {
234 Some(state)
235 } else {
236 None
237 }
238 }
239
240 fn ready_mut(&mut self) -> Option<&mut ReadyState> {
241 if let Self::Ready(state) = self {
242 Some(state)
243 } else {
244 None
245 }
246 }
247}
248
/// State accumulated during protocol initialization, before the channel
/// becomes operational. Fields are `None` until the guest establishes them.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
256
/// NDIS version reported by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
264
/// NDIS configuration sent by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    /// Maximum transmission unit, in bytes.
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
272
/// State for a channel that has completed initialization.
struct ReadyState {
    /// Negotiated buffers, shared across channels.
    buffers: Arc<ChannelBuffers>,
    /// Mutable packet-processing state.
    state: ActiveState,
    /// Preallocated scratch buffers.
    data: ProcessingData,
}
278
/// Interface to a virtual function (VF) device that can be offered to the
/// guest alongside the synthetic NIC for accelerated networking.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// The VF id, if a VF is currently available.
    async fn id(&self) -> Option<u32>;
    /// Signals that the guest is ready for the VF device to be presented.
    async fn guest_ready_for_device(&mut self);
    /// Waits for a change in VF availability. NOTE(review): presumably the
    /// returned RPC is completed once the caller has processed the change —
    /// confirm with implementations.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
293
/// Per-adapter configuration and state shared by the coordinator and all
/// channel workers.
struct Adapter {
    driver: VmTaskDriver,
    /// MAC address reported to the guest.
    mac_address: MacAddress,
    /// Maximum number of queues (channels) offered to the guest.
    max_queues: u16,
    /// Size of the RSS indirection table.
    indirection_table_size: u16,
    /// Offloads supported by the backend endpoint.
    offload_support: OffloadConfig,
    /// When nonzero, caps how much of the ring is used (see
    /// `NicBuilder::limit_ring_buffer`).
    ring_size_limit: AtomicUsize,
    /// Free tx-id threshold used to pace transmit processing.
    /// NOTE(review): confirm exact use at the processing site.
    free_tx_packet_threshold: usize,
    /// Whether the endpoint completes transmits quickly and in order.
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Callback for querying the guest OS id registered with the hypervisor.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    /// Number of channels currently opened.
    num_sub_channels_opened: AtomicUsize,
    /// Link speed reported to the guest. NOTE(review): units presumed bits
    /// per second, from `Endpoint::link_speed` — confirm.
    link_speed: u64,
}
308
/// A channel's backend queue and its receive-buffer assignment.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    /// The range of receive buffer ids owned by this queue.
    rx_buffer_range: RxBufferRange,
    /// Whether the interrupt target VP has been applied to the driver.
    target_vp_set: bool,
}
314
/// The contiguous range of receive buffer ids owned by one queue, plus the
/// channels used to return ids that belong to other queues.
struct RxBufferRange {
    id_range: Range<u32>,
    /// Receives buffer ids forwarded from other queues; `None` when this
    /// queue does not participate in remote forwarding.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    /// Shared bookkeeping for all queues' ranges.
    remote_ranges: Arc<RxBufferRanges>,
}
320
321impl RxBufferRange {
322 fn new(
323 ranges: Arc<RxBufferRanges>,
324 id_range: Range<u32>,
325 remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
326 ) -> Self {
327 Self {
328 id_range,
329 remote_buffer_id_recv,
330 remote_ranges: ranges,
331 }
332 }
333
334 fn send_if_remote(&self, id: u32) -> bool {
335 if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
338 false
339 } else {
340 let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
341 let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
345 let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
346 true
347 }
348 }
349}
350
/// Shared description of how receive buffer ids are divided among queues.
struct RxBufferRanges {
    /// Number of buffers assigned to each queue (after the reserved control
    /// buffers).
    buffers_per_queue: u32,
    /// One sender per queue, used to forward remotely-owned buffer ids.
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
355
impl RxBufferRanges {
    /// Splits `buffer_count` receive buffers (minus the reserved control
    /// buffers) evenly across `queue_count` queues, returning one receiver
    /// per queue for remotely-forwarded buffer ids.
    fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
        let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
        // NOTE(review): unbounded channels look safe here because in-flight
        // ids are bounded by the buffer count — confirm.
        #[expect(clippy::disallowed_methods)]
        let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
        (
            Self {
                buffers_per_queue,
                buffer_id_send: send,
            },
            recv,
        )
    }
}
370
/// Receive-side scaling parameters supplied by the guest.
struct RssState {
    /// Toeplitz-style hash key. NOTE(review): 40 bytes matches the NDIS RSS
    /// key size — confirm against rndisprot definitions.
    key: [u8; 40],
    /// Maps hash values to queue indices.
    indirection_table: Vec<u16>,
}
375
/// Vmbus channel state used by a worker to process its ring.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    /// Vmbus packet size in use. NOTE(review): presumably fixed during
    /// version negotiation — confirm.
    packet_size: usize,
    /// Pending send size requested on the outgoing ring when it is full.
    pending_send_size: usize,
    /// Coordinator message to send when the worker next yields, if any.
    restart: Option<CoordinatorMessage>,
    /// Whether the pending-send-size optimization is usable for this guest
    /// (see `can_use_ring_opt`).
    can_use_ring_size_opt: bool,
}
386
/// Preallocated scratch buffers reused across worker wakes to avoid
/// per-iteration allocation.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
}
395
396impl ProcessingData {
397 fn new() -> Self {
398 Self {
399 tx_segments: Vec::new(),
400 tx_done: vec![TxId(0); 8192].into(),
401 rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
402 rx_done: Vec::with_capacity(RX_BATCH_SIZE),
403 transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
404 }
405 }
406}
407
/// The negotiated buffer configuration for a channel, shared by the primary
/// channel and subchannels once initialization completes.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
423
/// Identifier of a reserved control-message receive buffer.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
427
/// Mutable packet-processing state for an initialized channel.
struct ActiveState {
    /// Present only on the primary channel.
    primary: Option<PrimaryChannelState>,

    /// One slot per tx id, tracking in-flight transmits.
    pending_tx_packets: Vec<PendingTxPacket>,
    /// Tx ids available for new transmits.
    free_tx_packets: Vec<TxId>,
    /// Completions waiting to be written back to the guest.
    pending_tx_completions: VecDeque<PendingTxCompletion>,

    /// Tracks which receive-buffer suballocations are in use.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}
440
/// Per-queue diagnostic counters, exposed via inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
453
/// A transmit completion that still needs to be written to the guest's ring.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    /// The tx id to release once the completion is delivered, if one was
    /// reserved.
    tx_id: Option<TxId>,
    status: protocol::Status,
}
460
/// State machine for offering a virtual function (VF) to the guest and
/// switching the data path between the synthetic NIC and the VF.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// VF state has not yet been evaluated.
    Initializing,
    /// Restoring from a saved state.
    Restoring(saved_state::GuestVfState),
    /// No VF is available.
    Unavailable,
    /// The VF was removed after having been available.
    UnavailableFromAvailable,
    /// The VF was removed while a data path switch was in flight.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// The VF was removed after the data path had been switched to it.
    UnavailableFromDataPathSwitched,
    /// A VF is available with the given id.
    Available { vfid: u32 },
    /// The guest has been notified that a VF is available.
    AvailableAdvertised,
    /// The VF is present in the guest.
    Ready,
    /// A data path switch is in progress.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// The data path has been switched to the VF.
    DataPathSwitched,
    /// The data path was switched back to synthetic due to an external state
    /// change; the VF remains available.
    DataPathSynthetic,
}
493
494impl std::fmt::Display for PrimaryChannelGuestVfState {
495 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496 match self {
497 PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
498 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
499 write!(f, "restoring")
500 }
501 PrimaryChannelGuestVfState::Restoring(
502 saved_state::GuestVfState::AvailableAdvertised,
503 ) => write!(f, "restoring from guest notified of vfid"),
504 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
505 write!(f, "restoring from vf present")
506 }
507 PrimaryChannelGuestVfState::Restoring(
508 saved_state::GuestVfState::DataPathSwitchPending {
509 to_guest, result, ..
510 },
511 ) => {
512 write!(
513 f,
514 "restoring from client requested data path switch: to {} {}",
515 if *to_guest { "guest" } else { "synthetic" },
516 if let Some(result) = result {
517 if *result { "succeeded\"" } else { "failed\"" }
518 } else {
519 "in progress\""
520 }
521 )
522 }
523 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
524 write!(f, "restoring from data path in guest")
525 }
526 PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
527 PrimaryChannelGuestVfState::UnavailableFromAvailable => {
528 write!(f, "\"unavailable (previously available)\"")
529 }
530 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
531 write!(f, "unavailable (previously switching data path)")
532 }
533 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
534 write!(f, "\"unavailable (previously using guest VF)\"")
535 }
536 PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
537 PrimaryChannelGuestVfState::AvailableAdvertised => {
538 write!(f, "\"available, guest notified\"")
539 }
540 PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
541 PrimaryChannelGuestVfState::DataPathSwitchPending {
542 to_guest, result, ..
543 } => {
544 write!(
545 f,
546 "\"switching to {} {}",
547 if *to_guest { "guest" } else { "synthetic" },
548 if let Some(result) = result {
549 if *result { "succeeded\"" } else { "failed\"" }
550 } else {
551 "in progress\""
552 }
553 )
554 }
555 PrimaryChannelGuestVfState::DataPathSwitched => {
556 write!(f, "\"available and data path switched\"")
557 }
558 PrimaryChannelGuestVfState::DataPathSynthetic => write!(
559 f,
560 "\"available but data path switched back to synthetic due to external state change\""
561 ),
562 }
563 }
564}
565
impl Inspect for PrimaryChannelGuestVfState {
    /// Reports the state as its `Display` string.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.value(self.to_string());
    }
}
571
/// A pending link state change to report to the guest. The `bool` is the
/// target link-up state.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action pending.
    Default,
    /// Link change ready to be processed.
    Active(bool),
    /// Link change deferred. NOTE(review): presumably delayed by
    /// `LINK_DELAY_DURATION` — confirm at the processing site.
    Delay(bool),
}
578
/// State carried only by the primary channel (channel 0).
struct PrimaryChannelState {
    /// Guest VF offer/data-path state machine.
    guest_vf_state: PrimaryChannelGuestVfState,
    /// Whether the data path is switched to the VF; `None` when unknown.
    is_data_path_switched: Option<bool>,
    /// Queued RNDIS control messages awaiting processing.
    control_messages: VecDeque<ControlMessage>,
    /// Total payload bytes across `control_messages`.
    control_messages_len: usize,
    /// Reserved control receive buffers not currently in use.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    /// Number of queues the guest asked for.
    requested_num_queues: u16,
    rndis_state: RndisState,
    offload_config: OffloadConfig,
    /// Whether an offload-configuration change still needs to be sent to the
    /// guest.
    pending_offload_change: bool,
    /// Whether the tx-spread message has been sent.
    /// NOTE(review): confirm semantics at the send site.
    tx_spread_sent: bool,
    /// Link state last reported to the guest.
    guest_link_up: bool,
    /// Pending link state change, if any.
    pending_link_action: PendingLinkAction,
}
594
impl Inspect for PrimaryChannelState {
    /// Reports primary-channel state; every field here is marked safe for
    /// unredacted diagnostics.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                // Render the pending action as a short human-readable string.
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
646
/// Task offloads advertised to (or requested by) the guest.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    /// Transmit checksum offloads.
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    /// Receive checksum offloads.
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// Large send offload (LSOv2) for IPv4.
    #[inspect(safe)]
    lso4: bool,
    /// Large send offload (LSOv2) for IPv6.
    #[inspect(safe)]
    lso6: bool,
}
658
/// Checksum offload capabilities, per protocol.
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    /// IPv4 header checksum.
    #[inspect(safe)]
    ipv4_header: bool,
    /// TCP checksum over IPv4.
    #[inspect(safe)]
    tcp4: bool,
    /// UDP checksum over IPv4.
    #[inspect(safe)]
    udp4: bool,
    /// TCP checksum over IPv6.
    #[inspect(safe)]
    tcp6: bool,
    /// UDP checksum over IPv6.
    #[inspect(safe)]
    udp6: bool,
}
672
673impl ChecksumOffloadConfig {
674 fn flags(
675 &self,
676 ) -> (
677 rndisprot::Ipv4ChecksumOffload,
678 rndisprot::Ipv6ChecksumOffload,
679 ) {
680 let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
681 let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
682 let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
683 if self.ipv4_header {
684 v4.set_ip_options_supported(on);
685 v4.set_ip_checksum(on);
686 }
687 if self.tcp4 {
688 v4.set_ip_options_supported(on);
689 v4.set_tcp_options_supported(on);
690 v4.set_tcp_checksum(on);
691 }
692 if self.tcp6 {
693 v6.set_ip_extension_headers_supported(on);
694 v6.set_tcp_options_supported(on);
695 v6.set_tcp_checksum(on);
696 }
697 if self.udp4 {
698 v4.set_ip_options_supported(on);
699 v4.set_udp_checksum(on);
700 }
701 if self.udp6 {
702 v6.set_ip_extension_headers_supported(on);
703 v6.set_udp_checksum(on);
704 }
705 (v4, v6)
706 }
707}
708
impl OffloadConfig {
    /// Builds the RNDIS offload structure advertising this configuration to
    /// the guest.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            // Maximum bytes accepted per large-send operation.
            // NOTE(review): confirm provenance of 0xF53C — presumably matches
            // the limit advertised by Hyper-V.
            const MAX_OFFLOAD_SIZE: u32 = 0xF53C;
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = 2;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = 2;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            // All other offloads (e.g. RSC, UDP segmentation) are left zeroed,
            // i.e. not advertised.
            ..FromZeros::new_zeroed()
        }
    }
}
758
/// RNDIS device state, driven by the guest's control messages.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Waiting for the guest's initialize message.
    Initializing,
    /// Initialized and operational.
    Operational,
    /// Halted by the guest.
    Halted,
}
765
impl PrimaryChannelState {
    /// Creates fresh primary-channel state with all reserved control buffers
    /// free and the link reported up.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from a saved state.
    ///
    /// Fails if the saved RSS state is incompatible with the current adapter
    /// (saved indirection table larger than supported, or an invalid key
    /// size).
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Total queued control-message payload bytes.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // A reserved control buffer is free unless the saved rx state marked
        // it in use.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                // The indirection table cannot shrink across restore; a
                // smaller saved table is grown by repeating its entries so
                // existing flow-to-queue mappings are preserved.
                if rss.indirection_table.len() > indirection_table_size as usize {
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            // Not saved; re-established after restore. NOTE(review): confirm
            // where the data-path-switched state is recomputed.
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
896
/// A queued RNDIS control message.
struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}

/// Maximum number of in-flight guest transmits per channel.
const TX_PACKET_QUOTA: usize = 1024;
903
impl ActiveState {
    /// Creates active state for a freshly opened channel.
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            // Ids are popped from the back, so reverse to hand out low ids
            // first.
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    /// Rebuilds active state from a saved channel state.
    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        // Re-mark the saved in-use receive suballocations as allocated.
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        // Requeue saved, not-yet-delivered tx completions, reserving a tx id
        // for each when one is available. NOTE(review): completions past the
        // quota are queued without an id — confirm delivery still works.
        for &transaction_id in pending_tx_completions {
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}
952
/// State for a guest transmit that is in flight with the backend endpoint.
#[derive(Default, Clone)]
struct PendingTxPacket {
    /// Number of backend packets still outstanding for this transmit.
    /// NOTE(review): confirm unit at the completion site.
    pending_packet_count: usize,
    /// The guest's vmbus transaction id, used for the completion.
    transaction_id: u64,
}

/// Maximum number of receives processed per wake.
const RX_BATCH_SIZE: usize = 375;

/// Number of receive-buffer suballocations reserved for RNDIS control
/// messages (not used for data packets).
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
969
/// A netvsp synthetic network adapter, offered to the guest over vmbus.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    /// Task coordinating the endpoint and the per-channel workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    /// Sender for messages to the coordinator, once it is running.
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
979
/// Builder for [`Nic`]; construct via [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
986
impl NicBuilder {
    /// Limits how much of the ring buffer is used (see
    /// `Adapter::ring_size_limit`).
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Caps the number of queues offered to the guest.
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Supplies a virtual function (VF) that can be offered to the guest.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Supplies a callback for querying the guest OS id.
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Builds the NIC over the given backend `endpoint`.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp the queue count to what the endpoint and the protocol allow.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // NOTE(review): 1024 — presumably enough in-flight ring data to keep
        // the endpoint busy while bounding memory; confirm.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // Endpoints with fast (in-order) tx completions can use the full
        // quota; otherwise use a lower threshold so completions flow sooner.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Rx checksum offloads are advertised unconditionally; tx offloads
        // depend on what the endpoint supports.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1098
1099fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1100 let Some(guest_os_id) = guest_os_id else {
1101 return false;
1103 };
1104
1105 if !queue.split().0.supports_pending_send_size() {
1106 return false;
1108 }
1109
1110 let Some(open_source_os) = guest_os_id.open_source() else {
1111 return true;
1113 };
1114
1115 match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1116 HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1120 _ => true,
1121 }
1122}
1123
1124impl Nic {
1125 pub fn builder() -> NicBuilder {
1126 NicBuilder {
1127 virtual_function: None,
1128 limit_ring_buffer: false,
1129 max_queues: !0,
1130 get_guest_os_id: None,
1131 }
1132 }
1133
1134 pub fn shutdown(self) -> Box<dyn Endpoint> {
1135 let (state, _) = self.coordinator.into_inner();
1136 state.endpoint
1137 }
1138}
1139
impl InspectMut for Nic {
    /// Delegates inspection to the coordinator task.
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}
1145
1146#[async_trait]
1147impl VmbusDevice for Nic {
1148 fn offer(&self) -> OfferParams {
1149 OfferParams {
1150 interface_name: "net".to_owned(),
1151 instance_id: self.instance_id,
1152 interface_id: Guid {
1153 data1: 0xf8615163,
1154 data2: 0xdf3e,
1155 data3: 0x46c5,
1156 data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
1157 },
1158 subchannel_index: 0,
1159 mnf_interrupt_latency: Some(Duration::from_micros(100)),
1160 ..Default::default()
1161 }
1162 }
1163
1164 fn max_subchannels(&self) -> u16 {
1165 self.adapter.max_queues
1166 }
1167
1168 fn install(&mut self, resources: DeviceResources) {
1169 self.resources = resources;
1170 }
1171
1172 async fn open(
1173 &mut self,
1174 channel_idx: u16,
1175 open_request: &OpenRequest,
1176 ) -> Result<(), ChannelOpenError> {
1177 let state = if channel_idx == 0 {
1179 self.insert_coordinator(1, false);
1180 WorkerState::Init(None)
1181 } else {
1182 self.coordinator.stop().await;
1183 let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
1185 WorkerState::Ready(ReadyState {
1186 state: ActiveState::new(None, buffers.recv_buffer.count),
1187 buffers,
1188 data: ProcessingData::new(),
1189 })
1190 };
1191
1192 let num_opened = self
1193 .adapter
1194 .num_sub_channels_opened
1195 .fetch_add(1, Ordering::SeqCst);
1196 let r = self.insert_worker(channel_idx, open_request, state, true);
1197 if channel_idx != 0
1198 && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
1199 {
1200 let coordinator = &mut self.coordinator.state_mut().unwrap();
1201 coordinator.workers[0].stop().await;
1202 }
1203
1204 if r.is_err() && channel_idx == 0 {
1205 self.coordinator.remove();
1206 } else {
1207 self.coordinator.start();
1209 }
1210 r?;
1211 Ok(())
1212 }
1213
1214 async fn close(&mut self, channel_idx: u16) {
1215 if !self.coordinator.has_state() {
1216 tracing::error!("Close called while vmbus channel is already closed");
1217 return;
1218 }
1219
1220 let restart = self.coordinator.stop().await;
1222
1223 {
1225 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1226 worker.stop().await;
1227 if worker.has_state() {
1228 worker.remove();
1229 }
1230 }
1231
1232 self.adapter
1233 .num_sub_channels_opened
1234 .fetch_sub(1, Ordering::SeqCst);
1235 if channel_idx == 0 {
1237 for worker in &mut self.coordinator.state_mut().unwrap().workers {
1238 worker.task_mut().queue_state = None;
1239 }
1240
1241 self.coordinator.task_mut().endpoint.stop().await;
1243
1244 self.coordinator.remove();
1249 } else {
1250 if restart {
1252 self.coordinator.start();
1253 }
1254 }
1255 }
1256
1257 async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
1258 if !self.coordinator.has_state() {
1259 return;
1260 }
1261
1262 let coordinator_running = self.coordinator.stop().await;
1264 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1265 worker.stop().await;
1266 let (net_queue, worker_state) = worker.get_mut();
1267
1268 net_queue.driver.retarget_vp(target_vp);
1270
1271 if let Some(worker_state) = worker_state {
1272 worker_state.target_vp = target_vp;
1274 if let Some(queue_state) = &mut net_queue.queue_state {
1275 queue_state.target_vp_set = false;
1277 }
1278 }
1279
1280 if coordinator_running {
1282 self.coordinator.start();
1283 }
1284 }
1285
1286 fn start(&mut self) {
1287 if !self.coordinator.is_running() {
1288 self.coordinator.start();
1289 }
1290 }
1291
1292 async fn stop(&mut self) {
1293 self.coordinator.stop().await;
1294 if let Some(coordinator) = self.coordinator.state_mut() {
1295 coordinator.stop_workers().await;
1296 }
1297 }
1298
    /// This device supports save/restore; expose the trait object.
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
1302}
1303
1304#[async_trait]
1305impl SaveRestoreVmbusDevice for Nic {
1306 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1307 let state = self.saved_state();
1308 Ok(SavedStateBlob::new(state))
1309 }
1310
1311 async fn restore(
1312 &mut self,
1313 control: RestoreControl<'_>,
1314 state: SavedStateBlob,
1315 ) -> Result<(), RestoreError> {
1316 let state: saved_state::SavedState = state.parse()?;
1317 if let Err(err) = self.restore_state(control, state).await {
1318 tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
1319 Err(err.into())
1320 } else {
1321 Ok(())
1322 }
1323 }
1324}
1325
impl Nic {
    /// Creates (and, when `start` is set, starts) the worker task for
    /// `channel_idx`, backed by the ring buffer described by `open_request`.
    ///
    /// The coordinator state must already exist. Fails if the ring buffer or
    /// vmbus queue cannot be established.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Retarget the per-queue driver to the guest-requested VP before the
        // channel is wired up on it.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // Whether the ring-size optimization is usable depends on the guest
        // OS, queried via the optional callback.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                // Start at the V1 packet size; renegotiated with the protocol.
                packet_size: protocol::PACKET_SIZE_V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1384
impl Nic {
    /// Creates the coordinator task along with a task slot (and dedicated
    /// driver) for every possible queue.
    ///
    /// `num_queues` is the number of initially active queues; `restoring`
    /// marks the coordinator for restart processing after a restore.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: bool) {
        let mut driver_builder = self.driver_source.builder();
        driver_builder.target_vp(0);
        // If the endpoint completes transmits quickly, don't run tasks on the
        // target VP itself; otherwise co-locate them with the target VP.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        // NOTE(review): capacity 1 appears sufficient for coordinator
        // wakeups; the lint allow suggests mpsc::channel is normally
        // disallowed in this codebase — confirm.
        #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring,
                // Pre-create a slot for every possible queue, even if fewer
                // are active initially.
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
            },
        );
    }
}
1423
/// Errors that can occur while restoring the NIC from saved state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1443
1444impl From<NetRestoreError> for RestoreError {
1445 fn from(err: NetRestoreError) -> Self {
1446 RestoreError::InvalidSavedState(anyhow::Error::new(err))
1447 }
1448}
1449
1450impl Nic {
    /// Rebuilds the NIC's runtime state from a parsed saved state.
    ///
    /// Re-establishes the vmbus channels via `control`, reconstructs the
    /// negotiated buffers and per-channel worker states, then inserts the
    /// coordinator (in "restoring" mode) and a worker per restored channel
    /// without starting the workers.
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        if let Some(state) = state.open {
            // Which channels were open at save time. Before the Ready state
            // only the primary channel (index 0) can exist.
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            // Per-channel worker states, filled in below.
            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // Re-open the vmbus channels; this yields the open requests
            // needed to rebuild each worker.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    // Re-map whichever buffers had been established.
                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                    } = ready;

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    // Rebuild the shared buffers once; they are shared by all
                    // channel workers via Arc.
                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        let channel = if let Some(channel) = channel {
                            channel
                        } else {
                            continue;
                        };

                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;

                        // Channel 0 additionally carries the primary state
                        // (RNDIS, RSS, VF, offload, and link configuration).
                        if channel_idx == 0 {
                            let primary = PrimaryChannelState::restore(
                                &guest_vf_state,
                                &rndis_state,
                                &offload_config,
                                pending_offload_change,
                                channels.len() as u16,
                                self.adapter.indirection_table_size,
                                &active.rx_bufs,
                                std::mem::take(&mut control_messages),
                                rss_state.take(),
                                tx_spread_sent,
                                guest_link_down,
                                pending_link_action,
                            )?;
                            active.primary = Some(primary);
                        }

                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
                            buffers: buffers.clone(),
                            state: active,
                            data: ProcessingData::new(),
                        }));
                    }
                }
            }

            self.insert_coordinator(states.len() as u16, true);

            // Insert (but do not start) a worker for every restored channel.
            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
                if let Some(state) = state {
                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
                }
            }
        } else {
            // The device was saved with the channel closed; restore it closed.
            control
                .restore(&[false])
                .await
                .map_err(NetRestoreError::Channel)?;
        }
        Ok(())
    }
1632
    /// Builds the serializable saved state for this NIC.
    ///
    /// Produces a closed state when no coordinator exists (channel never
    /// opened or already closed); otherwise snapshots the primary channel's
    /// protocol state plus each queue's in-flight work.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                WorkerState::Init(None) => saved_state::Primary::Version,
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::Ready(ready) => {
                    // In the Ready state the primary channel state must exist.
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued-but-unprocessed RNDIS control messages.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    // Both Active and Delay collapse to the same saved action.
                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Snapshot each queue's pending work: completions not yet
                    // returned to the guest, transmits still in flight at the
                    // endpoint, and receive buffers currently loaned out.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                let channel = if let Some(ready) = worker.state.ready() {
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                };
                                channel
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine to its saved form.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        // A restore that never completed round-trips as-is.
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                    })
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
1826}
1827
/// Errors that terminate a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
}
1877
1878impl From<task_control::Cancelled> for WorkerError {
1879 fn from(value: task_control::Cancelled) -> Self {
1880 Self::Cancelled(value)
1881 }
1882}
1883
/// Errors establishing the ring buffer or vmbus queue when opening a channel.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
1891
/// Errors decoding a single incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
1903
/// Protocol-state violations: messages arriving in an order the NVSP state
/// machine does not permit.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
1923
/// Decoded payload of an incoming NVSP packet, one variant per message type.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
    RevokeSendBuffer(Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // The two completion variants are synthesized from well-known
    // transaction IDs rather than decoded from a message body.
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
1941
/// A fully parsed incoming packet.
#[derive(Debug)]
struct Packet<'a> {
    // Decoded message payload.
    data: PacketData,
    // Transaction ID to complete against, if the sender requested one.
    transaction_id: Option<u64>,
    // External GPA ranges referenced by the packet (empty for completions).
    external_data: MultiPagedRangeBuf<GpnList>,
    // Send-buffer section referenced by an RNDIS packet, or empty.
    send_buffer_suballocation: PagedRange<'a>,
}
1949
/// Reader over an RNDIS payload: the send-buffer suballocation (possibly
/// empty) chained with the packet's external GPA ranges.
type PacketReader<'a> = PagedRangesReader<
    'a,
    std::iter::Chain<std::iter::Once<PagedRange<'a>>, MultiPagedRangeIter<'a>>,
>;
1954
1955impl Packet<'_> {
1956 fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
1957 PagedRanges::new(
1958 std::iter::once(self.send_buffer_suballocation).chain(self.external_data.iter()),
1959 )
1960 .reader(mem)
1961 }
1962}
1963
1964fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
1965 reader: &mut impl MemoryRead,
1966) -> Result<T, PacketError> {
1967 reader.read_plain().map_err(PacketError::Access)
1968}
1969
/// Parses an incoming vmbus packet into a [`Packet`].
///
/// Completions are recognized either by the well-known VF-association /
/// data-path-switch transaction IDs or by their NVSP completion header.
/// Data messages are decoded from their NVSP header, gated on the negotiated
/// protocol `version`; anything unknown (or not valid at that version) fails
/// with [`PacketError::UnknownType`].
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&'a SendBuffer>,
    version: Option<Version>,
) -> Result<Packet<'a>, PacketError> {
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            // Completions carry no external data or send-buffer reference.
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data: MultiPagedRangeBuf::empty(),
                send_buffer_suballocation: PagedRange::empty(),
            });
        }
    };

    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    let mut send_buffer_suballocation = PagedRange::empty();
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // 0xffffffff marks "no send-buffer section"; otherwise resolve
            // the referenced section within the send buffer's single range.
            // NOTE(review): 6144 appears to be the fixed send-buffer section
            // size in bytes — confirm against the NVSP protocol definition.
            if message.send_buffer_section_index != 0xffffffff {
                send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data: packet
            .read_external_ranges()
            .map_err(PacketError::ExternalData)?,
        send_buffer_suballocation,
    })
}
2061
/// An outgoing NVSP message: header, typed body, and trailing zero padding
/// to reach the negotiated packet size.
#[derive(Debug, Copy, Clone)]
struct NvspMessage<T> {
    header: protocol::MessageHeader,
    data: T,
    padding: &'static [u8],
}
2068
impl<T: IntoBytes + Immutable + KnownLayout> NvspMessage<T> {
    /// Returns the message as a scatter list: header, body, then padding.
    fn payload(&self) -> [&[u8]; 3] {
        [self.header.as_bytes(), self.data.as_bytes(), self.padding]
    }
}
2074
impl<T: RingMem> NetChannel<T> {
    /// Builds an NVSP message of `message_type` with padding appended to
    /// reach the negotiated packet size.
    fn message<P: IntoBytes + Immutable + KnownLayout>(
        &self,
        message_type: u32,
        data: P,
    ) -> NvspMessage<P> {
        let padding = self.padding(&data);
        NvspMessage {
            header: protocol::MessageHeader { message_type },
            data,
            padding,
        }
    }

    /// Returns the zero padding required to bring `header + data` up to the
    /// negotiated packet size; messages already at least that size get none.
    fn padding<P: IntoBytes + Immutable + KnownLayout>(&self, data: &P) -> &'static [u8] {
        // PACKET_SIZE_V61 is the largest negotiable packet size, so this
        // static zero buffer always covers the needed padding length.
        static PADDING: &[u8] = &[0; protocol::PACKET_SIZE_V61];
        // min() guards the subtraction against messages larger than the
        // packet size, yielding zero padding instead of an underflow.
        let padding_len = self.packet_size
            - cmp::min(
                self.packet_size,
                size_of::<protocol::MessageHeader>() + data.as_bytes().len(),
            );
        &PADDING[..padding_len]
    }

    /// Writes a completion packet for `transaction_id`, if one was requested.
    ///
    /// Uses a non-blocking write: ring space is expected to be reserved for
    /// completions, so a full ring is reported as `OutOfSpace`.
    fn send_completion(
        &mut self,
        transaction_id: Option<u64>,
        payload: &[&[u8]],
    ) -> Result<(), WorkerError> {
        match transaction_id {
            None => Ok(()),
            Some(transaction_id) => Ok(self
                .queue
                .split()
                .1
                .try_write(&queue::OutgoingPacket {
                    transaction_id,
                    packet_type: OutgoingPacketType::Completion,
                    payload,
                })
                .map_err(|err| match err {
                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?),
        }
    }
}
2124
/// NVSP protocol versions this device will negotiate, in ascending order.
// NOTE(review): V3 is absent from this list — confirm its omission is
// deliberate.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2133
2134fn check_version(requested_version: u32) -> Option<Version> {
2135 SUPPORTED_VERSIONS
2136 .iter()
2137 .find(|version| **version as u32 == requested_version)
2138 .copied()
2139}
2140
/// The guest-supplied receive buffer, divided into fixed-size suballocations
/// used to deliver received packets.
#[derive(Debug)]
struct ReceiveBuffer {
    // Mapped GPADL backing the buffer.
    gpadl: GpadlView,
    // Guest-assigned buffer ID, echoed in transfer-page packets.
    id: u16,
    // Number of suballocations that fit in the buffer.
    count: u32,
    // Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2148
/// Errors validating and mapping a guest send/receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2158
impl ReceiveBuffer {
    /// Maps the receive buffer's GPADL and splits it into fixed-size
    /// suballocations.
    ///
    /// Fails if the suballocation size cannot hold a default-MTU packet, if
    /// the GPADL is not a single contiguous page-aligned range, or if the
    /// buffer is too small for even one suballocation.
    fn new(
        gpadl_map: &GpadlMapView,
        gpadl_id: GpadlId,
        id: u16,
        sub_allocation_size: u32,
    ) -> Result<Self, BufferError> {
        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let gpadl = gpadl_map.map(gpadl_id)?;
        let range = gpadl
            .contiguous_aligned()
            .ok_or(BufferError::UnalignedGpadl)?;
        // Trailing bytes that don't fill a whole suballocation are unused.
        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
        if num_sub_allocations == 0 {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let recv_buffer = Self {
            gpadl,
            id,
            count: num_sub_allocations,
            sub_allocation_size,
        };
        Ok(recv_buffer)
    }

    /// Returns the guest memory range covering suballocation `index`.
    fn range(&self, index: u32) -> PagedRange<'_> {
        self.gpadl.first().unwrap().subrange(
            (index * self.sub_allocation_size) as usize,
            self.sub_allocation_size as usize,
        )
    }

    /// Builds a transfer-page range describing `len` bytes at suballocation
    /// `index`, for posting received data to the guest.
    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
        assert!(len <= self.sub_allocation_size as usize);
        ring::TransferPageRange {
            byte_offset: index * self.sub_allocation_size,
            byte_count: len as u32,
        }
    }

    /// Captures the parameters needed to remap this buffer at restore time.
    fn saved_state(&self) -> saved_state::ReceiveBuffer {
        saved_state::ReceiveBuffer {
            gpadl_id: self.gpadl.id(),
            id: self.id,
            sub_allocation_size: self.sub_allocation_size,
        }
    }
}
2213
/// The guest-supplied send buffer, referenced by RNDIS packets via
/// send-buffer section indices.
#[derive(Debug)]
struct SendBuffer {
    gpadl: GpadlView,
}
2218
2219impl SendBuffer {
2220 fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2221 let gpadl = gpadl_map.map(gpadl_id)?;
2222 gpadl
2223 .contiguous_aligned()
2224 .ok_or(BufferError::UnalignedGpadl)?;
2225 Ok(Self { gpadl })
2226 }
2227}
2228
2229impl<T: RingMem> NetChannel<T> {
    /// Dispatches one RNDIS message received from the guest.
    ///
    /// Data packets are parsed into `segments` for transmission; all other
    /// (control) messages are queued on the primary channel for later
    /// processing. Returns `true` if the message was a data packet.
    fn handle_rndis_message(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        message_type: u32,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<bool, WorkerError> {
        let is_packet = match message_type {
            rndisprot::MESSAGE_TYPE_PACKET_MSG => {
                self.handle_rndis_packet_message(
                    id,
                    reader,
                    &buffers.mem,
                    segments,
                    &mut state.stats,
                )?;
                true
            }
            // Halt is accepted but requires no action here.
            rndisprot::MESSAGE_TYPE_HALT_MSG => false,
            n => {
                // Control messages are only valid on the primary channel.
                let control = state
                    .primary
                    .as_mut()
                    .ok_or(WorkerError::NotSupportedOnSubChannel(n))?;

                // Cap total queued control-message bytes so the guest cannot
                // consume unbounded host memory.
                const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
                if reader.len() == 0 {
                    return Err(WorkerError::RndisMessageTooSmall);
                }
                // Written as remaining-capacity < len (rather than
                // queued + len > cap) to avoid any arithmetic overflow.
                if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
                    return Err(WorkerError::TooManyControlMessages);
                }
                control.control_messages_len += reader.len();
                control.control_messages.push_back(ControlMessage {
                    message_type,
                    data: reader.read_all()?.into(),
                });

                false
            }
        };
        Ok(is_packet)
    }
2283
2284 fn handle_rndis_packet_message(
2286 &mut self,
2287 id: TxId,
2288 reader: PacketReader<'_>,
2289 mem: &GuestMemory,
2290 segments: &mut Vec<TxSegment>,
2291 stats: &mut QueueStats,
2292 ) -> Result<(), WorkerError> {
2293 let headers = reader
2295 .clone()
2296 .into_inner()
2297 .paged_ranges()
2298 .find(|r| !r.is_empty())
2299 .ok_or(WorkerError::RndisMessageTooSmall)?;
2300 let mut data = reader.into_inner();
2301 let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2302 if request.num_oob_data_elements != 0
2303 || request.oob_data_length != 0
2304 || request.oob_data_offset != 0
2305 || request.vc_handle != 0
2306 {
2307 return Err(WorkerError::UnsupportedRndisBehavior);
2308 }
2309
2310 if data.len() < request.data_offset as usize
2311 || (data.len() - request.data_offset as usize) < request.data_length as usize
2312 || request.data_length == 0
2313 {
2314 return Err(WorkerError::RndisMessageTooSmall);
2315 }
2316
2317 data.skip(request.data_offset as usize);
2318 data.truncate(request.data_length as usize);
2319
2320 let mut metadata = net_backend::TxMetadata {
2321 id,
2322 len: request.data_length as usize,
2323 ..Default::default()
2324 };
2325
2326 if request.per_packet_info_length != 0 {
2327 let mut ppi = headers
2328 .try_subrange(
2329 request.per_packet_info_offset as usize,
2330 request.per_packet_info_length as usize,
2331 )
2332 .ok_or(WorkerError::RndisMessageTooSmall)?;
2333 while !ppi.is_empty() {
2334 let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2335 if h.size == 0 {
2336 return Err(WorkerError::RndisMessageTooSmall);
2337 }
2338 let (this, rest) = ppi
2339 .try_split(h.size as usize)
2340 .ok_or(WorkerError::RndisMessageTooSmall)?;
2341 let (_, d) = this
2342 .try_split(h.per_packet_information_offset as usize)
2343 .ok_or(WorkerError::RndisMessageTooSmall)?;
2344 match h.typ {
2345 rndisprot::PPI_TCP_IP_CHECKSUM => {
2346 let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2347
2348 metadata.offload_tcp_checksum =
2349 (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
2350 metadata.offload_udp_checksum =
2351 (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum();
2352 metadata.offload_ip_header_checksum = n.is_ipv4() && n.ip_header_checksum();
2353 metadata.l3_protocol = if n.is_ipv4() {
2354 L3Protocol::Ipv4
2355 } else if n.is_ipv6() {
2356 L3Protocol::Ipv6
2357 } else {
2358 L3Protocol::Unknown
2359 };
2360 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2361 if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2362 metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2363 n.tcp_header_offset() - metadata.l2_len as u16
2364 } else if n.is_ipv4() {
2365 let mut reader = data.clone().reader(mem);
2366 reader.skip(metadata.l2_len as usize)?;
2367 let mut b = 0;
2368 reader.read(std::slice::from_mut(&mut b))?;
2369 (b as u16 >> 4) * 4
2370 } else {
2371 40
2373 };
2374 }
2375 }
2376 rndisprot::PPI_LSO => {
2377 let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2378
2379 metadata.offload_tcp_segmentation = true;
2380 metadata.offload_tcp_checksum = true;
2381 metadata.offload_ip_header_checksum = n.is_ipv4();
2382 metadata.l3_protocol = if n.is_ipv4() {
2383 L3Protocol::Ipv4
2384 } else {
2385 L3Protocol::Ipv6
2386 };
2387 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2388 if n.tcp_header_offset() < metadata.l2_len as u16 {
2389 return Err(WorkerError::InvalidTcpHeaderOffset);
2390 }
2391 metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2392 metadata.l4_len = {
2393 let mut reader = data.clone().reader(mem);
2394 reader
2395 .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2396 let mut b = 0;
2397 reader.read(std::slice::from_mut(&mut b))?;
2398 (b >> 4) * 4
2399 };
2400 metadata.max_tcp_segment_size = n.mss() as u16;
2401 }
2402 _ => {}
2403 }
2404 ppi = rest;
2405 }
2406 }
2407
2408 let start = segments.len();
2409 for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2410 let range = range.map_err(WorkerError::InvalidGpadl)?;
2411 segments.push(TxSegment {
2412 ty: net_backend::TxSegmentType::Tail,
2413 gpa: range.start,
2414 len: range.len() as u32,
2415 });
2416 }
2417
2418 metadata.segment_count = segments.len() - start;
2419
2420 stats.tx_packets.increment();
2421 if metadata.offload_tcp_checksum || metadata.offload_udp_checksum {
2422 stats.tx_checksum_packets.increment();
2423 }
2424 if metadata.offload_tcp_segmentation {
2425 stats.tx_lso_packets.increment();
2426 }
2427
2428 segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2429
2430 Ok(())
2431 }
2432
2433 fn guest_vf_is_available(
2436 &mut self,
2437 guest_vf_id: Option<u32>,
2438 version: Version,
2439 config: NdisConfig,
2440 ) -> Result<bool, WorkerError> {
2441 let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
2442 if version >= Version::V4 && config.capabilities.sriov() {
2443 tracing::info!(
2444 available = serial_number.is_some(),
2445 serial_number,
2446 "sending VF association message."
2447 );
2448 let message = {
2450 self.message(
2451 protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
2452 protocol::Message4SendVfAssociation {
2453 vf_allocated: if serial_number.is_some() { 1 } else { 0 },
2454 serial_number: serial_number.unwrap_or(0),
2455 },
2456 )
2457 };
2458 self.queue
2459 .split()
2460 .1
2461 .try_write(&queue::OutgoingPacket {
2462 transaction_id: VF_ASSOCIATION_TRANSACTION_ID,
2463 packet_type: OutgoingPacketType::InBandWithCompletion,
2464 payload: &message.payload(),
2465 })
2466 .map_err(|err| match err {
2467 queue::TryWriteError::Full(len) => {
2468 tracing::error!(len, "failed to write vf association message");
2469 WorkerError::OutOfSpace
2470 }
2471 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2472 })?;
2473 Ok(true)
2474 } else {
2475 tracing::info!(
2476 available = serial_number.is_some(),
2477 serial_number,
2478 major = version.major(),
2479 minor = version.minor(),
2480 sriov_capable = config.capabilities.sriov(),
2481 "Skipping NvspMessage4TypeSendVFAssociation message"
2482 );
2483 Ok(false)
2484 }
2485 }
2486
2487 fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2489 if version < Version::V5 {
2491 return;
2492 }
2493
2494 #[repr(C)]
2495 #[derive(IntoBytes, Immutable, KnownLayout)]
2496 struct SendIndirectionMsg {
2497 pub message: protocol::Message5SendIndirectionTable,
2498 pub send_indirection_table:
2499 [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2500 }
2501
2502 let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2504 + size_of::<protocol::MessageHeader>();
2505 let mut data = SendIndirectionMsg {
2506 message: protocol::Message5SendIndirectionTable {
2507 table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2508 table_offset: send_indirection_table_offset as u32,
2509 },
2510 send_indirection_table: Default::default(),
2511 };
2512
2513 for i in 0..data.send_indirection_table.len() {
2514 data.send_indirection_table[i] = i as u32 % num_channels_opened;
2515 }
2516
2517 let message = NvspMessage {
2518 header: protocol::MessageHeader {
2519 message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2520 },
2521 data,
2522 padding: &[],
2523 };
2524 let result = self
2525 .queue
2526 .split()
2527 .1
2528 .try_write(&queue::OutgoingPacket {
2529 transaction_id: 0,
2530 packet_type: OutgoingPacketType::InBandNoCompletion,
2531 payload: &message.payload(),
2532 })
2533 .map_err(|err| match err {
2534 queue::TryWriteError::Full(len) => {
2535 tracing::error!(len, "failed to write send indirection table message");
2536 WorkerError::OutOfSpace
2537 }
2538 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2539 });
2540 if let Err(err) = result {
2541 tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2542 }
2543 }
2544
2545 fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2548 let message = NvspMessage {
2549 header: protocol::MessageHeader {
2550 message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2551 },
2552 data: protocol::Message4SwitchDataPath {
2553 active_data_path: protocol::DataPath::SYNTHETIC.0,
2554 },
2555 padding: &[],
2556 };
2557 let result = self
2558 .queue
2559 .split()
2560 .1
2561 .try_write(&queue::OutgoingPacket {
2562 transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2563 packet_type: OutgoingPacketType::InBandWithCompletion,
2564 payload: &message.payload(),
2565 })
2566 .map_err(|err| match err {
2567 queue::TryWriteError::Full(len) => {
2568 tracing::error!(
2569 len,
2570 "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2571 );
2572 WorkerError::OutOfSpace
2573 }
2574 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2575 });
2576 if let Err(err) = result {
2577 tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2578 }
2579 }
2580
    /// Drives the guest VF state machine forward after an external state
    /// change (VF arrival/removal, data path switch completion).
    ///
    /// Returns a `CoordinatorMessage` when the coordinator must take further
    /// action (currently only to update shared guest VF state).
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        // A VF just became available: advertise it once RNDIS is
        // operational and let the coordinator propagate the new state.
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::UpdateGuestVfState));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Run transitional states to completion; break on any steady state.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                // VF removed while advertised: tell the guest it is gone.
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                // VF removed mid data-path switch: complete the pending
                // guest request before handling the removal.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    self.send_completion(id, &[])?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                // VF removed while the data path pointed at it: notify the
                // guest of the fallback to synthetic first.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                // Data path switch result arrived: ack the guest request
                // and settle into the resulting state.
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    self.send_completion(id, &[])?;
                    if result {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSwitched
                        } else {
                            PrimaryChannelGuestVfState::Ready
                        }
                    } else {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        } else {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                // These states must have been resolved before this point.
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // Steady state: nothing more to do.
                _ => break,
            };
        }
        Ok(None)
    }
2671
    /// Handles a single buffered RNDIS control message (initialize, query,
    /// set, keepalive, ...), building the completion in receive-buffer
    /// suballocation `id` and sending it to the guest.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // Completion messages are written in place into this range.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // INITIALIZE is only legal before RNDIS is operational.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise any waiting VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        // Queue a guest VF state update unless a restart
                        // request is already pending.
                        if self.restart.is_none() {
                            self.restart = Some(CoordinatorMessage::UpdateGuestVfState);
                        }
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Reserve room for the completion header; the OID handler
                // writes the response body directly after it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok(restart_endpoint) => {
                        // Some OIDs (e.g. RSS parameters) only take effect
                        // after an endpoint restart.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
2845
2846 fn try_send_rndis_message(
2847 &mut self,
2848 transaction_id: u64,
2849 channel_type: u32,
2850 recv_buffer_id: u16,
2851 transfer_pages: &[ring::TransferPageRange],
2852 ) -> Result<Option<usize>, WorkerError> {
2853 let message = self.message(
2854 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2855 protocol::Message1SendRndisPacket {
2856 channel_type,
2857 send_buffer_section_index: 0xffffffff,
2858 send_buffer_section_size: 0,
2859 },
2860 );
2861 let pending_send_size = match self.queue.split().1.try_write(&queue::OutgoingPacket {
2862 transaction_id,
2863 packet_type: OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
2864 payload: &message.payload(),
2865 }) {
2866 Ok(()) => None,
2867 Err(queue::TryWriteError::Full(n)) => Some(n),
2868 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
2869 };
2870 Ok(pending_send_size)
2871 }
2872
2873 fn send_rndis_control_message(
2874 &mut self,
2875 buffers: &ChannelBuffers,
2876 id: u32,
2877 message_length: usize,
2878 ) -> Result<(), WorkerError> {
2879 let result = self.try_send_rndis_message(
2880 id as u64,
2881 protocol::CONTROL_CHANNEL_TYPE,
2882 buffers.recv_buffer.id,
2883 std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
2884 )?;
2885
2886 match result {
2888 None => Ok(()),
2889 Some(len) => {
2890 tracelimit::error_ratelimited!(len, "failed to write control message completion");
2891 Err(WorkerError::OutOfSpace)
2892 }
2893 }
2894 }
2895
2896 fn indicate_status(
2897 &mut self,
2898 buffers: &ChannelBuffers,
2899 id: u32,
2900 status: u32,
2901 payload: &[u8],
2902 ) -> Result<(), WorkerError> {
2903 let buffer = &buffers.recv_buffer.range(id);
2904 let mut writer = buffer.writer(&buffers.mem);
2905 let message_length = write_rndis_message(
2906 &mut writer,
2907 rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
2908 payload.len(),
2909 &rndisprot::IndicateStatus {
2910 status,
2911 status_buffer_length: payload.len() as u32,
2912 status_buffer_offset: if payload.is_empty() {
2913 0
2914 } else {
2915 size_of::<rndisprot::IndicateStatus>() as u32
2916 },
2917 },
2918 )?;
2919 writer.write(payload)?;
2920 self.send_rndis_control_message(buffers, id, message_length)?;
2921 Ok(())
2922 }
2923
    /// Processes queued RNDIS control messages and any pending offload
    /// change indication, while ring space and control buffers permit.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages are only processed on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Require space for a control completion; otherwise arm
            // pending_send_size so processing resumes when space frees up.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Reserve a receive-buffer suballocation for the response;
            // stop if none are free right now.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the suballocation in use until the guest completes it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued message, so the loop condition implies a
                // pending offload change: indicate the current config.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                // Unreachable per the loop condition above.
                unreachable!();
            }
        }
        Ok(())
    }
2977}
2978
2979fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
2981 writer: &mut impl MemoryWrite,
2982 message_type: u32,
2983 extra: usize,
2984 payload: &T,
2985) -> Result<usize, AccessError> {
2986 let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
2987 writer.write(
2988 rndisprot::MessageHeader {
2989 message_type,
2990 message_length: message_length as u32,
2991 }
2992 .as_bytes(),
2993 )?;
2994 writer.write(payload.as_bytes())?;
2995 Ok(message_length)
2996}
2997
/// Failure modes for RNDIS OID query/set handling.
///
/// Each variant is translated into an RNDIS status code (see `as_status`)
/// that is reported back to the guest in the completion message.
#[derive(Debug, Error)]
enum OidError {
    /// Guest memory access failed while reading or writing OID buffers.
    #[error(transparent)]
    Access(#[from] AccessError),
    /// The OID is not recognized.
    #[error("unknown oid")]
    UnknownOid,
    /// The OID payload was malformed; the string names the offending field.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    /// The negotiated NDIS version does not support this request.
    #[error("bad ndis version")]
    BadVersion,
    /// The OID is recognized, but the named feature is not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3011
3012impl OidError {
3013 fn as_status(&self) -> u32 {
3014 match self {
3015 OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3016 OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3017 OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3018 OidError::Access(_) => rndisprot::STATUS_FAILURE,
3019 }
3020 }
3021}
3022
// MTU values here include the Ethernet header: the OID handlers report
// frame size as `mtu - ETHERNET_HEADER_LEN` (1514 => 1500-byte payload).
const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
const MAX_MTU: u32 = 9216;

/// Length of an Ethernet (802.3) header, in bytes.
const ETHERNET_HEADER_LEN: u32 = 14;
3028
3029impl Adapter {
3030 fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3031 if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3032 if guest_os_id
3035 .microsoft()
3036 .unwrap_or(HvGuestOsMicrosoft::from(0))
3037 .os_id()
3038 == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3039 {
3040 self.adapter_index
3041 } else {
3042 vfid
3043 }
3044 } else {
3045 vfid
3046 }
3047 }
3048
    /// Handles an RNDIS OID query, writing the response body into `writer`.
    ///
    /// Returns the number of bytes written. On error, the caller converts
    /// the `OidError` into an RNDIS status via `OidError::as_status`.
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        // Bytes written are derived from the writer's remaining capacity.
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                // OIDs supported for every negotiated NDIS version.
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // Additional OIDs for NDIS 6.x guests.
                let supported_oids_6 = &[
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                ];

                // Additional OIDs for NDIS 6.30+ guests (RSS support).
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                // 0 == hardware ready.
                let status: u32 = 0; writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                // Frame size excludes the Ethernet header.
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                // Reported in 100 bps units (NDIS convention for this OID).
                let speed: u32 = (self.link_speed / 100) as u32; writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Microsoft OUI (00-15-5D).
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                writer.write(0x0100u16.as_bytes())? }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                // Fixed UTF-16 name; fits in the FriendlyName buffer.
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    // Always reported as connected.
                    media_connect_state: 1, media_duplex_state: 0, padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0, auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                // Report only the bytes the structure's header says it has.
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        // Fall back to a conventional table size when the
                        // backend does not specify one.
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3288
3289 fn handle_oid_set(
3290 &self,
3291 primary: &mut PrimaryChannelState,
3292 oid: rndisprot::Oid,
3293 reader: impl MemoryRead + Clone,
3294 ) -> Result<bool, OidError> {
3295 tracing::debug!(?oid, "oid set");
3296
3297 let mut restart_endpoint = false;
3298 match oid {
3299 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3300 }
3302 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3303 self.oid_set_offload_parameters(reader, primary)?;
3304 }
3305 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3306 self.oid_set_offload_encapsulation(reader)?;
3307 }
3308 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3309 self.oid_set_rndis_config_parameter(reader, primary)?;
3310 }
3311 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3312 }
3314 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3315 self.oid_set_rss_parameters(reader, primary)?;
3316
3317 restart_endpoint = true;
3321 }
3322 _ => {
3323 tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3324 return Err(OidError::UnknownOid);
3325 }
3326 }
3327 Ok(restart_endpoint)
3328 }
3329
3330 fn oid_set_rss_parameters(
3331 &self,
3332 mut reader: impl MemoryRead + Clone,
3333 primary: &mut PrimaryChannelState,
3334 ) -> Result<(), OidError> {
3335 let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3337 let len = reader.len().min(size_of_val(¶ms));
3338 reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3339
3340 if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3341 || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3342 {
3343 primary.rss_state = None;
3344 return Ok(());
3345 }
3346
3347 if params.hash_secret_key_size != 40 {
3348 return Err(OidError::InvalidInput("hash_secret_key_size"));
3349 }
3350 if params.indirection_table_size % 4 != 0 {
3351 return Err(OidError::InvalidInput("indirection_table_size"));
3352 }
3353 let indirection_table_size =
3354 (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3355 let mut key = [0; 40];
3356 let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3357 reader
3358 .clone()
3359 .skip(params.hash_secret_key_offset as usize)?
3360 .read(&mut key)?;
3361 reader
3362 .skip(params.indirection_table_offset as usize)?
3363 .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3364 tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3365 if indirection_table
3366 .iter()
3367 .any(|&x| x >= self.max_queues as u32)
3368 {
3369 return Err(OidError::InvalidInput("indirection_table"));
3370 }
3371 let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3372 for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3373 .flatten()
3374 .zip(indir_uninit)
3375 {
3376 *dest = src;
3377 }
3378 primary.rss_state = Some(RssState {
3379 key,
3380 indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3381 });
3382 Ok(())
3383 }
3384
    /// Handles OID_TCP_OFFLOAD_PARAMETERS: updates the channel's checksum
    /// and large-send-offload (LSO) configuration from the guest's request.
    ///
    /// LSOv1 is rejected; LSOv2 is honored only when the backend endpoint
    /// supports it. Changes are recorded in `primary.offload_config` and
    /// flagged via `pending_offload_change` for later application.
    fn oid_set_offload_parameters(
        &self,
        reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
            reader,
            rndisprot::NdisObjectType::DEFAULT,
            1,
            rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
        )?;

        tracing::debug!(?offload, "offload parameters");
        // Destructure exhaustively so a newly added field cannot be
        // silently ignored.
        let rndisprot::NdisOffloadParameters {
            header: _,
            ipv4_checksum,
            tcp4_checksum,
            udp4_checksum,
            tcp6_checksum,
            udp6_checksum,
            lsov1,
            ipsec_v1: _,
            lsov2_ipv4,
            lsov2_ipv6,
            tcp_connection_ipv4: _,
            tcp_connection_ipv6: _,
            reserved: _,
            flags: _,
        } = offload;

        // LSOv1 is not implemented.
        if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
            return Err(OidError::NotSupported("lsov1"));
        }
        // For each checksum knob, `tx_rx()` returning `None` means the
        // guest requested no change for that protocol.
        if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.ipv4_header = tx;
            primary.offload_config.checksum_rx.ipv4_header = rx;
        }
        if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp4 = tx;
            primary.offload_config.checksum_rx.tcp4 = rx;
        }
        if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.tcp6 = tx;
            primary.offload_config.checksum_rx.tcp6 = rx;
        }
        if let Some((tx, rx)) = udp4_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp4 = tx;
            primary.offload_config.checksum_rx.udp4 = rx;
        }
        if let Some((tx, rx)) = udp6_checksum.tx_rx() {
            primary.offload_config.checksum_tx.udp6 = tx;
            primary.offload_config.checksum_rx.udp6 = rx;
        }
        // LSOv2 can only be enabled if the backend supports it.
        if let Some(enable) = lsov2_ipv4.enable() {
            primary.offload_config.lso4 = enable && self.offload_support.lso4;
        }
        if let Some(enable) = lsov2_ipv6.enable() {
            primary.offload_config.lso6 = enable && self.offload_support.lso6;
        }
        // Signal that the new offload config still needs to be pushed out.
        primary.pending_offload_change = true;
        Ok(())
    }
3447
3448 fn oid_set_offload_encapsulation(
3449 &self,
3450 reader: impl MemoryRead + Clone,
3451 ) -> Result<(), OidError> {
3452 let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3453 reader,
3454 rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3455 1,
3456 rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3457 )?;
3458 if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3459 && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3460 || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3461 {
3462 return Err(OidError::NotSupported("ipv4 encap"));
3463 }
3464 if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3465 && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3466 || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3467 {
3468 return Err(OidError::NotSupported("ipv6 encap"));
3469 }
3470 Ok(())
3471 }
3472
3473 fn oid_set_rndis_config_parameter(
3474 &self,
3475 reader: impl MemoryRead + Clone,
3476 primary: &mut PrimaryChannelState,
3477 ) -> Result<(), OidError> {
3478 let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3479 if info.name_length > 255 {
3480 return Err(OidError::InvalidInput("name_length"));
3481 }
3482 if info.value_length > 255 {
3483 return Err(OidError::InvalidInput("value_length"));
3484 }
3485 let name = reader
3486 .clone()
3487 .skip(info.name_offset as usize)?
3488 .read_n::<u16>(info.name_length as usize / 2)?;
3489 let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3490 let mut value = reader;
3491 value.skip(info.value_offset as usize)?;
3492 let mut value = value.limit(info.value_length as usize);
3493 match info.parameter_type {
3494 rndisprot::NdisParameterType::STRING => {
3495 let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3496 let value =
3497 String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3498 let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3499 let tx = as_num & 1 != 0;
3500 let rx = as_num & 2 != 0;
3501
3502 tracing::debug!(name, value, "rndis config");
3503 match name.as_str() {
3504 "*IPChecksumOffloadIPv4" => {
3505 primary.offload_config.checksum_tx.ipv4_header = tx;
3506 primary.offload_config.checksum_rx.ipv4_header = rx;
3507 }
3508 "*LsoV2IPv4" => {
3509 primary.offload_config.lso4 = as_num != 0 && self.offload_support.lso4;
3510 }
3511 "*LsoV2IPv6" => {
3512 primary.offload_config.lso6 = as_num != 0 && self.offload_support.lso6;
3513 }
3514 "*TCPChecksumOffloadIPv4" => {
3515 primary.offload_config.checksum_tx.tcp4 = tx;
3516 primary.offload_config.checksum_rx.tcp4 = rx;
3517 }
3518 "*TCPChecksumOffloadIPv6" => {
3519 primary.offload_config.checksum_tx.tcp6 = tx;
3520 primary.offload_config.checksum_rx.tcp6 = rx;
3521 }
3522 "*UDPChecksumOffloadIPv4" => {
3523 primary.offload_config.checksum_tx.udp4 = tx;
3524 primary.offload_config.checksum_rx.udp4 = rx;
3525 }
3526 "*UDPChecksumOffloadIPv6" => {
3527 primary.offload_config.checksum_tx.udp6 = tx;
3528 primary.offload_config.checksum_rx.udp6 = rx;
3529 }
3530 _ => {}
3531 }
3532 }
3533 rndisprot::NdisParameterType::INTEGER => {
3534 let value: u32 = value.read_plain()?;
3535 tracing::debug!(name, value, "rndis config");
3536 }
3537 parameter_type => tracelimit::warn_ratelimited!(
3538 name,
3539 ?parameter_type,
3540 "unhandled rndis config parameter type"
3541 ),
3542 }
3543 Ok(())
3544 }
3545}
3546
3547fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3548 mut reader: impl MemoryRead,
3549 object_type: rndisprot::NdisObjectType,
3550 min_revision: u8,
3551 min_size: usize,
3552) -> Result<T, OidError> {
3553 let mut buffer = T::new_zeroed();
3554 let sent_size = reader.len();
3555 let len = sent_size.min(size_of_val(&buffer));
3556 reader.read(&mut buffer.as_mut_bytes()[..len])?;
3557 validate_ndis_object_header(
3558 &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3559 .unwrap()
3560 .0, sent_size,
3562 object_type,
3563 min_revision,
3564 min_size,
3565 )?;
3566 Ok(buffer)
3567}
3568
3569fn validate_ndis_object_header(
3570 header: &rndisprot::NdisObjectHeader,
3571 sent_size: usize,
3572 object_type: rndisprot::NdisObjectType,
3573 min_revision: u8,
3574 min_size: usize,
3575) -> Result<(), OidError> {
3576 if header.object_type != object_type {
3577 return Err(OidError::InvalidInput("header.object_type"));
3578 }
3579 if sent_size < header.size as usize {
3580 return Err(OidError::InvalidInput("header.size"));
3581 }
3582 if header.revision < min_revision {
3583 return Err(OidError::InvalidInput("header.revision"));
3584 }
3585 if (header.size as usize) < min_size {
3586 return Err(OidError::InvalidInput("header.size"));
3587 }
3588 Ok(())
3589}
3590
/// Mutable state for the coordinator task, which owns the per-queue workers
/// and reacts to control messages by stopping, reconfiguring, and
/// restarting them.
struct Coordinator {
    /// Control messages from the workers (e.g. restart or timer requests).
    recv: mpsc::Receiver<CoordinatorMessage>,
    /// Handle used to enable vmbus subchannels after queue setup.
    channel_control: ChannelControl,
    /// When set, all workers are stopped and the queues rebuilt on the next
    /// pass through the processing loop.
    restart: bool,
    /// One task per queue; index 0 is the primary channel.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// Channel buffers cached from the most recent queue restart.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of queues currently configured (including the primary).
    num_queues: u16,
}
3599
/// Progress of offering the VF (virtual function) device to the guest.
enum CoordinatorStatePendingVfState {
    /// No offer is in progress.
    Ready,
    /// The offer is being delayed until `delay_until`.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// Waiting for the guest to report it is ready for the device.
    Pending,
}
3614
/// Coordinator-owned resources that persist across worker restarts.
struct CoordinatorState {
    /// The backend network endpoint that provides the data-path queues.
    endpoint: Box<dyn Endpoint>,
    /// Shared, per-adapter configuration and counters.
    adapter: Arc<Adapter>,
    /// The guest VF device, if one is configured.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of offering the VF device to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3621
impl InspectTaskMut<Coordinator> for CoordinatorState {
    /// Responds to an inspect request for the coordinator, exposing adapter
    /// configuration, endpoint information, per-queue workers, and (when
    /// ready) the primary channel state.
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Mutable field: writing a value updates the shared ring size
            // limit and pokes each worker so it observes the change.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // Expose only the queues currently in use.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // The primary channel state lives inside worker 0; defer the
            // response until the worker can be safely accessed.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    if let Some(state) = worker.and_then(|worker| worker.state.ready()) {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            );
                        });
                    }
                })
            }));
        }
    }
}
3689
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Runs the coordinator's main processing loop as a cancellable task.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3699
3700impl Coordinator {
    /// The coordinator's main loop: restarts queues when requested, keeps
    /// the workers running, and multiplexes control events (internal
    /// messages, endpoint actions, VF state changes, timers) with a `race`.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        // When set, the next loop iteration arms a timer that fires
        // Message::TimerExpired at this deadline.
        let mut sleep_duration: Option<Instant> = None;
        loop {
            if self.restart {
                // Stop everything, rebuild the queues, then re-query and
                // restore the VF data path state.
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Subchannel workers are always (re)started; worker 0 is only
            // started once we know we will block waiting for a message.
            for worker in &mut self.workers[1..] {
                worker.start();
            }

            // The set of events the coordinator can wake up on.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let timer_sleep = async {
                if let Some(sleep_duration) = sleep_duration {
                    let mut timer = PolledTimer::new(&state.adapter.driver);
                    timer.sleep_until(sleep_duration).await;
                } else {
                    // No deadline armed; never complete.
                    pending::<()>().await;
                }
                Message::TimerExpired
            };
            let message = {
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // In the Delay state the offer fires when
                                // the delay elapses; in Ready it never does.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            CoordinatorStatePendingVfState::Pending => {
                                // An offer is outstanding; wait only for the
                                // guest to accept the device.
                                vf.guest_ready_for_device().await;
                                Message::PendingVfStateComplete
                            }
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                // Only start the primary worker if no message is already
                // available; otherwise handle the message first.
                let mut wait_for_message = std::pin::pin!(wait_for_message);
                match (&mut wait_for_message).now_or_never() {
                    Some(message) => message,
                    None => {
                        self.workers[0].start();
                        stop.until_stopped(wait_for_message).await?
                    }
                }
            };
            match message {
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Pause the primary worker while mutating its state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Promote a delayed link action to active so the worker
                    // processes it on restart.
                    if self.workers[0].is_running() {
                        self.workers[0].stop().await;
                        if let Some(primary) = self.primary_mut() {
                            if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                                primary.pending_link_action = PendingLinkAction::Active(up);
                            }
                        }
                    }
                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::UpdateGuestVfState) => {
                    self.update_guest_vf_state(state).await;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    sleep_duration = None;
                }
                Message::Internal(CoordinatorMessage::Restart) => self.restart = true,
                Message::Internal(CoordinatorMessage::StartTimer(duration)) => {
                    sleep_duration = Some(duration);
                    self.workers[0].stop().await;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
3893
3894 async fn stop_workers(&mut self) {
3895 for worker in &mut self.workers {
3896 worker.stop().await;
3897 }
3898 }
3899
    /// Reconciles the primary channel's guest VF state machine with the
    /// current presence (or absence) of the VF device, switching the data
    /// path at the endpoint as needed.
    ///
    /// Called after queue restarts, on VF state changes, and when restoring
    /// saved state; `Restoring(..)` variants are mapped back to their live
    /// equivalents here.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // A VF is present. First decide whether a device offer needs to
            // be scheduled (delayed) or resumed (pending).
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // Advertised but data path not yet switched: delay the
                    // device offer.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // A restored switch request that already has a result doesn't
            // need to be re-executed; just surface it and return.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                // States that require (re)issuing a data path switch.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Recover the original request parameters when there was
                    // an explicit pending switch; otherwise default to
                    // switching to the guest.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // The data path is switching now, so don't delay the
                    // device offer any further.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        // A pending request records the outcome for the
                        // guest to pick up.
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // No VF is present. If a switch to/from the VF was pending or
            // complete, force the data path back to synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Any pending offer is moot with no VF to offer.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Map every state to its "VF absent" counterpart.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4115
    /// Tears down and rebuilds the backend queues, partitioning the receive
    /// buffer suballocations across the queues referenced by the current
    /// RSS indirection table.
    ///
    /// No-op if the primary channel has not finished initialization yet.
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Drop each worker's existing queue and capture its driver for the
        // new queue configuration.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            // There is always at least the primary worker.
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        // Nothing to rebuild until the primary channel is ready.
        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        let mut active_queues = Vec::new();
        // A queue is "active" if the RSS indirection table references it;
        // without RSS state (or with an invalid table), all requested
        // queues are treated as active.
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };

        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        {
            let buffers = &state.buffers;
            let guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            let mut queue_config = Vec::new();
            let initial_rx;
            {
                // Gather every channel's ready state to find which receive
                // suballocations are currently free.
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                // All free, non-reserved suballocations become the initial
                // receive buffers, distributed to queues below.
                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| {
                        states
                            .clone()
                            .flatten()
                            .all(|s| (s.state.rx_bufs.is_free(n)))
                    })
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                // If the indirection table never references queue 0, the
                // primary queue still exists but receives only the reserved
                // control buffers and no initial rx IDs.
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: &[],
                        driver: Box::new(drivers[0].clone()),
                    });
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        // The final active queue absorbs all remaining
                        // buffers; queue 0 additionally covers the reserved
                        // control range.
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        // Inactive queues get an empty range.
                        (range_start, 0, None)
                    };

                    // Split off this queue's share of the free rx IDs.
                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        pool: Box::new(BufferPool::new(guest_buffers.clone())),
                        initial_rx: this,
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            // The subchannel count excludes the primary channel.
            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        // Hand each worker its new queue and receive-buffer range.
        for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
        }

        Ok(())
    }
4313
4314 fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4315 self.workers[0]
4316 .state_mut()
4317 .unwrap()
4318 .state
4319 .ready_mut()?
4320 .state
4321 .primary
4322 .as_mut()
4323 }
4324
    /// Re-evaluates the guest VF state, pausing the primary worker first so
    /// its state can be mutated safely.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4329}
4330
impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
    /// Runs the worker's processing loop, treating buffer revocation as a
    /// normal exit and logging (rather than propagating) other errors so
    /// the task completes cleanly.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        match worker.process(stop, self).await {
            // Buffer revocation is part of normal channel teardown.
            Ok(()) | Err(WorkerError::BufferRevoked) => {}
            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
            Err(err) => {
                tracing::error!(
                    channel_idx = worker.channel_idx,
                    error = &err as &dyn std::error::Error,
                    "netvsp error"
                );
            }
        }
        Ok(())
    }
}
4351
impl<T: RingMem + 'static> Worker<T> {
    /// The worker's state machine: completes channel initialization on the
    /// primary channel, then runs the per-queue main loop until stopped.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    // Only the primary channel performs initialization;
                    // subchannels park here until stopped/reconfigured.
                    if self.channel_idx != 0 {
                        stop.until_stopped(pending()).await?
                    }

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Ask the coordinator to (re)build the backend queues
                    // now that the channel buffers exist.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::Ready(state);
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Apply a pending target-VP change to the queue
                        // before processing traffic.
                        if !queue_state.target_vp_set
                            && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
                        {
                            tracing::debug!(
                                channel_idx = self.channel_idx,
                                target_vp = self.target_vp,
                                "updating target VP"
                            );
                            queue_state.queue.update_target_vp(self.target_vp).await;
                            queue_state.target_vp_set = true;
                        }

                        queue_state
                    } else {
                        // No queue assigned yet; wait for the coordinator.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    match result {
                        Ok(restart) => {
                            // Only the primary channel's main loop returns a
                            // coordinator message.
                            assert_eq!(self.channel_idx, 0);
                            let _ = self.coordinator_send.try_send(restart);
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            // If the restart request can't even be queued,
                            // surface the original endpoint error.
                            if let Err(try_send_err) =
                                self.coordinator_send.try_send(CoordinatorMessage::Restart)
                            {
                                tracing::error!(
                                    try_send_err = &try_send_err as &dyn std::error::Error,
                                    "failed to restart queues"
                                );
                                return Err(WorkerError::Endpoint(err));
                            }
                        }
                        Err(err) => return Err(err),
                    }

                    // Wait to be stopped and restarted by the coordinator.
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4438
4439impl<T: 'static + RingMem> NetChannel<T> {
    /// Attempts to read and parse the next vmbus packet without blocking.
    ///
    /// Returns `Ok(None)` when the incoming ring is currently empty.
    fn try_next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Option<Packet<'a>>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let packet = match read.try_read() {
            Ok(packet) => {
                parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
            }
            Err(queue::TryReadError::Empty) => return Ok(None),
            Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
        };

        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(Some(packet))
    }
4457
    /// Waits for and parses the next vmbus packet.
    ///
    /// If the packet is an RNDIS data packet, the ring read is reverted so
    /// the same packet can be consumed again by the data-path processing.
    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet =
            parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            // Leave the packet in the ring for re-reading.
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }
4475
4476 fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4477 (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4478 && initializing.ndis_version.is_some()
4479 && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4480 && initializing.recv_buffer.is_some()
4481 }
4482
    /// Drives protocol negotiation for a channel until enough state has been
    /// gathered to activate it.
    ///
    /// Reads setup messages from the ring, recording them in `initializing`,
    /// and returns the resulting `ChannelBuffers` plus a fresh `ActiveState`
    /// once ready.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        // Set when an RNDIS packet arrives before negotiation is complete; in
        // that case initialization proceeds even without a send buffer.
        let mut has_init_packet_arrived = false;
        loop {
            // Check whether enough state has arrived to finish.
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    // This path builds the primary channel's state.
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Guests that never sent an NDIS config get defaults.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Wait for enough outgoing ring space for a completion so the
            // completions sent below cannot fail for lack of ring space.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let packet = self
                .next_packet(None, initializing.as_ref().map(|x| x.version))
                .await?;

            if let Some(initializing) = &mut *initializing {
                // A version has been negotiated; process setup messages.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        // Each setup message may arrive at most once.
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // Out-of-range MTUs fall back to the default rather
                        // than failing negotiation.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        self.send_completion(packet.transaction_id, &[])?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The suballocation size depends on the MTU, which
                        // pre-V2 guests never report explicitly.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        // Report the buffer as a single section of
                        // equally-sized suballocations.
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                    protocol::Message1SendReceiveBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        num_sections: 1,
                                        sections: [protocol::ReceiveBufferSection {
                                            offset: 0,
                                            sub_allocation_size: recv_buffer.sub_allocation_size,
                                            num_sub_allocations: recv_buffer.count,
                                            end_offset: recv_buffer.sub_allocation_size
                                                * recv_buffer.count,
                                        }],
                                    },
                                )
                                .payload(),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            &self
                                .message(
                                    protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                    protocol::Message1SendSendBufferComplete {
                                        status: protocol::Status::SUCCESS,
                                        // Fixed section size reported for the
                                        // send buffer.
                                        section_size: 6144,
                                    },
                                )
                                .payload(),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // Some guests start RNDIS traffic before sending the
                        // (optional) send buffer; accept it once the rest of
                        // the state is present.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet; only the init message is
                // expected here (anything else is treated as unreachable —
                // NOTE(review): presumably the parser only yields Init before
                // a version exists; confirm against parse_packet).
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut message = self.message(
                            protocol::MESSAGE_TYPE_INIT_COMPLETE,
                            protocol::MessageInitComplete {
                                deprecated: protocol::INVALID_PROTOCOL_VERSION,
                                maximum_mdl_chain_length: 34,
                                status: protocol::Status::NONE,
                            },
                        );
                        if let Some(version) = version {
                            if version == Version::V1 {
                                // For V1, also echo the version in the
                                // deprecated field.
                                message.data.deprecated = Version::V1 as u32;
                            }
                            message.data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        self.send_completion(packet.transaction_id, &message.payload())?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            // V6.1+ uses the larger control packet size.
                            if version >= Version::V61 {
                                self.packet_size = protocol::PACKET_SIZE_V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }
    }
4698
    /// The main processing loop for an active channel: moves packets between
    /// the vmbus ring and the backend endpoint until stopped or asked to
    /// restart.
    ///
    /// Returns the `CoordinatorMessage` that caused the loop to exit.
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        // How much outgoing ring space must stay free. When the configured
        // limit is zero (or the optimization is disabled), reserve a fixed
        // 2048 bytes.
        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };

        // Primary-channel-only work before entering the poll loop.
        if let Some(primary) = state.primary.as_mut() {
            // Once all requested subchannels are open, send the guest the tx
            // indirection table so transmits spread across queues.
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            // Pending link-state change: indicate media connect/disconnect
            // using a reserved control buffer.
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // Guest already sees the target state: toggle through
                        // the opposite state now and finish via the timer.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    // No free control buffer; retry after a delay.
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            // Guest VF states that require coordinator interaction.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // When the outgoing ring is backed up past the reserve, stop
            // initiating work that would write to it.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            // Run every stage each wake; note `|` (not `||`) so no stage is
            // short-circuited away.
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            self.process_control_messages(buffers, state)?;

            // Sleep until one of the wakeup sources fires: the endpoint, the
            // incoming ring, outgoing ring space, remotely-completed rx
            // buffers, or a queued restart message.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        if !ring_full {
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // Only poll the incoming ring when there are enough
                        // free tx ids and no partially-sent transmit.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // Wait for outgoing ring space: either the size a
                        // failed send recorded, or (when throttled) for the
                        // ring to drain below the reserve.
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Reclaim rx buffers that were completed on another
                        // channel and forwarded back to this one.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    // Low ids are reserved for control
                                    // messages on the primary channel.
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }
4908
4909 fn process_endpoint_rx(
4910 &mut self,
4911 buffers: &ChannelBuffers,
4912 state: &mut ActiveState,
4913 data: &mut ProcessingData,
4914 epqueue: &mut dyn net_backend::Queue,
4915 ) -> Result<bool, WorkerError> {
4916 let n = epqueue
4917 .rx_poll(&mut data.rx_ready)
4918 .map_err(WorkerError::Endpoint)?;
4919 if n == 0 {
4920 return Ok(false);
4921 }
4922
4923 let transaction_id = data.rx_ready[0].0.into();
4924 let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);
4925
4926 state.rx_bufs.allocate(ready_ids.clone()).unwrap();
4927
4928 let len = buffers.recv_buffer.sub_allocation_size as usize;
4931 data.transfer_pages.clear();
4932 data.transfer_pages
4933 .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));
4934
4935 match self.try_send_rndis_message(
4936 transaction_id,
4937 protocol::DATA_CHANNEL_TYPE,
4938 buffers.recv_buffer.id,
4939 &data.transfer_pages,
4940 )? {
4941 None => {
4942 state.stats.rx_packets.add(n as u64);
4944 }
4945 Some(_) => {
4946 state.stats.rx_dropped_ring_full.add(n as u64);
4949
4950 state.rx_bufs.free(data.rx_ready[0].0);
4951 epqueue.rx_avail(&data.rx_ready[..n]);
4952 }
4953 }
4954
4955 state.stats.rx_packets_per_wake.add_sample(n as u64);
4956 Ok(true)
4957 }
4958
    /// Polls the endpoint for transmit completions and completes the
    /// corresponding guest packets once all of their endpoint packets are
    /// done.
    ///
    /// On `TxError::TryRestart`, every in-flight transmit is queued as a
    /// successful completion so the guest is not left waiting across the
    /// queue restart; the error is then propagated so the caller restarts
    /// the queues.
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    // A guest packet can map to multiple endpoint packets;
                    // complete it only when the last one finishes.
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Zero the in-flight counts and queue completions for every
                // pending transmit before reporting the restart.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }
5010
5011 fn switch_data_path(
5012 &mut self,
5013 state: &mut ActiveState,
5014 use_guest_vf: bool,
5015 transaction_id: Option<u64>,
5016 ) -> Result<(), WorkerError> {
5017 let primary = state.primary.as_mut().unwrap();
5018 let mut queue_switch_operation = false;
5019 match primary.guest_vf_state {
5020 PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5021 if use_guest_vf || primary.is_data_path_switched.is_none() {
5030 primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5031 to_guest: use_guest_vf,
5032 id: transaction_id,
5033 result: None,
5034 };
5035 queue_switch_operation = true;
5036 }
5037 }
5038 PrimaryChannelGuestVfState::DataPathSwitched => {
5039 if !use_guest_vf {
5040 primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5041 to_guest: false,
5042 id: transaction_id,
5043 result: None,
5044 };
5045 queue_switch_operation = true;
5046 }
5047 }
5048 _ if use_guest_vf => {
5049 tracing::warn!(
5050 state = %primary.guest_vf_state,
5051 use_guest_vf,
5052 "Data path switch requested while device is in wrong state"
5053 );
5054 }
5055 _ => (),
5056 };
5057 if queue_switch_operation {
5058 if self.restart.is_none() {
5060 self.restart = Some(CoordinatorMessage::UpdateGuestVfState)
5061 };
5062 } else {
5063 self.send_completion(transaction_id, &[])?;
5064 }
5065 Ok(())
5066 }
5067
    /// Processes incoming packets from the vmbus ring: guest transmits,
    /// receive completions, and control requests.
    ///
    /// Returns whether any packet was consumed.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Stop when out of tx ids, or while a previous transmit is still
            // partially queued to the endpoint.
            if state.free_tx_packets.is_empty() || !data.tx_segments.is_empty() {
                break;
            }
            let packet = if let Some(packet) =
                self.try_next_packet(buffers.send_buffer.as_ref(), Some(buffers.version))?
            {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                // A guest transmit: parse it into endpoint tx segments.
                PacketData::RndisPacket(_) => {
                    assert!(data.tx_segments.is_empty());
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    let num_packets = match result {
                        Ok(num_packets) => num_packets,
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                            continue;
                        }
                    };
                    total_packets += num_packets as u64;
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += num_packets;

                    if num_packets != 0 {
                        // Anything the endpoint cannot take now remains in
                        // tx_segments for transmit_pending_segments.
                        if self.transmit_segments(state, data, queue_state, id, num_packets)?
                            < num_packets
                        {
                            state.stats.tx_stalled.increment();
                        }
                    } else {
                        // The message produced no data packets (e.g. it was
                        // control-only); complete it immediately.
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }
                // The guest is done with a batch of receive buffers; return
                // them to the endpoint (or their remote owner).
                PacketData::RndisPacketComplete(_completion) => {
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                // Guest requests subchannels.
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels <= self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                                protocol::Message5SubchannelComplete {
                                    status,
                                    num_sub_channels: subchannel_count,
                                },
                            )
                            .payload(),
                    )?;

                    if subchannel_count > 0 {
                        // Restart so the coordinator can bring up the new
                        // queues; the indirection table is re-sent after.
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    // OID queries are not implemented; fail them explicitly.
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        &self
                            .message(
                                protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                                protocol::Message5OidQueryExComplete {
                                    status: rndisprot::STATUS_NOT_SUPPORTED,
                                    bytes: 0,
                                },
                            )
                            .payload(),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5210
5211 fn transmit_pending_segments(
5212 &mut self,
5213 state: &mut ActiveState,
5214 data: &mut ProcessingData,
5215 queue_state: &mut QueueState,
5216 ) -> Result<bool, WorkerError> {
5217 if data.tx_segments.is_empty() {
5218 return Ok(false);
5219 }
5220 let net_backend::TxSegmentType::Head(metadata) = &data.tx_segments[0].ty else {
5221 unreachable!()
5222 };
5223 let id = metadata.id;
5224 let num_packets = state.pending_tx_packets[id.0 as usize].pending_packet_count;
5225 let packets_sent = self.transmit_segments(state, data, queue_state, id, num_packets)?;
5226 Ok(num_packets == packets_sent)
5227 }
5228
5229 fn transmit_segments(
5230 &mut self,
5231 state: &mut ActiveState,
5232 data: &mut ProcessingData,
5233 queue_state: &mut QueueState,
5234 id: TxId,
5235 num_packets: usize,
5236 ) -> Result<usize, WorkerError> {
5237 let (sync, segments_sent) = queue_state
5238 .queue
5239 .tx_avail(&data.tx_segments)
5240 .map_err(WorkerError::Endpoint)?;
5241
5242 assert!(segments_sent <= data.tx_segments.len());
5243
5244 let packets_sent = if segments_sent == data.tx_segments.len() {
5245 num_packets
5246 } else {
5247 net_backend::packet_count(&data.tx_segments[..segments_sent])
5248 };
5249
5250 data.tx_segments.drain(..segments_sent);
5251
5252 if sync {
5253 state.pending_tx_packets[id.0 as usize].pending_packet_count -= packets_sent;
5254 }
5255
5256 if state.pending_tx_packets[id.0 as usize].pending_packet_count == 0 {
5257 self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5258 }
5259
5260 Ok(packets_sent)
5261 }
5262
    /// Handles an incoming RNDIS packet from the guest, which may carry a
    /// chain of RNDIS messages.
    ///
    /// Data messages append endpoint tx segments to `segments`; the return
    /// value is the number of endpoint packets produced.
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert!(tx_packet.pending_packet_count == 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe all externally-referenced guest memory up front so invalid
        // GPAs fail before any message is partially processed.
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        // Walk the message chain: read each header from a clone of the
        // reader (so the message body starts after the header), then skip
        // the full message length in the original reader.
        let mut reader = packet.rndis_reader(&buffers.mem);
        while reader.len() > 0 {
            let mut this_reader = reader.clone();
            let header: rndisprot::MessageHeader = this_reader.read_plain()?;
            if self.handle_rndis_message(
                buffers,
                state,
                id,
                header.message_type,
                this_reader,
                segments,
            )? {
                num_packets += 1;
            }
            reader.skip(header.message_length as usize)?;
        }

        Ok(num_packets)
    }
5310
5311 fn try_send_tx_packet(
5312 &mut self,
5313 transaction_id: u64,
5314 status: protocol::Status,
5315 ) -> Result<bool, WorkerError> {
5316 let message = self.message(
5317 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5318 protocol::Message1SendRndisPacketComplete { status },
5319 );
5320 let result = self.queue.split().1.try_write(&queue::OutgoingPacket {
5321 transaction_id,
5322 packet_type: OutgoingPacketType::Completion,
5323 payload: &message.payload(),
5324 });
5325 let sent = match result {
5326 Ok(()) => true,
5327 Err(queue::TryWriteError::Full(n)) => {
5328 self.pending_send_size = n;
5329 false
5330 }
5331 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5332 };
5333 Ok(sent)
5334 }
5335
5336 fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5337 let mut did_some_work = false;
5338 while let Some(pending) = state.pending_tx_completions.front() {
5339 if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5340 return Ok(did_some_work);
5341 }
5342 did_some_work = true;
5343 if let Some(id) = pending.tx_id {
5344 state.free_tx_packets.push(id);
5345 }
5346 tracing::trace!(?pending, "sent tx completion");
5347 state.pending_tx_completions.pop_front();
5348 }
5349
5350 self.pending_send_size = 0;
5351 Ok(did_some_work)
5352 }
5353
5354 fn complete_tx_packet(
5355 &mut self,
5356 state: &mut ActiveState,
5357 id: TxId,
5358 status: protocol::Status,
5359 ) -> Result<(), WorkerError> {
5360 let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5361 assert_eq!(tx_packet.pending_packet_count, 0);
5362 if self.pending_send_size == 0
5363 && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5364 {
5365 tracing::trace!(id = id.0, "sent tx completion");
5366 state.free_tx_packets.push(id);
5367 } else {
5368 tracing::trace!(id = id.0, "pended tx completion");
5369 state.pending_tx_completions.push_back(PendingTxCompletion {
5370 transaction_id: tx_packet.transaction_id,
5371 tx_id: Some(id),
5372 status,
5373 });
5374 }
5375 Ok(())
5376 }
5377}
5378
5379impl ActiveState {
5380 fn release_recv_buffers(
5381 &mut self,
5382 transaction_id: u64,
5383 rx_buffer_range: &RxBufferRange,
5384 done: &mut Vec<RxId>,
5385 ) -> Option<()> {
5386 let first_id: u32 = transaction_id.try_into().ok()?;
5388 let ids = self.rx_bufs.free(first_id)?;
5389 for id in ids {
5390 if !rx_buffer_range.send_if_remote(id) {
5391 if id >= RX_RESERVED_CONTROL_BUFFERS {
5392 done.push(RxId(id));
5393 } else {
5394 self.primary
5395 .as_mut()?
5396 .free_control_buffers
5397 .push(ControlMessageId(id));
5398 }
5399 }
5400 }
5401 Some(())
5402 }
5403}