1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
19use crate::protocol::Version;
20use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
21use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
22use async_trait::async_trait;
23pub use buffers::BufferPool;
24use buffers::sub_allocation_size_for_mtu;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures::channel::mpsc;
28use futures::channel::mpsc::TrySendError;
29use futures_concurrency::future::Race;
30use guestmem::AccessError;
31use guestmem::GuestMemory;
32use guestmem::GuestMemoryError;
33use guestmem::MemoryRead;
34use guestmem::MemoryWrite;
35use guestmem::ranges::PagedRange;
36use guestmem::ranges::PagedRanges;
37use guestmem::ranges::PagedRangesReader;
38use guid::Guid;
39use hvdef::hypercall::HvGuestOsId;
40use hvdef::hypercall::HvGuestOsMicrosoft;
41use hvdef::hypercall::HvGuestOsMicrosoftIds;
42use hvdef::hypercall::HvGuestOsOpenSourceType;
43use inspect::Inspect;
44use inspect::InspectMut;
45use inspect::SensitivityLevel;
46use inspect_counters::Counter;
47use inspect_counters::Histogram;
48use mesh::rpc::Rpc;
49use net_backend::Endpoint;
50use net_backend::EndpointAction;
51use net_backend::QueueConfig;
52use net_backend::RxId;
53use net_backend::TxError;
54use net_backend::TxId;
55use net_backend::TxSegment;
56use net_backend_resources::mac_address::MacAddress;
57use pal_async::timer::Instant;
58use pal_async::timer::PolledTimer;
59use ring::gparange::MultiPagedRangeIter;
60use rx_bufs::RxBuffers;
61use rx_bufs::SubAllocationInUse;
62use std::collections::VecDeque;
63use std::fmt::Debug;
64use std::future::pending;
65use std::mem::offset_of;
66use std::ops::Range;
67use std::sync::Arc;
68use std::sync::atomic::AtomicUsize;
69use std::sync::atomic::Ordering;
70use std::task::Poll;
71use std::time::Duration;
72use task_control::AsyncRun;
73use task_control::InspectTaskMut;
74use task_control::StopTask;
75use task_control::TaskControl;
76use thiserror::Error;
77use tracing::Instrument;
78use vmbus_async::queue;
79use vmbus_async::queue::ExternalDataError;
80use vmbus_async::queue::IncomingPacket;
81use vmbus_async::queue::Queue;
82use vmbus_channel::bus::OfferParams;
83use vmbus_channel::bus::OpenRequest;
84use vmbus_channel::channel::ChannelControl;
85use vmbus_channel::channel::ChannelOpenError;
86use vmbus_channel::channel::ChannelRestoreError;
87use vmbus_channel::channel::DeviceResources;
88use vmbus_channel::channel::RestoreControl;
89use vmbus_channel::channel::SaveRestoreVmbusDevice;
90use vmbus_channel::channel::VmbusDevice;
91use vmbus_channel::gpadl::GpadlId;
92use vmbus_channel::gpadl::GpadlMapView;
93use vmbus_channel::gpadl::GpadlView;
94use vmbus_channel::gpadl::UnknownGpadlId;
95use vmbus_channel::gpadl_ring::GpadlRingMem;
96use vmbus_channel::gpadl_ring::gpadl_channel;
97use vmbus_ring as ring;
98use vmbus_ring::OutgoingPacketType;
99use vmbus_ring::RingMem;
100use vmbus_ring::gparange::MultiPagedRangeBuf;
101use vmcore::save_restore::RestoreError;
102use vmcore::save_restore::SaveError;
103use vmcore::save_restore::SavedStateBlob;
104use vmcore::vm_task::VmTaskDriver;
105use vmcore::vm_task::VmTaskDriverSource;
106use zerocopy::FromBytes;
107use zerocopy::FromZeros;
108use zerocopy::Immutable;
109use zerocopy::IntoBytes;
110use zerocopy::KnownLayout;
111
// Minimum outgoing ring space (bytes) needed before sending a control
// message completion — presumably sized to the largest such message; TODO
// confirm against the protocol definitions.
const MIN_CONTROL_RING_SIZE: usize = 144;

// Minimum outgoing ring space (bytes) needed before sending a state-change
// (e.g. link status) message to the guest.
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Host-initiated transaction IDs with the high bit set, likely chosen to
// avoid colliding with guest-assigned transaction IDs.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

// Upper bound on subchannels offered per vNIC (also caps queue count).
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Delay before presenting the VF device; shortened under test to keep unit
// tests fast.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Delay applied to pending link-state changes; shortened under test.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
145
/// Flags selecting which pieces of state a [`CoordinatorMessage::Update`]
/// asks the coordinator to refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    // Re-evaluate guest VF state (availability / data path).
    guest_vf_state: bool,
    // Re-apply the current packet filter.
    filter_state: bool,
}
154
/// Messages sent from channel workers to the coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Refresh the indicated state (VF and/or packet filter).
    Update(CoordinatorMessageUpdateType),
    /// Restart the coordinator's workers.
    Restart,
    /// Arm the coordinator's timer to fire at the given instant.
    StartTimer(Instant),
}
165
/// Per-channel worker: protocol state plus the vmbus channel it services.
struct Worker<T: RingMem> {
    // Index of this channel (0 = primary, >0 = subchannel).
    channel_idx: u16,
    // Target VP for interrupts, if configured.
    target_vp: Option<u32>,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    // Used to notify the coordinator of events (restart, updates, timers).
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
174
/// Backend queue state for a worker; `queue_state` is `None` until the
/// backend endpoint queue has been created.
struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}
179
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Reports diagnostics for the queue and (when running) its worker.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report when neither the worker nor the backend queue
        // exists.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                // Init(None) means the protocol version has not yet been
                // negotiated.
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            // Packet-level statistics are only meaningful once ready.
            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    // Quota minus free entries = packets currently in flight.
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
232
/// Protocol state machine for a channel worker.
enum WorkerState {
    /// Negotiating: `None` before version negotiation, `Some` while
    /// collecting the remaining initialization messages.
    Init(Option<InitState>),
    /// Fully initialized and processing packets.
    Ready(ReadyState),
    /// Paused pending coordinator action; retains ready state if it existed.
    WaitingForCoordinator(Option<ReadyState>),
}
238
239impl WorkerState {
240 fn ready(&self) -> Option<&ReadyState> {
241 if let Self::Ready(state) = self {
242 Some(state)
243 } else {
244 None
245 }
246 }
247
248 fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249 if let Self::Ready(state) = self {
250 Some(state)
251 } else {
252 None
253 }
254 }
255}
256
/// State accumulated during protocol initialization; each optional field is
/// populated as the corresponding setup message arrives from the guest.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
264
/// NDIS version reported by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
272
/// NDIS configuration sent by the guest during initialization.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
280
/// State for a worker that has completed initialization.
struct ReadyState {
    // Negotiated buffers, shared (via Arc) with other channels.
    buffers: Arc<ChannelBuffers>,
    // Active per-queue packet state (in-flight tx/rx, stats).
    state: ActiveState,
    // Scratch buffers reused across processing wakes.
    data: ProcessingData,
}
286
287impl ReadyState {
288 fn reset_tx_after_endpoint_stop(&mut self) {
297 let state = &mut self.state;
298
299 let pending_tx = state
302 .pending_tx_packets
303 .iter_mut()
304 .enumerate()
305 .filter_map(|(id, inflight)| {
306 if inflight.pending_packet_count > 0 {
307 inflight.pending_packet_count = 0;
308 Some(PendingTxCompletion {
309 transaction_id: inflight.transaction_id,
310 tx_id: Some(TxId(id as u32)),
311 status: protocol::Status::SUCCESS,
312 })
313 } else {
314 None
315 }
316 })
317 .collect::<Vec<_>>();
318 state.pending_tx_completions.extend(pending_tx);
319
320 self.data.tx_segments.clear();
323 self.data.tx_segments_sent = 0;
324 }
325}
326
/// A virtual function (VF) device that can be paired with the synthetic NIC.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// Returns the VF id, or `None` if no VF is currently available.
    async fn id(&self) -> Option<u32>;
    /// Signals that the guest is ready for the VF device to be presented.
    async fn guest_ready_for_device(&mut self);
    /// Waits for a VF availability change; the returned RPC is completed by
    /// the caller once the change has been processed.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
341
/// Immutable-ish adapter configuration shared across all channels.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    // Maximum number of queues (and thus subchannels) supported.
    max_queues: u16,
    indirection_table_size: u16,
    // Offloads the backend endpoint supports; guest requests are masked to
    // this set.
    offload_support: OffloadConfig,
    // When nonzero, caps how much of the ring is processed per wake.
    ring_size_limit: AtomicUsize,
    // Free tx-id threshold below which the worker proactively completes.
    free_tx_packet_threshold: usize,
    tx_fast_completions: bool,
    adapter_index: u32,
    // Optional callback to read the guest OS id (used for ring-size
    // optimization quirks).
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}
356
/// Backend queue plus the rx buffer bookkeeping for one channel.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    pool: BufferPool,
    // The rx buffer ids owned by this queue.
    rx_buffer_range: RxBufferRange,
    // Whether the channel's target VP has been applied to the driver.
    target_vp_set: bool,
}
363
/// The contiguous range of rx buffer ids owned by one queue, plus channels
/// for exchanging ids that belong to other queues.
struct RxBufferRange {
    id_range: Range<u32>,
    // Receives ids returned by other queues; `None` for ranges that do not
    // participate in remote return.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}
369
370impl RxBufferRange {
371 fn new(
372 ranges: Arc<RxBufferRanges>,
373 id_range: Range<u32>,
374 remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
375 ) -> Self {
376 Self {
377 id_range,
378 remote_buffer_id_recv,
379 remote_ranges: ranges,
380 }
381 }
382
383 fn send_if_remote(&self, id: u32) -> bool {
384 if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
387 false
388 } else {
389 let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
390 let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
394 let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
395 true
396 }
397 }
398}
399
/// Shared routing table mapping rx buffer ids to their owning queues.
struct RxBufferRanges {
    // How many (non-reserved) buffer ids each queue owns.
    buffers_per_queue: u32,
    // One sender per queue for returning that queue's buffer ids.
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
404
405impl RxBufferRanges {
406 fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
407 let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
408 #[expect(clippy::disallowed_methods)] let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
410 (
411 Self {
412 buffers_per_queue,
413 buffer_id_send: send,
414 },
415 recv,
416 )
417 }
418}
419
/// RSS (receive side scaling) parameters configured by the guest.
struct RssState {
    // Toeplitz-style hash key; 40 bytes per the protocol definition —
    // presumably fixed by NDIS; TODO confirm.
    key: [u8; 40],
    // Maps hash values to queue indices.
    indirection_table: Vec<u16>,
}
424
/// Vmbus channel state for one queue.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    // Which packet layout the negotiated protocol version uses.
    packet_size: PacketSize,
    // Outgoing ring pending-send-size currently in effect, in bytes.
    pending_send_size: usize,
    // Message to deliver to the coordinator on restart, if any.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}
436
/// Packet header layout variant, selected by negotiated protocol version.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// Original (version 1) packet layout.
    V1,
    /// Extended layout introduced with protocol version 6.1.
    V61,
}
445
/// Scratch buffers reused across processing wakes to avoid per-wake
/// allocation.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    // How many of `tx_segments` have already been handed to the endpoint.
    tx_segments_sent: usize,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
456
457impl ProcessingData {
458 fn new() -> Self {
459 Self {
460 tx_segments: Vec::new(),
461 tx_segments_sent: 0,
462 tx_done: vec![TxId(0); 8192].into(),
463 rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
464 rx_done: Vec::with_capacity(RX_BATCH_SIZE),
465 transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
466 external_data: MultiPagedRangeBuf::new(),
467 }
468 }
469}
470
/// Buffers and settings negotiated during initialization, shared by all
/// channels via `Arc`.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    // `None` when the guest did not register a send buffer.
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
486
/// Identifier for one of the reserved control-message rx buffers.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
490
/// Mutable per-queue packet state while the channel is active.
struct ActiveState {
    // Present only on the primary channel (channel 0).
    primary: Option<PrimaryChannelState>,

