1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
19use crate::protocol::Version;
20use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
21use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
22use async_trait::async_trait;
23pub use buffers::BufferPool;
24use buffers::sub_allocation_size_for_mtu;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures::channel::mpsc;
28use futures::channel::mpsc::TrySendError;
29use futures_concurrency::future::Race;
30use guestmem::AccessError;
31use guestmem::GuestMemory;
32use guestmem::GuestMemoryError;
33use guestmem::MemoryRead;
34use guestmem::MemoryWrite;
35use guestmem::ranges::PagedRange;
36use guestmem::ranges::PagedRanges;
37use guestmem::ranges::PagedRangesReader;
38use guid::Guid;
39use hvdef::hypercall::HvGuestOsId;
40use hvdef::hypercall::HvGuestOsMicrosoft;
41use hvdef::hypercall::HvGuestOsMicrosoftIds;
42use hvdef::hypercall::HvGuestOsOpenSourceType;
43use inspect::Inspect;
44use inspect::InspectMut;
45use inspect::SensitivityLevel;
46use inspect_counters::Counter;
47use inspect_counters::Histogram;
48use mesh::rpc::Rpc;
49use net_backend::Endpoint;
50use net_backend::EndpointAction;
51use net_backend::QueueConfig;
52use net_backend::RxId;
53use net_backend::TxError;
54use net_backend::TxId;
55use net_backend::TxSegment;
56use net_backend_resources::mac_address::MacAddress;
57use pal_async::timer::Instant;
58use pal_async::timer::PolledTimer;
59use ring::gparange::MultiPagedRangeIter;
60use rx_bufs::RxBuffers;
61use rx_bufs::SubAllocationInUse;
62use std::collections::VecDeque;
63use std::fmt::Debug;
64use std::future::pending;
65use std::mem::offset_of;
66use std::ops::Range;
67use std::sync::Arc;
68use std::sync::atomic::AtomicUsize;
69use std::sync::atomic::Ordering;
70use std::task::Poll;
71use std::time::Duration;
72use task_control::AsyncRun;
73use task_control::InspectTaskMut;
74use task_control::StopTask;
75use task_control::TaskControl;
76use thiserror::Error;
77use tracing::Instrument;
78use vmbus_async::queue;
79use vmbus_async::queue::ExternalDataError;
80use vmbus_async::queue::IncomingPacket;
81use vmbus_async::queue::Queue;
82use vmbus_channel::bus::OfferParams;
83use vmbus_channel::bus::OpenRequest;
84use vmbus_channel::channel::ChannelControl;
85use vmbus_channel::channel::ChannelOpenError;
86use vmbus_channel::channel::ChannelRestoreError;
87use vmbus_channel::channel::DeviceResources;
88use vmbus_channel::channel::RestoreControl;
89use vmbus_channel::channel::SaveRestoreVmbusDevice;
90use vmbus_channel::channel::VmbusDevice;
91use vmbus_channel::gpadl::GpadlId;
92use vmbus_channel::gpadl::GpadlMapView;
93use vmbus_channel::gpadl::GpadlView;
94use vmbus_channel::gpadl::UnknownGpadlId;
95use vmbus_channel::gpadl_ring::GpadlRingMem;
96use vmbus_channel::gpadl_ring::gpadl_channel;
97use vmbus_ring as ring;
98use vmbus_ring::OutgoingPacketType;
99use vmbus_ring::RingMem;
100use vmbus_ring::gparange::MultiPagedRangeBuf;
101use vmcore::save_restore::RestoreError;
102use vmcore::save_restore::SaveError;
103use vmcore::save_restore::SavedStateBlob;
104use vmcore::vm_task::VmTaskDriver;
105use vmcore::vm_task::VmTaskDriverSource;
106use zerocopy::FromBytes;
107use zerocopy::FromZeros;
108use zerocopy::Immutable;
109use zerocopy::IntoBytes;
110use zerocopy::KnownLayout;
111
// Minimum outgoing ring space thresholds before the corresponding message
// type is processed. NOTE(review): the byte values 144/196 are not derivable
// from this chunk — confirm against the ring packet sizes they must cover.
const MIN_CONTROL_RING_SIZE: usize = 144;

const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Transaction ids for host-initiated messages. The high bit is set,
// presumably to keep these out of the guest-initiated id space — confirm
// against the completion-handling code.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

/// Maximum number of subchannels offered per vNIC (caps the queue count).
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Delays are shortened under `cfg(test)` so unit tests run quickly.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
145
/// Flags describing which pieces of state a [`CoordinatorMessage::Update`]
/// asks the coordinator to refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    // Re-evaluate the guest VF (virtual function) state.
    guest_vf_state: bool,
    // Re-evaluate the packet filter state.
    filter_state: bool,
}
154
/// Messages sent from channel workers to the coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Refresh the indicated state (VF and/or packet filter).
    Update(CoordinatorMessageUpdateType),
    /// Stop and restart the workers/endpoint.
    Restart,
    /// Arm the coordinator's timer to fire at the given instant.
    StartTimer(Instant),
}
165
/// Per-channel worker: owns one vmbus channel and its protocol state.
struct Worker<T: RingMem> {
    /// Index of this channel (0 = primary, >0 = subchannel).
    channel_idx: u16,
    /// Target VP for interrupts, if already assigned.
    target_vp: Option<u32>,
    mem: GuestMemory,
    channel: NetChannel<T>,
    state: WorkerState,
    /// Sender for notifying the coordinator task.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
174
/// Backend queue wrapper driven alongside a [`Worker`]; `queue_state` is
/// `None` until the backend queue has been created.
struct NetQueue {
    driver: VmTaskDriver,
    queue_state: Option<QueueState>,
}
179
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    /// Reports diagnostics for the queue and (when running) its worker.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report if neither the worker nor the queue exists.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                // Init(None) means version negotiation has not completed yet.
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            // Packet-level counters are only meaningful once ready.
            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    // Quota minus free ids = packets currently in flight.
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
232
/// Protocol lifecycle of a channel worker.
enum WorkerState {
    /// Negotiating: `None` before the version handshake, `Some` during init.
    Init(Option<InitState>),
    /// Fully initialized and processing packets.
    Ready(ReadyState),
    /// Paused pending coordinator action; retains any ready state.
    WaitingForCoordinator(Option<ReadyState>),
}
238
239impl WorkerState {
240 fn ready(&self) -> Option<&ReadyState> {
241 if let Self::Ready(state) = self {
242 Some(state)
243 } else {
244 None
245 }
246 }
247
248 fn ready_mut(&mut self) -> Option<&mut ReadyState> {
249 if let Self::Ready(state) = self {
250 Some(state)
251 } else {
252 None
253 }
254 }
255}
256
/// State accumulated during protocol initialization; each field is filled
/// in as the guest sends the corresponding setup message.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
264
/// NDIS version reported by the guest.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
272
/// NDIS configuration negotiated with the guest (MTU and capability flags).
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
280
/// Worker state once initialization has completed.
struct ReadyState {
    /// Negotiated buffers, shared across channels.
    buffers: Arc<ChannelBuffers>,
    /// Live packet-processing state (tx/rx tracking, stats).
    state: ActiveState,
    /// Reusable scratch buffers for the processing loop.
    data: ProcessingData,
}
286
287impl ReadyState {
288 fn reset_tx_after_endpoint_stop(&mut self) {
297 let state = &mut self.state;
298
299 let pending_tx = state
302 .pending_tx_packets
303 .iter_mut()
304 .enumerate()
305 .filter_map(|(id, inflight)| {
306 if inflight.pending_packet_count > 0 {
307 inflight.pending_packet_count = 0;
308 Some(PendingTxCompletion {
309 transaction_id: inflight.transaction_id,
310 tx_id: Some(TxId(id as u32)),
311 status: protocol::Status::SUCCESS,
312 })
313 } else {
314 None
315 }
316 })
317 .collect::<Vec<_>>();
318 state.pending_tx_completions.extend(pending_tx);
319
320 self.data.tx_segments.clear();
323 self.data.tx_segments_sent = 0;
324 }
325}
326
/// Interface to an accelerated (VF) networking device paired with this vNIC.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// The VF id to advertise to the guest, if a VF is currently available.
    async fn id(&self) -> Option<u32>;
    /// Notifies the implementation that the guest is ready for the device.
    async fn guest_ready_for_device(&mut self);
    /// Waits for a VF availability change; the returned RPC is completed by
    /// the caller once the change has been handled.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
341
/// Immutable-ish per-NIC configuration shared by all channels.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    /// Upper bound on queues/subchannels for this NIC.
    max_queues: u16,
    /// Size of the RSS indirection table.
    indirection_table_size: u16,
    /// Offloads the backend endpoint supports.
    offload_support: OffloadConfig,
    /// Ring-size cap used when ring-size limiting is enabled (0 = off).
    ring_size_limit: AtomicUsize,
    /// Free tx-id threshold used by the processing loop.
    free_tx_packet_threshold: usize,
    /// Whether the endpoint completes transmits quickly (affects quotas).
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Optional callback returning the guest OS id (used for quirks).
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    /// Count of currently-open channels.
    num_sub_channels_opened: AtomicUsize,
    link_speed: u64,
}
356
/// Per-queue backend state.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    /// Receive buffer ids owned by this queue.
    rx_buffer_range: RxBufferRange,
    /// Whether the interrupt target VP has been programmed.
    target_vp_set: bool,
}
362
/// The contiguous range of rx buffer ids owned by one queue, plus channels
/// for returning buffers that belong to other queues.
struct RxBufferRange {
    id_range: Range<u32>,
    /// Receives buffer ids released by other queues that belong to us.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    remote_ranges: Arc<RxBufferRanges>,
}
368
369impl RxBufferRange {
370 fn new(
371 ranges: Arc<RxBufferRanges>,
372 id_range: Range<u32>,
373 remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
374 ) -> Self {
375 Self {
376 id_range,
377 remote_buffer_id_recv,
378 remote_ranges: ranges,
379 }
380 }
381
382 fn send_if_remote(&self, id: u32) -> bool {
383 if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
386 false
387 } else {
388 let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
389 let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
393 let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
394 true
395 }
396 }
397}
398
/// Shared map from rx buffer id to owning queue, with one sender per queue
/// for routing released buffers back to their owners.
struct RxBufferRanges {
    buffers_per_queue: u32,
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
403
404impl RxBufferRanges {
405 fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
406 let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
407 #[expect(clippy::disallowed_methods)] let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
409 (
410 Self {
411 buffers_per_queue,
412 buffer_id_send: send,
413 },
414 recv,
415 )
416 }
417}
418
/// RSS (receive-side scaling) parameters set by the guest.
struct RssState {
    /// 40-byte Toeplitz hash key.
    key: [u8; 40],
    /// Maps hash values to queue indices.
    indirection_table: Vec<u16>,
}
423
/// Vmbus channel plumbing for one worker.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    /// Which packet header layout the negotiated protocol version uses.
    packet_size: PacketSize,
    /// Outgoing ring space to wait for before resuming sends.
    pending_send_size: usize,
    /// Coordinator message to deliver once possible, if any.
    restart: Option<CoordinatorMessage>,
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}
435
/// Which RNDIS packet message layout is in use, by protocol version.
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    V1,
    V61,
}
444
/// Reusable scratch buffers for the per-queue processing loop, allocated
/// once so the hot path does not allocate.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    /// Number of entries of `tx_segments` already handed to the backend.
    tx_segments_sent: usize,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
455
456impl ProcessingData {
457 fn new() -> Self {
458 Self {
459 tx_segments: Vec::new(),
460 tx_segments_sent: 0,
461 tx_done: vec![TxId(0); 8192].into(),
462 rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
463 rx_done: Vec::with_capacity(RX_BATCH_SIZE),
464 transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
465 external_data: MultiPagedRangeBuf::new(),
466 }
467 }
468}
469
/// Buffers and negotiated parameters shared by all channels of a NIC once
/// initialization completes.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
485
/// Id of a receive buffer reserved for RNDIS control message responses.
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
489
/// Live packet-processing state for one channel.
struct ActiveState {
    /// Primary-channel-only state; `None` on subchannels.
    primary: Option<PrimaryChannelState>,