    // Indexed by TxId; entries with nonzero pending count are in flight.
    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    // Completions waiting for ring space to be written back to the guest.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    // Tracks which rx sub-allocations are loaned to the guest.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}
504
/// Per-queue diagnostic counters exposed via inspect.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    // Wakes that found no work to do.
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
519
/// A tx completion waiting to be written back to the guest.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    // `None` when no TxId could be associated (e.g. restored completions
    // beyond the quota).
    tx_id: Option<TxId>,
    status: protocol::Status,
}
526
/// State machine tracking VF (accelerated networking) availability and the
/// guest's data path selection. See the `Display` impl for the rendered
/// descriptions of each state.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// Initial state before VF availability has been evaluated.
    Initializing,
    /// Restoring from a saved state snapshot.
    Restoring(saved_state::GuestVfState),
    /// No VF is available.
    Unavailable,
    /// VF became unavailable after previously being available.
    UnavailableFromAvailable,
    /// VF became unavailable while a data path switch was pending.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// VF became unavailable while the data path was switched to the guest.
    UnavailableFromDataPathSwitched,
    /// A VF is available but the guest has not been told yet.
    Available { vfid: u32 },
    /// The guest has been notified of the available VF.
    AvailableAdvertised,
    /// The VF is present in the guest.
    Ready,
    /// A data path switch is in progress; `result` is set once it completes.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// The data path has been switched to the VF.
    DataPathSwitched,
    /// A VF is available but the data path reverted to synthetic due to an
    /// external state change.
    DataPathSynthetic,
}
559
impl std::fmt::Display for PrimaryChannelGuestVfState {
    /// Renders a human-readable description of the VF state for inspect.
    ///
    /// NOTE(review): the embedded `\"` quoting is inconsistent across arms —
    /// some strings are fully quoted, others (e.g. the
    /// `Restoring(DataPathSwitchPending)` arm) end with a stray closing
    /// quote. This may be intentional for the inspect renderer; confirm
    /// before changing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                write!(f, "restoring")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::AvailableAdvertised,
            ) => write!(f, "restoring from guest notified of vfid"),
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                write!(f, "restoring from vf present")
            }
            PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest, result, ..
                },
            ) => {
                write!(
                    f,
                    "restoring from client requested data path switch: to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
                write!(f, "restoring from data path in guest")
            }
            PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
            PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                write!(f, "\"unavailable (previously available)\"")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
                write!(f, "unavailable (previously switching data path)")
            }
            PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                write!(f, "\"unavailable (previously using guest VF)\"")
            }
            PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
            PrimaryChannelGuestVfState::AvailableAdvertised => {
                write!(f, "\"available, guest notified\"")
            }
            PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
            PrimaryChannelGuestVfState::DataPathSwitchPending {
                to_guest, result, ..
            } => {
                write!(
                    f,
                    "\"switching to {} {}",
                    if *to_guest { "guest" } else { "synthetic" },
                    if let Some(result) = result {
                        if *result { "succeeded\"" } else { "failed\"" }
                    } else {
                        "in progress\""
                    }
                )
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                write!(f, "\"available and data path switched\"")
            }
            PrimaryChannelGuestVfState::DataPathSynthetic => write!(
                f,
                "\"available but data path switched back to synthetic due to external state change\""
            ),
        }
    }
}
631
632impl Inspect for PrimaryChannelGuestVfState {
633 fn inspect(&self, req: inspect::Request<'_>) {
634 req.value(self.to_string());
635 }
636}
637
/// Pending link-state change requested for the guest; the `bool` is the
/// desired link-up state.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action pending.
    Default,
    /// Link change ready to be delivered.
    Active(bool),
    /// Link change deferred (delayed) before delivery.
    Delay(bool),
}
644
/// Extra state carried only by the primary channel (channel 0).
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    // `None` until the data path state has been determined.
    is_data_path_switched: Option<bool>,
    // Control messages queued for processing, plus their total payload size.
    control_messages: VecDeque<ControlMessage>,
    control_messages_len: usize,
    // Reserved rx buffers available for control message responses.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    // Offload configuration currently active for the guest.
    offload_config: OffloadConfig,
    // True when an offload-change indication still needs to be sent.
    pending_offload_change: bool,
    tx_spread_sent: bool,
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
660
661impl Inspect for PrimaryChannelState {
662 fn inspect(&self, req: inspect::Request<'_>) {
663 req.respond()
664 .sensitivity_field(
665 "guest_vf_state",
666 SensitivityLevel::Safe,
667 self.guest_vf_state,
668 )
669 .sensitivity_field(
670 "data_path_switched",
671 SensitivityLevel::Safe,
672 self.is_data_path_switched,
673 )
674 .sensitivity_field(
675 "pending_control_messages",
676 SensitivityLevel::Safe,
677 self.control_messages.len(),
678 )
679 .sensitivity_field(
680 "free_control_message_buffers",
681 SensitivityLevel::Safe,
682 self.free_control_buffers.len(),
683 )
684 .sensitivity_field(
685 "pending_offload_change",
686 SensitivityLevel::Safe,
687 self.pending_offload_change,
688 )
689 .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
690 .sensitivity_field(
691 "offload_config",
692 SensitivityLevel::Safe,
693 &self.offload_config,
694 )
695 .sensitivity_field(
696 "tx_spread_sent",
697 SensitivityLevel::Safe,
698 self.tx_spread_sent,
699 )
700 .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
701 .sensitivity_field(
702 "pending_link_action",
703 SensitivityLevel::Safe,
704 match &self.pending_link_action {
705 PendingLinkAction::Active(up) => format!("Active({:x?})", up),
706 PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
707 PendingLinkAction::Default => "None".to_string(),
708 },
709 );
710 }
711}
712
/// Checksum and LSO offload configuration.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    // Large send offload for IPv4 / IPv6.
    #[inspect(safe)]
    lso4: bool,
    #[inspect(safe)]
    lso6: bool,
}
724
725impl OffloadConfig {
726 fn mask_to_supported(&mut self, supported: &OffloadConfig) {
727 self.checksum_tx.mask_to_supported(&supported.checksum_tx);
728 self.checksum_rx.mask_to_supported(&supported.checksum_rx);
729 self.lso4 &= supported.lso4;
730 self.lso6 &= supported.lso6;
731 }
732}
733
/// Per-protocol checksum offload flags (one direction, tx or rx).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
747
748impl ChecksumOffloadConfig {
749 fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
750 self.ipv4_header &= supported.ipv4_header;
751 self.tcp4 &= supported.tcp4;
752 self.udp4 &= supported.udp4;
753 self.tcp6 &= supported.tcp6;
754 self.udp6 &= supported.udp6;
755 }
756
757 fn flags(
758 &self,
759 ) -> (
760 rndisprot::Ipv4ChecksumOffload,
761 rndisprot::Ipv6ChecksumOffload,
762 ) {
763 let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
764 let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
765 let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
766 if self.ipv4_header {
767 v4.set_ip_options_supported(on);
768 v4.set_ip_checksum(on);
769 }
770 if self.tcp4 {
771 v4.set_ip_options_supported(on);
772 v4.set_tcp_options_supported(on);
773 v4.set_tcp_checksum(on);
774 }
775 if self.tcp6 {
776 v6.set_ip_extension_headers_supported(on);
777 v6.set_tcp_options_supported(on);
778 v6.set_tcp_checksum(on);
779 }
780 if self.udp4 {
781 v4.set_ip_options_supported(on);
782 v4.set_udp_checksum(on);
783 }
784 if self.udp6 {
785 v6.set_ip_extension_headers_supported(on);
786 v6.set_udp_checksum(on);
787 }
788 (v4, v6)
789 }
790}
791
impl OffloadConfig {
    /// Builds the NDIS offload capabilities structure advertised to the
    /// guest for this configuration.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        // Checksum offload flags for both directions and both IP versions.
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        // LSOv2 capabilities; left zeroed when the corresponding IP version
        // is not enabled.
        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            // Remaining capability fields are intentionally zero.
            ..FromZeros::new_zeroed()
        }
    }
}
839
/// RNDIS device lifecycle state.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Awaiting RNDIS initialization from the guest.
    Initializing,
    /// Initialized and operating normally.
    Operational,
    /// Halted by the guest.
    Halted,
}
846
impl PrimaryChannelState {
    /// Creates the initial primary-channel state with all reserved control
    /// buffers free and the guest link considered up.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from a saved-state snapshot.
    ///
    /// Returns an error if the saved RSS key has the wrong size or the
    /// saved indirection table is larger than the adapter now supports. A
    /// smaller saved table is tolerated by cycling its entries up to the
    /// adapter's size.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Total queued control-message payload, tracked for quota purposes.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // A control buffer is free only if the rx-buffer tracker does not
        // consider it loaned out.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                if rss.indirection_table.len() > indirection_table_size as usize {
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Grow the table by repeating the saved entries.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        // Convert the saved offload configuration field by field.
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            // Recomputed after restore, not carried in saved state here.
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
977
/// An RNDIS control message queued for processing.
struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}
982
// Maximum number of tx packets that can be in flight per queue.
const TX_PACKET_QUOTA: usize = 1024;
984
impl ActiveState {
    /// Creates fresh active state with the full tx quota free.
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            // Reversed so that ids are handed out in ascending order.
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    /// Rebuilds active state from a saved channel snapshot: re-marks in-use
    /// rx buffers and re-queues saved tx completions.
    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Associate a TxId when one is free; a completion beyond the
            // quota is still queued, just without an id.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}
1034
/// Per-TxId tracking for an in-flight guest transmit; a zero pending count
/// means the slot is idle.
#[derive(Default, Clone)]
struct PendingTxPacket {
    pending_packet_count: usize,
    transaction_id: u64,
}
1042
// Maximum rx packets processed per wake.
const RX_BATCH_SIZE: usize = 375;

// Rx buffer ids below this value are reserved for control messages.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1051
/// The netvsp synthetic network vmbus device.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    // Coordinator task driving all channel workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
1061
/// Builder for [`Nic`]; construct via [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1068
impl NicBuilder {
    /// Limits how much of the ring buffer is processed per wake.
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Caps the number of queues (and thus subchannels).
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Pairs the NIC with a virtual function device.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Supplies a callback for reading the guest OS id, used to decide
    /// whether the ring-size optimization is safe for the running guest.
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC over the given backend endpoint.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp requested queues to what the endpoint and protocol allow.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // 0 disables the per-wake ring processing limit.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // Endpoints with fast tx completion can use the full quota;
        // otherwise complete proactively at a quarter of the quota.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Rx checksum is always claimed; tx offloads reflect endpoint
        // support.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            // LSOv4 additionally requires IPv4 header checksum offload.
            lso4: tx_offloads.tso && tx_offloads.ipv4_header,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1183
1184fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1185 let Some(guest_os_id) = guest_os_id else {
1186 return false;
1188 };
1189
1190 if !queue.split().0.supports_pending_send_size() {
1191 return false;
1193 }
1194
1195 let Some(open_source_os) = guest_os_id.open_source() else {
1196 return true;
1198 };
1199
1200 match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1201 HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1204 HvGuestOsOpenSourceType::LINUX => {
1208 open_source_os.version() >= 199424
1211 }
1212 _ => true,
1213 }
1214}
1215
impl Nic {
    /// Returns a builder for constructing a [`Nic`].
    pub fn builder() -> NicBuilder {
        NicBuilder {
            virtual_function: None,
            limit_ring_buffer: false,
            // !0 = unbounded request; clamped to endpoint limits in build().
            max_queues: !0,
            get_guest_os_id: None,
        }
    }

    /// Tears down the device, returning the backend endpoint and MAC
    /// address for reuse.
    pub fn shutdown(self) -> (Box<dyn Endpoint>, MacAddress) {
        let (state, _) = self.coordinator.into_inner();
        (state.endpoint, self.adapter.mac_address)
    }
}
1231
impl InspectMut for Nic {
    /// Delegates inspection to the coordinator task, which owns all state.
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}
1237
1238#[async_trait]
1239impl VmbusDevice for Nic {
    /// Returns the vmbus offer parameters for the synthetic NIC.
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            // Well-known netvsp interface GUID
            // (f8615163-df3e-46c5-913f-f2d2f965ede).
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            // Enable interrupt latency mitigation for this channel.
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }
1255
1256 fn max_subchannels(&self) -> u16 {
1257 self.adapter.max_queues
1258 }
1259
1260 fn install(&mut self, resources: DeviceResources) {
1261 self.resources = resources;
1262 }
1263
1264 async fn open(
1265 &mut self,
1266 channel_idx: u16,
1267 open_request: &OpenRequest,
1268 ) -> Result<(), ChannelOpenError> {
1269 let state = if channel_idx == 0 {
1271 self.insert_coordinator(1, None);
1272 WorkerState::Init(None)
1273 } else {
1274 self.coordinator.stop().await;
1275 let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
1277 WorkerState::Ready(ReadyState {
1278 state: ActiveState::new(None, buffers.recv_buffer.count),
1279 buffers,
1280 data: ProcessingData::new(),
1281 })
1282 };
1283
1284 let num_opened = self
1285 .adapter
1286 .num_sub_channels_opened
1287 .fetch_add(1, Ordering::SeqCst);
1288 let r = self.insert_worker(channel_idx, open_request, state, true);
1289 if channel_idx != 0
1290 && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
1291 {
1292 let coordinator = &mut self.coordinator.state_mut().unwrap();
1293 coordinator.workers[0].stop().await;
1294 }
1295
1296 if r.is_err() && channel_idx == 0 {
1297 self.coordinator.remove();
1298 } else {
1299 self.coordinator.start();
1301 }
1302 r?;
1303 Ok(())
1304 }
1305
1306 async fn close(&mut self, channel_idx: u16) {
1307 if !self.coordinator.has_state() {
1308 tracing::error!(
1309 channel_idx,
1310 instance_id = %self.instance_id,
1311 "Close called while vmbus channel is already closed"
1312 );
1313 return;
1314 }
1315
1316 let restart = self.coordinator.stop().await;
1318
1319 {
1321 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1322 worker.stop().await;
1323 if worker.has_state() {
1324 worker.remove();
1325 }
1326 }
1327
1328 self.adapter
1329 .num_sub_channels_opened
1330 .fetch_sub(1, Ordering::SeqCst);
1331 if channel_idx == 0 {
1333 for worker in &mut self.coordinator.state_mut().unwrap().workers {
1334 worker.task_mut().queue_state = None;
1335 }
1336
1337 self.coordinator.task_mut().endpoint.stop().await;
1339
1340 self.coordinator.remove();
1345 } else {
1346 if restart {
1348 self.coordinator.start();
1349 }
1350 }
1351 }
1352
1353 async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
1354 if !self.coordinator.has_state() {
1355 return;
1356 }
1357
1358 let coordinator_running = self.coordinator.stop().await;
1360 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1361 worker.stop().await;
1362 let (net_queue, worker_state) = worker.get_mut();
1363
1364 net_queue.driver.retarget_vp(target_vp);
1366
1367 if let Some(worker_state) = worker_state {
1368 worker_state.target_vp = Some(target_vp);
1370 if let Some(queue_state) = &mut net_queue.queue_state {
1371 queue_state.target_vp_set = false;
1373 }
1374 }
1375
1376 if coordinator_running {
1378 self.coordinator.start();
1379 }
1380 }
1381
1382 fn start(&mut self) {
1383 if !self.coordinator.is_running() {
1384 self.coordinator.start();
1385 }
1386 }
1387
1388 async fn stop(&mut self) {
1389 self.coordinator.stop().await;
1390 if let Some(coordinator) = self.coordinator.state_mut() {
1391 coordinator.stop_workers().await;
1392 }
1393 }
1394
1395 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1396 Some(self)
1397 }
1398}
1399
#[async_trait]
impl SaveRestoreVmbusDevice for Nic {
    // Serializes the current device state into a saved-state blob.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        let state = self.saved_state();
        Ok(SavedStateBlob::new(state))
    }

    // Parses and applies a previously saved state. Failures are logged here
    // in addition to being returned, so the instance ID is captured even if
    // the caller discards the error detail.
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        let state: saved_state::SavedState = state.parse()?;
        if let Err(err) = self.restore_state(control, state).await {
            tracing::error!(
                error = &err as &dyn std::error::Error,
                instance_id = %self.instance_id,
                "Failed restoring network vmbus state"
            );
            Err(err.into())
        } else {
            Ok(())
        }
    }
}
1425
impl Nic {
    /// Creates a channel worker for `channel_idx` over the given vmbus open
    /// request and installs it into the coordinator's worker slot,
    /// optionally starting it immediately.
    ///
    /// The coordinator must already exist and be stopped (callers hold it
    /// stopped while mutating its state).
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Reuse the per-queue driver and point it at the channel's target VP
        // (if none was given, VP 0 via unwrap_or_default).
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp.unwrap_or_default());

        // Establish the ring buffer and vmbus queue for this channel.
        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // The ring-size optimization depends on the guest OS identity,
        // queried lazily through the registered callback.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                // Packet size starts at V1 until a newer protocol version is
                // negotiated.
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1487
/// Coordinator state carried across a save/restore cycle.
struct RestoreCoordinatorState {
    // Packet filter in effect at save time, re-applied on restore.
    active_packet_filter: u32,
}
1491
impl Nic {
    /// Creates the coordinator task with `num_queues` worker slots and
    /// installs it. `restoring` is `Some` when called from the restore path,
    /// in which case the coordinator restarts immediately and the saved
    /// packet filter is carried over.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        driver_builder.target_vp(0);
        // Run workers on the target VP only when the endpoint does not
        // complete transmits quickly; see tx_fast_completions.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                // Pre-create one task slot per possible queue; workers are
                // inserted into these slots as channels open.
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}
1534
/// Errors that can occur while restoring saved network vmbus state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1554
1555impl From<NetRestoreError> for RestoreError {
1556 fn from(err: NetRestoreError) -> Self {
1557 RestoreError::InvalidSavedState(anyhow::Error::new(err))
1558 }
1559}
1560
1561impl Nic {
    /// Rebuilds the device from a saved state: restores the vmbus channels,
    /// reconstructs per-channel worker state (Init or Ready), and reinstalls
    /// the coordinator and workers (stopped; the caller starts the device).
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            // One bool per channel: whether that channel was open at save
            // time. Pre-Ready states only ever had the primary channel open.
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // Restore the vmbus channels first; this yields the per-channel
            // open requests needed to build workers below.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    // Re-map the receive/send buffer GPADLs if they had been
                    // established before the save.
                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // Older saved states predate the packet_filter field;
                    // fall back to the default filter in that case.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    // The shared buffers are keyed off the primary channel's
                    // restored open request.
                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        let channel = if let Some(channel) = channel {
                            channel
                        } else {
                            continue;
                        };

                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;

                        // Channel 0 additionally carries the primary channel
                        // state (RNDIS, offloads, RSS, VF, link state).
                        if channel_idx == 0 {
                            let primary = PrimaryChannelState::restore(
                                &guest_vf_state,
                                &rndis_state,
                                &offload_config,
                                pending_offload_change,
                                channels.len() as u16,
                                self.adapter.indirection_table_size,
                                &active.rx_bufs,
                                std::mem::take(&mut control_messages),
                                rss_state.take(),
                                tx_spread_sent,
                                guest_link_down,
                                pending_link_action,
                            )?;
                            active.primary = Some(primary);
                        }

                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
                            buffers: buffers.clone(),
                            state: active,
                            data: ProcessingData::new(),
                        }));
                    }
                }
            }

            self.insert_coordinator(
                states.len() as u16,
                Some(RestoreCoordinatorState {
                    active_packet_filter: saved_packet_filter,
                }),
            );

            // Install workers without starting them; the device's start()
            // path runs the coordinator (and through it the workers).
            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
                if let Some(state) = state {
                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
                }
            }
        } else {
            // Device was saved closed: restore zero open channels.
            control
                .restore(&[false])
                .await
                .map_err(NetRestoreError::Channel)?;
        }
        Ok(())
    }
1753
    /// Captures the device's current state for save. Returns `open: None`
    /// when the device is closed (no coordinator); otherwise serializes the
    /// primary worker's state (Version, Init, or Ready) plus all per-channel
    /// in-flight transmit/receive bookkeeping.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                // Version negotiation not yet started.
                WorkerState::Init(None) => saved_state::Primary::Version,
                // Mid-initialization: version chosen, buffers possibly set.
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Per-channel state: transaction IDs still owing a tx
                    // completion and receive suballocations still in use.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine onto the coarser
                    // saved-state representation.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The packet filter lives on the primary worker's channel.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
1955}
1956
/// Errors that terminate a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    // Converted via the manual `From<task_control::Cancelled>` impl rather
    // than `#[from]`.
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
}
2010
2011impl From<task_control::Cancelled> for WorkerError {
2012 fn from(value: task_control::Cancelled) -> Self {
2013 Self::Cancelled(value)
2014 }
2015}
2016
/// Errors that can occur while opening a vmbus channel's ring and queue.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
2024
/// Errors produced while parsing an incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
2036
/// Protocol-sequencing violations: messages arriving in a state where they
/// are not allowed.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2056
/// The parsed body of an incoming NVSP message or completion.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer),
    RevokeSendBuffer(protocol::Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // Completions matched by well-known transaction IDs rather than by a
    // message header.
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2074
/// A fully parsed incoming vmbus packet.
#[derive(Debug)]
struct Packet<'a> {
    // The decoded message body.
    data: PacketData,
    // Transaction ID to complete against, if the sender requested one.
    transaction_id: Option<u64>,
    // Guest memory ranges referenced by the packet (e.g. RNDIS payload).
    external_data: &'a MultiPagedRangeBuf,
}
2081
/// Reader over a packet's external data ranges in guest memory.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2083
impl Packet<'_> {
    // Returns a guest-memory reader over this packet's external data ranges,
    // used to parse the embedded RNDIS message.
    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
        PagedRanges::new(self.external_data.iter()).reader(mem)
    }
}
2089
2090fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2091 reader: &mut impl MemoryRead,
2092) -> Result<T, PacketError> {
2093 reader.read_plain().map_err(PacketError::Access)
2094}
2095
/// Parses an incoming vmbus packet into a [`Packet`].
///
/// Completions are matched first by well-known transaction IDs, then by
/// message header. Data packets are dispatched by message type, with each
/// type gated on the minimum negotiated protocol `version` that introduced
/// it. Any external GPA ranges are collected into `external_data`.
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&SendBuffer>,
    version: Option<Version>,
    external_data: &'a mut MultiPagedRangeBuf,
) -> Result<Packet<'a>, PacketError> {
    external_data.clear();
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            // Host-initiated messages use reserved transaction IDs, so their
            // completions carry no parseable header.
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data,
            });
        }
    };

    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // 0xffffffff means the payload is not in the send buffer;
            // otherwise resolve the referenced send-buffer section. 6144 is
            // the per-section size in bytes — presumably fixed by the NVSP
            // protocol; confirm against protocol.rs.
            if message.send_buffer_section_index != 0xffffffff {
                let send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;

                external_data.push_range(send_buffer_suballocation);
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    packet
        .read_external_ranges(external_data)
        .map_err(PacketError::ExternalData)?;
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data,
    })
}
2189
/// A fixed-size, u64-aligned buffer holding an encoded NVSP message.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    // Backing storage, sized for the largest packet format (V6.1).
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    // Negotiated packet size, which determines how much of `buf` is sent.
    size: PacketSize,
}
2195
impl NvspMessage {
    /// Encodes a message header plus payload `data` into a zero-padded
    /// buffer. The compile-time assertion guarantees `P` fits even in the
    /// smallest (V1) packet, so the `write_to_prefix` calls cannot fail.
    fn new<P: IntoBytes + Immutable + KnownLayout>(
        size: PacketSize,
        message_type: u32,
        data: P,
    ) -> Self {
        const {
            assert!(
                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
                "packet might not fit in message"
            )
        };
        let mut message = NvspMessage {
            buf: [0; protocol::PACKET_SIZE_V61 / 8],
            size,
        };
        let header = protocol::MessageHeader { message_type };
        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
        data.write_to_prefix(
            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
        )
        .unwrap();
        message
    }

    /// Returns the u64-aligned slice to place on the ring, trimmed to the
    /// negotiated packet size.
    fn aligned_payload(&self) -> &[u64] {
        let len = match self.size {
            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.div_ceil(8) },
            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.div_ceil(8) },
        };
        &self.buf[..len]
    }
}
2237
impl<T: RingMem> NetChannel<T> {
    // Encodes an NVSP message using this channel's negotiated packet size.
    fn message<P: IntoBytes + Immutable + KnownLayout>(
        &self,
        message_type: u32,
        data: P,
    ) -> NvspMessage {
        NvspMessage::new(self.packet_size, message_type, data)
    }

    // Sends a completion for `transaction_id`, if one was requested.
    // `message` of `None` sends an empty completion. A full ring is treated
    // as a fatal worker error (completion space should have been reserved).
    fn send_completion(
        &mut self,
        transaction_id: Option<u64>,
        message: Option<&NvspMessage>,
    ) -> Result<(), WorkerError> {
        match transaction_id {
            None => Ok(()),
            Some(transaction_id) => Ok(self
                .queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    transaction_id,
                    OutgoingPacketType::Completion,
                    message.map_or(&[], |m| m.aligned_payload()),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?),
        }
    }
}
2271
/// NVSP protocol versions this device will negotiate, in ascending order.
/// NOTE(review): V3 is not offered — presumably intentional; confirm against
/// the protocol module before adding it.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2280
2281fn check_version(requested_version: u32) -> Option<Version> {
2282 SUPPORTED_VERSIONS
2283 .iter()
2284 .find(|version| **version as u32 == requested_version)
2285 .copied()
2286}
2287
/// The guest-supplied receive buffer, divided into fixed-size suballocations
/// used to deliver incoming packets.
#[derive(Debug)]
struct ReceiveBuffer {
    // Mapped GPADL backing the buffer.
    gpadl: GpadlView,
    // Buffer ID assigned by the guest.
    id: u16,
    // Number of suballocations that fit in the buffer.
    count: u32,
    // Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2295
/// Errors validating a guest-supplied send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2305
impl ReceiveBuffer {
    /// Maps and validates a receive buffer GPADL: the suballocation size
    /// must hold at least a default-MTU packet, the GPADL must be one
    /// contiguous page-aligned range, and at least one suballocation must
    /// fit.
    fn new(
        gpadl_map: &GpadlMapView,
        gpadl_id: GpadlId,
        id: u16,
        sub_allocation_size: u32,
    ) -> Result<Self, BufferError> {
        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let gpadl = gpadl_map.map(gpadl_id)?;
        let range = gpadl
            .contiguous_aligned()
            .ok_or(BufferError::UnalignedGpadl)?;
        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
        if num_sub_allocations == 0 {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let recv_buffer = Self {
            gpadl,
            id,
            count: num_sub_allocations,
            sub_allocation_size,
        };
        Ok(recv_buffer)
    }

    /// Returns the guest-memory range of suballocation `index`.
    fn range(&self, index: u32) -> PagedRange<'_> {
        self.gpadl.first().unwrap().subrange(
            (index * self.sub_allocation_size) as usize,
            self.sub_allocation_size as usize,
        )
    }