    /// Indexed by tx id; tracks in-flight guest transmits.
    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    /// Tracks which receive suballocations are in use.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}
503
/// Per-queue diagnostic counters, surfaced through `Inspect`.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    rx_dropped_ring_full: Counter,
    rx_dropped_filtered: Counter,
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
518
/// A tx completion waiting to be written back to the guest's ring.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    /// Tx id to recycle when the completion is sent, if one was assigned.
    tx_id: Option<TxId>,
    status: protocol::Status,
}
525
/// State machine for the guest-visible virtual function (VF) device and the
/// data path (synthetic vs. VF) selection.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// VF availability has not been evaluated yet.
    Initializing,
    /// Restoring from a saved state snapshot.
    Restoring(saved_state::GuestVfState),
    /// No VF is available.
    Unavailable,
    /// VF became unavailable after having been offered.
    UnavailableFromAvailable,
    /// VF became unavailable while a data path switch was pending.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// VF became unavailable while the guest data path was active.
    UnavailableFromDataPathSwitched,
    /// A VF is available but the guest has not been told yet.
    Available { vfid: u32 },
    /// The guest has been notified of the VF id.
    AvailableAdvertised,
    /// The VF is present in the guest.
    Ready,
    /// A data path switch was requested and is in progress.
    DataPathSwitchPending {
        to_guest: bool,
        id: Option<u64>,
        result: Option<bool>,
    },
    /// The data path is switched to the guest VF.
    DataPathSwitched,
    /// VF still present, but the data path reverted to synthetic externally.
    DataPathSynthetic,
}
558
559impl std::fmt::Display for PrimaryChannelGuestVfState {
560 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561 match self {
562 PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
563 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
564 write!(f, "restoring")
565 }
566 PrimaryChannelGuestVfState::Restoring(
567 saved_state::GuestVfState::AvailableAdvertised,
568 ) => write!(f, "restoring from guest notified of vfid"),
569 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
570 write!(f, "restoring from vf present")
571 }
572 PrimaryChannelGuestVfState::Restoring(
573 saved_state::GuestVfState::DataPathSwitchPending {
574 to_guest, result, ..
575 },
576 ) => {
577 write!(
578 f,
579 "restoring from client requested data path switch: to {} {}",
580 if *to_guest { "guest" } else { "synthetic" },
581 if let Some(result) = result {
582 if *result { "succeeded\"" } else { "failed\"" }
583 } else {
584 "in progress\""
585 }
586 )
587 }
588 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
589 write!(f, "restoring from data path in guest")
590 }
591 PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
592 PrimaryChannelGuestVfState::UnavailableFromAvailable => {
593 write!(f, "\"unavailable (previously available)\"")
594 }
595 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
596 write!(f, "unavailable (previously switching data path)")
597 }
598 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
599 write!(f, "\"unavailable (previously using guest VF)\"")
600 }
601 PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
602 PrimaryChannelGuestVfState::AvailableAdvertised => {
603 write!(f, "\"available, guest notified\"")
604 }
605 PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
606 PrimaryChannelGuestVfState::DataPathSwitchPending {
607 to_guest, result, ..
608 } => {
609 write!(
610 f,
611 "\"switching to {} {}",
612 if *to_guest { "guest" } else { "synthetic" },
613 if let Some(result) = result {
614 if *result { "succeeded\"" } else { "failed\"" }
615 } else {
616 "in progress\""
617 }
618 )
619 }
620 PrimaryChannelGuestVfState::DataPathSwitched => {
621 write!(f, "\"available and data path switched\"")
622 }
623 PrimaryChannelGuestVfState::DataPathSynthetic => write!(
624 f,
625 "\"available but data path switched back to synthetic due to external state change\""
626 ),
627 }
628 }
629}
630
631impl Inspect for PrimaryChannelGuestVfState {
632 fn inspect(&self, req: inspect::Request<'_>) {
633 req.value(self.to_string());
634 }
635}
636
/// Pending link status notification toward the guest; the `bool` payload is
/// the target link-up state.
#[derive(Debug)]
enum PendingLinkAction {
    /// No link action pending.
    Default,
    /// A link change should be delivered now.
    Active(bool),
    /// A link change is being delayed before delivery.
    Delay(bool),
}
643
/// State that exists only on the primary (index 0) channel.
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    is_data_path_switched: Option<bool>,
    /// Control messages queued for processing.
    control_messages: VecDeque<ControlMessage>,
    /// Total bytes across `control_messages`, used for accounting.
    control_messages_len: usize,
    /// Reserved rx buffers available for control responses.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    /// Offloads currently active for the guest.
    offload_config: OffloadConfig,
    /// Whether an offload-change indication still needs to be sent.
    pending_offload_change: bool,
    tx_spread_sent: bool,
    /// Link state as last reported to the guest.
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
659
impl Inspect for PrimaryChannelState {
    /// Reports primary-channel diagnostics; every field is marked safe for
    /// unredacted output.
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                // Rendered as a string since PendingLinkAction is not Inspect.
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
711
/// Checksum and large-send offload configuration.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// IPv4 large send offload (LSOv2).
    #[inspect(safe)]
    lso4: bool,
    /// IPv6 large send offload (LSOv2).
    #[inspect(safe)]
    lso6: bool,
}
723
impl OffloadConfig {
    /// Clears any offload not present in `supported`, so the active config
    /// never exceeds what the backend endpoint can do.
    fn mask_to_supported(&mut self, supported: &OffloadConfig) {
        self.checksum_tx.mask_to_supported(&supported.checksum_tx);
        self.checksum_rx.mask_to_supported(&supported.checksum_rx);
        self.lso4 &= supported.lso4;
        self.lso6 &= supported.lso6;
    }
}
732
/// Per-protocol checksum offload switches (one direction: tx or rx).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
746
747impl ChecksumOffloadConfig {
748 fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
749 self.ipv4_header &= supported.ipv4_header;
750 self.tcp4 &= supported.tcp4;
751 self.udp4 &= supported.udp4;
752 self.tcp6 &= supported.tcp6;
753 self.udp6 &= supported.udp6;
754 }
755
756 fn flags(
757 &self,
758 ) -> (
759 rndisprot::Ipv4ChecksumOffload,
760 rndisprot::Ipv6ChecksumOffload,
761 ) {
762 let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
763 let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
764 let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
765 if self.ipv4_header {
766 v4.set_ip_options_supported(on);
767 v4.set_ip_checksum(on);
768 }
769 if self.tcp4 {
770 v4.set_ip_options_supported(on);
771 v4.set_tcp_options_supported(on);
772 v4.set_tcp_checksum(on);
773 }
774 if self.tcp6 {
775 v6.set_ip_extension_headers_supported(on);
776 v6.set_tcp_options_supported(on);
777 v6.set_tcp_checksum(on);
778 }
779 if self.udp4 {
780 v4.set_ip_options_supported(on);
781 v4.set_udp_checksum(on);
782 }
783 if self.udp6 {
784 v6.set_ip_extension_headers_supported(on);
785 v6.set_udp_checksum(on);
786 }
787 (v4, v6)
788 }
789}
790
impl OffloadConfig {
    /// Builds the NDIS_OFFLOAD structure (revision 3) advertising the
    /// currently-configured checksum and LSOv2 capabilities.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        // Checksum capabilities: one flag word per direction and IP version.
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        // Large send offload v2: left zeroed (disabled) unless enabled.
        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            // All other capabilities (e.g. RSC, encapsulation) are zero.
            ..FromZeros::new_zeroed()
        }
    }
}
838
/// RNDIS device lifecycle state.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    Initializing,
    Operational,
    Halted,
}
845
impl PrimaryChannelState {
    /// Creates fresh primary-channel state with all reserved control
    /// buffers free and the guest link reported up.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from a saved-state snapshot.
    ///
    /// Fails if the saved RSS key has the wrong size or the saved
    /// indirection table is larger than the current adapter's table.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Total byte accounting must be computed before the messages are
        // consumed below.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // Only control buffers not marked in-use by the rx tracker are free.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                // A smaller adapter table cannot hold the saved entries.
                if rss.indirection_table.len() > indirection_table_size as usize {
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Grow by cycling the existing entries so the queue
                    // distribution is preserved.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            // Recomputed after restore rather than trusted from the snapshot.
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
976
/// A queued RNDIS control message awaiting processing.
struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}
981
/// Maximum number of guest transmit packets in flight per channel; sizes the
/// `pending_tx_packets`/`free_tx_packets` pools.
const TX_PACKET_QUOTA: usize = 1024;
983
impl ActiveState {
    /// Creates fresh processing state with the full tx quota free.
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            // Reversed so ids are popped in ascending order.
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    /// Rebuilds processing state from a saved channel snapshot: re-marks
    /// in-use rx buffers and re-queues pending tx completions.
    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Reserve a tx id when available so the completion path can
            // recycle it; if the quota is exhausted, complete without one.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}
1033
1034#[derive(Default, Clone)]
1037struct PendingTxPacket {
1038 pending_packet_count: usize,
1039 transaction_id: u64,
1040}
1041
// Number of receive ids processed per batch; sizes the rx scratch buffers.
// NOTE(review): the specific value 375 is not derivable from this chunk.
const RX_BATCH_SIZE: usize = 375;

/// Receive buffers reserved for RNDIS control message responses.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1050
/// A netvsp synthetic NIC vmbus device.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    /// Task driving the coordinator, which owns the workers and endpoint.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
1060
/// Builder for [`Nic`]; construct via [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1067
impl NicBuilder {
    /// Enables or disables limiting the usable ring buffer size.
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Sets the maximum queue count (clamped at build time to what the
    /// endpoint and protocol support).
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Pairs the NIC with an accelerated virtual function device.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Provides a callback for querying the guest OS id (used to decide
    /// whether ring-size optimizations are safe for the running guest).
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC backed by `endpoint`.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp the requested queue count to endpoint and protocol limits.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // Fast-completing endpoints can use the full quota as the threshold;
        // otherwise be more conservative.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Rx checksum offload is always claimed; tx offloads mirror what the
        // endpoint actually supports.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1179
1180fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
1181 let Some(guest_os_id) = guest_os_id else {
1182 return false;
1184 };
1185
1186 if !queue.split().0.supports_pending_send_size() {
1187 return false;
1189 }
1190
1191 let Some(open_source_os) = guest_os_id.open_source() else {
1192 return true;
1194 };
1195
1196 match HvGuestOsOpenSourceType(open_source_os.os_type()) {
1197 HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
1200 HvGuestOsOpenSourceType::LINUX => {
1204 open_source_os.version() >= 199424
1207 }
1208 _ => true,
1209 }
1210}
1211
1212impl Nic {
1213 pub fn builder() -> NicBuilder {
1214 NicBuilder {
1215 virtual_function: None,
1216 limit_ring_buffer: false,
1217 max_queues: !0,
1218 get_guest_os_id: None,
1219 }
1220 }
1221
1222 pub fn shutdown(self) -> Box<dyn Endpoint> {
1223 let (state, _) = self.coordinator.into_inner();
1224 state.endpoint
1225 }
1226}
1227
impl InspectMut for Nic {
    /// Delegates inspection to the coordinator task, which reports both its
    /// own state and the per-channel workers.
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        self.coordinator.inspect_mut(req);
    }
}
1233
1234#[async_trait]
1235impl VmbusDevice for Nic {
    /// Returns the vmbus offer for the synthetic network interface.
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            // Fixed interface id for synthetic networking channels.
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }
1251
1252 fn max_subchannels(&self) -> u16 {
1253 self.adapter.max_queues
1254 }
1255
1256 fn install(&mut self, resources: DeviceResources) {
1257 self.resources = resources;
1258 }
1259
1260 async fn open(
1261 &mut self,
1262 channel_idx: u16,
1263 open_request: &OpenRequest,
1264 ) -> Result<(), ChannelOpenError> {
1265 let state = if channel_idx == 0 {
1267 self.insert_coordinator(1, None);
1268 WorkerState::Init(None)
1269 } else {
1270 self.coordinator.stop().await;
1271 let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
1273 WorkerState::Ready(ReadyState {
1274 state: ActiveState::new(None, buffers.recv_buffer.count),
1275 buffers,
1276 data: ProcessingData::new(),
1277 })
1278 };
1279
1280 let num_opened = self
1281 .adapter
1282 .num_sub_channels_opened
1283 .fetch_add(1, Ordering::SeqCst);
1284 let r = self.insert_worker(channel_idx, open_request, state, true);
1285 if channel_idx != 0
1286 && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
1287 {
1288 let coordinator = &mut self.coordinator.state_mut().unwrap();
1289 coordinator.workers[0].stop().await;
1290 }
1291
1292 if r.is_err() && channel_idx == 0 {
1293 self.coordinator.remove();
1294 } else {
1295 self.coordinator.start();
1297 }
1298 r?;
1299 Ok(())
1300 }
1301
1302 async fn close(&mut self, channel_idx: u16) {
1303 if !self.coordinator.has_state() {
1304 tracing::error!(
1305 channel_idx,
1306 instance_id = %self.instance_id,
1307 "Close called while vmbus channel is already closed"
1308 );
1309 return;
1310 }
1311
1312 let restart = self.coordinator.stop().await;
1314
1315 {
1317 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1318 worker.stop().await;
1319 if worker.has_state() {
1320 worker.remove();
1321 }
1322 }
1323
1324 self.adapter
1325 .num_sub_channels_opened
1326 .fetch_sub(1, Ordering::SeqCst);
1327 if channel_idx == 0 {
1329 for worker in &mut self.coordinator.state_mut().unwrap().workers {
1330 worker.task_mut().queue_state = None;
1331 }
1332
1333 self.coordinator.task_mut().endpoint.stop().await;
1335
1336 self.coordinator.remove();
1341 } else {
1342 if restart {
1344 self.coordinator.start();
1345 }
1346 }
1347 }
1348
1349 async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
1350 if !self.coordinator.has_state() {
1351 return;
1352 }
1353
1354 let coordinator_running = self.coordinator.stop().await;
1356 let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
1357 worker.stop().await;
1358 let (net_queue, worker_state) = worker.get_mut();
1359
1360 net_queue.driver.retarget_vp(target_vp);
1362
1363 if let Some(worker_state) = worker_state {
1364 worker_state.target_vp = Some(target_vp);
1366 if let Some(queue_state) = &mut net_queue.queue_state {
1367 queue_state.target_vp_set = false;
1369 }
1370 }
1371
1372 if coordinator_running {
1374 self.coordinator.start();
1375 }
1376 }
1377
1378 fn start(&mut self) {
1379 if !self.coordinator.is_running() {
1380 self.coordinator.start();
1381 }
1382 }
1383
1384 async fn stop(&mut self) {
1385 self.coordinator.stop().await;
1386 if let Some(coordinator) = self.coordinator.state_mut() {
1387 coordinator.stop_workers().await;
1388 }
1389 }
1390
1391 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1392 Some(self)
1393 }
1394}
1395
1396#[async_trait]
1397impl SaveRestoreVmbusDevice for Nic {
1398 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1399 let state = self.saved_state();
1400 Ok(SavedStateBlob::new(state))
1401 }
1402
1403 async fn restore(
1404 &mut self,
1405 control: RestoreControl<'_>,
1406 state: SavedStateBlob,
1407 ) -> Result<(), RestoreError> {
1408 let state: saved_state::SavedState = state.parse()?;
1409 if let Err(err) = self.restore_state(control, state).await {
1410 tracing::error!(
1411 error = &err as &dyn std::error::Error,
1412 instance_id = %self.instance_id,
1413 "Failed restoring network vmbus state"
1414 );
1415 Err(err.into())
1416 } else {
1417 Ok(())
1418 }
1419 }
1420}
1421
impl Nic {
    /// Creates the worker that services vmbus channel `channel_idx`.
    ///
    /// The coordinator must already be inserted. On success the worker is
    /// registered with the coordinator and, when `start` is set, started
    /// immediately (the restore path passes `start = false`).
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Reuse the per-queue driver built by `insert_coordinator`, pointed
        // at the channel's requested target VP.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp.unwrap_or_default());

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // The guest OS id (when a lookup fn was configured) gates the
        // pending-send-size ring optimization; see `can_use_ring_opt`.
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                // Packet size starts at V1 until a newer protocol version is
                // negotiated.
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1483
/// Coordinator state threaded through [`Nic::insert_coordinator`] when
/// restoring from a saved state.
struct RestoreCoordinatorState {
    // Packet filter that was active at save time.
    active_packet_filter: u32,
}
1487
impl Nic {
    /// Creates the coordinator task along with a per-queue worker shell for
    /// each of the adapter's possible queues.
    ///
    /// `num_queues` is the number of currently active queues; `restoring`
    /// carries state (the active packet filter) preserved across
    /// save/restore.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        // Start on VP 0; each worker is retargeted when its channel opens.
        driver_builder.target_vp(0);
        // Run tasks on the target VP only when the endpoint does not report
        // fast tx completions.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        // Depth-1 channel used by workers to send messages to the
        // coordinator.
        #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                // When restoring, the coordinator needs an initial restart
                // pass to bring the endpoint back up.
                restart: restoring.is_some(),
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}
1530
/// Errors that can occur while restoring saved network vmbus state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1550
impl From<NetRestoreError> for RestoreError {
    fn from(err: NetRestoreError) -> Self {
        // All restore failures are reported as invalid saved state.
        RestoreError::InvalidSavedState(anyhow::Error::new(err))
    }
}
1556
impl Nic {
    /// Rebuilds the device's runtime state from a parsed saved state,
    /// reopening vmbus channels via `control`.
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            // Which channels were open at save time; pre-Ready states only
            // ever have the primary channel.
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            // Per-channel worker states, filled in below.
            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            // Restore the vmbus channels, collecting the per-channel open
            // requests used later to build the workers.
            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    // Only version negotiation had completed.
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    // Remap the receive/send buffer GPADLs if they had been
                    // established before the save.
                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // Older saved states predate the packet_filter field;
                    // fall back to the default filter.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    // The channel buffers are shared by all queues; rebuild
                    // them from the primary channel's open request.
                    let request = requests[0].as_ref().unwrap();
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        // Skip channels that were not open at save time.
                        let channel = if let Some(channel) = channel {
                            channel
                        } else {
                            continue;
                        };

                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;

                        // The primary channel additionally carries RNDIS, VF,
                        // RSS, offload, and link state.
                        if channel_idx == 0 {
                            let primary = PrimaryChannelState::restore(
                                &guest_vf_state,
                                &rndis_state,
                                &offload_config,
                                pending_offload_change,
                                channels.len() as u16,
                                self.adapter.indirection_table_size,
                                &active.rx_bufs,
                                std::mem::take(&mut control_messages),
                                rss_state.take(),
                                tx_spread_sent,
                                guest_link_down,
                                pending_link_action,
                            )?;
                            active.primary = Some(primary);
                        }

                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
                            buffers: buffers.clone(),
                            state: active,
                            data: ProcessingData::new(),
                        }));
                    }
                }
            }

            self.insert_coordinator(
                states.len() as u16,
                Some(RestoreCoordinatorState {
                    active_packet_filter: saved_packet_filter,
                }),
            );

            // Create the workers without starting them (`start = false`).
            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
                if let Some(state) = state {
                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
                }
            }
        } else {
            // The device was saved closed; restore with no open channels.
            control
                .restore(&[false])
                .await
                .map_err(NetRestoreError::Channel)?;
        }
        Ok(())
    }

    /// Snapshots the current device state into a saved-state structure; the
    /// inverse of `restore_state`.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            // The primary worker's state determines which save variant to
            // emit.
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                WorkerState::Init(None) => saved_state::Primary::Version,
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    // Queued-but-unprocessed control messages are saved so
                    // they can be replayed after restore.
                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Per-channel state: pending tx completions (including
                    // in-flight packets still owned by the endpoint) and the
                    // receive buffers currently loaned to the backend.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse the runtime VF state machine into the coarser
                    // saved-state representation.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    // The packet filter is saved from worker 0's channel.
                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
}
1952
/// Errors that terminate a channel worker.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
}
2006
impl From<task_control::Cancelled> for WorkerError {
    // Allows `?` on cancellable task-control operations inside workers.
    fn from(value: task_control::Cancelled) -> Self {
        Self::Cancelled(value)
    }
}
2012
/// Errors that can occur while opening a vmbus channel for a worker.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
2020
/// Errors produced while parsing an incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
2032
/// Protocol-sequencing violations: messages that are individually valid but
/// arrive in an unexpected order or duplicate prior state.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2052
/// Decoded NVSP message payloads, one variant per supported message type.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer),
    RevokeSendBuffer(protocol::Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    // Synthesized for completions of host-initiated messages, recognized by
    // their reserved transaction ids (see `parse_packet`).
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2070
/// A parsed incoming vmbus packet.
#[derive(Debug)]
struct Packet<'a> {
    // Decoded NVSP message payload.
    data: PacketData,
    // Transaction id to complete, if the sender requested a completion.
    transaction_id: Option<u64>,
    // External data ranges (GPA direct or send-buffer section) attached to
    // the packet.
    external_data: &'a MultiPagedRangeBuf,
}
2077
/// Reader over a packet's external data ranges in guest memory.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2079
impl Packet<'_> {
    /// Returns a reader over the packet's external data ranges, used to parse
    /// the RNDIS payload out of guest memory.
    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
        PagedRanges::new(self.external_data.iter()).reader(mem)
    }
}
2085
/// Reads a plain value of type `T` from `reader`, mapping failures to
/// [`PacketError::Access`].
fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
    reader: &mut impl MemoryRead,
) -> Result<T, PacketError> {
    reader.read_plain().map_err(PacketError::Access)
}
2091
/// Parses an incoming vmbus packet into a [`Packet`], decoding the NVSP
/// header and message and collecting any external data ranges into
/// `external_data`.
///
/// Most messages are gated on the negotiated protocol `version`; a message
/// arriving before the version that introduced it is rejected as unknown.
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&SendBuffer>,
    version: Option<Version>,
    external_data: &'a mut MultiPagedRangeBuf,
) -> Result<Packet<'a>, PacketError> {
    external_data.clear();
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            // Completions for host-initiated messages are recognized by
            // reserved transaction ids; anything else must be an RNDIS
            // packet completion.
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data,
            });
        }
    };

    // Data packet: decode the NVSP message, gated on the negotiated version.
    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // 0xffffffff means the packet does not use a send buffer
            // section; otherwise the payload lives in a 6144-byte-strided
            // section of the send buffer.
            if message.send_buffer_section_index != 0xffffffff {
                let send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;

                external_data.push_range(send_buffer_suballocation);
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    // Collect the packet's GPA direct ranges (if any) after the inline
    // message payload has been decoded.
    packet
        .read_external_ranges(external_data)
        .map_err(PacketError::ExternalData)?;
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data,
    })
}
2185
/// A fixed-size buffer holding an outgoing NVSP message.
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    // Storage sized for the largest (V61) packet; `u64` elements keep the
    // buffer 8-byte aligned for aligned ring writes.
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    // Negotiated packet size; determines how much of `buf` goes on the wire.
    size: PacketSize,
}
2191
impl NvspMessage {
    /// Builds a message of wire `size` consisting of a
    /// [`protocol::MessageHeader`] with `message_type` followed by `data`.
    fn new<P: IntoBytes + Immutable + KnownLayout>(
        size: PacketSize,
        message_type: u32,
        data: P,
    ) -> Self {
        // Compile-time check: the payload must fit within even the smallest
        // (V1) packet size, so the `write_to_prefix` calls below succeed.
        const {
            assert!(
                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
                "packet might not fit in message"
            )
        };
        let mut message = NvspMessage {
            buf: [0; protocol::PACKET_SIZE_V61 / 8],
            size,
        };
        let header = protocol::MessageHeader { message_type };
        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
        data.write_to_prefix(
            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
        )
        .unwrap();
        message
    }

    /// Returns the message truncated to the negotiated packet size, as
    /// 8-byte words suitable for an aligned ring write.
    fn aligned_payload(&self) -> &[u64] {
        let len = match self.size {
            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.div_ceil(8) },
            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.div_ceil(8) },
        };
        &self.buf[..len]
    }
}
2233
2234impl<T: RingMem> NetChannel<T> {
2235 fn message<P: IntoBytes + Immutable + KnownLayout>(
2236 &self,
2237 message_type: u32,
2238 data: P,
2239 ) -> NvspMessage {
2240 NvspMessage::new(self.packet_size, message_type, data)
2241 }
2242
2243 fn send_completion(
2244 &mut self,
2245 transaction_id: Option<u64>,
2246 message: Option<&NvspMessage>,
2247 ) -> Result<(), WorkerError> {
2248 match transaction_id {
2249 None => Ok(()),
2250 Some(transaction_id) => Ok(self
2251 .queue
2252 .split()
2253 .1
2254 .batched()
2255 .try_write_aligned(
2256 transaction_id,
2257 OutgoingPacketType::Completion,
2258 message.map_or(&[], |m| m.aligned_payload()),
2259 )
2260 .map_err(|err| match err {
2261 queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
2262 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2263 })?),
2264 }
2265 }
2266}
2267
/// Protocol versions offered to the guest, in ascending order. V3 is not in
/// this list.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2276
2277fn check_version(requested_version: u32) -> Option<Version> {
2278 SUPPORTED_VERSIONS
2279 .iter()
2280 .find(|version| **version as u32 == requested_version)
2281 .copied()
2282}
2283
/// A view of the guest-supplied receive buffer, carved into equally sized
/// suballocations.
#[derive(Debug)]
struct ReceiveBuffer {
    gpadl: GpadlView,
    // Guest-assigned receive buffer id.
    id: u16,
    // Number of suballocations that fit in the buffer.
    count: u32,
    // Size in bytes of each suballocation.
    sub_allocation_size: u32,
}
2291
/// Errors establishing a guest-supplied send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2301
2302impl ReceiveBuffer {
2303 fn new(
2304 gpadl_map: &GpadlMapView,
2305 gpadl_id: GpadlId,
2306 id: u16,
2307 sub_allocation_size: u32,
2308 ) -> Result<Self, BufferError> {
2309 if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
2310 return Err(BufferError::UnsupportedSuballocationSize(
2311 sub_allocation_size,
2312 ));
2313 }
2314 let gpadl = gpadl_map.map(gpadl_id)?;
2315 let range = gpadl
2316 .contiguous_aligned()
2317 .ok_or(BufferError::UnalignedGpadl)?;
2318 let num_sub_allocations = range.len() as u32 / sub_allocation_size;
2319 if num_sub_allocations == 0 {
2320 return Err(BufferError::UnsupportedSuballocationSize(
2321 sub_allocation_size,
2322 ));
2323 }
2324 let recv_buffer = Self {
2325 gpadl,
2326 id,
2327 count: num_sub_allocations,
2328 sub_allocation_size,
2329 };
2330 Ok(recv_buffer)
2331 }
2332
2333 fn range(&self, index: u32) -> PagedRange<'_> {
2334 self.gpadl.first().unwrap().subrange(
2335 (index * self.sub_allocation_size) as usize,
2336 self.sub_allocation_size as usize,
2337 )
2338 }
2339
2340 fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
2341 assert!(len <= self.sub_allocation_size as usize);
2342 ring::TransferPageRange {
2343 byte_offset: index * self.sub_allocation_size,
2344 byte_count: len as u32,
2345 }
2346 }
2347
2348 fn saved_state(&self) -> saved_state::ReceiveBuffer {
2349 saved_state::ReceiveBuffer {
2350 gpadl_id: self.gpadl.id(),
2351 id: self.id,
2352 sub_allocation_size: self.sub_allocation_size,
2353 }
2354 }
2355}
2356
/// A guest-supplied send buffer, mapped from a GPADL.
#[derive(Debug)]
struct SendBuffer {
    // Keeps the GPADL mapped for the lifetime of the buffer.
    gpadl: GpadlView,
}
2361
2362impl SendBuffer {
2363 fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2364 let gpadl = gpadl_map.map(gpadl_id)?;
2365 gpadl
2366 .contiguous_aligned()
2367 .ok_or(BufferError::UnalignedGpadl)?;
2368 Ok(Self { gpadl })
2369 }
2370}
2371
2372impl<T: RingMem> NetChannel<T> {
    /// Handles a non-data RNDIS message (anything other than a packet
    /// message) by queuing it for later processing on the primary channel.
    ///
    /// Returns an error if invoked on a subchannel, since control messages
    /// are only valid on the primary channel.
    fn handle_rndis_message(
        &mut self,
        state: &mut ActiveState,
        message_type: u32,
        mut reader: PacketReader<'_>,
    ) -> Result<(), WorkerError> {
        assert_ne!(
            message_type,
            rndisprot::MESSAGE_TYPE_PACKET_MSG,
            "handled elsewhere"
        );
        let control = state
            .primary
            .as_mut()
            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;

        // Halt is accepted but ignored; the channel keeps running.
        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
            return Ok(());
        }

        // Cap the total bytes of queued control messages so a misbehaving
        // guest cannot consume unbounded host memory.
        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
        if reader.len() == 0 {
            return Err(WorkerError::RndisMessageTooSmall);
        }
        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
            return Err(WorkerError::TooManyControlMessages);
        }

        // Queue the message body; it is drained later by
        // process_control_messages when a control buffer is free.
        control.control_messages_len += reader.len();
        control.control_messages.push_back(ControlMessage {
            message_type,
            data: reader.read_all()?.into(),
        });

        Ok(())
    }