    /// Returns the transfer-page descriptor for `len` bytes of suballocation
    /// `index`, for posting a receive to the guest.
    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
        assert!(len <= self.sub_allocation_size as usize);
        ring::TransferPageRange {
            byte_offset: index * self.sub_allocation_size,
            byte_count: len as u32,
        }
    }

    /// Captures the parameters needed to re-map this buffer on restore.
    /// `count` is derived, so it is not saved.
    fn saved_state(&self) -> saved_state::ReceiveBuffer {
        saved_state::ReceiveBuffer {
            gpadl_id: self.gpadl.id(),
            id: self.id,
            sub_allocation_size: self.sub_allocation_size,
        }
    }
}
2360
/// A guest-supplied send buffer backed by a mapped GPADL.
#[derive(Debug)]
struct SendBuffer {
    /// Mapped GPADL view; validated as contiguous and aligned in `new`.
    gpadl: GpadlView,
}
2365
2366impl SendBuffer {
2367 fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2368 let gpadl = gpadl_map.map(gpadl_id)?;
2369 gpadl
2370 .contiguous_aligned()
2371 .ok_or(BufferError::UnalignedGpadl)?;
2372 Ok(Self { gpadl })
2373 }
2374}
2375
2376impl<T: RingMem> NetChannel<T> {
    /// Queues a non-data RNDIS message for later processing by
    /// `process_control_messages`.
    ///
    /// Data packets (`MESSAGE_TYPE_PACKET_MSG`) take a separate fast path;
    /// this method asserts it never sees them. Control messages are only
    /// valid on the primary channel.
    fn handle_rndis_message(
        &mut self,
        state: &mut ActiveState,
        message_type: u32,
        mut reader: PacketReader<'_>,
    ) -> Result<(), WorkerError> {
        assert_ne!(
            message_type,
            rndisprot::MESSAGE_TYPE_PACKET_MSG,
            "handled elsewhere"
        );
        let control = state
            .primary
            .as_mut()
            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;

        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
            // HALT is accepted and dropped without queuing a response.
            return Ok(());
        }

        // Bound the total bytes of queued control messages so the guest
        // cannot grow host memory usage without limit.
        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
        if reader.len() == 0 {
            return Err(WorkerError::RndisMessageTooSmall);
        }
        // Reject if accepting this message would exceed the quota.
        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
            return Err(WorkerError::TooManyControlMessages);
        }

        control.control_messages_len += reader.len();
        control.control_messages.push_back(ControlMessage {
            message_type,
            data: reader.read_all()?.into(),
        });

        Ok(())
    }
2423
    /// Parses a chain of RNDIS packet messages carried in a single vmbus
    /// message, converting each into tx segments appended to `segments`.
    ///
    /// `message_len` is the declared length of the first RNDIS message
    /// (including its header, which the caller has already consumed from
    /// `reader`); any following messages declare their own lengths.
    /// Returns the number of packets parsed.
    fn handle_rndis_packet_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        mut message_len: usize,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        loop {
            // `message_len` counts the header that was already consumed, so
            // the remaining body is header-size smaller.
            let next_message_offset = message_len
                .checked_sub(size_of::<rndisprot::MessageHeader>())
                .ok_or(WorkerError::RndisMessageTooSmall)?;

            self.handle_rndis_packet_message(
                id,
                reader.clone(),
                &buffers.mem,
                segments,
                &mut state.stats,
            )?;
            num_packets += 1;

            // Advance past this message; stop when the vmbus payload is
            // exhausted.
            reader.skip(next_message_offset)?;
            if reader.len() == 0 {
                break;
            }
            // Only packet messages may follow a packet message in a chain.
            let header: rndisprot::MessageHeader = reader.read_plain()?;
            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
            }
            message_len = header.message_length as usize;
        }
        Ok(num_packets)
    }
2468
2469 fn handle_rndis_packet_message(
2471 &mut self,
2472 id: TxId,
2473 reader: PacketReader<'_>,
2474 mem: &GuestMemory,
2475 segments: &mut Vec<TxSegment>,
2476 stats: &mut QueueStats,
2477 ) -> Result<(), WorkerError> {
2478 let headers = reader
2480 .clone()
2481 .into_inner()
2482 .paged_ranges()
2483 .next()
2484 .ok_or(WorkerError::RndisMessageTooSmall)?;
2485 let mut data = reader.into_inner();
2486 let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2487 if request.num_oob_data_elements != 0
2488 || request.oob_data_length != 0
2489 || request.oob_data_offset != 0
2490 || request.vc_handle != 0
2491 {
2492 return Err(WorkerError::UnsupportedRndisBehavior);
2493 }
2494
2495 if data.len() < request.data_offset as usize
2496 || (data.len() - request.data_offset as usize) < request.data_length as usize
2497 || request.data_length == 0
2498 {
2499 return Err(WorkerError::RndisMessageTooSmall);
2500 }
2501
2502 data.skip(request.data_offset as usize);
2503 data.truncate(request.data_length as usize);
2504
2505 let mut metadata = net_backend::TxMetadata {
2506 id,
2507 len: request.data_length,
2508 ..Default::default()
2509 };
2510
2511 if request.per_packet_info_length != 0 {
2512 let mut ppi = headers
2513 .try_subrange(
2514 request.per_packet_info_offset as usize,
2515 request.per_packet_info_length as usize,
2516 )
2517 .ok_or(WorkerError::RndisMessageTooSmall)?;
2518 while !ppi.is_empty() {
2519 let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2520 if h.size == 0 {
2521 return Err(WorkerError::RndisMessageTooSmall);
2522 }
2523 let (this, rest) = ppi
2524 .try_split(h.size as usize)
2525 .ok_or(WorkerError::RndisMessageTooSmall)?;
2526 let (_, d) = this
2527 .try_split(h.per_packet_information_offset as usize)
2528 .ok_or(WorkerError::RndisMessageTooSmall)?;
2529 match h.typ {
2530 rndisprot::PPI_TCP_IP_CHECKSUM => {
2531 let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2532
2533 metadata.flags.set_offload_tcp_checksum(
2534 (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2535 );
2536 metadata.flags.set_offload_udp_checksum(
2537 (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2538 );
2539 metadata
2540 .flags
2541 .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2542 metadata.flags.set_is_ipv4(n.is_ipv4());
2543 metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2544 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2545 if metadata.flags.offload_tcp_checksum()
2546 || metadata.flags.offload_udp_checksum()
2547 {
2548 metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2549 n.tcp_header_offset() - metadata.l2_len as u16
2550 } else if n.is_ipv4() {
2551 let mut reader = data.clone().reader(mem);
2552 reader.skip(metadata.l2_len as usize)?;
2553 let mut b = 0;
2554 reader.read(std::slice::from_mut(&mut b))?;
2555 (b as u16 >> 4) * 4
2556 } else {
2557 40
2559 };
2560 }
2561 }
2562 rndisprot::PPI_LSO => {
2563 let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2564
2565 metadata.flags.set_offload_tcp_segmentation(true);
2566 metadata.flags.set_offload_tcp_checksum(true);
2567 metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2568 metadata.flags.set_is_ipv4(n.is_ipv4());
2569 metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2570 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2571 if n.tcp_header_offset() < metadata.l2_len as u16 {
2572 return Err(WorkerError::InvalidTcpHeaderOffset);
2573 }
2574 metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2575 metadata.l4_len = {
2576 let mut reader = data.clone().reader(mem);
2577 reader
2578 .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2579 let mut b = 0;
2580 reader.read(std::slice::from_mut(&mut b))?;
2581 (b >> 4) * 4
2582 };
2583 metadata.max_tcp_segment_size = n.mss() as u16;
2584
2585 if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2586 stats.tx_invalid_lso_packets.increment();
2588 }
2589 }
2590 _ => {}
2591 }
2592 ppi = rest;
2593 }
2594 }
2595
2596 let start = segments.len();
2597 for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2598 let range = range.map_err(WorkerError::InvalidGpadl)?;
2599 segments.push(TxSegment {
2600 ty: net_backend::TxSegmentType::Tail,
2601 gpa: range.start,
2602 len: range.len() as u32,
2603 });
2604 }
2605
2606 metadata.segment_count = (segments.len() - start) as u8;
2607
2608 stats.tx_packets.increment();
2609 if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2610 stats.tx_checksum_packets.increment();
2611 }
2612 if metadata.flags.offload_tcp_segmentation() {
2613 stats.tx_lso_packets.increment();
2614 }
2615
2616 segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2617
2618 Ok(())
2619 }
2620
    /// Sends the guest a VF association message describing whether a VF is
    /// currently available, if the negotiated protocol supports it
    /// (version 4+ with the SR-IOV capability).
    ///
    /// Returns `Ok(true)` if the message was sent, `Ok(false)` if the
    /// negotiation does not support VF association (the message is
    /// skipped).
    fn guest_vf_is_available(
        &mut self,
        guest_vf_id: Option<u32>,
        version: Version,
        config: NdisConfig,
    ) -> Result<bool, WorkerError> {
        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
        if version >= Version::V4 && config.capabilities.sriov() {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                "sending VF association message"
            );
            let message = {
                self.message(
                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
                    protocol::Message4SendVfAssociation {
                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
                        serial_number: serial_number.unwrap_or(0),
                    },
                )
            };
            // Sent with a completion under a fixed transaction ID so the
            // guest's acknowledgement can be recognized later.
            self.queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    VF_ASSOCIATION_TRANSACTION_ID,
                    OutgoingPacketType::InBandWithCompletion,
                    message.aligned_payload(),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(len) => {
                        tracing::error!(len, "failed to write vf association message");
                        WorkerError::OutOfSpace
                    }
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?;
            Ok(true)
        } else {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                major = version.major(),
                minor = version.minor(),
                sriov_capable = config.capabilities.sriov(),
                "Skipping NvspMessage4TypeSendVFAssociation message"
            );
            Ok(false)
        }
    }
2675
    /// Sends the guest a send indirection table populated round-robin
    /// across the opened channels. Only supported on protocol V5 and
    /// later; on older versions this is a no-op.
    ///
    /// Failures are logged rather than propagated.
    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
        // The indirection table message was added in protocol version 5.
        if version < Version::V5 {
            return;
        }

        #[repr(C)]
        #[derive(IntoBytes, Immutable, KnownLayout)]
        struct SendIndirectionMsg {
            pub message: protocol::Message5SendIndirectionTable,
            pub send_indirection_table:
                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
        }

        // The reported offset is measured from the start of the NVSP
        // message, so it includes the protocol message header that is
        // written separately below.
        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
            + size_of::<protocol::MessageHeader>();
        let mut data = SendIndirectionMsg {
            message: protocol::Message5SendIndirectionTable {
                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
                table_offset: send_indirection_table_offset as u32,
            },
            send_indirection_table: Default::default(),
        };

        // Spread table entries round-robin across the opened channels.
        for i in 0..data.send_indirection_table.len() {
            data.send_indirection_table[i] = i as u32 % num_channels_opened;
        }

        let header = protocol::MessageHeader {
            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
        };
        let result = self
            .queue
            .split()
            .1
            .try_write(&queue::OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[header.as_bytes(), data.as_bytes()],
            })
            .map_err(|err| match err {
                queue::TryWriteError::Full(len) => {
                    tracing::error!(len, "failed to write send indirection table message");
                    WorkerError::OutOfSpace
                }
                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
            });
        if let Err(err) = result {
            tracing::error!(
                error = &err as &dyn std::error::Error,
                "Failed to notify guest about the send indirection table"
            );
        }
    }