2419
    /// Parses a chain of RNDIS packet messages contained in one vmbus packet,
    /// converting each into tx segments.
    ///
    /// `message_len` is the total length (from the RNDIS header) of the first
    /// message; `reader` is positioned just past that header. Returns the
    /// number of packet messages parsed.
    fn handle_rndis_packet_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        mut message_len: usize,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        loop {
            // Bytes of this message remaining past the already-consumed header.
            let next_message_offset = message_len
                .checked_sub(size_of::<rndisprot::MessageHeader>())
                .ok_or(WorkerError::RndisMessageTooSmall)?;

            self.handle_rndis_packet_message(
                id,
                reader.clone(),
                &buffers.mem,
                segments,
                &mut state.stats,
            )?;
            num_packets += 1;

            // Advance to the next message in the chain, if any remains.
            reader.skip(next_message_offset)?;
            if reader.len() == 0 {
                break;
            }
            let header: rndisprot::MessageHeader = reader.read_plain()?;
            // Only packet messages may follow a packet message in a chain.
            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
            }
            message_len = header.message_length as usize;
        }
        Ok(num_packets)
    }
2464
    /// Parses a single RNDIS packet message into tx segments, translating
    /// per-packet info (checksum and LSO offload requests) into `TxMetadata`
    /// flags for the backend endpoint.
    fn handle_rndis_packet_message(
        &mut self,
        id: TxId,
        reader: PacketReader<'_>,
        mem: &GuestMemory,
        segments: &mut Vec<TxSegment>,
        stats: &mut QueueStats,
    ) -> Result<(), WorkerError> {
        // The fixed Packet header (and its per-packet info area) must lie
        // entirely within the first paged range.
        let headers = reader
            .clone()
            .into_inner()
            .paged_ranges()
            .next()
            .ok_or(WorkerError::RndisMessageTooSmall)?;
        let mut data = reader.into_inner();
        let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
        // Out-of-band data and VC handles are not supported.
        if request.num_oob_data_elements != 0
            || request.oob_data_length != 0
            || request.oob_data_offset != 0
            || request.vc_handle != 0
        {
            return Err(WorkerError::UnsupportedRndisBehavior);
        }

        // Validate that the declared data region fits within the message
        // and is non-empty.
        if data.len() < request.data_offset as usize
            || (data.len() - request.data_offset as usize) < request.data_length as usize
            || request.data_length == 0
        {
            return Err(WorkerError::RndisMessageTooSmall);
        }

        // Narrow `data` down to just the packet payload.
        data.skip(request.data_offset as usize);
        data.truncate(request.data_length as usize);

        let mut metadata = net_backend::TxMetadata {
            id,
            len: request.data_length,
            ..Default::default()
        };

        if request.per_packet_info_length != 0 {
            // Walk the chain of per-packet info (PPI) elements; each element
            // declares its own size and payload offset.
            let mut ppi = headers
                .try_subrange(
                    request.per_packet_info_offset as usize,
                    request.per_packet_info_length as usize,
                )
                .ok_or(WorkerError::RndisMessageTooSmall)?;
            while !ppi.is_empty() {
                let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
                // A zero-size element would make the walk loop forever.
                if h.size == 0 {
                    return Err(WorkerError::RndisMessageTooSmall);
                }
                let (this, rest) = ppi
                    .try_split(h.size as usize)
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (_, d) = this
                    .try_split(h.per_packet_information_offset as usize)
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                match h.typ {
                    rndisprot::PPI_TCP_IP_CHECKSUM => {
                        let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;

                        // TCP checksum offload takes precedence over UDP.
                        metadata.flags.set_offload_tcp_checksum(
                            (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
                        );
                        metadata.flags.set_offload_udp_checksum(
                            (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
                        );
                        metadata
                            .flags
                            .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
                        metadata.flags.set_is_ipv4(n.is_ipv4());
                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
                        if metadata.flags.offload_tcp_checksum()
                            || metadata.flags.offload_udp_checksum()
                        {
                            // Derive the L3 header length: prefer the guest's
                            // declared TCP header offset; otherwise read the
                            // IHL nibble from the IPv4 header in guest
                            // memory, or assume the fixed 40-byte IPv6 header.
                            metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
                                n.tcp_header_offset() - metadata.l2_len as u16
                            } else if n.is_ipv4() {
                                let mut reader = data.clone().reader(mem);
                                reader.skip(metadata.l2_len as usize)?;
                                let mut b = 0;
                                reader.read(std::slice::from_mut(&mut b))?;
                                // IHL (low nibble of the first byte is the
                                // version; high nibble here after shift) is
                                // in 4-byte units.
                                (b as u16 >> 4) * 4
                            } else {
                                40
                            };
                        }
                    }
                    rndisprot::PPI_LSO => {
                        let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;

                        // LSO implies TCP checksum offload.
                        metadata.flags.set_offload_tcp_segmentation(true);
                        metadata.flags.set_offload_tcp_checksum(true);
                        metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
                        metadata.flags.set_is_ipv4(n.is_ipv4());
                        metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
                        metadata.l2_len = ETHERNET_HEADER_LEN as u8;
                        if n.tcp_header_offset() < metadata.l2_len as u16 {
                            return Err(WorkerError::InvalidTcpHeaderOffset);
                        }
                        metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
                        metadata.l4_len = {
                            // Read the TCP data-offset field (byte 12 of the
                            // TCP header) to compute the TCP header length.
                            let mut reader = data.clone().reader(mem);
                            reader
                                .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
                            let mut b = 0;
                            reader.read(std::slice::from_mut(&mut b))?;
                            (b >> 4) * 4
                        };
                        metadata.max_tcp_segment_size = n.mss() as u16;

                        // NOTE(review): oversized LSO requests are only
                        // counted here, not rejected — presumably the backend
                        // copes with them; confirm.
                        if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
                            stats.tx_invalid_lso_packets.increment();
                        }
                    }
                    _ => {}
                }
                ppi = rest;
            }
        }

        // Convert the payload's guest physical ranges into tx segments.
        let start = segments.len();
        for range in data.paged_ranges().flat_map(|r| r.ranges()) {
            let range = range.map_err(WorkerError::InvalidGpadl)?;
            segments.push(TxSegment {
                ty: net_backend::TxSegmentType::Tail,
                gpa: range.start,
                len: range.len() as u32,
            });
        }

        metadata.segment_count = (segments.len() - start) as u8;

        stats.tx_packets.increment();
        if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
            stats.tx_checksum_packets.increment();
        }
        if metadata.flags.offload_tcp_segmentation() {
            stats.tx_lso_packets.increment();
        }

        // Promote the first segment to the head, carrying the metadata.
        segments[start].ty = net_backend::TxSegmentType::Head(metadata);

        Ok(())
    }
2616
    /// Sends the VF association message to the guest when the negotiated
    /// protocol (V4+) and NDIS configuration (SR-IOV capable) support it.
    ///
    /// Returns `Ok(true)` if the message was sent, `Ok(false)` if the current
    /// negotiation does not support VF offers.
    fn guest_vf_is_available(
        &mut self,
        guest_vf_id: Option<u32>,
        version: Version,
        config: NdisConfig,
    ) -> Result<bool, WorkerError> {
        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
        if version >= Version::V4 && config.capabilities.sriov() {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                "sending VF association message"
            );
            let message = {
                self.message(
                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
                    protocol::Message4SendVfAssociation {
                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
                        serial_number: serial_number.unwrap_or(0),
                    },
                )
            };
            // A dedicated transaction ID lets the guest's completion for this
            // message be recognized later.
            self.queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    VF_ASSOCIATION_TRANSACTION_ID,
                    OutgoingPacketType::InBandWithCompletion,
                    message.aligned_payload(),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(len) => {
                        tracing::error!(len, "failed to write vf association message");
                        WorkerError::OutOfSpace
                    }
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?;
            Ok(true)
        } else {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                major = version.major(),
                minor = version.minor(),
                sriov_capable = config.capabilities.sriov(),
                "Skipping NvspMessage4TypeSendVFAssociation message"
            );
            Ok(false)
        }
    }
2671
    /// Sends the default send indirection table (round-robin over the opened
    /// channels) to guests that negotiated protocol V5 or later.
    ///
    /// Failures are logged but not fatal; the message is best effort.
    fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
        // Send indirection tables were introduced in protocol version 5.
        if version < Version::V5 {
            return;
        }

        #[repr(C)]
        #[derive(IntoBytes, Immutable, KnownLayout)]
        struct SendIndirectionMsg {
            pub message: protocol::Message5SendIndirectionTable,
            pub send_indirection_table:
                [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
        }

        // The table offset is relative to the start of the NVSP message,
        // which begins with the message header.
        let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
            + size_of::<protocol::MessageHeader>();
        let mut data = SendIndirectionMsg {
            message: protocol::Message5SendIndirectionTable {
                table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
                table_offset: send_indirection_table_offset as u32,
            },
            send_indirection_table: Default::default(),
        };

        // Spread table entries evenly across the opened channels.
        // NOTE(review): this panics (mod by zero) if num_channels_opened is 0;
        // presumably callers guarantee at least one open channel — confirm.
        for i in 0..data.send_indirection_table.len() {
            data.send_indirection_table[i] = i as u32 % num_channels_opened;
        }

        let header = protocol::MessageHeader {
            message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
        };
        let result = self
            .queue
            .split()
            .1
            .try_write(&queue::OutgoingPacket {
                transaction_id: 0,
                packet_type: OutgoingPacketType::InBandNoCompletion,
                payload: &[header.as_bytes(), data.as_bytes()],
            })
            .map_err(|err| match err {
                queue::TryWriteError::Full(len) => {
                    tracing::error!(len, "failed to write send indirection table message");
                    WorkerError::OutOfSpace
                }
                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
            });
        if let Err(err) = result {
            tracing::error!(
                error = &err as &dyn std::error::Error,
                "Failed to notify guest about the send indirection table"
            );
        }
    }