2732
2733 fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2736 let header = protocol::MessageHeader {
2737 message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2738 };
2739 let data = protocol::Message4SwitchDataPath {
2740 active_data_path: protocol::DataPath::SYNTHETIC.0,
2741 };
2742 let result = self
2743 .queue
2744 .split()
2745 .1
2746 .try_write(&queue::OutgoingPacket {
2747 transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2748 packet_type: OutgoingPacketType::InBandWithCompletion,
2749 payload: &[header.as_bytes(), data.as_bytes()],
2750 })
2751 .map_err(|err| match err {
2752 queue::TryWriteError::Full(len) => {
2753 tracing::error!(len, "failed to write switch data path message");
2754 WorkerError::OutOfSpace
2755 }
2756 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2757 });
2758 if let Err(err) = result {
2759 tracing::error!(
2760 error = &err as &dyn std::error::Error,
2761 "Failed to notify guest that data path is now synthetic"
2762 );
2763 }
2764 }
2765
    /// Drives the guest VF state machine forward after a state change,
    /// sending the required guest notifications (VF association, data-path
    /// switch completions) until a stable state is reached.
    ///
    /// Returns a coordinator message when the coordinator must take
    /// further action.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Only advertise the VF once RNDIS initialization has
            // completed; otherwise wait for a later state change.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Process transitional states repeatedly until a terminal state is
        // reached (the `_ => break` arm).
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Tell the guest the VF has gone away.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the guest's outstanding switch request before
                    // continuing the teardown.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    self.send_completion(id, None)?;

                    // (direction of the switch, whether it succeeded)
                    match (to_guest, result) {
                        (true, true) => PrimaryChannelGuestVfState::DataPathSwitched,
                        (true, false) => {
                            tracing::error!(
                                "Failure switching to guest VF, remaining on synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        }
                        (false, true) => PrimaryChannelGuestVfState::Ready,
                        (false, false) => {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                _ => break,
            };
        }
        Ok(None)
    }
2866
    /// Handles a single queued RNDIS control message, building the
    /// completion in receive-buffer suballocation `id` and sending it to
    /// the guest.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // The response is written into the suballocation reserved by the
        // caller.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // INITIALIZE is only valid once, before RNDIS is up.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise an available VF
                // (if any) to the guest.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Reserve room for the completion header up front; the OID
                // handler writes its result directly after it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Some OID changes require restarting the endpoint.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate packet filter changes to the
                        // coordinator only when the value actually changed.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            oid = ?request.oid,
                            "oid set failure"
                        );
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
3047
3048 fn try_send_rndis_message(
3049 &mut self,
3050 transaction_id: u64,
3051 channel_type: u32,
3052 recv_buffer_id: u16,
3053 transfer_pages: &[ring::TransferPageRange],
3054 ) -> Result<Option<usize>, WorkerError> {
3055 let message = self.message(
3056 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
3057 protocol::Message1SendRndisPacket {
3058 channel_type,
3059 send_buffer_section_index: 0xffffffff,
3060 send_buffer_section_size: 0,
3061 },
3062 );
3063 let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3064 transaction_id,
3065 OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3066 message.aligned_payload(),
3067 ) {
3068 Ok(()) => None,
3069 Err(queue::TryWriteError::Full(n)) => Some(n),
3070 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3071 };
3072 Ok(pending_send_size)
3073 }
3074
3075 fn send_rndis_control_message(
3076 &mut self,
3077 buffers: &ChannelBuffers,
3078 id: u32,
3079 message_length: usize,
3080 ) -> Result<(), WorkerError> {
3081 let result = self.try_send_rndis_message(
3082 id as u64,
3083 protocol::CONTROL_CHANNEL_TYPE,
3084 buffers.recv_buffer.id,
3085 std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3086 )?;
3087
3088 match result {
3090 None => Ok(()),
3091 Some(len) => {
3092 tracelimit::error_ratelimited!(len, "failed to write control message completion");
3093 Err(WorkerError::OutOfSpace)
3094 }
3095 }
3096 }
3097
3098 fn indicate_status(
3099 &mut self,
3100 buffers: &ChannelBuffers,
3101 id: u32,
3102 status: u32,
3103 payload: &[u8],
3104 ) -> Result<(), WorkerError> {
3105 let buffer = &buffers.recv_buffer.range(id);
3106 let mut writer = buffer.writer(&buffers.mem);
3107 let message_length = write_rndis_message(
3108 &mut writer,
3109 rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3110 payload.len(),
3111 &rndisprot::IndicateStatus {
3112 status,
3113 status_buffer_length: payload.len() as u32,
3114 status_buffer_offset: if payload.is_empty() {
3115 0
3116 } else {
3117 size_of::<rndisprot::IndicateStatus>() as u32
3118 },
3119 },
3120 )?;
3121 writer.write(payload)?;
3122 self.send_rndis_control_message(buffers, id, message_length)?;
3123 Ok(())
3124 }
3125
    /// Processes queued control messages and any pending offload-change
    /// notification, for as long as ring space and free control buffers
    /// allow.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages only exist on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Back off if the outgoing ring cannot fit a control response,
            // recording the size needed before retrying.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Each in-flight response occupies one receive-buffer
            // suballocation.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the suballocation in use until the guest releases it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                // Release this message's share of the queued-bytes quota.
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued messages, so the loop condition guarantees a
                // pending offload change: report the current offload
                // configuration to the guest.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                unreachable!();
            }
        }
        Ok(())
    }
3179
3180 fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3181 if self.restart.is_none() {
3182 self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3183 guest_vf_state: guest_vf,
3184 filter_state: packet_filter,
3185 }));
3186 } else if let Some(CoordinatorMessage::Restart) = self.restart {
3187 } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3191 update.guest_vf_state |= guest_vf;
3193 update.filter_state |= packet_filter;
3194 }
3195 }
3196
    /// Requests a coordinator update for guest VF state only.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3200
    /// Requests a coordinator update for packet filter state only.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3204}