2728
2729 fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2732 let header = protocol::MessageHeader {
2733 message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2734 };
2735 let data = protocol::Message4SwitchDataPath {
2736 active_data_path: protocol::DataPath::SYNTHETIC.0,
2737 };
2738 let result = self
2739 .queue
2740 .split()
2741 .1
2742 .try_write(&queue::OutgoingPacket {
2743 transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2744 packet_type: OutgoingPacketType::InBandWithCompletion,
2745 payload: &[header.as_bytes(), data.as_bytes()],
2746 })
2747 .map_err(|err| match err {
2748 queue::TryWriteError::Full(len) => {
2749 tracing::error!(len, "failed to write switch data path message");
2750 WorkerError::OutOfSpace
2751 }
2752 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2753 });
2754 if let Err(err) = result {
2755 tracing::error!(
2756 error = &err as &dyn std::error::Error,
2757 "Failed to notify guest that data path is now synthetic"
2758 );
2759 }
2760 }
2761
    /// Advances the primary channel's guest VF state machine, sending any
    /// guest notifications (VF association, data path switch) that the
    /// transitions require.
    ///
    /// Returns a coordinator message when further coordination is needed.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // Offer the VF only once RNDIS initialization has completed;
            // before that the offer is deferred.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Run transitional states to completion: each arm yields the next
        // state, and any steady state breaks out of the loop.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Retract a previously advertised VF.
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the guest's pending switch request before
                    // continuing the teardown.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    self.send_completion(id, None)?;

                    // (direction, success) determines the resulting state.
                    match (to_guest, result) {
                        (true, true) => PrimaryChannelGuestVfState::DataPathSwitched,
                        (true, false) => {
                            tracing::error!(
                                "Failure switching to guest VF, remaining on synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        }
                        (false, true) => PrimaryChannelGuestVfState::Ready,
                        (false, false) => {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                // Steady states: nothing to do.
                _ => break,
            };
        }
        Ok(None)
    }
2862
    /// Handles a dequeued RNDIS control message, building the completion in
    /// the control receive suballocation `id` and sending it to the guest.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // Responses are written directly into the guest's suballocation.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                // Initialize is only legal once, before RNDIS is operational.
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // Now that RNDIS is operational, advertise any available VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Reserve space for the completion header so the OID handler
                // can write the information buffer directly after it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Some OID sets require restarting the endpoint for
                        // the change to take effect.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate packet filter changes to the coordinator.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(
                            error = &err as &dyn std::error::Error,
                            oid = ?request.oid,
                            "oid set failure"
                        );
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                // Keepalives are always answered with success.
                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
3043
3044 fn try_send_rndis_message(
3045 &mut self,
3046 transaction_id: u64,
3047 channel_type: u32,
3048 recv_buffer_id: u16,
3049 transfer_pages: &[ring::TransferPageRange],
3050 ) -> Result<Option<usize>, WorkerError> {
3051 let message = self.message(
3052 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
3053 protocol::Message1SendRndisPacket {
3054 channel_type,
3055 send_buffer_section_index: 0xffffffff,
3056 send_buffer_section_size: 0,
3057 },
3058 );
3059 let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3060 transaction_id,
3061 OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3062 message.aligned_payload(),
3063 ) {
3064 Ok(()) => None,
3065 Err(queue::TryWriteError::Full(n)) => Some(n),
3066 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3067 };
3068 Ok(pending_send_size)
3069 }
3070
3071 fn send_rndis_control_message(
3072 &mut self,
3073 buffers: &ChannelBuffers,
3074 id: u32,
3075 message_length: usize,
3076 ) -> Result<(), WorkerError> {
3077 let result = self.try_send_rndis_message(
3078 id as u64,
3079 protocol::CONTROL_CHANNEL_TYPE,
3080 buffers.recv_buffer.id,
3081 std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3082 )?;
3083
3084 match result {
3086 None => Ok(()),
3087 Some(len) => {
3088 tracelimit::error_ratelimited!(len, "failed to write control message completion");
3089 Err(WorkerError::OutOfSpace)
3090 }
3091 }
3092 }
3093
3094 fn indicate_status(
3095 &mut self,
3096 buffers: &ChannelBuffers,
3097 id: u32,
3098 status: u32,
3099 payload: &[u8],
3100 ) -> Result<(), WorkerError> {
3101 let buffer = &buffers.recv_buffer.range(id);
3102 let mut writer = buffer.writer(&buffers.mem);
3103 let message_length = write_rndis_message(
3104 &mut writer,
3105 rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3106 payload.len(),
3107 &rndisprot::IndicateStatus {
3108 status,
3109 status_buffer_length: payload.len() as u32,
3110 status_buffer_offset: if payload.is_empty() {
3111 0
3112 } else {
3113 size_of::<rndisprot::IndicateStatus>() as u32
3114 },
3115 },
3116 )?;
3117 writer.write(payload)?;
3118 self.send_rndis_control_message(buffers, id, message_length)?;
3119 Ok(())
3120 }
3121
    /// Drains queued RNDIS control messages and pending offload-change
    /// indications, consuming one free control suballocation per message.
    ///
    /// Stops early if the outgoing ring lacks room (recording the needed
    /// space in `pending_send_size`) or if no control buffers are free.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages only exist on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure there is ring space for the response before dequeuing.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                // Ask to be woken once this much ring space is available.
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            let Some(id) = primary.free_control_buffers.pop() else {
                // All control suballocations are in flight to the guest.
                break;
            };

            // Mark the suballocation in use until the guest completes it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // No queued message, so the loop condition implies an
                // offload-change indication is due.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                unreachable!();
            }
        }
        Ok(())
    }
3175
3176 fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3177 if self.restart.is_none() {
3178 self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3179 guest_vf_state: guest_vf,
3180 filter_state: packet_filter,
3181 }));
3182 } else if let Some(CoordinatorMessage::Restart) = self.restart {
3183 } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3187 update.guest_vf_state |= guest_vf;
3189 update.filter_state |= packet_filter;
3190 }
3191 }
3192
    /// Requests that the coordinator refresh guest VF state.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3196
    /// Requests that the coordinator refresh packet filter state.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3200}
3201
3202fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3204 writer: &mut impl MemoryWrite,
3205 message_type: u32,
3206 extra: usize,
3207 payload: &T,
3208) -> Result<usize, AccessError> {
3209 let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3210 writer.write(
3211 rndisprot::MessageHeader {
3212 message_type,
3213 message_length: message_length as u32,
3214 }
3215 .as_bytes(),
3216 )?;
3217 writer.write(payload.as_bytes())?;
3218 Ok(message_length)
3219}
3220
/// Errors from handling an RNDIS OID query or set request.
#[derive(Debug, Error)]
enum OidError {
    #[error(transparent)]
    Access(#[from] AccessError),
    #[error("unknown oid")]
    UnknownOid,
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    #[error("bad ndis version")]
    BadVersion,
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3234
3235impl OidError {
3236 fn as_status(&self) -> u32 {
3237 match self {
3238 OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3239 OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3240 OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3241 OidError::Access(_) => rndisprot::STATUS_FAILURE,
3242 }
3243 }
3244}
3245
// Frame-size limits in bytes. These values include the 14-byte ethernet
// header (the OID handlers report `mtu - ETHERNET_HEADER_LEN` as the maximum
// frame size): 1514 = 1500 payload + 14 header.
const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
// Maximum supported jumbo-frame size.
const MAX_MTU: u32 = 9216;

// Length in bytes of an ethernet header (without a VLAN tag).
const ETHERNET_HEADER_LEN: u32 = 14;
3251
3252impl Adapter {
3253 fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3254 if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3255 if guest_os_id
3258 .microsoft()
3259 .unwrap_or(HvGuestOsMicrosoft::from(0))
3260 .os_id()
3261 == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3262 {
3263 self.adapter_index
3264 } else {
3265 vfid
3266 }
3267 } else {
3268 vfid
3269 }
3270 }
3271
3272 fn handle_oid_query(
3273 &self,
3274 buffers: &ChannelBuffers,
3275 primary: &PrimaryChannelState,
3276 oid: rndisprot::Oid,
3277 mut writer: impl MemoryWrite,
3278 ) -> Result<usize, OidError> {
3279 tracing::debug!(?oid, "oid query");
3280 let available_len = writer.len();
3281 match oid {
3282 rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
3283 let supported_oids_common = &[
3284 rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
3285 rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
3286 rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
3287 rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
3288 rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
3289 rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
3290 rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
3291 rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
3292 rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
3293 rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
3294 rndisprot::Oid::OID_GEN_LINK_SPEED,
3295 rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
3296 rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
3297 rndisprot::Oid::OID_GEN_VENDOR_ID,
3298 rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
3299 rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
3300 rndisprot::Oid::OID_GEN_DRIVER_VERSION,
3301 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
3302 rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
3303 rndisprot::Oid::OID_GEN_MAC_OPTIONS,
3304 rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
3305 rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
3306 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
3307 rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
3308 rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
3310 rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
3311 rndisprot::Oid::OID_802_3_MULTICAST_LIST,
3312 rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
3313 rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
3315 rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
3316 rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
3317 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
3322 ];
3323
3324 let supported_oids_6 = &[
3326 rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
3328 rndisprot::Oid::OID_GEN_LINK_STATE,
3329 rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
3330 rndisprot::Oid::OID_GEN_BYTES_RCV,
3332 rndisprot::Oid::OID_GEN_BYTES_XMIT,
3333 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
3335 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
3336 rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
3337 rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
3338 ];
3341
3342 let supported_oids_63 = &[
3343 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
3344 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
3345 ];
3346
3347 match buffers.ndis_version.major {
3348 5 => {
3349 writer.write(supported_oids_common.as_bytes())?;
3350 }
3351 6 => {
3352 writer.write(supported_oids_common.as_bytes())?;
3353 writer.write(supported_oids_6.as_bytes())?;
3354 if buffers.ndis_version.minor >= 30 {
3355 writer.write(supported_oids_63.as_bytes())?;
3356 }
3357 }
3358 _ => return Err(OidError::BadVersion),
3359 }
3360 }
3361 rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
3362 let status: u32 = 0; writer.write(status.as_bytes())?;
3364 }
3365 rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
3366 writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
3367 }
3368 rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
3369 | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
3370 | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
3371 let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
3372 writer.write(len.as_bytes())?;
3373 }
3374 rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
3375 | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
3376 | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
3377 let len: u32 = buffers.ndis_config.mtu;
3378 writer.write(len.as_bytes())?;
3379 }
3380 rndisprot::Oid::OID_GEN_LINK_SPEED => {
3381 let speed: u32 = (self.link_speed / 100) as u32; writer.write(speed.as_bytes())?;
3383 }
3384 rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
3385 | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
3386 writer.write((256u32 * 1024).as_bytes())?
3388 }
3389 rndisprot::Oid::OID_GEN_VENDOR_ID => {
3390 writer.write(0x0000155du32.as_bytes())?;
3393 }
3394 rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
3395 rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
3396 | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
3397 writer.write(0x0100u16.as_bytes())? }
3399 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
3400 rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
3401 let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
3402 | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
3403 | rndisprot::MAC_OPTION_NO_LOOPBACK;
3404 writer.write(options.as_bytes())?;
3405 }
3406 rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
3407 writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
3408 }
3409 rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
3410 rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
3411 let name16: Vec<u16> = "Network Device".encode_utf16().collect();
3412 let mut name = rndisprot::FriendlyName::new_zeroed();
3413 name.name[..name16.len()].copy_from_slice(&name16);
3414 writer.write(name.as_bytes())?
3415 }
3416 rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
3417 | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
3418 writer.write(&self.mac_address.to_bytes())?
3419 }
3420 rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
3421 writer.write(0u32.as_bytes())?;
3422 }
3423 rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
3424 | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
3425 | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,
3426
3427 rndisprot::Oid::OID_GEN_LINK_STATE => {
3429 let link_state = rndisprot::LinkState {
3430 header: rndisprot::NdisObjectHeader {
3431 object_type: rndisprot::NdisObjectType::DEFAULT,
3432 revision: 1,
3433 size: size_of::<rndisprot::LinkState>() as u16,
3434 },
3435 media_connect_state: 1, media_duplex_state: 0, padding: 0,
3438 xmit_link_speed: self.link_speed,
3439 rcv_link_speed: self.link_speed,
3440 pause_functions: 0, auto_negotiation_flags: 0,
3442 };
3443 writer.write(link_state.as_bytes())?;
3444 }
3445 rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
3446 let link_speed = rndisprot::LinkSpeed {
3447 xmit: self.link_speed,
3448 rcv: self.link_speed,
3449 };
3450 writer.write(link_speed.as_bytes())?;
3451 }
3452 rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
3453 let ndis_offload = self.offload_support.ndis_offload();
3454 writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3455 }
3456 rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
3457 let ndis_offload = primary.offload_config.ndis_offload();
3458 writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
3459 }
3460 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3461 writer.write(
3462 &rndisprot::NdisOffloadEncapsulation {
3463 header: rndisprot::NdisObjectHeader {
3464 object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3465 revision: 1,
3466 size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
3467 },
3468 ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3469 ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3470 ipv4_header_size: ETHERNET_HEADER_LEN,
3471 ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
3472 ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
3473 ipv6_header_size: ETHERNET_HEADER_LEN,
3474 }
3475 .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
3476 )?;
3477 }
3478 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
3479 writer.write(
3480 &rndisprot::NdisReceiveScaleCapabilities {
3481 header: rndisprot::NdisObjectHeader {
3482 object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
3483 revision: 2,
3484 size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
3485 as u16,
3486 },
3487 capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
3488 | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
3489 | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
3490 number_of_interrupt_messages: 1,
3491 number_of_receive_queues: self.max_queues.into(),
3492 number_of_indirection_table_entries: if self.indirection_table_size != 0 {
3493 self.indirection_table_size
3494 } else {
3495 128
3498 },
3499 padding: 0,
3500 }
3501 .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
3502 )?;
3503 }
3504 _ => {
3505 tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
3506 return Err(OidError::UnknownOid);
3507 }
3508 };
3509 Ok(available_len - writer.len())
3510 }
3511
3512 fn handle_oid_set(
3513 &self,
3514 primary: &mut PrimaryChannelState,
3515 oid: rndisprot::Oid,
3516 reader: impl MemoryRead + Clone,
3517 ) -> Result<(bool, Option<u32>), OidError> {
3518 tracing::debug!(?oid, "oid set");
3519
3520 let mut restart_endpoint = false;
3521 let mut packet_filter = None;
3522 match oid {
3523 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3524 packet_filter = self.oid_set_packet_filter(reader)?;
3525 }
3526 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3527 self.oid_set_offload_parameters(reader, primary)?;
3528 }
3529 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3530 self.oid_set_offload_encapsulation(reader)?;
3531 }
3532 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3533 self.oid_set_rndis_config_parameter(reader, primary)?;
3534 }
3535 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3536 }
3538 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3539 let rss_was_enabled = self.oid_set_rss_parameters(reader, primary)?;
3540
3541 if rss_was_enabled || primary.rss_state.is_some() {
3548 restart_endpoint = true;
3549 }
3550 }
3551 _ => {
3552 tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3553 return Err(OidError::UnknownOid);
3554 }
3555 }
3556 Ok((restart_endpoint, packet_filter))
3557 }
3558
3559 fn oid_set_rss_parameters(
3560 &self,
3561 mut reader: impl MemoryRead + Clone,
3562 primary: &mut PrimaryChannelState,
3563 ) -> Result<bool, OidError> {
3564 let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
3566 let len = reader.len().min(size_of_val(¶ms));
3567 reader.clone().read(&mut params.as_mut_bytes()[..len])?;
3568
3569 let rss_was_enabled = primary.rss_state.is_some();
3570
3571 if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
3572 || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
3573 {
3574 primary.rss_state = None;
3575 return Ok(rss_was_enabled);
3576 }
3577
3578 if params.hash_secret_key_size != 40 {
3579 return Err(OidError::InvalidInput("hash_secret_key_size"));
3580 }
3581 if params.indirection_table_size % 4 != 0 {
3582 return Err(OidError::InvalidInput("indirection_table_size"));
3583 }
3584 let indirection_table_size =
3585 (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
3586 let mut key = [0; 40];
3587 let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
3588 reader
3589 .clone()
3590 .skip(params.hash_secret_key_offset as usize)?
3591 .read(&mut key)?;
3592 reader
3593 .skip(params.indirection_table_offset as usize)?
3594 .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
3595 tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
3596 if indirection_table
3597 .iter()
3598 .any(|&x| x >= self.max_queues as u32)
3599 {
3600 return Err(OidError::InvalidInput("indirection_table"));
3601 }
3602 let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
3603 for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
3604 .flatten()
3605 .zip(indir_uninit)
3606 {
3607 *dest = src;
3608 }
3609 primary.rss_state = Some(RssState {
3610 key,
3611 indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
3612 });
3613 Ok(rss_was_enabled)
3614 }
3615
3616 fn oid_set_packet_filter(
3617 &self,
3618 reader: impl MemoryRead + Clone,
3619 ) -> Result<Option<u32>, OidError> {
3620 let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3621 tracing::debug!(filter, "set packet filter");
3622 Ok(Some(filter))
3623 }
3624
3625 fn oid_set_offload_parameters(
3626 &self,
3627 reader: impl MemoryRead + Clone,
3628 primary: &mut PrimaryChannelState,
3629 ) -> Result<(), OidError> {
3630 let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3631 reader,
3632 rndisprot::NdisObjectType::DEFAULT,
3633 1,
3634 rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3635 )?;
3636
3637 tracing::debug!(?offload, "offload parameters");
3638 let rndisprot::NdisOffloadParameters {
3639 header: _,
3640 ipv4_checksum,
3641 tcp4_checksum,
3642 udp4_checksum,
3643 tcp6_checksum,
3644 udp6_checksum,
3645 lsov1,
3646 ipsec_v1: _,
3647 lsov2_ipv4,
3648 lsov2_ipv6,
3649 tcp_connection_ipv4: _,
3650 tcp_connection_ipv6: _,
3651 reserved: _,
3652 flags: _,
3653 } = offload;
3654
3655 if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3656 return Err(OidError::NotSupported("lsov1"));
3657 }
3658 if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3659 primary.offload_config.checksum_tx.ipv4_header = tx;
3660 primary.offload_config.checksum_rx.ipv4_header = rx;
3661 }
3662 if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3663 primary.offload_config.checksum_tx.tcp4 = tx;
3664 primary.offload_config.checksum_rx.tcp4 = rx;
3665 }
3666 if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3667 primary.offload_config.checksum_tx.tcp6 = tx;
3668 primary.offload_config.checksum_rx.tcp6 = rx;
3669 }
3670 if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3671 primary.offload_config.checksum_tx.udp4 = tx;
3672 primary.offload_config.checksum_rx.udp4 = rx;
3673 }
3674 if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3675 primary.offload_config.checksum_tx.udp6 = tx;
3676 primary.offload_config.checksum_rx.udp6 = rx;
3677 }
3678 if let Some(enable) = lsov2_ipv4.enable() {
3679 primary.offload_config.lso4 = enable;
3680 }
3681 if let Some(enable) = lsov2_ipv6.enable() {
3682 primary.offload_config.lso6 = enable;
3683 }
3684 primary
3685 .offload_config
3686 .mask_to_supported(&self.offload_support);
3687 primary.pending_offload_change = true;
3688 Ok(())
3689 }
3690
3691 fn oid_set_offload_encapsulation(
3692 &self,
3693 reader: impl MemoryRead + Clone,
3694 ) -> Result<(), OidError> {
3695 let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3696 reader,
3697 rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3698 1,
3699 rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3700 )?;
3701 if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3702 && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3703 || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3704 {
3705 return Err(OidError::NotSupported("ipv4 encap"));
3706 }
3707 if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3708 && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3709 || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3710 {
3711 return Err(OidError::NotSupported("ipv6 encap"));
3712 }
3713 Ok(())
3714 }
3715
3716 fn oid_set_rndis_config_parameter(
3717 &self,
3718 reader: impl MemoryRead + Clone,
3719 primary: &mut PrimaryChannelState,
3720 ) -> Result<(), OidError> {
3721 let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3722 if info.name_length > 255 {
3723 return Err(OidError::InvalidInput("name_length"));
3724 }
3725 if info.value_length > 255 {
3726 return Err(OidError::InvalidInput("value_length"));
3727 }
3728 let name = reader
3729 .clone()
3730 .skip(info.name_offset as usize)?
3731 .read_n::<u16>(info.name_length as usize / 2)?;
3732 let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3733 let mut value = reader;
3734 value.skip(info.value_offset as usize)?;
3735 let mut value = value.limit(info.value_length as usize);
3736 match info.parameter_type {
3737 rndisprot::NdisParameterType::STRING => {
3738 let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3739 let value =
3740 String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3741 let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3742 let tx = as_num & 1 != 0;
3743 let rx = as_num & 2 != 0;
3744
3745 tracing::debug!(name, value, "rndis config");
3746 match name.as_str() {
3747 "*IPChecksumOffloadIPv4" => {
3748 primary.offload_config.checksum_tx.ipv4_header = tx;
3749 primary.offload_config.checksum_rx.ipv4_header = rx;
3750 }
3751 "*LsoV2IPv4" => {
3752 primary.offload_config.lso4 = as_num != 0;
3753 }
3754 "*LsoV2IPv6" => {
3755 primary.offload_config.lso6 = as_num != 0;
3756 }
3757 "*TCPChecksumOffloadIPv4" => {
3758 primary.offload_config.checksum_tx.tcp4 = tx;
3759 primary.offload_config.checksum_rx.tcp4 = rx;
3760 }
3761 "*TCPChecksumOffloadIPv6" => {
3762 primary.offload_config.checksum_tx.tcp6 = tx;
3763 primary.offload_config.checksum_rx.tcp6 = rx;
3764 }
3765 "*UDPChecksumOffloadIPv4" => {
3766 primary.offload_config.checksum_tx.udp4 = tx;
3767 primary.offload_config.checksum_rx.udp4 = rx;
3768 }
3769 "*UDPChecksumOffloadIPv6" => {
3770 primary.offload_config.checksum_tx.udp6 = tx;
3771 primary.offload_config.checksum_rx.udp6 = rx;
3772 }
3773 _ => {}
3774 }
3775 primary
3776 .offload_config
3777 .mask_to_supported(&self.offload_support);
3778 }
3779 rndisprot::NdisParameterType::INTEGER => {
3780 let value: u32 = value.read_plain()?;
3781 tracing::debug!(name, value, "rndis config");
3782 }
3783 parameter_type => tracelimit::warn_ratelimited!(
3784 name,
3785 ?parameter_type,
3786 "unhandled rndis config parameter type"
3787 ),
3788 }
3789 Ok(())
3790 }
3791}
3792
3793fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3794 mut reader: impl MemoryRead,
3795 object_type: rndisprot::NdisObjectType,
3796 min_revision: u8,
3797 min_size: usize,
3798) -> Result<T, OidError> {
3799 let mut buffer = T::new_zeroed();
3800 let sent_size = reader.len();
3801 let len = sent_size.min(size_of_val(&buffer));
3802 reader.read(&mut buffer.as_mut_bytes()[..len])?;
3803 validate_ndis_object_header(
3804 &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3805 .unwrap()
3806 .0, sent_size,
3808 object_type,
3809 min_revision,
3810 min_size,
3811 )?;
3812 Ok(buffer)
3813}
3814
3815fn validate_ndis_object_header(
3816 header: &rndisprot::NdisObjectHeader,
3817 sent_size: usize,
3818 object_type: rndisprot::NdisObjectType,
3819 min_revision: u8,
3820 min_size: usize,
3821) -> Result<(), OidError> {
3822 if header.object_type != object_type {
3823 return Err(OidError::InvalidInput("header.object_type"));
3824 }
3825 if sent_size < header.size as usize {
3826 return Err(OidError::InvalidInput("header.size"));
3827 }
3828 if header.revision < min_revision {
3829 return Err(OidError::InvalidInput("header.revision"));
3830 }
3831 if (header.size as usize) < min_size {
3832 return Err(OidError::InvalidInput("header.size"));
3833 }
3834 Ok(())
3835}
3836
/// Mutable state of the coordinator task, which owns the per-queue worker
/// tasks and orchestrates restarts, packet filter updates, and VF state
/// transitions.
struct Coordinator {
    /// Receives messages from elsewhere in the device (see
    /// `handle_coordinator_message`); a closed channel ends the loop.
    recv: mpsc::Receiver<CoordinatorMessage>,
    /// Control handle for the vmbus channel.
    channel_control: ChannelControl,
    /// When set, the queues are torn down and restarted on the next pass
    /// through the coordinator loop.
    restart: bool,
    /// One task per queue; index 0 is the primary channel's worker.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    /// The channel buffers captured from the primary worker during the most
    /// recent queue restart, if any.
    buffers: Option<Arc<ChannelBuffers>>,
    /// Number of currently active queues (`workers[..num_queues]`).
    num_queues: u16,
    /// Packet filter most recently propagated from the primary channel to
    /// the subchannels.
    active_packet_filter: u32,
    /// Deadline for the coordinator's wakeup timer, if armed.
    sleep_deadline: Option<Instant>,
}
3847
/// Progress of offering the virtual function (VF) device to the guest.
enum CoordinatorStatePendingVfState {
    /// No VF offer is in flight.
    Ready,
    /// The VF offer is deferred until `delay_until` elapses.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// Waiting for the guest to signal it is ready for the VF device
    /// (`guest_ready_for_device` is awaited in the coordinator loop).
    Pending,
}
3862
/// State used by the coordinator task but kept separate from [`Coordinator`];
/// it remains available (e.g. for inspect) even when the coordinator task
/// state itself is not.
struct CoordinatorState {
    /// The backing network endpoint.
    endpoint: Box<dyn Endpoint>,
    /// Shared, mostly-static adapter configuration.
    adapter: Arc<Adapter>,
    /// The guest-visible virtual function, if accelerated networking is
    /// configured.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    /// Progress of offering the VF device to the guest.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3869
impl InspectTaskMut<Coordinator> for CoordinatorState {
    /// Responds to an inspect request for the device. `coordinator` is
    /// `None` when the coordinator task state is unavailable; in that case
    /// only the adapter/endpoint fields are reported.
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        // Adapter-level (static) configuration.
        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Writable field: updates the shared ring size limit and pokes
            // each worker so the new value takes effect.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        // Endpoint-level information.
        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // Per-queue worker state, indexed by queue number.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Merge in details from the primary worker's ready state, if it
            // has one; the response is deferred because the worker state is
            // only reachable from inside `update_with`.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3945
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Runs the coordinator until cancelled, delegating to the main
    /// processing loop.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3955
3956impl Coordinator {
    /// The coordinator's main loop: restarts queues when requested, keeps
    /// the workers running, and races the various event sources (internal
    /// messages, endpoint actions, VF state changes, timers) until the
    /// message channel disconnects or the task is cancelled.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            // Phase 1: if a restart was requested, tear down and rebuild the
            // queues, re-query the data path state, and reconcile VF state.
            if self.restart {
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Phase 2: ensure workers are running. Subchannel workers always
            // start; the primary worker is held back while it is waiting for
            // the coordinator.
            for worker in &mut self.workers[1..] {
                worker.start();
            }
            if !self.workers[0].is_running()
                && self.workers[0].state().is_none_or(|worker| {
                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
                })
            {
                self.workers[0].start();
            }

            // Phase 3: wait for the next event.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // A VF offer is in flight: block until the guest is ready
                // for the device before processing anything else.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // Fires when the coordinator's sleep deadline elapses (or
                // never, if no deadline is armed).
                let timer_sleep = async {
                    if let Some(deadline) = self.sleep_deadline {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(deadline).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Fires when the delayed VF offer comes due
                                // (or never, if no delay is pending).
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            // Handled by the branch above this `else`.
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                stop.until_stopped(wait_for_message).await?
            };
            // Phase 4: act on the event.
            match message {
                Message::Internal(msg) => {
                    self.handle_coordinator_message(msg, state).await;
                }
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Stop the primary worker while mutating primary state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    self.workers[0].stop().await;
                    // Promote a delayed link action to active now that the
                    // delay has elapsed.
                    if let Some(primary) = self.primary_mut() {
                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                            primary.pending_link_action = PendingLinkAction::Active(up);
                        }
                    }
                    self.sleep_deadline = None;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    // Record the new link status for the primary channel to
                    // act on; any pending delay is superseded.
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    self.sleep_deadline = None;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4139
    /// Handles one [`CoordinatorMessage`]. Stops the primary worker first so
    /// its state can be safely examined and mutated; the caller's loop will
    /// restart it afterwards.
    async fn handle_coordinator_message(
        &mut self,
        msg: CoordinatorMessage,
        state: &mut CoordinatorState,
    ) {
        self.workers[0].stop().await;
        // If the primary worker parked itself waiting for the coordinator,
        // move it back to the ready state. The two-step replace extracts the
        // owned ready state out of the enum variant.
        if let Some(worker) = self.workers[0].state_mut() {
            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
                let WorkerState::WaitingForCoordinator(Some(ready)) =
                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
                else {
                    unreachable!("valid ready state")
                };
                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
            }
        }
        match msg {
            CoordinatorMessage::Update(update_type) => {
                if update_type.filter_state {
                    // Propagate the primary channel's packet filter to all
                    // subchannel workers (all workers must be stopped first).
                    self.stop_workers().await;
                    self.active_packet_filter =
                        self.workers[0].state().unwrap().channel.packet_filter;
                    self.workers.iter_mut().skip(1).for_each(|worker| {
                        if let Some(state) = worker.state_mut() {
                            state.channel.packet_filter = self.active_packet_filter;
                            tracing::debug!(
                                packet_filter = ?self.active_packet_filter,
                                channel_idx = state.channel_idx,
                                "update packet filter"
                            );
                        }
                    });
                }

                if update_type.guest_vf_state {
                    self.update_guest_vf_state(state).await;
                }
            }
            CoordinatorMessage::StartTimer(deadline) => {
                // Arm the coordinator's wakeup timer; the main loop races on
                // it and emits TimerExpired when it elapses.
                self.sleep_deadline = Some(deadline);
            }
            CoordinatorMessage::Restart => self.restart = true,
        }
    }
4184
4185 async fn stop_workers(&mut self) {
4186 for worker in &mut self.workers {
4187 worker.stop().await;
4188 }
4189 }
4190
    /// Reconciles the primary channel's guest VF state with the actual
    /// presence (or absence) of a virtual function, e.g. after a restart or
    /// a save/restore. Also re-issues pending data path switches.
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        // Only the primary channel tracks VF state.
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        // Determine whether a VF is currently present.
        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // A VF is present. First decide how the VF device offer should
            // proceed.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // Advertised but the data path was never switched:
                    // delay the offer briefly.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                // In all of these states the guest is already aware of the
                // VF, so the offer should proceed without delay.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // A restored switch request that already has a result does not
            // need to be re-issued; just carry the result forward.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            // Map the current state to its post-reconciliation value.
            primary.guest_vf_state = match primary.guest_vf_state {
                // States that require (re)issuing a data path switch.
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // Recover the switch target; default to switching to the
                    // guest when the originating request is unknown.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // The data path is being switched now, so skip any
                    // pending offer delay.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        // A pending switch stays pending, now carrying the
                        // result for the guest to be notified.
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // No VF is present. If the data path was (or was being) switched
            // to the VF, move it back to synthetic.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    // A pending switch back to synthetic is completed here;
                    // a pending switch to the (now absent) guest VF is not.
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any in-flight offer for the now-absent VF.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Map each state to its "VF unavailable" counterpart.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4408
4409 async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4410 let drivers = self
4412 .workers
4413 .iter_mut()
4414 .map(|worker| {
4415 let task = worker.task_mut();
4416 task.queue_state = None;
4417 task.driver.clone()
4418 })
4419 .collect::<Vec<_>>();
4420
4421 c_state.endpoint.stop().await;
4422
4423 let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4424 {
4425 (primary, sub)
4426 } else {
4427 unreachable!()
4428 };
4429
4430 let state = primary_worker
4431 .state_mut()
4432 .and_then(|worker| worker.state.ready_mut());
4433
4434 let state = if let Some(state) = state {
4435 state
4436 } else {
4437 return Ok(());
4438 };
4439
4440 self.buffers = Some(state.buffers.clone());
4442
4443 let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4444 let mut active_queues = Vec::new();
4445 let active_queue_count =
4446 if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
4447 active_queues.clone_from(&rss_state.indirection_table);
4449 active_queues.sort();
4450 active_queues.dedup();
4451 active_queues = active_queues
4452 .into_iter()
4453 .filter(|&index| index < num_queues)
4454 .collect::<Vec<_>>();
4455 if !active_queues.is_empty() {
4456 active_queues.len() as u16
4457 } else {
4458 tracelimit::warn_ratelimited!("Invalid RSS indirection table");
4459 num_queues
4460 }
4461 } else {
4462 num_queues
4463 };
4464
4465 let (ranges, mut remote_buffer_id_recvs) =
4467 RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
4468 let ranges = Arc::new(ranges);
4469
4470 let mut queues = Vec::new();
4471 let mut rx_buffers = Vec::new();
4472 {
4473 let buffers = &state.buffers;
4474 let guest_buffers = Arc::new(
4475 GuestBuffers::new(
4476 buffers.mem.clone(),
4477 buffers.recv_buffer.gpadl.clone(),
4478 buffers.recv_buffer.sub_allocation_size,
4479 buffers.ndis_config.mtu,
4480 )
4481 .map_err(WorkerError::GpadlError)?,
4482 );
4483
4484 let mut queue_config = Vec::new();
4487 let initial_rx;
4488 {
4489 let states = std::iter::once(Some(&*state)).chain(
4490 subworkers
4491 .iter()
4492 .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4493 );
4494
4495 initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4496 .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
4497 .map(RxId)
4498 .collect::<Vec<_>>();
4499
4500 let mut initial_rx = initial_rx.as_slice();
4501 let mut range_start = 0;
4502 let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4503 let first_queue = if !primary_queue_excluded {
4504 0
4505 } else {
4506 queue_config.push(QueueConfig {
4510 pool: Box::new(BufferPool::new(guest_buffers.clone())),
4511 initial_rx: &[],
4512 driver: Box::new(drivers[0].clone()),
4513 });
4514 rx_buffers.push(RxBufferRange::new(
4515 ranges.clone(),
4516 0..RX_RESERVED_CONTROL_BUFFERS,
4517 None,
4518 ));
4519 range_start = RX_RESERVED_CONTROL_BUFFERS;
4520 1
4521 };
4522 for queue_index in first_queue..num_queues {
4523 let queue_active = active_queues.is_empty()
4524 || active_queues.binary_search(&queue_index).is_ok();
4525 let (range_end, end, buffer_id_recv) = if queue_active {
4526 let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4527 state.buffers.recv_buffer.count
4529 } else if queue_index == 0 {
4530 RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4532 } else {
4533 range_start + ranges.buffers_per_queue
4534 };
4535 (
4536 range_end,
4537 initial_rx.partition_point(|id| id.0 < range_end),
4538 Some(remote_buffer_id_recvs.remove(0)),
4539 )
4540 } else {
4541 (range_start, 0, None)
4542 };
4543
4544 let (this, rest) = initial_rx.split_at(end);
4545 queue_config.push(QueueConfig {
4546 pool: Box::new(BufferPool::new(guest_buffers.clone())),
4547 initial_rx: this,
4548 driver: Box::new(drivers[queue_index as usize].clone()),
4549 });
4550 initial_rx = rest;
4551 rx_buffers.push(RxBufferRange::new(
4552 ranges.clone(),
4553 range_start..range_end,
4554 buffer_id_recv,
4555 ));
4556
4557 range_start = range_end;
4558 }
4559 }
4560
4561 let primary = state.state.primary.as_mut().unwrap();
4562 tracing::debug!(num_queues, "enabling endpoint");
4563
4564 let rss = primary
4565 .rss_state
4566 .as_ref()
4567 .map(|rss| net_backend::RssConfig {
4568 key: &rss.key,
4569 indirection_table: &rss.indirection_table,
4570 flags: 0,
4571 });
4572
4573 c_state
4574 .endpoint
4575 .get_queues(queue_config, rss.as_ref(), &mut queues)
4576 .instrument(tracing::info_span!("netvsp_get_queues"))
4577 .await
4578 .map_err(WorkerError::Endpoint)?;
4579
4580 assert_eq!(queues.len(), num_queues as usize);
4581
4582 self.channel_control
4584 .enable_subchannels(num_queues - 1)
4585 .expect("already validated");
4586
4587 self.num_queues = num_queues;
4588 }
4589
4590 self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4591 for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
4593 worker.task_mut().queue_state = Some(QueueState {
4594 queue,
4595 target_vp_set: false,
4596 rx_buffer_range: rx_buffer,
4597 });
4598 if let Some(worker) = worker.state_mut() {
4600 worker.channel.packet_filter = self.active_packet_filter;
4601 if let Some(ready_state) = worker.state.ready_mut() {
4604 ready_state.state.pending_rx_packets.clear();
4605 ready_state.reset_tx_after_endpoint_stop();
4606 }
4607 }
4608 }
4609
4610 Ok(())
4611 }
4612
4613 fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4614 self.workers[0]
4615 .state_mut()
4616 .unwrap()
4617 .state
4618 .ready_mut()?
4619 .state
4620 .primary
4621 .as_mut()
4622 }
4623
    /// Stops the primary worker and reapplies the guest VF state via the
    /// coordinator's restore path.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        // Quiesce the primary worker before touching shared state.
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4628}
4629
impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
    /// Runs the worker's processing loop until cancelled.
    ///
    /// Most worker errors are logged and swallowed (returning `Ok`) so that a
    /// failing channel does not tear down the task; only cancellation is
    /// propagated to the caller.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        match worker.process(stop, self).await {
            // BufferRevoked ends channel processing deliberately (the guest
            // revoked its buffers); it is not an error worth logging.
            Ok(()) | Err(WorkerError::BufferRevoked) => {}
            Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
            Err(err) => {
                tracing::error!(
                    error = &err as &dyn std::error::Error,
                    channel_idx = worker.channel_idx,
                    "netvsp error"
                );
            }
        }
        Ok(())
    }
}
4650
impl<T: RingMem + 'static> Worker<T> {
    /// Drives this channel's state machine: protocol negotiation on the
    /// primary channel, waiting for the coordinator, then per-queue packet
    /// processing once a queue has been assigned.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        queue: &mut NetQueue,
    ) -> Result<(), WorkerError> {
        loop {
            match &mut self.state {
                WorkerState::Init(initializing) => {
                    // Only the primary channel performs protocol negotiation.
                    assert_eq!(self.channel_idx, 0);

                    tracelimit::info_ratelimited!("network accepted");

                    let (buffers, state) = stop
                        .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
                        .await??;

                    let state = ReadyState {
                        buffers: Arc::new(buffers),
                        state,
                        data: ProcessingData::new(),
                    };

                    // Wake the coordinator so it can (re)build the queues. A
                    // full channel is fine: a restart is already pending then.
                    let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);

                    tracelimit::info_ratelimited!("network initialized");
                    self.state = WorkerState::WaitingForCoordinator(Some(state));
                }
                WorkerState::WaitingForCoordinator(_) => {
                    assert_eq!(self.channel_idx, 0);
                    // Park until the coordinator stops and restarts this task.
                    stop.until_stopped(pending()).await?
                }
                WorkerState::Ready(state) => {
                    let queue_state = if let Some(queue_state) = &mut queue.queue_state {
                        // Lazily apply the interrupt target VP the first time
                        // it becomes known.
                        if !queue_state.target_vp_set {
                            if let Some(target_vp) = self.target_vp {
                                tracing::debug!(
                                    channel_idx = self.channel_idx,
                                    target_vp,
                                    "updating target VP"
                                );
                                queue_state.queue.update_target_vp(target_vp).await;
                                queue_state.target_vp_set = true;
                            }
                        }

                        queue_state
                    } else {
                        // No queue assigned yet; wait for the coordinator.
                        stop.until_stopped(pending()).await?
                    };

                    let result = self.channel.main_loop(stop, state, queue_state).await;
                    let msg = match result {
                        Ok(restart) => {
                            // Only the primary channel returns a coordinator
                            // message from the main loop.
                            assert_eq!(self.channel_idx, 0);
                            restart
                        }
                        Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
                            tracelimit::warn_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Endpoint requires queues to restart",
                            );
                            CoordinatorMessage::Restart
                        }
                        Err(err) => return Err(err),
                    };

                    // Move the ready state out of `self.state` and park it for
                    // the coordinator; done in two steps so `ready` can be
                    // extracted by value.
                    let WorkerState::Ready(ready) = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(None),
                    ) else {
                        unreachable!("must be running in ready state")
                    };
                    let _ = std::mem::replace(
                        &mut self.state,
                        WorkerState::WaitingForCoordinator(Some(ready)),
                    );
                    self.coordinator_send
                        .try_send(msg)
                        .map_err(WorkerError::CoordinatorMessageSendFailed)?;
                    stop.until_stopped(pending()).await?
                }
            }
        }
    }
}
4744
4745impl<T: 'static + RingMem> NetChannel<T> {
4746 fn try_next_packet<'a>(
4747 &mut self,
4748 send_buffer: Option<&SendBuffer>,
4749 version: Option<Version>,
4750 external_data: &'a mut MultiPagedRangeBuf,
4751 ) -> Result<Option<Packet<'a>>, WorkerError> {
4752 let (mut read, _) = self.queue.split();
4753 let packet = match read.try_read() {
4754 Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4755 .map_err(WorkerError::Packet)?,
4756 Err(queue::TryReadError::Empty) => return Ok(None),
4757 Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4758 };
4759
4760 tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4761 Ok(Some(packet))
4762 }
4763
    /// Reads and parses the next vmbus packet, waiting until one is available.
    ///
    /// If the packet is an RNDIS data packet, the ring read is reverted so the
    /// packet stays in the ring to be consumed again by the regular processing
    /// path (here it is only peeked, e.g. to detect an early RNDIS initialize
    /// message during negotiation).
    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
        external_data: &'a mut MultiPagedRangeBuf,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
            .map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            // Leave the packet in the ring for the caller to re-read later.
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }
4782
4783 fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4784 (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4785 && initializing.ndis_version.is_some()
4786 && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4787 && initializing.recv_buffer.is_some()
4788 }
4789
    /// Runs protocol negotiation on the primary channel until the guest has
    /// supplied everything needed to start processing, then returns the
    /// negotiated buffers and the initial channel state.
    ///
    /// `initializing` is `None` until the guest's Init message negotiates a
    /// version, after which each setup message (NDIS config/version, receive
    /// and send buffers) is accumulated into it.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        // Set when an RNDIS packet arrives before setup is formally complete;
        // treated as an implicit signal that the guest is done with setup.
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Pre-V2 guests never send a config; use defaults.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Ensure there is room to post a completion before reading the
            // next request.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                // Version already negotiated: accumulate setup messages.
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // Clamp out-of-range MTUs to the default.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // Sub-allocation sizing needs the MTU; pre-V2 guests
                        // never send one, so fall back to the default there.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        // Report the buffer carved into one section of
                        // fixed-size sub-allocations.
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // Tolerated only once setup is otherwise complete
                        // (send buffer optional); next_packet reverted the
                        // read, so the packet will be reprocessed later.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                // No version negotiated yet: only Init is possible here
                // (parse_packet was called with version None).
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            // V6.1 introduces a larger packet layout.
                            if version >= Version::V61 {
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }
    }