3205
3206fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3208 writer: &mut impl MemoryWrite,
3209 message_type: u32,
3210 extra: usize,
3211 payload: &T,
3212) -> Result<usize, AccessError> {
3213 let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3214 writer.write(
3215 rndisprot::MessageHeader {
3216 message_type,
3217 message_length: message_length as u32,
3218 }
3219 .as_bytes(),
3220 )?;
3221 writer.write(payload.as_bytes())?;
3222 Ok(message_length)
3223}
3224
/// Errors produced while handling an RNDIS OID query or set request.
#[derive(Debug, Error)]
enum OidError {
    /// Guest memory access failed while reading/writing the OID buffer.
    #[error(transparent)]
    Access(#[from] AccessError),
    /// The OID is not recognized.
    #[error("unknown oid")]
    UnknownOid,
    /// The OID payload failed validation; names the offending field.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    /// The negotiated NDIS version does not support this request.
    #[error("bad ndis version")]
    BadVersion,
    /// The OID is recognized but the named feature is not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3238
3239impl OidError {
3240 fn as_status(&self) -> u32 {
3241 match self {
3242 OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3243 OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3244 OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3245 OidError::Access(_) => rndisprot::STATUS_FAILURE,
3246 }
3247 }
3248}
3249
/// Default Ethernet frame size: 1500-byte payload plus the 14-byte header.
const DEFAULT_MTU: u32 = 1514;
/// Minimum allowed MTU value.
const MIN_MTU: u32 = DEFAULT_MTU;
/// Maximum allowed MTU value (jumbo frames).
const MAX_MTU: u32 = 9216;

/// Length of an Ethernet (802.3) header: two MAC addresses plus ethertype.
const ETHERNET_HEADER_LEN: u32 = 14;
3255
3256impl Adapter {
3257 fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3258 if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3259 if guest_os_id
3262 .microsoft()
3263 .unwrap_or(HvGuestOsMicrosoft::from(0))
3264 .os_id()
3265 == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3266 {
3267 self.adapter_index
3268 } else {
3269 vfid
3270 }
3271 } else {
3272 vfid
3273 }
3274 }
3275
3276 fn handle_oid_query(
3277 &self,
3278 buffers: &ChannelBuffers,
3279 primary: &PrimaryChannelState,
3280 oid: rndisprot::Oid,
3281 mut writer: impl MemoryWrite,
3282 ) -> Result<usize, OidError> {
3283 tracing::debug!(?oid, "oid query");
3284 let available_len = writer.len();
3285 match oid {
3286 rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
3287 let supported_oids_common = &[
3288 rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
3289 rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
3290 rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
3291 rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
3292 rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
3293 rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
3294 rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
3295 rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
3296 rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
3297 rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
3298 rndisprot::Oid::OID_GEN_LINK_SPEED,
3299 rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
3300 rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
3301 rndisprot::Oid::OID_GEN_VENDOR_ID,
3302 rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
3303 rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
3304 rndisprot::Oid::OID_GEN_DRIVER_VERSION,
3305 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
3306 rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
3307 rndisprot::Oid::OID_GEN_MAC_OPTIONS,
3308 rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
3309 rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
3310 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
3311 rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
3312 rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
3314 rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
3315 rndisprot::Oid::OID_802_3_MULTICAST_LIST,
3316 rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
3317 rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
3319 rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
3320 rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
3321 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
3326 ];
3327
3328 let supported_oids_6 = &[
3330 rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
3332 rndisprot::Oid::OID_GEN_LINK_STATE,
3333 rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
3334 rndisprot::Oid::OID_GEN_BYTES_RCV,
3336 rndisprot::Oid::OID_GEN_BYTES_XMIT,
3337 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
3339 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
3340 rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
3341 rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
3342 ];
3345
3346 let supported_oids_63 = &[
3347 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
3348 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
3349 ];
3350
3351 match buffers.ndis_version.major {
3352 5 => {
3353 writer.write(supported_oids_common.as_bytes())?;
3354 }
3355 6 => {
3356 writer.write(supported_oids_common.as_bytes())?;
3357 writer.write(supported_oids_6.as_bytes())?;
3358 if buffers.ndis_version.minor >= 30 {
3359 writer.write(supported_oids_63.as_bytes())?;
3360 }
3361 }
3362 _ => return Err(OidError::BadVersion),
3363 }
3364 }
3365 rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
3366 let status: u32 = 0; writer.write(status.as_bytes())?;
3368 }
3369 rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
3370 writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
3371 }
3372 rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
3373 | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
3374 | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
3375 let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
3376 writer.write(len.as_bytes())?;
3377 }
3378 rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
3379 | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
3380 | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
3381 let len: u32 = buffers.ndis_config.mtu;
3382 writer.write(len.as_bytes())?;
3383 }
3384 rndisprot::Oid::OID_GEN_LINK_SPEED => {
3385 let speed: u32 = (self.link_speed / 100) as u32; writer.write(speed.as_bytes())?;
3387 }
3388 rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
3389 | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
3390 writer.write((256u32 * 1024).as_bytes())?
3392 }
3393 rndisprot::Oid::OID_GEN_VENDOR_ID => {
3394 writer.write(0x0000155du32.as_bytes())?;
3397 }
3398 rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
3399 rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
3400 | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
3401 writer.write(0x0100u16.as_bytes())? }
3403 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
3404 rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
3405 let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
3406 | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
3407 | rndisprot::MAC_OPTION_NO_LOOPBACK;
3408 writer.write(options.as_bytes())?;
3409 }
3410 rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
3411 writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
3412 }
3413 rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
3414 rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
3415 let name16: Vec<u16> = "Network Device".encode_utf16().collect();
3416 let mut name = rndisprot::FriendlyName::new_zeroed();
3417 name.name[..name16.len()].copy_from_slice(&name16);
3418 writer.write(name.as_bytes())?
3419 }
3420 rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
3421 | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
3422 writer.write(&self.mac_address.to_bytes())?
3423 }
3424 rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
3425 writer.write(0u32.as_bytes())?;
3426 }
3427 rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
3428 | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
3429 | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,
3430
3431 rndisprot::Oid::OID_GEN_LINK_STATE => {
3433 let link_state = rndisprot::LinkState {
3434 header: rndisprot::NdisObjectHeader {
3435 object_type: rndisprot::NdisObjectType::DEFAULT,
3436 revision: 1,
3437 size: size_of::<rndisprot::LinkState>() as u16,
3438 },
3439 media_connect_state: 1, media_duplex_state: 0, padding: 0,
3442 xmit_link_speed: self.link_speed,
3443 rcv_link_speed: self.link_speed,
3444 pause_functions: 0, auto_negotiation_flags: 0,
3446 };
3447 writer.write(link_state.as_bytes())?;
3448 }
3449 rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
3450 let link_speed = rndisprot::LinkSpeed {
3451 xmit: self.link_speed,
3452 rcv: self.link_speed,
3453 };
3454 writer.write(link_speed.as_bytes())?;
3455 }
3456 rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
3457 let ndis_offload = self.offload_support.ndis_offload();
3458 writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3459 }
3460 rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
3461 let ndis_offload = primary.offload_config.ndis_offload();
3462 writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3463 }
3464 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3465 writer.write(
3466 &rndisprot::NdisOffloadEncapsulation {
3467 header: rndisprot::NdisObjectHeader {
3468 object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3469 revision: 1,
3470 size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
3471 },
3472 ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3473 ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3474 ipv4_header_size: ETHERNET_HEADER_LEN,
3475 ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3476 ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3477 ipv6_header_size: ETHERNET_HEADER_LEN,
3478 }
3479 .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
3480 )?;
3481 }
3482 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
3483 writer.write(
3484 &rndisprot::NdisReceiveScaleCapabilities {
3485 header: rndisprot::NdisObjectHeader {
3486 object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
3487 revision: 2,
3488 size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
3489 as u16,
3490 },
3491 capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
3492 | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
3493 | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
3494 number_of_interrupt_messages: 1,
3495 number_of_receive_queues: self.max_queues.into(),
3496 number_of_indirection_table_entries: if self.indirection_table_size != 0 {
3497 self.indirection_table_size
3498 } else {
3499 128
3502 },
3503 padding: 0,
3504 }
3505 .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
3506 )?;
3507 }
3508 _ => {
3509 tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
3510 return Err(OidError::UnknownOid);
3511 }
3512 };
3513 Ok(available_len - writer.len())
3514 }
3515
3516 fn handle_oid_set(
3517 &self,
3518 primary: &mut PrimaryChannelState,
3519 oid: rndisprot::Oid,
3520 reader: impl MemoryRead + Clone,
3521 ) -> Result<(bool, Option<u32>), OidError> {
3522 tracing::debug!(?oid, "oid set");
3523
3524 let mut restart_endpoint = false;
3525 let mut packet_filter = None;
3526 match oid {
3527 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3528 packet_filter = self.oid_set_packet_filter(reader)?;
3529 }
3530 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3531 self.oid_set_offload_parameters(reader, primary)?;
3532 }
3533 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3534 self.oid_set_offload_encapsulation(reader)?;
3535 }
3536 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3537 self.oid_set_rndis_config_parameter(reader, primary)?;
3538 }
3539 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3540 }
3542 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3543 let rss_was_enabled = self.oid_set_rss_parameters(reader, primary)?;
3544
3545 if rss_was_enabled || primary.rss_state.is_some() {
3552 restart_endpoint = true;
3553 }
3554 }
3555 _ => {
3556 tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3557 return Err(OidError::UnknownOid);
3558 }
3559 }
3560 Ok((restart_endpoint, packet_filter))
3561 }
3562
3563 fn oid_set_rss_parameters(
3564 &self,
3565 mut reader: impl MemoryRead + Clone,
3566 primary: &mut PrimaryChannelState,
3567 ) -> Result<bool, OidError> {
3568 let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3570 let len = reader.len().min(size_of_val(¶ms));
3571 reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3572
3573 let rss_was_enabled = primary.rss_state.is_some();
3574
3575 if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3576 || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3577 {
3578 primary.rss_state = None;
3579 return Ok(rss_was_enabled);
3580 }
3581
3582 if params.hash_secret_key_size != 40 {
3583 return Err(OidError::InvalidInput("hash_secret_key_size"));
3584 }
3585 if params.indirection_table_size % 4 != 0 {
3586 return Err(OidError::InvalidInput("indirection_table_size"));
3587 }
3588 let indirection_table_size =
3589 (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3590 let mut key = [0; 40];
3591 let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3592 reader
3593 .clone()
3594 .skip(params.hash_secret_key_offset as usize)?
3595 .read(&mut key)?;
3596 reader
3597 .skip(params.indirection_table_offset as usize)?
3598 .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3599 tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3600 if indirection_table
3601 .iter()
3602 .any(|&x| x >= self.max_queues as u32)
3603 {
3604 return Err(OidError::InvalidInput("indirection_table"));
3605 }
3606 let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3607 for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3608 .flatten()
3609 .zip(indir_uninit)
3610 {
3611 *dest = src;
3612 }
3613 primary.rss_state = Some(RssState {
3614 key,
3615 indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3616 });
3617 Ok(rss_was_enabled)
3618 }
3619
3620 fn oid_set_packet_filter(
3621 &self,
3622 reader: impl MemoryRead + Clone,
3623 ) -> Result<Option<u32>, OidError> {
3624 let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3625 tracing::debug!(filter, "set packet filter");
3626 Ok(Some(filter))
3627 }
3628
3629 fn oid_set_offload_parameters(
3630 &self,
3631 reader: impl MemoryRead + Clone,
3632 primary: &mut PrimaryChannelState,
3633 ) -> Result<(), OidError> {
3634 let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3635 reader,
3636 rndisprot::NdisObjectType::DEFAULT,
3637 1,
3638 rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3639 )?;
3640
3641 tracing::debug!(?offload, "offload parameters");
3642 let rndisprot::NdisOffloadParameters {
3643 header: _,
3644 ipv4_checksum,
3645 tcp4_checksum,
3646 udp4_checksum,
3647 tcp6_checksum,
3648 udp6_checksum,
3649 lsov1,
3650 ipsec_v1: _,
3651 lsov2_ipv4,
3652 lsov2_ipv6,
3653 tcp_connection_ipv4: _,
3654 tcp_connection_ipv6: _,
3655 reserved: _,
3656 flags: _,
3657 } = offload;
3658
3659 if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3660 return Err(OidError::NotSupported("lsov1"));
3661 }
3662 if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3663 primary.offload_config.checksum_tx.ipv4_header = tx;
3664 primary.offload_config.checksum_rx.ipv4_header = rx;
3665 }
3666 if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3667 primary.offload_config.checksum_tx.tcp4 = tx;
3668 primary.offload_config.checksum_rx.tcp4 = rx;
3669 }
3670 if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3671 primary.offload_config.checksum_tx.tcp6 = tx;
3672 primary.offload_config.checksum_rx.tcp6 = rx;
3673 }
3674 if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3675 primary.offload_config.checksum_tx.udp4 = tx;
3676 primary.offload_config.checksum_rx.udp4 = rx;
3677 }
3678 if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3679 primary.offload_config.checksum_tx.udp6 = tx;
3680 primary.offload_config.checksum_rx.udp6 = rx;
3681 }
3682 if let Some(enable) = lsov2_ipv4.enable() {
3683 primary.offload_config.lso4 = enable;
3684 }
3685 if let Some(enable) = lsov2_ipv6.enable() {
3686 primary.offload_config.lso6 = enable;
3687 }
3688 primary
3689 .offload_config
3690 .mask_to_supported(&self.offload_support);
3691 primary.pending_offload_change = true;
3692 Ok(())
3693 }
3694
3695 fn oid_set_offload_encapsulation(
3696 &self,
3697 reader: impl MemoryRead + Clone,
3698 ) -> Result<(), OidError> {
3699 let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3700 reader,
3701 rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3702 1,
3703 rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3704 )?;
3705 if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3706 && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3707 || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3708 {
3709 return Err(OidError::NotSupported("ipv4 encap"));
3710 }
3711 if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3712 && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3713 || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3714 {
3715 return Err(OidError::NotSupported("ipv6 encap"));
3716 }
3717 Ok(())
3718 }
3719
3720 fn oid_set_rndis_config_parameter(
3721 &self,
3722 reader: impl MemoryRead + Clone,
3723 primary: &mut PrimaryChannelState,
3724 ) -> Result<(), OidError> {
3725 let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3726 if info.name_length > 255 {
3727 return Err(OidError::InvalidInput("name_length"));
3728 }
3729 if info.value_length > 255 {
3730 return Err(OidError::InvalidInput("value_length"));
3731 }
3732 let name = reader
3733 .clone()
3734 .skip(info.name_offset as usize)?
3735 .read_n::<u16>(info.name_length as usize / 2)?;
3736 let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3737 let mut value = reader;
3738 value.skip(info.value_offset as usize)?;
3739 let mut value = value.limit(info.value_length as usize);
3740 match info.parameter_type {
3741 rndisprot::NdisParameterType::STRING => {
3742 let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3743 let value =
3744 String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3745 let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3746 let tx = as_num & 1 != 0;
3747 let rx = as_num & 2 != 0;
3748
3749 tracing::debug!(name, value, "rndis config");
3750 match name.as_str() {
3751 "*IPChecksumOffloadIPv4" => {
3752 primary.offload_config.checksum_tx.ipv4_header = tx;
3753 primary.offload_config.checksum_rx.ipv4_header = rx;
3754 }
3755 "*LsoV2IPv4" => {
3756 primary.offload_config.lso4 = as_num != 0;
3757 }
3758 "*LsoV2IPv6" => {
3759 primary.offload_config.lso6 = as_num != 0;
3760 }
3761 "*TCPChecksumOffloadIPv4" => {
3762 primary.offload_config.checksum_tx.tcp4 = tx;
3763 primary.offload_config.checksum_rx.tcp4 = rx;
3764 }
3765 "*TCPChecksumOffloadIPv6" => {
3766 primary.offload_config.checksum_tx.tcp6 = tx;
3767 primary.offload_config.checksum_rx.tcp6 = rx;
3768 }
3769 "*UDPChecksumOffloadIPv4" => {
3770 primary.offload_config.checksum_tx.udp4 = tx;
3771 primary.offload_config.checksum_rx.udp4 = rx;
3772 }
3773 "*UDPChecksumOffloadIPv6" => {
3774 primary.offload_config.checksum_tx.udp6 = tx;
3775 primary.offload_config.checksum_rx.udp6 = rx;
3776 }
3777 _ => {}
3778 }
3779 primary
3780 .offload_config
3781 .mask_to_supported(&self.offload_support);
3782 }
3783 rndisprot::NdisParameterType::INTEGER => {
3784 let value: u32 = value.read_plain()?;
3785 tracing::debug!(name, value, "rndis config");
3786 }
3787 parameter_type => tracelimit::warn_ratelimited!(
3788 name,
3789 ?parameter_type,
3790 "unhandled rndis config parameter type"
3791 ),
3792 }
3793 Ok(())
3794 }
3795}
3796
3797fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3798 mut reader: impl MemoryRead,
3799 object_type: rndisprot::NdisObjectType,
3800 min_revision: u8,
3801 min_size: usize,
3802) -> Result<T, OidError> {
3803 let mut buffer = T::new_zeroed();
3804 let sent_size = reader.len();
3805 let len = sent_size.min(size_of_val(&buffer));
3806 reader.read(&mut buffer.as_mut_bytes()[..len])?;
3807 validate_ndis_object_header(
3808 &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3809 .unwrap()
3810 .0, sent_size,
3812 object_type,
3813 min_revision,
3814 min_size,
3815 )?;
3816 Ok(buffer)
3817}
3818
3819fn validate_ndis_object_header(
3820 header: &rndisprot::NdisObjectHeader,
3821 sent_size: usize,
3822 object_type: rndisprot::NdisObjectType,
3823 min_revision: u8,
3824 min_size: usize,
3825) -> Result<(), OidError> {
3826 if header.object_type != object_type {
3827 return Err(OidError::InvalidInput("header.object_type"));
3828 }
3829 if sent_size < header.size as usize {
3830 return Err(OidError::InvalidInput("header.size"));
3831 }
3832 if header.revision < min_revision {
3833 return Err(OidError::InvalidInput("header.revision"));
3834 }
3835 if (header.size as usize) < min_size {
3836 return Err(OidError::InvalidInput("header.size"));
3837 }
3838 Ok(())
3839}
3840
/// Mutable state of the coordinator task, which owns the per-queue worker
/// tasks and services control, timer, VF, and endpoint events.
struct Coordinator {
    /// Control messages sent to the coordinator.
    recv: mpsc::Receiver<CoordinatorMessage>,
    // NOTE(review): not referenced in this part of the file — presumably used
    // for vmbus channel management elsewhere; confirm before documenting
    // further.
    channel_control: ChannelControl,
    /// When set, the endpoint and queues are torn down and restarted at the
    /// top of the processing loop.
    restart: bool,
    /// Per-queue worker tasks; index 0 is the primary channel worker.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// Channel buffers captured from the primary worker during queue restart.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of currently active queues (bounds the inspected worker list).
    num_queues: u16,
    /// Last packet filter value propagated from the primary channel to the
    /// subchannel workers.
    active_packet_filter: u32,
    /// Deadline for the coordinator's wakeup timer; `None` when no timer is
    /// armed.
    sleep_deadline: Option<Instant>,
}
3851
/// Progress of offering the virtual function (VF) device to the guest.
enum CoordinatorStatePendingVfState {
    /// No VF offer is in flight.
    Ready,
    /// The VF offer is delayed until `delay_until`.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// The VF offer is being presented to the guest; the coordinator waits
    /// for the guest to be ready for the device before doing anything else.
    Pending,
}
3866
/// Immutable-task-side state for the coordinator (separate from the mutable
/// [`Coordinator`] so it can be inspected while the task is stopped).
struct CoordinatorState {
    /// Backend network endpoint.
    endpoint: Box<dyn Endpoint>,
    /// Shared adapter configuration.
    adapter: Arc<Adapter>,
    /// Guest VF device handle, if one is configured.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of offering the VF device to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3873
/// Inspect support: reports adapter- and endpoint-level state, plus per-queue
/// worker and primary-channel state when the coordinator is available.
impl InspectTaskMut<Coordinator> for CoordinatorState {
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Static adapter configuration.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Writable: stores the new limit, then pokes every worker with a
            // no-op update so they observe the changed value.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // Per-queue worker state, keyed by queue index.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Primary channel state lives inside worker 0; it must be
            // accessed through update_with, so respond via a deferred
            // request.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3949
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Runs the coordinator task until it completes or is cancelled by
    /// `stop`, delegating to [`Coordinator::process`].
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3959
3960impl Coordinator {
    /// Main coordinator loop: restarts queues when requested, keeps the
    /// worker tasks running, and services control messages, timers, VF state
    /// changes, and endpoint events until the channel disconnects or the
    /// task is cancelled.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                // Tear everything down and bring the queues back up against
                // the (possibly changed) endpoint.
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Re-query the data path state, which may have changed across
                // the endpoint restart.
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Subchannel workers always run; the primary worker is only
            // started when it is not parked waiting for the coordinator.
            for worker in &mut self.workers[1..] {
                worker.start();
            }
            if !self.workers[0].is_running()
                && self.workers[0].state().is_none_or(|worker| {
                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
                })
            {
                self.workers[0].start();
            }

            // Unified event type for the race below.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // A VF offer is in flight: wait (uncancellably racing
                // nothing else) for the guest to accept the device.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Wake up when the coordinator's sleep deadline (if armed)
                // expires.
                let timer_sleep = async {
                    if let Some(deadline) = self.sleep_deadline {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(deadline).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Fires when the delayed VF offer comes due;
                                // never fires when no delay is pending.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            // Pending is handled before building these
                            // futures.
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::Internal(msg) => {
                    self.handle_coordinator_message(msg, state).await;
                }
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Stop the primary worker before mutating its state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    // Promote a delayed link action to active now that the
                    // delay elapsed.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                            primary.pending_link_action = PendingLinkAction::Active(up);
                        }
                    }
                    self.sleep_deadline = None;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    self.sleep_deadline = None;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4143
    /// Handles a single [`CoordinatorMessage`], stopping the primary worker
    /// first so its state can be safely read and mutated.
    async fn handle_coordinator_message(
        &mut self,
        msg: CoordinatorMessage,
        state: &mut CoordinatorState,
    ) {
        self.workers[0].stop().await;
        // If the primary worker parked itself waiting on the coordinator,
        // transition it back to Ready so it resumes when restarted. The
        // two-step replace moves the ready payload out without cloning.
        if let Some(worker) = self.workers[0].state_mut() {
            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
                let WorkerState::WaitingForCoordinator(Some(ready)) =
                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
                else {
                    unreachable!("valid ready state")
                };
                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
            }
        }
        match msg {
            CoordinatorMessage::Update(update_type) => {
                // Propagate the primary channel's packet filter to every
                // subchannel worker.
                if update_type.filter_state {
                    self.stop_workers().await;
                    self.active_packet_filter =
                        self.workers[0].state().unwrap().channel.packet_filter;
                    self.workers.iter_mut().skip(1).for_each(|worker| {
                        if let Some(state) = worker.state_mut() {
                            state.channel.packet_filter = self.active_packet_filter;
                            tracing::debug!(
                                packet_filter = ?self.active_packet_filter,
                                channel_idx = state.channel_idx,
                                "update packet filter"
                            );
                        }
                    });
                }

                if update_type.guest_vf_state {
                    self.update_guest_vf_state(state).await;
                }
            }
            // Arm the coordinator's wakeup timer; `process` acts on expiry.
            CoordinatorMessage::StartTimer(deadline) => {
                self.sleep_deadline = Some(deadline);
            }
            CoordinatorMessage::Restart => self.restart = true,
        }
    }
4188
4189 async fn stop_workers(&mut self) {
4190 for worker in &mut self.workers {
4191 worker.stop().await;
4192 }
4193 }
4194
    /// Reconciles the primary channel's guest VF state with the actual VF
    /// availability and the endpoint's data path state, after a restart or
    /// restore. Also schedules the VF device offer when appropriate.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Determine whether a VF is currently available.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // A VF is available: schedule the device offer as needed.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // Advertised but the data path is still synthetic: delay
                    // the offer briefly.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                // States where the guest already knows about (or is using)
                // the VF: offer it immediately.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // A restored switch that already has a result only needs the
            // guest notified; no data path work is required.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                // States that require (re)issuing the data path switch.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // The data path is being switched, so skip any pending
                    // offer delay and offer the device right away.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        // A switch was in flight: record its outcome so the
                        // guest can be notified.
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // No VF is available. If the data path was (or was becoming)
            // switched to the VF, force it back to synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any pending offer of the now-absent device.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Map each state to its "VF unavailable" counterpart.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4412
    /// Stops the backend endpoint, rebuilds per-queue state for every worker,
    /// and restarts the endpoint with a fresh queue configuration.
    ///
    /// Invoked by the coordinator whenever queue topology must change (initial
    /// start, RSS reconfiguration, or subchannel allocation).
    async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
        // Tear down each worker's queue state, keeping its driver so the new
        // queues can be built against the same drivers below.
        let drivers = self
            .workers
            .iter_mut()
            .map(|worker| {
                let task = worker.task_mut();
                task.queue_state = None;
                task.driver.clone()
            })
            .collect::<Vec<_>>();