5004
    /// The steady-state processing loop for a ready channel: moves packets
    /// between the vmbus rings and the backend endpoint queue until stopped or
    /// until a coordinator action is required.
    ///
    /// Returns the `CoordinatorMessage` that the caller should forward to the
    /// coordinator (e.g. a restart or a link-status timer request).
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        // Amount of outgoing-ring space deliberately kept free. When the ring
        // cannot absorb this much, rx/tx processing is throttled below.
        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };

        // Return rx buffers that were held back while the packet filter was
        // NONE, now that the filter allows receives again.
        if !state.pending_rx_packets.is_empty()
            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
        {
            let epqueue = queue_state.queue.as_mut();
            let (front, back) = state.pending_rx_packets.as_slices();
            epqueue.rx_avail(front);
            epqueue.rx_avail(back);
            state.pending_rx_packets.clear();
        }

        // Primary-channel-only housekeeping before entering the loop.
        if let Some(primary) = state.primary.as_mut() {
            // Once all requested subchannels are open, tell the guest how to
            // spread transmits across them.
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            // Pending link up/down indication toward the guest.
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // Guest already sees the target state; toggle through
                        // the opposite state first, then finish after a delay.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    // No control buffer available right now; retry later.
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            // Guest VF states that require coordinator-driven handling.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // Throttle when the outgoing ring lacks the reserved headroom.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            // Non-short-circuiting `|`: every stage runs each iteration.
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            self.process_control_messages(buffers, state)?;

            // Sleep until any of the wake sources below is ready.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        // 1. Backend endpoint has rx/tx work (only useful if
                        //    the outgoing ring can take the results).
                        if !ring_full {
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        // 2. Incoming ring has packets, and we have tx
                        //    resources to process them.
                        let (mut recv, mut send) = self.queue.split();
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        // 3. Outgoing ring has room for a pended send (or for
                        //    the reserved headroom when throttled).
                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // 4. Rx buffers released by other queues come home:
                        //    data buffers back to the endpoint, reserved ids
                        //    back to the primary's control-buffer pool.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        // 5. A coordinator action was queued by processing.
                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }
5225
    /// Moves received packets from the backend endpoint into the guest's
    /// receive buffer and posts them on the vmbus ring.
    ///
    /// Returns `Ok(true)` if any packets were polled from the endpoint.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        // With the filter set to NONE the guest must not see these packets;
        // hold the buffer ids until the filter changes (see main_loop).
        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The first buffer id doubles as the vmbus transaction id; the
        // completion will release the whole batch.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        // Mark the whole batch as in use until the guest completes it.
        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring full: undo the allocation and give the buffers back to
                // the endpoint rather than waiting.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5291
    /// Polls the backend endpoint for transmit completions and completes the
    /// corresponding guest tx packets.
    ///
    /// Returns `Ok(true)` if any completions were processed.
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        // Poll for completed transmits.
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    // Complete the guest's send only after every backend
                    // packet spawned from it has finished.
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // Recoverable: the coordinator will rebuild the queues.
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }
5326
    /// Handles a guest request to switch the data path between the synthetic
    /// path and the guest VF.
    ///
    /// Either queues the switch with the coordinator (the completion is sent
    /// once the switch finishes, using `transaction_id` stashed in the state)
    /// or, when no switch is needed/possible, completes the request
    /// immediately.
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // Switch to VF, or to synthetic when the current switch state
                // is unknown.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                // Only a switch back to synthetic is meaningful here.
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while device is in wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, None)?;
        }
        Ok(())
    }
5380
    /// Drains and dispatches guest packets from the incoming vmbus ring:
    /// RNDIS sends, receive-buffer completions, subchannel requests, buffer
    /// revocations, data-path switches, and OID queries.
    ///
    /// Returns `Ok(true)` if any packet was processed.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        // Don't accumulate more tx segments until the current batch has been
        // handed to the endpoint.
        if !data.tx_segments.is_empty() {
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Each RNDIS send consumes a tx packet id; stop when exhausted.
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // Control-only messages produce no data packets;
                            // complete the guest's send immediately.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                PacketData::RndisPacketComplete(_completion) => {
                    // Guest is done with a batch of rx buffers; return them
                    // to the endpoint.
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    // Granting subchannels requires rebuilding the queues.
                    if subchannel_count > 0 {
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(protocol::Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    // OID queries are not supported; fail them explicitly.
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Kick the accumulated tx segments to the endpoint.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5533
5534 fn transmit_pending_segments(
5537 &mut self,
5538 state: &mut ActiveState,
5539 data: &mut ProcessingData,
5540 queue_state: &mut QueueState,
5541 ) -> Result<bool, WorkerError> {
5542 if data.tx_segments.is_empty() {
5543 return Ok(false);
5544 }
5545 let sent = data.tx_segments_sent;
5546 let did_work =
5547 self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5548 Ok(did_work)
5549 }
5550
    /// Offers the not-yet-sent tx segments to the backend queue.
    ///
    /// Returns `Ok(true)` when every accumulated segment has been handed off
    /// (the segment buffer is then reset).
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(segments)
            .map_err(WorkerError::Endpoint)?;

        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        // If the backend completed the transmits synchronously, retire them
        // now instead of waiting for tx_poll; walk packet by packet using the
        // head segment's count.
        if sync {
            while let Some(head) = segments.first() {
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5590
    /// Handles one RNDIS packet from the guest, turning data messages into tx
    /// segments and dispatching control messages.
    ///
    /// Returns the number of backend data packets produced; `0` means the
    /// caller should complete the guest's send immediately.
    fn handle_rndis(
        &mut self,
        buffers: &ChannelBuffers,
        id: TxId,
        state: &mut ActiveState,
        packet: &Packet<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut total_packets = 0;
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        tx_packet.transaction_id = packet
            .transaction_id
            .ok_or(WorkerError::MissingTransactionId)?;

        // Validate all guest GPA ranges up front so later accesses can't fail
        // mid-packet.
        packet
            .external_data
            .iter()
            .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
            .map_err(WorkerError::GpaDirectError)?;

        let mut reader = packet.rndis_reader(&buffers.mem);
        let header: rndisprot::MessageHeader = reader.read_plain()?;
        if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
            let start = segments.len();
            match self.handle_rndis_packet_messages(
                buffers,
                state,
                id,
                header.message_length as usize,
                reader,
                segments,
            ) {
                Ok(n) => {
                    state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
                    total_packets += n;
                }
                Err(err) => {
                    // Drop any partially-built segments for this packet.
                    segments.truncate(start);
                    return Err(err);
                }
            }
        } else {
            // Control-path RNDIS message; produces no data packets.
            self.handle_rndis_message(state, header.message_type, reader)?;
        }

        Ok(total_packets)
    }
5643
5644 fn try_send_tx_packet(
5645 &mut self,
5646 transaction_id: u64,
5647 status: protocol::Status,
5648 ) -> Result<bool, WorkerError> {
5649 let message = self.message(
5650 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5651 protocol::Message1SendRndisPacketComplete { status },
5652 );
5653 let result = self.queue.split().1.batched().try_write_aligned(
5654 transaction_id,
5655 OutgoingPacketType::Completion,
5656 message.aligned_payload(),
5657 );
5658 let sent = match result {
5659 Ok(()) => true,
5660 Err(queue::TryWriteError::Full(n)) => {
5661 self.pending_send_size = n;
5662 false
5663 }
5664 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5665 };
5666 Ok(sent)
5667 }
5668
5669 fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
5670 let mut did_some_work = false;
5671 while let Some(pending) = state.pending_tx_completions.front() {
5672 if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
5673 return Ok(did_some_work);
5674 }
5675 did_some_work = true;
5676 if let Some(id) = pending.tx_id {
5677 state.free_tx_packets.push(id);
5678 }
5679 tracing::trace!(?pending, "sent tx completion");
5680 state.pending_tx_completions.pop_front();
5681 }
5682
5683 self.pending_send_size = 0;
5684 Ok(did_some_work)
5685 }
5686
5687 fn complete_tx_packet(
5688 &mut self,
5689 state: &mut ActiveState,
5690 id: TxId,
5691 status: protocol::Status,
5692 ) -> Result<(), WorkerError> {
5693 let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5694 assert_eq!(tx_packet.pending_packet_count, 0);
5695 if self.pending_send_size == 0
5696 && self.try_send_tx_packet(tx_packet.transaction_id, status)?
5697 {
5698 tracing::trace!(id = id.0, "sent tx completion");
5699 state.free_tx_packets.push(id);
5700 } else {
5701 tracing::trace!(id = id.0, "pended tx completion");
5702 state.pending_tx_completions.push_back(PendingTxCompletion {
5703 transaction_id: tx_packet.transaction_id,
5704 tx_id: Some(id),
5705 status,
5706 });
5707 }
5708 Ok(())
5709 }
5710}
5711
5712impl ActiveState {
5713 fn release_recv_buffers(
5714 &mut self,
5715 transaction_id: u64,
5716 rx_buffer_range: &RxBufferRange,
5717 done: &mut Vec<RxId>,
5718 ) -> Option<()> {
5719 let first_id: u32 = transaction_id.try_into().ok()?;
5721 let ids = self.rx_bufs.free(first_id)?;
5722 for id in ids {
5723 if !rx_buffer_range.send_if_remote(id) {
5724 if id >= RX_RESERVED_CONTROL_BUFFERS {
5725 done.push(RxId(id));
5726 } else {
5727 self.primary
5728 .as_mut()?
5729 .free_control_buffers
5730 .push(ControlMessageId(id));
5731 }
5732 }
5733 }
5734 Some(())
5735 }
5736}