        c_state.endpoint.stop().await;

        // Split workers into the primary (index 0) and the subchannel workers.
        let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
        {
            (primary, sub)
        } else {
            unreachable!()
        };

        let state = primary_worker
            .state_mut()
            .and_then(|worker| worker.state.ready_mut());

        // Nothing to restart until the primary channel is fully initialized.
        let state = if let Some(state) = state {
            state
        } else {
            return Ok(());
        };

        self.buffers = Some(state.buffers.clone());

        let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
        // Derive the set of queues actually referenced by the RSS indirection
        // table (sorted, deduplicated, bounded by num_queues). If the table
        // references no valid queue, fall back to treating all queues active.
        let mut active_queues = Vec::new();
        let active_queue_count =
            if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
                active_queues.clone_from(&rss_state.indirection_table);
                active_queues.sort();
                active_queues.dedup();
                active_queues = active_queues
                    .into_iter()
                    .filter(|&index| index < num_queues)
                    .collect::<Vec<_>>();
                if !active_queues.is_empty() {
                    active_queues.len() as u16
                } else {
                    tracelimit::warn_ratelimited!("Invalid RSS indirection table");
                    num_queues
                }
            } else {
                num_queues
            };

        // Partition the receive-buffer suballocations across the active
        // queues; each active queue also gets a receiver for buffer IDs
        // released on behalf of other channels (drained in main_loop).
        let (ranges, mut remote_buffer_id_recvs) =
            RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
        let ranges = Arc::new(ranges);

        let mut queues = Vec::new();
        let mut rx_buffers = Vec::new();
        let mut per_queue_rx: Vec<Vec<RxId>> = Vec::new();
        let guest_buffers;
        {
            let buffers = &state.buffers;
            guest_buffers = Arc::new(
                GuestBuffers::new(
                    buffers.mem.clone(),
                    buffers.recv_buffer.gpadl.clone(),
                    buffers.recv_buffer.sub_allocation_size,
                    buffers.ndis_config.mtu,
                )
                .map_err(WorkerError::GpadlError)?,
            );

            let mut queue_config = Vec::new();
            let initial_rx;
            {
                // Gather the ready state of every channel so we can compute
                // which receive buffers are free across all of them.
                let states = std::iter::once(Some(&*state)).chain(
                    subworkers
                        .iter()
                        .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
                );

                // A buffer is initially available only if no channel still has
                // it outstanding. The first RX_RESERVED_CONTROL_BUFFERS are
                // reserved for control messages and are excluded here.
                initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
                    .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
                    .map(RxId)
                    .collect::<Vec<_>>();

                let mut initial_rx = initial_rx.as_slice();
                let mut range_start = 0;
                // If the indirection table excludes queue 0, the primary queue
                // is still created (control traffic) but only owns the
                // reserved control-buffer range.
                let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
                let first_queue = if !primary_queue_excluded {
                    0
                } else {
                    queue_config.push(QueueConfig {
                        driver: Box::new(drivers[0].clone()),
                    });
                    per_queue_rx.push(Vec::new());
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        0..RX_RESERVED_CONTROL_BUFFERS,
                        None,
                    ));
                    range_start = RX_RESERVED_CONTROL_BUFFERS;
                    1
                };
                for queue_index in first_queue..num_queues {
                    let queue_active = active_queues.is_empty()
                        || active_queues.binary_search(&queue_index).is_ok();
                    // Active queues get a contiguous slice of buffers (the
                    // last active queue takes the remainder); inactive queues
                    // get an empty range and no remote-buffer receiver.
                    let (range_end, end, buffer_id_recv) = if queue_active {
                        let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
                            state.buffers.recv_buffer.count
                        } else if queue_index == 0 {
                            RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
                        } else {
                            range_start + ranges.buffers_per_queue
                        };
                        (
                            range_end,
                            initial_rx.partition_point(|id| id.0 < range_end),
                            Some(remote_buffer_id_recvs.remove(0)),
                        )
                    } else {
                        (range_start, 0, None)
                    };

                    let (this, rest) = initial_rx.split_at(end);
                    queue_config.push(QueueConfig {
                        driver: Box::new(drivers[queue_index as usize].clone()),
                    });
                    per_queue_rx.push(this.to_vec());
                    initial_rx = rest;
                    rx_buffers.push(RxBufferRange::new(
                        ranges.clone(),
                        range_start..range_end,
                        buffer_id_recv,
                    ));

                    range_start = range_end;
                }
            }

            let primary = state.state.primary.as_mut().unwrap();
            tracing::debug!(num_queues, "enabling endpoint");

            let rss = primary
                .rss_state
                .as_ref()
                .map(|rss| net_backend::RssConfig {
                    key: &rss.key,
                    indirection_table: &rss.indirection_table,
                    flags: 0,
                });

            c_state
                .endpoint
                .get_queues(queue_config, rss.as_ref(), &mut queues)
                .instrument(tracing::info_span!("netvsp_get_queues"))
                .await
                .map_err(WorkerError::Endpoint)?;

            assert_eq!(queues.len(), num_queues as usize);

            self.channel_control
                .enable_subchannels(num_queues - 1)
                .expect("already validated");

            self.num_queues = num_queues;
        }

        // Hand each worker its new queue, buffer range, and initially
        // available rx buffers, then clear state made stale by the stop.
        self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
        for (((worker, mut queue), rx_buffer), initial) in self
            .workers
            .iter_mut()
            .zip(queues)
            .zip(rx_buffers)
            .zip(per_queue_rx)
        {
            let mut pool = BufferPool::new(guest_buffers.clone());
            if !initial.is_empty() {
                queue.rx_avail(&mut pool, &initial);
            }
            worker.task_mut().queue_state = Some(QueueState {
                queue,
                pool,
                target_vp_set: false,
                rx_buffer_range: rx_buffer,
            });
            if let Some(worker) = worker.state_mut() {
                worker.channel.packet_filter = self.active_packet_filter;
                if let Some(ready_state) = worker.state.ready_mut() {
                    ready_state.state.pending_rx_packets.clear();
                    ready_state.reset_tx_after_endpoint_stop();
                }
            }
        }

        Ok(())
    }
4627
4628 fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4629 self.workers[0]
4630 .state_mut()
4631 .unwrap()
4632 .state
4633 .ready_mut()?
4634 .state
4635 .primary
4636 .as_mut()
4637 }
4638
    /// Stops the primary worker and then re-applies the guest VF state.
    ///
    /// NOTE(review): the worker is not restarted here; presumably the
    /// coordinator restarts it after this returns — confirm at call sites.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4643}
4644
4645impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4646 async fn run(
4647 &mut self,
4648 stop: &mut StopTask<'_>,
4649 worker: &mut Worker<T>,
4650 ) -> Result<(), task_control::Cancelled> {
4651 match worker.process(stop, self).await {
4652 Ok(()) | Err(WorkerError::BufferRevoked) => {}
4653 Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4654 Err(err) => {
4655 tracing::error!(
4656 error = &err as &dyn std::error::Error,
4657 channel_idx = worker.channel_idx,
4658 "netvsp error"
4659 );
4660 }
4661 }
4662 Ok(())
4663 }
4664}
4665
impl<T: RingMem + 'static> Worker<T> {
    /// Drives this channel's state machine: protocol initialization on the
    /// primary channel, then the main packet-processing loop once the
    /// coordinator has assigned a queue.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    // Only the primary channel performs initialization.
                    assert_eq!(self.channel_idx, 0);

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Ask the coordinator to build the queues. A failed send
                    // is ignored; the coordinator is already processing.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::WaitingForCoordinator(Some(state));
                }
                WorkerState::WaitingForCoordinator(_) => {
                    assert_eq!(self.channel_idx, 0);
                    // Parked until the coordinator restarts this task.
                    stop.until_stopped(pending()).await?
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Lazily bind the backend queue to its target VP the
                        // first time the assignment is known.
                        if !queue_state.target_vp_set {
                            if let Some(target_vp) = self.target_vp {
                                tracing::debug!(
                                    channel_idx = self.channel_idx,
                                    target_vp,
                                    "updating target VP"
                                );
                                queue_state.queue.update_target_vp(target_vp).await;
                                queue_state.target_vp_set = true;
                            }
                        }

                        queue_state
                    } else {
                        // No queue assigned yet; wait for the coordinator.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    let msg = match result {
                        Ok(restart) => {
                            assert_eq!(self.channel_idx, 0);
                            restart
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            CoordinatorMessage::Restart
                        }
                        Err(err) => return Err(err),
                    };

                    // Move the ready state into WaitingForCoordinator so the
                    // coordinator can restart this worker with it intact.
                    let WorkerState::Ready(ready) = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(None),
                    ) else {
                        unreachable!("must be running in ready state")
                    };
                    let _ = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(Some(ready)),
                    );
                    self.coordinator_send
                        .try_send(msg)
                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4759
4760impl<T: 'static + RingMem> NetChannel<T> {
4761 fn try_next_packet<'a>(
4762 &mut self,
4763 send_buffer: Option<&SendBuffer>,
4764 version: Option<Version>,
4765 external_data: &'a mut MultiPagedRangeBuf,
4766 ) -> Result<Option<Packet<'a>>, WorkerError> {
4767 let (mut read, _) = self.queue.split();
4768 let packet = match read.try_read() {
4769 Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4770 .map_err(WorkerError::Packet)?,
4771 Err(queue::TryReadError::Empty) => return Ok(None),
4772 Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4773 };
4774
4775 tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4776 Ok(Some(packet))
4777 }
4778
4779 async fn next_packet<'a>(
4780 &mut self,
4781 send_buffer: Option<&'a SendBuffer>,
4782 version: Option<Version>,
4783 external_data: &'a mut MultiPagedRangeBuf,
4784 ) -> Result<Packet<'a>, WorkerError> {
4785 let (mut read, _) = self.queue.split();
4786 let mut packet_ref = read.read().await?;
4787 let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
4788 .map_err(WorkerError::Packet)?;
4789 if matches!(packet.data, PacketData::RndisPacket(_)) {
4790 tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
4792 packet_ref.revert();
4793 }
4794 tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4795 Ok(packet)
4796 }
4797
4798 fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4799 (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4800 && initializing.ndis_version.is_some()
4801 && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4802 && initializing.recv_buffer.is_some()
4803 }
4804
    /// Runs the channel-establishment handshake on the primary channel:
    /// version negotiation, then NDIS config/version and receive/send buffer
    /// exchanges, completing once enough state exists to build the shared
    /// channel buffers and active state.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        // Set when an early RNDIS packet forces initialization to finish
        // without waiting for the (optional) send buffer.
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Pre-V2 guests never send SendNdisConfig; use the
                        // defaults in that case.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Ensure a completion can be written before consuming a request.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                // Version already negotiated; process the setup messages.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // Clamp an out-of-range MTU to the default rather
                        // than failing the handshake.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The MTU determines the suballocation size; pre-V2
                        // guests get the default since they send no config.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // An RNDIS packet this early is tolerated only if
                        // everything except the send buffer has arrived; the
                        // packet itself was reverted by next_packet and will
                        // be re-read by the main loop.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet; only an Init message is valid
                // (parse_packet enforces this, hence the unreachable below).
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            // V6.1+ uses the larger packet format.
                            if version >= Version::V61 {
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }
    }
5019
    /// The steady-state processing loop for a ready channel with an assigned
    /// queue: moves packets between the vmbus rings and the backend endpoint
    /// until stopped or until a coordinator message must be delivered.
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        // Compute how much outgoing ring capacity to keep in reserve. A
        // configured ring size limit applies only when the optimization is
        // allowed; otherwise reserve all but 2048 bytes.
        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };

        // Return any rx buffers parked while the packet filter was NONE.
        if !state.pending_rx_packets.is_empty()
            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
        {
            let (front, back) = state.pending_rx_packets.as_slices();
            queue_state.queue.rx_avail(&mut queue_state.pool, front);
            queue_state.queue.rx_avail(&mut queue_state.pool, back);
            state.pending_rx_packets.clear();
        }

        // Primary-channel-only housekeeping before entering the loop.
        if let Some(primary) = state.primary.as_mut() {
            // Send the tx indirection table once all subchannels are open.
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            // Deliver a pending link status change to the guest, using a
            // reserved control buffer. If the guest already has the target
            // state, toggle it (disconnect/connect) via a delayed action.
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    // No control buffer free; retry after a delay.
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            // Advance any in-flight guest VF state transition.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // When the outgoing ring is below the reserve threshold, stop
            // pulling new work and only drain completions.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            // Non-short-circuiting `|` is deliberate: run every phase each
            // iteration and record whether any of them made progress.
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(
                    buffers,
                    state,
                    data,
                    queue_state.queue.as_mut(),
                    &mut queue_state.pool,
                )?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(
                        state,
                        data,
                        queue_state.queue.as_mut(),
                        &mut queue_state.pool,
                    )?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            self.process_control_messages(buffers, state)?;

            // Wait for any wakeup source: endpoint readiness, incoming ring
            // data, outgoing ring space, remotely released rx buffers, or a
            // pending coordinator message.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        if !ring_full {
                            if queue_state
                                .queue
                                .poll_ready(cx, &mut queue_state.pool)
                                .is_ready()
                            {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // Only poll for new guest tx work if enough tx IDs
                        // are free and no segments are still pending.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // When the ring was full, wake once the reserve is
                        // available again rather than the smaller send size.
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Reclaim buffers released on other channels that
                        // belong to this queue's range; control buffers go
                        // back to the primary's free list.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state
                                        .queue
                                        .rx_avail(&mut queue_state.pool, &[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }
5256
    /// Polls the endpoint for received packets and forwards them to the
    /// guest as a transfer-page RNDIS message.
    ///
    /// Returns `Ok(false)` when nothing was polled or the packet filter is
    /// NONE (buffers are parked for later), `Ok(true)` otherwise — even when
    /// the ring was full and the buffers were returned to the endpoint.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
        pool: &mut BufferPool,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(pool, &mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Park the buffers; main_loop hands them back to the endpoint
            // once the packet filter is re-enabled.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The first rx buffer ID doubles as the transaction ID, letting the
        // guest's completion be mapped back to this batch of buffers.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring full: undo the allocation and give the buffers
                // straight back to the endpoint.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(pool, &data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5323
5324 fn process_endpoint_tx(
5325 &mut self,
5326 state: &mut ActiveState,
5327 data: &mut ProcessingData,
5328 epqueue: &mut dyn net_backend::Queue,
5329 pool: &mut BufferPool,
5330 ) -> Result<bool, WorkerError> {
5331 let result = epqueue.tx_poll(pool, &mut data.tx_done);
5333
5334 match result {
5335 Ok(n) => {
5336 if n == 0 {
5337 return Ok(false);
5338 }
5339
5340 for &id in &data.tx_done[..n] {
5341 let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5342 assert!(tx_packet.pending_packet_count > 0);
5343 tx_packet.pending_packet_count -= 1;
5344 if tx_packet.pending_packet_count == 0 {
5345 self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
5346 }
5347 }
5348
5349 Ok(true)
5350 }
5351 Err(TxError::TryRestart(err)) => {
5352 Err(WorkerError::EndpointRequiresQueueRestart(err))
5355 }
5356 Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
5357 }
5358 }
5359
5360 fn switch_data_path(
5361 &mut self,
5362 state: &mut ActiveState,
5363 use_guest_vf: bool,
5364 transaction_id: Option<u64>,
5365 ) -> Result<(), WorkerError> {
5366 let primary = state.primary.as_mut().unwrap();
5367 let mut queue_switch_operation = false;
5368 match primary.guest_vf_state {
5369 PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
5370 if use_guest_vf || primary.is_data_path_switched.is_none() {
5379 primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5380 to_guest: use_guest_vf,
5381 id: transaction_id,
5382 result: None,
5383 };
5384 queue_switch_operation = true;
5385 }
5386 }
5387 PrimaryChannelGuestVfState::DataPathSwitched => {
5388 if !use_guest_vf {
5389 primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
5390 to_guest: false,
5391 id: transaction_id,
5392 result: None,
5393 };
5394 queue_switch_operation = true;
5395 }
5396 }
5397 _ if use_guest_vf => {
5398 tracing::warn!(
5399 state = %primary.guest_vf_state,
5400 use_guest_vf,
5401 "Data path switch requested while device is in wrong state"
5402 );
5403 }
5404 _ => (),
5405 };
5406 if queue_switch_operation {
5407 self.send_coordinator_update_vf();
5408 } else {
5409 self.send_completion(transaction_id, None)?;
5410 }
5411 Ok(())
5412 }
5413
    /// Drains incoming vmbus packets from the ring: RNDIS tx messages are
    /// turned into endpoint segments, completions release receive buffers,
    /// and control requests (subchannels, data-path switch, OID queries) are
    /// answered inline.
    ///
    /// Returns whether any packets were consumed from the ring.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        // Don't parse more tx packets while segments from a previous batch
        // are still waiting to be handed to the endpoint.
        if !data.tx_segments.is_empty() {
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Stop when no tx IDs remain to describe another packet.
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // Control-only messages generate no endpoint
                            // packets and complete immediately.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                PacketData::RndisPacketComplete(_completion) => {
                    // The guest is done with a batch of receive buffers;
                    // return them to the endpoint.
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state
                        .queue
                        .rx_avail(&mut queue_state.pool, &data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    // Only ALLOCATE within the adapter's queue limit succeeds.
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    // Queues must be rebuilt for the new channel count; ask
                    // the coordinator to restart on the next wait.
                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(protocol::Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Kick the endpoint with any tx segments generated above; a stall
        // means the endpoint could not accept them all this pass.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5568
5569 fn transmit_pending_segments(
5572 &mut self,
5573 state: &mut ActiveState,
5574 data: &mut ProcessingData,
5575 queue_state: &mut QueueState,
5576 ) -> Result<bool, WorkerError> {
5577 if data.tx_segments.is_empty() {
5578 return Ok(false);
5579 }
5580 let sent = data.tx_segments_sent;
5581 let did_work =
5582 self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5583 Ok(did_work)
5584 }
5585
    /// Hands pending tx segments to the backend queue, advancing
    /// `data.tx_segments_sent` past whatever the backend accepted.
    ///
    /// Returns `true` when every pending segment has been handed off (the
    /// pending list is then reset).
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(&mut queue_state.pool, segments)
            .map_err(WorkerError::Endpoint)?;

        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        // If the backend completed the transmits synchronously, complete the
        // corresponding guest packets now instead of waiting for tx_poll.
        // Walk the accepted segments packet by packet via the Head metadata.
        if sync {
            while let Some(head) = segments.first() {
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5625
    /// Handles one inbound RNDIS message from the guest on tx id `id`.
    ///
    /// Data messages (`MESSAGE_TYPE_PACKET_MSG`) are translated into tx
    /// segments appended to `segments`; all other message types are
    /// dispatched to `handle_rndis_message`. Returns the number of endpoint
    /// packets generated (0 means the caller completes the transaction).
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut total_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Probe every guest page referenced by the packet up front so a bad
        // GPA fails cleanly here rather than mid-transmit.
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        let header: rndisprot::MessageHeader = reader.read_plain()?;
        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
            let start = segments.len();
            match self.handle_rndis_packet_messages(
                buffers,
                state,
                id,
                header.message_length as usize,
                reader,
                segments,
            ) {
                Ok(n) => {
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
                    total_packets += n;
                }
                Err(err) => {
                    // Discard any partially-built segments so the shared
                    // segment list stays consistent.
                    segments.truncate(start);
                    return Err(err);
                }
            }
        } else {
            self.handle_rndis_message(state, header.message_type, reader)?;
        }

        Ok(total_packets)
    }
5678
5679 fn try_send_tx_packet(
5680 &mut self,
5681 transaction_id: u64,
5682 status: protocol::Status,
5683 ) -> Result<bool, WorkerError> {
5684 let message = self.message(
5685 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5686 protocol::Message1SendRndisPacketComplete { status },
5687 );
5688 let result = self.queue.split().1.batched().try_write_aligned(
5689 transaction_id,
5690 OutgoingPacketType::Completion,
5691 message.aligned_payload(),
5692 );
5693 let sent = match result {
5694 Ok(()) => true,
5695 Err(queue::TryWriteError::Full(n)) => {
5696 self.pending_send_size = n;
5697 false
5698 }
5699 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5700 };
5701 Ok(sent)
5702 }
5703
5704 fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5705 let mut did_some_work = false;
5706 while let Some(pending) = state.pending_tx_completions.front() {
5707 if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5708 return Ok(did_some_work);
5709 }
5710 did_some_work = true;
5711 if let Some(id) = pending.tx_id {
5712 state.free_tx_packets.push(id);
5713 }
5714 tracing::trace!(?pending, "sent tx completion");
5715 state.pending_tx_completions.pop_front();
5716 }
5717
5718 self.pending_send_size = 0;
5719 Ok(did_some_work)
5720 }
5721
5722 fn complete_tx_packet(
5723 &mut self,
5724 state: &mut ActiveState,
5725 id: TxId,
5726 status: protocol::Status,
5727 ) -> Result<(), WorkerError> {
5728 let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5729 assert_eq!(tx_packet.pending_packet_count, 0);
5730 if self.pending_send_size == 0
5731 && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5732 {
5733 tracing::trace!(id = id.0, "sent tx completion");
5734 state.free_tx_packets.push(id);
5735 } else {
5736 tracing::trace!(id = id.0, "pended tx completion");
5737 state.pending_tx_completions.push_back(PendingTxCompletion {
5738 transaction_id: tx_packet.transaction_id,
5739 tx_id: Some(id),
5740 status,
5741 });
5742 }
5743 Ok(())
5744 }
5745}
5746
5747impl ActiveState {
5748 fn release_recv_buffers(
5749 &mut self,
5750 transaction_id: u64,
5751 rx_buffer_range: &RxBufferRange,
5752 done: &mut Vec<RxId>,
5753 ) -> Option<()> {
5754 let first_id: u32 = transaction_id.try_into().ok()?;
5756 let ids = self.rx_bufs.free(first_id)?;
5757 for id in ids {
5758 if !rx_buffer_range.send_if_remote(id) {
5759 if id >= RX_RESERVED_CONTROL_BUFFERS {
5760 done.push(RxId(id));
5761 } else {
5762 self.primary
5763 .as_mut()?
5764 .free_control_buffers
5765 .push(ControlMessageId(id));
5766 }
5767 }
5768 }
5769 Some(())
5770 }
5771}