1#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9mod buffers;
10mod protocol;
11pub mod resolver;
12mod rndisprot;
13mod rx_bufs;
14mod saved_state;
15mod test;
16
17use crate::buffers::GuestBuffers;
18use crate::protocol::Message1RevokeReceiveBuffer;
19use crate::protocol::Message1RevokeSendBuffer;
20use crate::protocol::VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES;
21use crate::protocol::Version;
22use crate::rndisprot::NDIS_HASH_FUNCTION_MASK;
23use crate::rndisprot::NDIS_RSS_PARAM_FLAG_DISABLE_RSS;
24use async_trait::async_trait;
25pub use buffers::BufferPool;
26use buffers::sub_allocation_size_for_mtu;
27use futures::FutureExt;
28use futures::StreamExt;
29use futures::channel::mpsc;
30use futures::channel::mpsc::TrySendError;
31use futures_concurrency::future::Race;
32use guestmem::AccessError;
33use guestmem::GuestMemory;
34use guestmem::GuestMemoryError;
35use guestmem::MemoryRead;
36use guestmem::MemoryWrite;
37use guestmem::ranges::PagedRange;
38use guestmem::ranges::PagedRanges;
39use guestmem::ranges::PagedRangesReader;
40use guid::Guid;
41use hvdef::hypercall::HvGuestOsId;
42use hvdef::hypercall::HvGuestOsMicrosoft;
43use hvdef::hypercall::HvGuestOsMicrosoftIds;
44use hvdef::hypercall::HvGuestOsOpenSourceType;
45use inspect::Inspect;
46use inspect::InspectMut;
47use inspect::SensitivityLevel;
48use inspect_counters::Counter;
49use inspect_counters::Histogram;
50use mesh::rpc::Rpc;
51use net_backend::Endpoint;
52use net_backend::EndpointAction;
53use net_backend::QueueConfig;
54use net_backend::RxId;
55use net_backend::TxError;
56use net_backend::TxId;
57use net_backend::TxSegment;
58use net_backend_resources::mac_address::MacAddress;
59use pal_async::timer::Instant;
60use pal_async::timer::PolledTimer;
61use ring::gparange::MultiPagedRangeIter;
62use rx_bufs::RxBuffers;
63use rx_bufs::SubAllocationInUse;
64use std::collections::VecDeque;
65use std::fmt::Debug;
66use std::future::pending;
67use std::mem::offset_of;
68use std::ops::Range;
69use std::sync::Arc;
70use std::sync::atomic::AtomicUsize;
71use std::sync::atomic::Ordering;
72use std::task::Poll;
73use std::time::Duration;
74use task_control::AsyncRun;
75use task_control::InspectTaskMut;
76use task_control::StopTask;
77use task_control::TaskControl;
78use thiserror::Error;
79use tracing::Instrument;
80use vmbus_async::queue;
81use vmbus_async::queue::ExternalDataError;
82use vmbus_async::queue::IncomingPacket;
83use vmbus_async::queue::Queue;
84use vmbus_channel::bus::OfferParams;
85use vmbus_channel::bus::OpenRequest;
86use vmbus_channel::channel::ChannelControl;
87use vmbus_channel::channel::ChannelOpenError;
88use vmbus_channel::channel::ChannelRestoreError;
89use vmbus_channel::channel::DeviceResources;
90use vmbus_channel::channel::RestoreControl;
91use vmbus_channel::channel::SaveRestoreVmbusDevice;
92use vmbus_channel::channel::VmbusDevice;
93use vmbus_channel::gpadl::GpadlId;
94use vmbus_channel::gpadl::GpadlMapView;
95use vmbus_channel::gpadl::GpadlView;
96use vmbus_channel::gpadl::UnknownGpadlId;
97use vmbus_channel::gpadl_ring::GpadlRingMem;
98use vmbus_channel::gpadl_ring::gpadl_channel;
99use vmbus_ring as ring;
100use vmbus_ring::OutgoingPacketType;
101use vmbus_ring::RingMem;
102use vmbus_ring::gparange::MultiPagedRangeBuf;
103use vmcore::save_restore::RestoreError;
104use vmcore::save_restore::SaveError;
105use vmcore::save_restore::SavedStateBlob;
106use vmcore::vm_task::VmTaskDriver;
107use vmcore::vm_task::VmTaskDriverSource;
108use zerocopy::FromBytes;
109use zerocopy::FromZeros;
110use zerocopy::Immutable;
111use zerocopy::IntoBytes;
112use zerocopy::KnownLayout;
113
// Minimum outgoing ring space required for control-message processing.
// NOTE(review): the derivation of 144 is not visible in this file — confirm
// against the protocol message sizes.
const MIN_CONTROL_RING_SIZE: usize = 144;

// Minimum outgoing ring space required for state-change notifications.
// NOTE(review): derivation of 196 not visible here.
const MIN_STATE_CHANGE_RING_SIZE: usize = 196;

// Transaction ids used for host-initiated messages. The high bit is set,
// presumably to keep them distinct from guest-initiated transaction ids —
// TODO confirm against the protocol specification.
const VF_ASSOCIATION_TRANSACTION_ID: u64 = 0x8000000000000000;
const SWITCH_DATA_PATH_TRANSACTION_ID: u64 = 0x8000000000000001;

/// Upper bound on subchannels offered per vNIC (used to clamp `max_queues`).
const NETVSP_MAX_SUBCHANNELS_PER_VNIC: u16 = 64;

// Delay related to VF device arrival handling; shortened under test builds so
// unit tests run quickly. NOTE(review): usage is outside this view.
#[cfg(not(test))]
const VF_DEVICE_DELAY: Duration = Duration::from_secs(1);
#[cfg(test)]
const VF_DEVICE_DELAY: Duration = Duration::from_millis(100);

// Delay applied to link state change propagation; shortened under test
// builds. NOTE(review): usage is outside this view.
#[cfg(not(test))]
const LINK_DELAY_DURATION: Duration = Duration::from_secs(5);
#[cfg(test)]
const LINK_DELAY_DURATION: Duration = Duration::from_millis(333);
147
/// Flags selecting which pieces of state a [`CoordinatorMessage::Update`]
/// asks the coordinator to refresh.
#[derive(Default, PartialEq)]
struct CoordinatorMessageUpdateType {
    /// The guest VF state needs updating.
    guest_vf_state: bool,
    /// The packet filter state needs updating.
    filter_state: bool,
}
156
/// Messages sent from channel workers to the coordinator task.
#[derive(PartialEq)]
enum CoordinatorMessage {
    /// Refresh the indicated subset of coordinator-owned state.
    Update(CoordinatorMessageUpdateType),
    /// Restart the workers/endpoint.
    Restart,
    /// Arm the coordinator's timer to fire at the given instant.
    StartTimer(Instant),
}
167
/// Per-channel worker that services a single vmbus ring for the NIC.
struct Worker<T: RingMem> {
    /// Index of the vmbus channel this worker services (0 is the primary).
    channel_idx: u16,
    /// Virtual processor the channel interrupts target.
    target_vp: u32,
    mem: GuestMemory,
    channel: NetChannel<T>,
    /// Current protocol state of this channel.
    state: WorkerState,
    /// Used to notify the coordinator task of events.
    coordinator_send: mpsc::Sender<CoordinatorMessage>,
}
176
/// Backend queue state for a channel, kept separate from [`Worker`] so it can
/// still be inspected when no worker state is present (see the
/// `InspectTaskMut` implementation below).
struct NetQueue {
    driver: VmTaskDriver,
    /// `None` until a backend queue has been attached.
    queue_state: Option<QueueState>,
}
181
// Inspection for a queue/worker pair; the worker is optional because the task
// may not currently be running.
impl<T: RingMem + 'static + Sync> InspectTaskMut<Worker<T>> for NetQueue {
    fn inspect_mut(&mut self, req: inspect::Request<'_>, worker: Option<&mut Worker<T>>) {
        // Nothing to report when neither the worker nor the queue exists.
        if worker.is_none() && self.queue_state.is_none() {
            req.ignore();
            return;
        }

        let mut resp = req.respond();
        resp.field("driver", &self.driver);
        if let Some(worker) = worker {
            resp.field(
                "protocol_state",
                match &worker.state {
                    WorkerState::Init(None) => "version",
                    WorkerState::Init(Some(_)) => "init",
                    WorkerState::Ready(_) => "ready",
                    WorkerState::WaitingForCoordinator(_) => "waiting for coordinator",
                },
            )
            .field("ring", &worker.channel.queue)
            .field(
                "can_use_ring_size_optimization",
                worker.channel.can_use_ring_size_opt,
            );

            // Traffic counters are only meaningful once the channel is ready.
            if let WorkerState::Ready(state) = &worker.state {
                resp.field(
                    "outstanding_tx_packets",
                    // Ids not on the free list are outstanding.
                    state.state.pending_tx_packets.len() - state.state.free_tx_packets.len(),
                )
                .field("pending_rx_packets", state.state.pending_rx_packets.len())
                .field(
                    "pending_tx_completions",
                    state.state.pending_tx_completions.len(),
                )
                .field("free_tx_packets", state.state.free_tx_packets.len())
                .merge(&state.state.stats);
            }

            resp.field("packet_filter", worker.channel.packet_filter);
        }

        if let Some(queue_state) = &mut self.queue_state {
            resp.field_mut("queue", &mut queue_state.queue)
                .field("rx_buffers", queue_state.rx_buffer_range.id_range.len())
                .field(
                    "rx_buffers_start",
                    queue_state.rx_buffer_range.id_range.start,
                );
        }
    }
}
234
/// Protocol state of a channel worker.
enum WorkerState {
    /// Negotiating; `None` until a protocol version has been agreed upon.
    Init(Option<InitState>),
    /// Fully initialized and processing traffic.
    Ready(ReadyState),
    /// Waiting on the coordinator; holds the ready state (if any) so it can
    /// be resumed afterwards.
    WaitingForCoordinator(Option<ReadyState>),
}
240
241impl WorkerState {
242 fn ready(&self) -> Option<&ReadyState> {
243 if let Self::Ready(state) = self {
244 Some(state)
245 } else {
246 None
247 }
248 }
249
250 fn ready_mut(&mut self) -> Option<&mut ReadyState> {
251 if let Self::Ready(state) = self {
252 Some(state)
253 } else {
254 None
255 }
256 }
257}
258
/// State accumulated during protocol initialization, before the channel can
/// carry traffic. Optional fields fill in as the guest sends the
/// corresponding setup messages.
struct InitState {
    version: Version,
    ndis_config: Option<NdisConfig>,
    ndis_version: Option<NdisVersion>,
    recv_buffer: Option<ReceiveBuffer>,
    send_buffer: Option<SendBuffer>,
}
266
/// NDIS version pair, rendered in hex when inspected.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisVersion {
    #[inspect(hex)]
    major: u32,
    #[inspect(hex)]
    minor: u32,
}
274
/// NDIS configuration established during initialization: the MTU and the
/// guest-advertised capability flags.
#[derive(Copy, Clone, Debug, Inspect)]
struct NdisConfig {
    #[inspect(safe)]
    mtu: u32,
    #[inspect(safe)]
    capabilities: protocol::NdisConfigCapabilities,
}
282
/// State of a fully-initialized channel.
struct ReadyState {
    /// Channel-wide buffers, shared across tasks.
    buffers: Arc<ChannelBuffers>,
    /// Live packet-processing state.
    state: ActiveState,
    /// Reusable scratch buffers for processing passes.
    data: ProcessingData,
}
288
/// Interface to an accelerated virtual function (VF) device that can be
/// offered to the guest alongside the synthetic NIC.
#[async_trait]
pub trait VirtualFunction: Sync + Send {
    /// The id of the VF; `None` presumably means not currently available —
    /// confirm with implementors.
    async fn id(&self) -> Option<u32>;
    /// Notifies the implementation that the guest is ready for the device.
    async fn guest_ready_for_device(&mut self);
    /// Waits for a change in VF availability; the returned RPC is used to
    /// signal when handling of the change is complete.
    async fn wait_for_state_change(&mut self) -> Rpc<(), ()>;
}
303
/// Static per-adapter configuration plus counters shared across channels.
struct Adapter {
    driver: VmTaskDriver,
    mac_address: MacAddress,
    /// Upper bound on queues, already clamped to
    /// `NETVSP_MAX_SUBCHANNELS_PER_VNIC` and backend support (see `build`).
    max_queues: u16,
    indirection_table_size: u16,
    /// Offloads the backend endpoint supports.
    offload_support: OffloadConfig,
    /// When nonzero, limits ring usage (set by `NicBuilder::limit_ring_buffer`).
    ring_size_limit: AtomicUsize,
    /// Free-tx-id threshold; set from the endpoint's completion behavior in
    /// `build`. Usage is outside this view.
    free_tx_packet_threshold: usize,
    /// Whether the endpoint completes transmits quickly (mirrors
    /// `Endpoint::tx_fast_completions`).
    tx_fast_completions: bool,
    adapter_index: u32,
    /// Optional callback to query the guest OS id reported to the hypervisor.
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
    /// Count of currently-opened subchannels (see `open`/`close`).
    num_sub_channels_opened: AtomicUsize,
    /// Endpoint-reported link speed. NOTE(review): units not visible here.
    link_speed: u64,
}
318
/// A backend queue bound to a channel.
struct QueueState {
    queue: Box<dyn net_backend::Queue>,
    /// Receive buffer ids owned by this queue.
    rx_buffer_range: RxBufferRange,
    /// Whether the backend queue's target VP has been programmed.
    target_vp_set: bool,
}
324
/// The span of receive buffer ids owned by one queue, plus the channels used
/// to route ids that belong to other queues.
struct RxBufferRange {
    /// Ids this queue owns.
    id_range: Range<u32>,
    /// Receives ids owned by this queue that other queues released.
    remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
    /// Shared cross-queue bookkeeping (senders for every queue).
    remote_ranges: Arc<RxBufferRanges>,
}
330
331impl RxBufferRange {
332 fn new(
333 ranges: Arc<RxBufferRanges>,
334 id_range: Range<u32>,
335 remote_buffer_id_recv: Option<mpsc::UnboundedReceiver<u32>>,
336 ) -> Self {
337 Self {
338 id_range,
339 remote_buffer_id_recv,
340 remote_ranges: ranges,
341 }
342 }
343
344 fn send_if_remote(&self, id: u32) -> bool {
345 if id < RX_RESERVED_CONTROL_BUFFERS || self.id_range.contains(&id) {
348 false
349 } else {
350 let i = (id - RX_RESERVED_CONTROL_BUFFERS) / self.remote_ranges.buffers_per_queue;
351 let i = (i as usize).min(self.remote_ranges.buffer_id_send.len() - 1);
355 let _ = self.remote_ranges.buffer_id_send[i].unbounded_send(id);
356 true
357 }
358 }
359}
360
/// Cross-queue bookkeeping for the receive buffer id space.
struct RxBufferRanges {
    /// Number of non-reserved buffers assigned to each queue.
    buffers_per_queue: u32,
    /// One sender per queue, used to route a released buffer id back to the
    /// queue that owns it.
    buffer_id_send: Vec<mpsc::UnboundedSender<u32>>,
}
365
366impl RxBufferRanges {
367 fn new(buffer_count: u32, queue_count: u32) -> (Self, Vec<mpsc::UnboundedReceiver<u32>>) {
368 let buffers_per_queue = (buffer_count - RX_RESERVED_CONTROL_BUFFERS) / queue_count;
369 #[expect(clippy::disallowed_methods)] let (send, recv): (Vec<_>, Vec<_>) = (0..queue_count).map(|_| mpsc::unbounded()).unzip();
371 (
372 Self {
373 buffers_per_queue,
374 buffer_id_send: send,
375 },
376 recv,
377 )
378 }
379}
380
/// Receive-side scaling state configured by the guest.
struct RssState {
    /// 40-byte hash key (presumably a Toeplitz key — confirm against the
    /// RNDIS RSS parameters).
    key: [u8; 40],
    /// Maps hash values to queue indices.
    indirection_table: Vec<u16>,
}
385
/// Vmbus-channel-facing state for one worker.
struct NetChannel<T: RingMem> {
    adapter: Arc<Adapter>,
    queue: Queue<T>,
    gpadl_map: GpadlMapView,
    /// Which RNDIS packet-size revision is in use.
    packet_size: PacketSize,
    /// Outgoing ring space we are currently waiting on, if any.
    pending_send_size: usize,
    /// Coordinator message to deliver; usage outside this view.
    restart: Option<CoordinatorMessage>,
    /// Whether the ring size optimization may be used (see `can_use_ring_opt`).
    can_use_ring_size_opt: bool,
    packet_filter: u32,
}
397
/// RNDIS packet message size revision in use (presumably selected by the
/// negotiated protocol version — confirm at the negotiation site).
#[derive(Debug, Copy, Clone)]
enum PacketSize {
    /// Original packet layout.
    V1,
    /// 6.1 packet layout.
    V61,
}
406
/// Scratch buffers reused across packet-processing passes to avoid
/// per-iteration allocation.
struct ProcessingData {
    tx_segments: Vec<TxSegment>,
    /// How many of `tx_segments` have been handed to the backend.
    tx_segments_sent: usize,
    tx_done: Box<[TxId]>,
    rx_ready: Box<[RxId]>,
    rx_done: Vec<RxId>,
    transfer_pages: Vec<ring::TransferPageRange>,
    external_data: MultiPagedRangeBuf,
}
417
418impl ProcessingData {
419 fn new() -> Self {
420 Self {
421 tx_segments: Vec::new(),
422 tx_segments_sent: 0,
423 tx_done: vec![TxId(0); 8192].into(),
424 rx_ready: vec![RxId(0); RX_BATCH_SIZE].into(),
425 rx_done: Vec::with_capacity(RX_BATCH_SIZE),
426 transfer_pages: Vec::with_capacity(RX_BATCH_SIZE),
427 external_data: MultiPagedRangeBuf::new(),
428 }
429 }
430}
431
/// Buffers and negotiated parameters shared by all channels of the NIC, set
/// up through the primary channel.
#[derive(Debug, Inspect)]
struct ChannelBuffers {
    version: Version,
    #[inspect(skip)]
    mem: GuestMemory,
    #[inspect(skip)]
    recv_buffer: ReceiveBuffer,
    /// `None` when the guest did not supply a send buffer.
    #[inspect(skip)]
    send_buffer: Option<SendBuffer>,
    ndis_version: NdisVersion,
    #[inspect(safe)]
    ndis_config: NdisConfig,
}
447
/// Id of one of the receive buffers reserved for RNDIS control messages
/// (ids below `RX_RESERVED_CONTROL_BUFFERS`).
#[derive(Copy, Clone, Debug)]
struct ControlMessageId(u32);
451
/// Per-channel state while actively processing traffic.
struct ActiveState {
    /// Present only for the primary channel.
    primary: Option<PrimaryChannelState>,

    /// Indexed by `TxId`; an entry is live when its id is not on
    /// `free_tx_packets`.
    pending_tx_packets: Vec<PendingTxPacket>,
    free_tx_packets: Vec<TxId>,
    /// Completions waiting to be written back to the ring.
    pending_tx_completions: VecDeque<PendingTxCompletion>,
    pending_rx_packets: VecDeque<RxId>,

    /// Tracks which receive sub-allocations are in use.
    rx_bufs: RxBuffers,

    stats: QueueStats,
}
465
/// Per-queue diagnostic counters, merged into inspect output.
#[derive(Inspect, Default)]
struct QueueStats {
    tx_stalled: Counter,
    /// Receives dropped because the guest ring was full.
    rx_dropped_ring_full: Counter,
    /// Receives dropped by the packet filter.
    rx_dropped_filtered: Counter,
    /// Wakes that found no work to do.
    spurious_wakes: Counter,
    rx_packets: Counter,
    tx_packets: Counter,
    tx_lso_packets: Counter,
    tx_checksum_packets: Counter,
    /// Transmits with LSO metadata that failed validation.
    tx_invalid_lso_packets: Counter,
    tx_packets_per_wake: Histogram<10>,
    rx_packets_per_wake: Histogram<10>,
}
480
/// A transmit completion waiting to be sent back to the guest.
#[derive(Debug)]
struct PendingTxCompletion {
    transaction_id: u64,
    /// Backend tx id, if one was reserved (may be `None` after restore when
    /// ids were exhausted — see `ActiveState::restore`).
    tx_id: Option<TxId>,
    status: protocol::Status,
}
487
/// State machine for the guest-visible virtual function (VF), tracked on the
/// primary channel. Human-readable renderings live in the `Display` impl
/// below.
#[derive(Clone, Copy)]
enum PrimaryChannelGuestVfState {
    /// Initial state before VF availability is known.
    Initializing,
    /// Restoring from a saved state.
    Restoring(saved_state::GuestVfState),
    /// No VF is available.
    Unavailable,
    /// VF became unavailable after having been available.
    UnavailableFromAvailable,
    /// VF became unavailable while a data path switch was in flight.
    UnavailableFromDataPathSwitchPending { to_guest: bool, id: Option<u64> },
    /// VF became unavailable while the data path was switched to the guest.
    UnavailableFromDataPathSwitched,
    /// VF is available with the given id, guest not yet notified.
    Available { vfid: u32 },
    /// Guest has been notified of the available VF.
    AvailableAdvertised,
    /// VF is present in the guest.
    Ready,
    /// A data path switch is in flight.
    DataPathSwitchPending {
        to_guest: bool,
        /// Transaction id to complete, if the guest initiated the switch.
        id: Option<u64>,
        /// `None` while in progress; then success/failure.
        result: Option<bool>,
    },
    /// Data path is switched to the VF.
    DataPathSwitched,
    /// VF available but data path forced back to synthetic by an external
    /// state change.
    DataPathSynthetic,
}
520
521impl std::fmt::Display for PrimaryChannelGuestVfState {
522 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 match self {
524 PrimaryChannelGuestVfState::Initializing => write!(f, "initializing"),
525 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
526 write!(f, "restoring")
527 }
528 PrimaryChannelGuestVfState::Restoring(
529 saved_state::GuestVfState::AvailableAdvertised,
530 ) => write!(f, "restoring from guest notified of vfid"),
531 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
532 write!(f, "restoring from vf present")
533 }
534 PrimaryChannelGuestVfState::Restoring(
535 saved_state::GuestVfState::DataPathSwitchPending {
536 to_guest, result, ..
537 },
538 ) => {
539 write!(
540 f,
541 "restoring from client requested data path switch: to {} {}",
542 if *to_guest { "guest" } else { "synthetic" },
543 if let Some(result) = result {
544 if *result { "succeeded\"" } else { "failed\"" }
545 } else {
546 "in progress\""
547 }
548 )
549 }
550 PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::DataPathSwitched) => {
551 write!(f, "restoring from data path in guest")
552 }
553 PrimaryChannelGuestVfState::Unavailable => write!(f, "unavailable"),
554 PrimaryChannelGuestVfState::UnavailableFromAvailable => {
555 write!(f, "\"unavailable (previously available)\"")
556 }
557 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. } => {
558 write!(f, "unavailable (previously switching data path)")
559 }
560 PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
561 write!(f, "\"unavailable (previously using guest VF)\"")
562 }
563 PrimaryChannelGuestVfState::Available { vfid } => write!(f, "available vfid: {}", vfid),
564 PrimaryChannelGuestVfState::AvailableAdvertised => {
565 write!(f, "\"available, guest notified\"")
566 }
567 PrimaryChannelGuestVfState::Ready => write!(f, "\"available and present in guest\""),
568 PrimaryChannelGuestVfState::DataPathSwitchPending {
569 to_guest, result, ..
570 } => {
571 write!(
572 f,
573 "\"switching to {} {}",
574 if *to_guest { "guest" } else { "synthetic" },
575 if let Some(result) = result {
576 if *result { "succeeded\"" } else { "failed\"" }
577 } else {
578 "in progress\""
579 }
580 )
581 }
582 PrimaryChannelGuestVfState::DataPathSwitched => {
583 write!(f, "\"available and data path switched\"")
584 }
585 PrimaryChannelGuestVfState::DataPathSynthetic => write!(
586 f,
587 "\"available but data path switched back to synthetic due to external state change\""
588 ),
589 }
590 }
591}
592
impl Inspect for PrimaryChannelGuestVfState {
    fn inspect(&self, req: inspect::Request<'_>) {
        // Report the state as its Display rendering.
        req.value(self.to_string());
    }
}
598
/// A pending link-state action; the `bool` payload is the desired link-up
/// state.
#[derive(Debug)]
enum PendingLinkAction {
    /// No action pending.
    Default,
    /// Action is being processed.
    Active(bool),
    /// Action is delayed (presumably by `LINK_DELAY_DURATION` — processing is
    /// outside this view).
    Delay(bool),
}
605
/// State tracked only on the primary channel (channel 0).
struct PrimaryChannelState {
    guest_vf_state: PrimaryChannelGuestVfState,
    /// Whether the data path is switched to the VF; `None` when unknown
    /// (e.g. immediately after restore).
    is_data_path_switched: Option<bool>,
    /// Queued incoming RNDIS control messages.
    control_messages: VecDeque<ControlMessage>,
    /// Total bytes of data across `control_messages`.
    control_messages_len: usize,
    /// Reserved control buffers not currently in use.
    free_control_buffers: Vec<ControlMessageId>,
    rss_state: Option<RssState>,
    requested_num_queues: u16,
    rndis_state: RndisState,
    /// Offloads currently active for the guest.
    offload_config: OffloadConfig,
    /// Whether an offload-change notification still needs to be sent.
    pending_offload_change: bool,
    /// Whether the tx spread message has been sent. NOTE(review): semantics
    /// live outside this view.
    tx_spread_sent: bool,
    /// Link state last reported to the guest.
    guest_link_up: bool,
    pending_link_action: PendingLinkAction,
}
621
// Inspect implementation; all fields are reported as Safe sensitivity.
impl Inspect for PrimaryChannelState {
    fn inspect(&self, req: inspect::Request<'_>) {
        req.respond()
            .sensitivity_field(
                "guest_vf_state",
                SensitivityLevel::Safe,
                self.guest_vf_state,
            )
            .sensitivity_field(
                "data_path_switched",
                SensitivityLevel::Safe,
                self.is_data_path_switched,
            )
            .sensitivity_field(
                "pending_control_messages",
                SensitivityLevel::Safe,
                self.control_messages.len(),
            )
            .sensitivity_field(
                "free_control_message_buffers",
                SensitivityLevel::Safe,
                self.free_control_buffers.len(),
            )
            .sensitivity_field(
                "pending_offload_change",
                SensitivityLevel::Safe,
                self.pending_offload_change,
            )
            .sensitivity_field("rndis_state", SensitivityLevel::Safe, self.rndis_state)
            .sensitivity_field(
                "offload_config",
                SensitivityLevel::Safe,
                &self.offload_config,
            )
            .sensitivity_field(
                "tx_spread_sent",
                SensitivityLevel::Safe,
                self.tx_spread_sent,
            )
            .sensitivity_field("guest_link_up", SensitivityLevel::Safe, self.guest_link_up)
            .sensitivity_field(
                "pending_link_action",
                SensitivityLevel::Safe,
                // Rendered as a string since PendingLinkAction has no Inspect impl.
                match &self.pending_link_action {
                    PendingLinkAction::Active(up) => format!("Active({:x?})", up),
                    PendingLinkAction::Delay(up) => format!("Delay({:x?})", up),
                    PendingLinkAction::Default => "None".to_string(),
                },
            );
    }
}
673
/// Checksum and large-send offload configuration.
#[derive(Debug, Inspect, Clone)]
struct OffloadConfig {
    #[inspect(safe)]
    checksum_tx: ChecksumOffloadConfig,
    #[inspect(safe)]
    checksum_rx: ChecksumOffloadConfig,
    /// IPv4 large send offload (LSOv2).
    #[inspect(safe)]
    lso4: bool,
    /// IPv6 large send offload (LSOv2).
    #[inspect(safe)]
    lso6: bool,
}
685
686impl OffloadConfig {
687 fn mask_to_supported(&mut self, supported: &OffloadConfig) {
688 self.checksum_tx.mask_to_supported(&supported.checksum_tx);
689 self.checksum_rx.mask_to_supported(&supported.checksum_rx);
690 self.lso4 &= supported.lso4;
691 self.lso6 &= supported.lso6;
692 }
693}
694
/// Per-protocol checksum offload flags (one direction: tx or rx).
#[derive(Debug, Inspect, Clone)]
struct ChecksumOffloadConfig {
    /// IPv4 header checksum.
    #[inspect(safe)]
    ipv4_header: bool,
    #[inspect(safe)]
    tcp4: bool,
    #[inspect(safe)]
    udp4: bool,
    #[inspect(safe)]
    tcp6: bool,
    #[inspect(safe)]
    udp6: bool,
}
708
709impl ChecksumOffloadConfig {
710 fn mask_to_supported(&mut self, supported: &ChecksumOffloadConfig) {
711 self.ipv4_header &= supported.ipv4_header;
712 self.tcp4 &= supported.tcp4;
713 self.udp4 &= supported.udp4;
714 self.tcp6 &= supported.tcp6;
715 self.udp6 &= supported.udp6;
716 }
717
718 fn flags(
719 &self,
720 ) -> (
721 rndisprot::Ipv4ChecksumOffload,
722 rndisprot::Ipv6ChecksumOffload,
723 ) {
724 let on = rndisprot::NDIS_OFFLOAD_SUPPORTED;
725 let mut v4 = rndisprot::Ipv4ChecksumOffload::new();
726 let mut v6 = rndisprot::Ipv6ChecksumOffload::new();
727 if self.ipv4_header {
728 v4.set_ip_options_supported(on);
729 v4.set_ip_checksum(on);
730 }
731 if self.tcp4 {
732 v4.set_ip_options_supported(on);
733 v4.set_tcp_options_supported(on);
734 v4.set_tcp_checksum(on);
735 }
736 if self.tcp6 {
737 v6.set_ip_extension_headers_supported(on);
738 v6.set_tcp_options_supported(on);
739 v6.set_tcp_checksum(on);
740 }
741 if self.udp4 {
742 v4.set_ip_options_supported(on);
743 v4.set_udp_checksum(on);
744 }
745 if self.udp6 {
746 v6.set_ip_extension_headers_supported(on);
747 v6.set_udp_checksum(on);
748 }
749 (v4, v6)
750 }
751}
752
impl OffloadConfig {
    /// Builds the NDIS_OFFLOAD (revision 3) structure describing the
    /// currently-enabled checksum and LSOv2 capabilities, for reporting to
    /// the guest.
    fn ndis_offload(&self) -> rndisprot::NdisOffload {
        let checksum = {
            let (ipv4_tx_flags, ipv6_tx_flags) = self.checksum_tx.flags();
            let (ipv4_rx_flags, ipv6_rx_flags) = self.checksum_rx.flags();
            rndisprot::TcpIpChecksumOffload {
                ipv4_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_tx_flags,
                ipv4_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv4_rx_flags,
                ipv6_tx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_tx_flags,
                ipv6_rx_encapsulation: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                ipv6_rx_flags,
            }
        };

        // LSOv2 capabilities are reported per address family; fields stay
        // zeroed when the offload is disabled.
        let lso_v2 = {
            let mut lso = rndisprot::TcpLargeSendOffloadV2::new_zeroed();
            if self.lso4 {
                lso.ipv4_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv4_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv4_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
            }
            if self.lso6 {
                lso.ipv6_encapsulation = rndisprot::NDIS_ENCAPSULATION_IEEE_802_3;
                lso.ipv6_max_offload_size = rndisprot::LSO_MAX_OFFLOAD_SIZE;
                lso.ipv6_min_segment_count = rndisprot::LSO_MIN_SEGMENT_COUNT;
                lso.ipv6_flags = rndisprot::Ipv6LsoFlags::new()
                    .with_ip_extension_headers_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED)
                    .with_tcp_options_supported(rndisprot::NDIS_OFFLOAD_SUPPORTED);
            }
            lso
        };

        rndisprot::NdisOffload {
            header: rndisprot::NdisObjectHeader {
                object_type: rndisprot::NdisObjectType::OFFLOAD,
                revision: 3,
                size: rndisprot::NDIS_SIZEOF_NDIS_OFFLOAD_REVISION_3 as u16,
            },
            checksum,
            lso_v2,
            // Remaining capability fields are intentionally left zeroed.
            ..FromZeros::new_zeroed()
        }
    }
}
800
/// RNDIS control-plane state of the device.
#[derive(Debug, Inspect, PartialEq, Eq, Copy, Clone)]
pub enum RndisState {
    /// Awaiting the guest's initialize message.
    Initializing,
    /// Initialized and operational.
    Operational,
    /// Halted by the guest.
    Halted,
}
807
impl PrimaryChannelState {
    /// Creates fresh primary-channel state with all reserved control buffers
    /// free and the given offload configuration active.
    fn new(offload_config: OffloadConfig) -> Self {
        Self {
            guest_vf_state: PrimaryChannelGuestVfState::Initializing,
            is_data_path_switched: None,
            control_messages: VecDeque::new(),
            control_messages_len: 0,
            // All reserved control-message buffers start out free.
            free_control_buffers: (0..RX_RESERVED_CONTROL_BUFFERS)
                .map(ControlMessageId)
                .collect(),
            rss_state: None,
            requested_num_queues: 1,
            rndis_state: RndisState::Initializing,
            pending_offload_change: false,
            offload_config,
            tx_spread_sent: false,
            guest_link_up: true,
            pending_link_action: PendingLinkAction::Default,
        }
    }

    /// Rebuilds primary-channel state from a saved state.
    ///
    /// # Errors
    /// Fails if the saved RSS indirection table is larger than the adapter
    /// now supports, or if the saved RSS key has the wrong size.
    fn restore(
        guest_vf_state: &saved_state::GuestVfState,
        rndis_state: &saved_state::RndisState,
        offload_config: &saved_state::OffloadConfig,
        pending_offload_change: bool,
        num_queues: u16,
        indirection_table_size: u16,
        rx_bufs: &RxBuffers,
        control_messages: Vec<saved_state::IncomingControlMessage>,
        rss_state: Option<saved_state::RssState>,
        tx_spread_sent: bool,
        guest_link_down: bool,
        pending_link_action: Option<bool>,
    ) -> Result<Self, NetRestoreError> {
        // Recompute the total queued control-message byte count.
        let control_messages_len = control_messages.iter().map(|msg| msg.data.len()).sum();

        let control_messages = control_messages
            .into_iter()
            .map(|msg| ControlMessage {
                message_type: msg.message_type,
                data: msg.data.into(),
            })
            .collect();

        // A control buffer is free only if the restored rx-buffer state does
        // not mark it as in use.
        let free_control_buffers = (0..RX_RESERVED_CONTROL_BUFFERS)
            .filter_map(|id| rx_bufs.is_free(id).then_some(ControlMessageId(id)))
            .collect();

        let rss_state = rss_state
            .map(|mut rss| {
                // Shrinking the table could drop queues the guest is using;
                // refuse to restore in that case.
                if rss.indirection_table.len() > indirection_table_size as usize {
                    return Err(NetRestoreError::ReducedIndirectionTableSize);
                }
                if rss.indirection_table.len() < indirection_table_size as usize {
                    tracing::warn!(
                        saved_indirection_table_size = rss.indirection_table.len(),
                        adapter_indirection_table_size = indirection_table_size,
                        "increasing indirection table size",
                    );
                    // Grow the table by repeating the existing entries so the
                    // queue distribution is preserved.
                    let table_clone = rss.indirection_table.clone();
                    let num_to_add = indirection_table_size as usize - rss.indirection_table.len();
                    rss.indirection_table
                        .extend(table_clone.iter().cycle().take(num_to_add));
                }
                Ok(RssState {
                    key: rss
                        .key
                        .try_into()
                        .map_err(|_| NetRestoreError::InvalidRssKeySize)?,
                    indirection_table: rss.indirection_table,
                })
            })
            .transpose()?;

        let rndis_state = match rndis_state {
            saved_state::RndisState::Initializing => RndisState::Initializing,
            saved_state::RndisState::Operational => RndisState::Operational,
            saved_state::RndisState::Halted => RndisState::Halted,
        };

        let guest_vf_state = PrimaryChannelGuestVfState::Restoring(*guest_vf_state);
        let offload_config = OffloadConfig {
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_tx.ipv4_header,
                tcp4: offload_config.checksum_tx.tcp4,
                udp4: offload_config.checksum_tx.udp4,
                tcp6: offload_config.checksum_tx.tcp6,
                udp6: offload_config.checksum_tx.udp6,
            },
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: offload_config.checksum_rx.ipv4_header,
                tcp4: offload_config.checksum_rx.tcp4,
                udp4: offload_config.checksum_rx.udp4,
                tcp6: offload_config.checksum_rx.tcp6,
                udp6: offload_config.checksum_rx.udp6,
            },
            lso4: offload_config.lso4,
            lso6: offload_config.lso6,
        };

        // A saved pending link action resumes as immediately active.
        let pending_link_action = if let Some(pending) = pending_link_action {
            PendingLinkAction::Active(pending)
        } else {
            PendingLinkAction::Default
        };

        Ok(Self {
            guest_vf_state,
            // Not persisted; re-derived after restore.
            is_data_path_switched: None,
            control_messages,
            control_messages_len,
            free_control_buffers,
            rss_state,
            requested_num_queues: num_queues,
            rndis_state,
            pending_offload_change,
            offload_config,
            tx_spread_sent,
            guest_link_up: !guest_link_down,
            pending_link_action,
        })
    }
}
938
/// A queued incoming RNDIS control message.
struct ControlMessage {
    message_type: u32,
    data: Box<[u8]>,
}
943
/// Maximum number of transmit packets outstanding per channel (sizes the tx
/// id pool in [`ActiveState`]).
const TX_PACKET_QUOTA: usize = 1024;
945
impl ActiveState {
    /// Creates active state with a full pool of `TX_PACKET_QUOTA` tx ids.
    fn new(primary: Option<PrimaryChannelState>, recv_buffer_count: u32) -> Self {
        Self {
            primary,
            pending_tx_packets: vec![Default::default(); TX_PACKET_QUOTA],
            // Reversed so ids are handed out in ascending order via pop().
            free_tx_packets: (0..TX_PACKET_QUOTA as u32).rev().map(TxId).collect(),
            pending_tx_completions: VecDeque::new(),
            pending_rx_packets: VecDeque::new(),
            rx_bufs: RxBuffers::new(recv_buffer_count),
            stats: Default::default(),
        }
    }

    /// Rebuilds active state from a saved channel: re-marks in-use receive
    /// buffers and re-queues pending tx completions.
    fn restore(
        channel: &saved_state::Channel,
        recv_buffer_count: u32,
    ) -> Result<Self, NetRestoreError> {
        let mut active = Self::new(None, recv_buffer_count);
        let saved_state::Channel {
            pending_tx_completions,
            in_use_rx,
        } = channel;
        for rx in in_use_rx {
            active
                .rx_bufs
                .allocate(rx.buffers.as_slice().iter().copied())?;
        }
        for &transaction_id in pending_tx_completions {
            // Reserve a tx id for the completion if one is available; if the
            // pool is exhausted, the completion is queued with tx_id = None.
            let tx_id = active.free_tx_packets.pop();
            if let Some(id) = tx_id {
                active.pending_tx_packets[id.0 as usize].transaction_id = transaction_id;
            }
            // Restored completions always report success.
            active
                .pending_tx_completions
                .push_back(PendingTxCompletion {
                    transaction_id,
                    tx_id,
                    status: protocol::Status::SUCCESS,
                });
        }
        Ok(active)
    }
}
995
/// Bookkeeping for one in-flight guest transmit, indexed by its `TxId`.
#[derive(Default, Clone)]
struct PendingTxPacket {
    /// Outstanding backend sub-packets; the tx completes when this hits zero.
    pending_packet_count: usize,
    /// Guest transaction id to complete with.
    transaction_id: u64,
}
1003
// Receive batch size. NOTE(review): the derivation of 375 is not visible in
// this file — confirm against the ring/buffer sizing.
const RX_BATCH_SIZE: usize = 375;

/// The first 16 receive buffer ids are reserved for RNDIS control messages.
const RX_RESERVED_CONTROL_BUFFERS: u32 = 16;
1012
/// A netvsp (synthetic network) vmbus device.
pub struct Nic {
    instance_id: Guid,
    resources: DeviceResources,
    /// Coordinator task managing the endpoint and per-channel workers.
    coordinator: TaskControl<CoordinatorState, Coordinator>,
    coordinator_send: Option<mpsc::Sender<CoordinatorMessage>>,
    adapter: Arc<Adapter>,
    driver_source: VmTaskDriverSource,
}
1022
/// Builder for [`Nic`]; construct via [`Nic::builder`].
pub struct NicBuilder {
    virtual_function: Option<Box<dyn VirtualFunction>>,
    limit_ring_buffer: bool,
    max_queues: u16,
    get_guest_os_id: Option<Box<dyn Fn() -> HvGuestOsId + Send + Sync>>,
}
1029
impl NicBuilder {
    /// Limits how much of the ring buffer is used (see `ring_size_limit`).
    pub fn limit_ring_buffer(mut self, limit: bool) -> Self {
        self.limit_ring_buffer = limit;
        self
    }

    /// Sets the maximum number of queues (clamped in `build`).
    pub fn max_queues(mut self, max_queues: u16) -> Self {
        self.max_queues = max_queues;
        self
    }

    /// Provides an accelerated VF device to offer alongside the NIC.
    pub fn virtual_function(mut self, virtual_function: Box<dyn VirtualFunction>) -> Self {
        self.virtual_function = Some(virtual_function);
        self
    }

    /// Provides a callback to query the guest OS id, used to gate
    /// guest-specific optimizations (see `can_use_ring_opt`).
    pub fn get_guest_os_id(mut self, os_type: Box<dyn Fn() -> HvGuestOsId + Send + Sync>) -> Self {
        self.get_guest_os_id = Some(os_type);
        self
    }

    /// Creates a new NIC backed by `endpoint`, exposed to the guest with the
    /// given `instance_id` and `mac_address`.
    pub fn build(
        self,
        driver_source: &VmTaskDriverSource,
        instance_id: Guid,
        endpoint: Box<dyn Endpoint>,
        mac_address: MacAddress,
        adapter_index: u32,
    ) -> Nic {
        let multiqueue = endpoint.multiqueue_support();

        // Clamp the requested queue count to what the backend and the
        // protocol allow.
        let max_queues = self.max_queues.clamp(
            1,
            multiqueue.max_queues.min(NETVSP_MAX_SUBCHANNELS_PER_VNIC),
        );

        // 0 disables the ring usage limit.
        let ring_size_limit = if self.limit_ring_buffer { 1024 } else { 0 };

        // Endpoints that complete transmits quickly can use the full quota;
        // others are throttled to a quarter of it.
        let free_tx_packet_threshold = if endpoint.tx_fast_completions() {
            TX_PACKET_QUOTA
        } else {
            TX_PACKET_QUOTA / 4
        };

        let tx_offloads = endpoint.tx_offload_support();

        // Receive checksum offload is always offered; transmit offloads
        // mirror what the endpoint supports.
        let offload_support = OffloadConfig {
            checksum_rx: ChecksumOffloadConfig {
                ipv4_header: true,
                tcp4: true,
                udp4: true,
                tcp6: true,
                udp6: true,
            },
            checksum_tx: ChecksumOffloadConfig {
                ipv4_header: tx_offloads.ipv4_header,
                tcp4: tx_offloads.tcp,
                tcp6: tx_offloads.tcp,
                udp4: tx_offloads.udp,
                udp6: tx_offloads.udp,
            },
            lso4: tx_offloads.tso,
            lso6: tx_offloads.tso,
        };

        let driver = driver_source.simple();
        let adapter = Arc::new(Adapter {
            driver,
            mac_address,
            max_queues,
            indirection_table_size: multiqueue.indirection_table_size,
            offload_support,
            free_tx_packet_threshold,
            ring_size_limit: ring_size_limit.into(),
            tx_fast_completions: endpoint.tx_fast_completions(),
            adapter_index,
            get_guest_os_id: self.get_guest_os_id,
            num_sub_channels_opened: AtomicUsize::new(0),
            link_speed: endpoint.link_speed(),
        });

        let coordinator = TaskControl::new(CoordinatorState {
            endpoint,
            adapter: adapter.clone(),
            virtual_function: self.virtual_function,
            pending_vf_state: CoordinatorStatePendingVfState::Ready,
        });

        Nic {
            instance_id,
            resources: Default::default(),
            coordinator,
            coordinator_send: None,
            adapter,
            driver_source: driver_source.clone(),
        }
    }
}
1141
/// Returns whether the pending-send-size ring optimization can be used for
/// this channel, based on the guest OS id and ring capabilities.
fn can_use_ring_opt<T: RingMem>(queue: &mut Queue<T>, guest_os_id: Option<HvGuestOsId>) -> bool {
    // Without a known guest OS id, conservatively disable the optimization.
    let Some(guest_os_id) = guest_os_id else {
        return false;
    };

    // The ring itself must support the pending send size mechanism.
    if !queue.split().0.supports_pending_send_size() {
        return false;
    }

    // Guests without an open-source OS id are allowed unconditionally.
    let Some(open_source_os) = guest_os_id.open_source() else {
        return true;
    };

    // Gate known open-source guests on minimum versions. NOTE(review): the
    // thresholds below are encoded kernel versions (e.g. 199424 = 0x30B00,
    // presumably Linux 3.11; 1400097 presumably FreeBSD 14.x) — confirm the
    // encoding against the hvdef guest OS id definitions.
    match HvGuestOsOpenSourceType(open_source_os.os_type()) {
        HvGuestOsOpenSourceType::FREEBSD => open_source_os.version() >= 1400097,
        HvGuestOsOpenSourceType::LINUX => {
            open_source_os.version() >= 199424
        }
        _ => true,
    }
}
1173
1174impl Nic {
1175 pub fn builder() -> NicBuilder {
1176 NicBuilder {
1177 virtual_function: None,
1178 limit_ring_buffer: false,
1179 max_queues: !0,
1180 get_guest_os_id: None,
1181 }
1182 }
1183
1184 pub fn shutdown(self) -> Box<dyn Endpoint> {
1185 let (state, _) = self.coordinator.into_inner();
1186 state.endpoint
1187 }
1188}
1189
impl InspectMut for Nic {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        // Delegate entirely to the coordinator task.
        self.coordinator.inspect_mut(req);
    }
}
1195
#[async_trait]
impl VmbusDevice for Nic {
    /// Describes the vmbus offer for the synthetic network device.
    fn offer(&self) -> OfferParams {
        OfferParams {
            interface_name: "net".to_owned(),
            instance_id: self.instance_id,
            // Synthetic network interface ID:
            // f8615163-df3e-46c5-913f-f2d2f965ede0.
            interface_id: Guid {
                data1: 0xf8615163,
                data2: 0xdf3e,
                data3: 0x46c5,
                data4: [0x91, 0x3f, 0xf2, 0xd2, 0xf9, 0x65, 0xed, 0xe],
            },
            subchannel_index: 0,
            // Batch guest interrupts with a small latency window.
            mnf_interrupt_latency: Some(Duration::from_micros(100)),
            ..Default::default()
        }
    }

    fn max_subchannels(&self) -> u16 {
        self.adapter.max_queues
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    /// Opens the primary channel (`channel_idx == 0`) or a subchannel.
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        let state = if channel_idx == 0 {
            // Opening the primary channel starts the protocol from scratch
            // with a fresh coordinator and a single queue.
            self.insert_coordinator(1, None);
            WorkerState::Init(None)
        } else {
            // Subchannels reuse the buffers already negotiated on the
            // primary channel. Stop the coordinator so its state can be
            // accessed safely.
            self.coordinator.stop().await;
            let buffers = self.coordinator.state().unwrap().buffers.clone().unwrap();
            WorkerState::Ready(ReadyState {
                state: ActiveState::new(None, buffers.recv_buffer.count),
                buffers,
                data: ProcessingData::new(),
            })
        };

        let num_opened = self
            .adapter
            .num_sub_channels_opened
            .fetch_add(1, Ordering::SeqCst);
        let r = self.insert_worker(channel_idx, open_request, state, true);
        // Once the last expected subchannel has opened, stop the primary
        // worker so the coordinator can rebalance work across all queues.
        if channel_idx != 0
            && num_opened + 1 == self.coordinator.state_mut().unwrap().num_queues as usize
        {
            let coordinator = &mut self.coordinator.state_mut().unwrap();
            coordinator.workers[0].stop().await;
        }

        if r.is_err() && channel_idx == 0 {
            // Primary open failed: tear down the coordinator created above.
            self.coordinator.remove();
        } else {
            self.coordinator.start();
        }
        r?;
        Ok(())
    }

    /// Closes a channel; closing the primary channel tears down the whole
    /// device's runtime state.
    async fn close(&mut self, channel_idx: u16) {
        if !self.coordinator.has_state() {
            tracing::error!("Close called while vmbus channel is already closed");
            return;
        }

        // Remember whether the coordinator was running so it can be
        // restarted after a subchannel close.
        let restart = self.coordinator.stop().await;

        {
            let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
            worker.stop().await;
            if worker.has_state() {
                worker.remove();
            }
        }

        self.adapter
            .num_sub_channels_opened
            .fetch_sub(1, Ordering::SeqCst);
        if channel_idx == 0 {
            // Primary close: drop all queues, stop the endpoint, and remove
            // the coordinator entirely.
            for worker in &mut self.coordinator.state_mut().unwrap().workers {
                worker.task_mut().queue_state = None;
            }

            self.coordinator.task_mut().endpoint.stop().await;

            self.coordinator.remove();
        } else {
            if restart {
                self.coordinator.start();
            }
        }
    }

    /// Moves a channel's interrupt/processing target to a new virtual
    /// processor.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32) {
        if !self.coordinator.has_state() {
            return;
        }

        let coordinator_running = self.coordinator.stop().await;
        let worker = &mut self.coordinator.state_mut().unwrap().workers[channel_idx as usize];
        worker.stop().await;
        let (net_queue, worker_state) = worker.get_mut();

        net_queue.driver.retarget_vp(target_vp);

        if let Some(worker_state) = worker_state {
            worker_state.target_vp = target_vp;
            // Force the queue to re-apply the target VP when it next runs.
            if let Some(queue_state) = &mut net_queue.queue_state {
                queue_state.target_vp_set = false;
            }
        }

        if coordinator_running {
            self.coordinator.start();
        }
    }

    fn start(&mut self) {
        if !self.coordinator.is_running() {
            self.coordinator.start();
        }
    }

    async fn stop(&mut self) {
        // Stop the coordinator first, then its workers.
        self.coordinator.stop().await;
        if let Some(coordinator) = self.coordinator.state_mut() {
            coordinator.stop_workers().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}
1353
1354#[async_trait]
1355impl SaveRestoreVmbusDevice for Nic {
1356 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1357 let state = self.saved_state();
1358 Ok(SavedStateBlob::new(state))
1359 }
1360
1361 async fn restore(
1362 &mut self,
1363 control: RestoreControl<'_>,
1364 state: SavedStateBlob,
1365 ) -> Result<(), RestoreError> {
1366 let state: saved_state::SavedState = state.parse()?;
1367 if let Err(err) = self.restore_state(control, state).await {
1368 tracing::error!(err = &err as &dyn std::error::Error, instance_id = %self.instance_id, "Failed restoring network vmbus state");
1369 Err(err.into())
1370 } else {
1371 Ok(())
1372 }
1373 }
1374}
1375
impl Nic {
    /// Creates a channel worker for `channel_idx` and inserts it into the
    /// coordinator, optionally starting it immediately.
    ///
    /// The coordinator must be stopped (its state accessible) when this is
    /// called.
    fn insert_worker(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
        state: WorkerState,
        start: bool,
    ) -> Result<(), OpenError> {
        let coordinator = self.coordinator.state_mut().unwrap();

        // Point the queue's driver at the VP the guest requested for this
        // channel.
        let driver = coordinator.workers[channel_idx as usize]
            .task()
            .driver
            .clone();
        driver.retarget_vp(open_request.open_data.target_vp);

        let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
            .map_err(OpenError::Ring)?;
        let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
        // The guest OS identity gates the pending-send-size ring
        // optimization (some guests cannot use it; see can_use_ring_opt).
        let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
        let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
        let worker = Worker {
            channel_idx,
            target_vp: open_request.open_data.target_vp,
            mem: self
                .resources
                .offer_resources
                .guest_memory(open_request)
                .clone(),
            channel: NetChannel {
                adapter: self.adapter.clone(),
                queue,
                gpadl_map: self.resources.gpadl_map.clone(),
                packet_size: PacketSize::V1,
                pending_send_size: 0,
                restart: None,
                can_use_ring_size_opt,
                packet_filter: coordinator.active_packet_filter,
            },
            state,
            coordinator_send: self.coordinator_send.clone().unwrap(),
        };
        let instance_id = self.instance_id;
        let worker_task = &mut coordinator.workers[channel_idx as usize];
        worker_task.insert(
            driver,
            format!("netvsp-{}-{}", instance_id, channel_idx),
            worker,
        );
        if start {
            worker_task.start();
        }
        Ok(())
    }
}
1435
/// Coordinator state carried across a save/restore cycle.
struct RestoreCoordinatorState {
    /// The packet filter that was active when the state was saved.
    active_packet_filter: u32,
}
1439
impl Nic {
    /// Creates the coordinator task, with one worker slot per possible
    /// queue, and inserts it into `self.coordinator`.
    ///
    /// `restoring` is `Some` on the restore path; it seeds the saved packet
    /// filter and flags the coordinator for restart.
    fn insert_coordinator(&mut self, num_queues: u16, restoring: Option<RestoreCoordinatorState>) {
        let mut driver_builder = self.driver_source.builder();
        driver_builder.target_vp(0);
        // Run worker tasks on the target VP only when the endpoint does not
        // complete transmits quickly.
        driver_builder.run_on_target(!self.adapter.tx_fast_completions);

        // Bounded channel of depth 1: coordinator messages coalesce.
        #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::channel(1);
        self.coordinator_send = Some(send);
        self.coordinator.insert(
            &self.adapter.driver,
            format!("netvsp-{}-coordinator", self.instance_id),
            Coordinator {
                recv,
                channel_control: self.resources.channel_control.clone(),
                restart: restoring.is_some(),
                workers: (0..self.adapter.max_queues)
                    .map(|i| {
                        TaskControl::new(NetQueue {
                            queue_state: None,
                            driver: driver_builder
                                .build(format!("netvsp-{}-{}", self.instance_id, i)),
                        })
                    })
                    .collect(),
                buffers: None,
                num_queues,
                active_packet_filter: restoring
                    .map(|r| r.active_packet_filter)
                    .unwrap_or(rndisprot::NDIS_PACKET_TYPE_NONE),
                sleep_deadline: None,
            },
        );
    }
}
1482
/// Errors that can occur while restoring saved netvsp state.
#[derive(Debug, Error)]
enum NetRestoreError {
    #[error("unsupported protocol version {0:#x}")]
    UnsupportedVersion(u32),
    #[error("send/receive buffer invalid gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
    #[error("failed to restore channels")]
    Channel(#[source] ChannelRestoreError),
    #[error(transparent)]
    ReceiveBuffer(#[from] BufferError),
    #[error(transparent)]
    SuballocationMisconfigured(#[from] SubAllocationInUse),
    #[error(transparent)]
    Open(#[from] OpenError),
    #[error("invalid rss key size")]
    InvalidRssKeySize,
    #[error("reduced indirection table size")]
    ReducedIndirectionTableSize,
}
1502
1503impl From<NetRestoreError> for RestoreError {
1504 fn from(err: NetRestoreError) -> Self {
1505 RestoreError::InvalidSavedState(anyhow::Error::new(err))
1506 }
1507}
1508
impl Nic {
    /// Rebuilds the device's runtime state from a saved state, recreating
    /// the coordinator and one worker per previously open channel.
    async fn restore_state(
        &mut self,
        mut control: RestoreControl<'_>,
        state: saved_state::SavedState,
    ) -> Result<(), NetRestoreError> {
        let mut saved_packet_filter = 0u32;
        if let Some(state) = state.open {
            // Compute which channels were open at save time; vmbus needs
            // this to re-establish the channel set.
            let open = match &state.primary {
                saved_state::Primary::Version => vec![true],
                saved_state::Primary::Init(_) => vec![true],
                saved_state::Primary::Ready(ready) => {
                    ready.channels.iter().map(|x| x.is_some()).collect()
                }
            };

            let mut states: Vec<_> = open.iter().map(|_| None).collect();

            let requests = control
                .restore(&open)
                .await
                .map_err(NetRestoreError::Channel)?;

            match state.primary {
                saved_state::Primary::Version => {
                    states[0] = Some(WorkerState::Init(None));
                }
                saved_state::Primary::Init(init) => {
                    let version = check_version(init.version)
                        .ok_or(NetRestoreError::UnsupportedVersion(init.version))?;

                    // Remap any buffers the guest had already supplied.
                    let recv_buffer = init
                        .receive_buffer
                        .map(|recv_buffer| {
                            ReceiveBuffer::new(
                                &self.resources.gpadl_map,
                                recv_buffer.gpadl_id,
                                recv_buffer.id,
                                recv_buffer.sub_allocation_size,
                            )
                        })
                        .transpose()?;

                    let send_buffer = init
                        .send_buffer
                        .map(|send_buffer| {
                            SendBuffer::new(&self.resources.gpadl_map, send_buffer.gpadl_id)
                        })
                        .transpose()?;

                    let state = InitState {
                        version,
                        ndis_config: init.ndis_config.map(
                            |saved_state::NdisConfig { mtu, capabilities }| NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            },
                        ),
                        ndis_version: init.ndis_version.map(
                            |saved_state::NdisVersion { major, minor }| NdisVersion {
                                major,
                                minor,
                            },
                        ),
                        recv_buffer,
                        send_buffer,
                    };
                    states[0] = Some(WorkerState::Init(Some(state)));
                }
                saved_state::Primary::Ready(ready) => {
                    let saved_state::ReadyPrimary {
                        version,
                        receive_buffer,
                        send_buffer,
                        mut control_messages,
                        mut rss_state,
                        channels,
                        ndis_version,
                        ndis_config,
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change,
                        tx_spread_sent,
                        guest_link_down,
                        pending_link_action,
                        packet_filter,
                    } = ready;

                    // Older saved states predate the packet_filter field.
                    saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER);

                    let version = check_version(version)
                        .ok_or(NetRestoreError::UnsupportedVersion(version))?;

                    let request = requests[0].as_ref().unwrap();
                    // Shared buffer state used by every channel worker.
                    let buffers = Arc::new(ChannelBuffers {
                        version,
                        mem: self.resources.offer_resources.guest_memory(request).clone(),
                        recv_buffer: ReceiveBuffer::new(
                            &self.resources.gpadl_map,
                            receive_buffer.gpadl_id,
                            receive_buffer.id,
                            receive_buffer.sub_allocation_size,
                        )?,
                        send_buffer: {
                            if let Some(send_buffer) = send_buffer {
                                Some(SendBuffer::new(
                                    &self.resources.gpadl_map,
                                    send_buffer.gpadl_id,
                                )?)
                            } else {
                                None
                            }
                        },
                        ndis_version: {
                            let saved_state::NdisVersion { major, minor } = ndis_version;
                            NdisVersion { major, minor }
                        },
                        ndis_config: {
                            let saved_state::NdisConfig { mtu, capabilities } = ndis_config;
                            NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                    });

                    for (channel_idx, channel) in channels.iter().enumerate() {
                        let channel = if let Some(channel) = channel {
                            channel
                        } else {
                            continue;
                        };

                        let mut active = ActiveState::restore(channel, buffers.recv_buffer.count)?;

                        // Primary-channel-only state (RSS, VF, offloads,
                        // queued control messages) lives on channel 0.
                        if channel_idx == 0 {
                            let primary = PrimaryChannelState::restore(
                                &guest_vf_state,
                                &rndis_state,
                                &offload_config,
                                pending_offload_change,
                                channels.len() as u16,
                                self.adapter.indirection_table_size,
                                &active.rx_bufs,
                                std::mem::take(&mut control_messages),
                                rss_state.take(),
                                tx_spread_sent,
                                guest_link_down,
                                pending_link_action,
                            )?;
                            active.primary = Some(primary);
                        }

                        states[channel_idx] = Some(WorkerState::Ready(ReadyState {
                            buffers: buffers.clone(),
                            state: active,
                            data: ProcessingData::new(),
                        }));
                    }
                }
            }

            self.insert_coordinator(
                states.len() as u16,
                Some(RestoreCoordinatorState {
                    active_packet_filter: saved_packet_filter,
                }),
            );

            // Insert workers without starting them; the coordinator starts
            // them when it runs.
            for (channel_idx, (state, request)) in states.into_iter().zip(requests).enumerate() {
                if let Some(state) = state {
                    self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?;
                }
            }
        } else {
            // Device was closed at save time.
            control
                .restore(&[false])
                .await
                .map_err(NetRestoreError::Channel)?;
        }
        Ok(())
    }

    /// Captures the device's current state for saving. Returns
    /// `open: None` when the channel is closed.
    fn saved_state(&self) -> saved_state::SavedState {
        let open = if let Some(coordinator) = self.coordinator.state() {
            let primary = coordinator.workers[0].state().unwrap();
            let primary = match &primary.state {
                WorkerState::Init(None) => saved_state::Primary::Version,
                WorkerState::Init(Some(init)) => {
                    saved_state::Primary::Init(saved_state::InitPrimary {
                        version: init.version as u32,
                        ndis_config: init.ndis_config.map(|NdisConfig { mtu, capabilities }| {
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        }),
                        ndis_version: init.ndis_version.map(|NdisVersion { major, minor }| {
                            saved_state::NdisVersion { major, minor }
                        }),
                        receive_buffer: init.recv_buffer.as_ref().map(|x| x.saved_state()),
                        send_buffer: init.send_buffer.as_ref().map(|x| saved_state::SendBuffer {
                            gpadl_id: x.gpadl.id(),
                        }),
                    })
                }
                WorkerState::WaitingForCoordinator(Some(ready)) | WorkerState::Ready(ready) => {
                    let primary = ready.state.primary.as_ref().unwrap();

                    let rndis_state = match primary.rndis_state {
                        RndisState::Initializing => saved_state::RndisState::Initializing,
                        RndisState::Operational => saved_state::RndisState::Operational,
                        RndisState::Halted => saved_state::RndisState::Halted,
                    };

                    let offload_config = saved_state::OffloadConfig {
                        checksum_tx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_tx.ipv4_header,
                            tcp4: primary.offload_config.checksum_tx.tcp4,
                            udp4: primary.offload_config.checksum_tx.udp4,
                            tcp6: primary.offload_config.checksum_tx.tcp6,
                            udp6: primary.offload_config.checksum_tx.udp6,
                        },
                        checksum_rx: saved_state::ChecksumOffloadConfig {
                            ipv4_header: primary.offload_config.checksum_rx.ipv4_header,
                            tcp4: primary.offload_config.checksum_rx.tcp4,
                            udp4: primary.offload_config.checksum_rx.udp4,
                            tcp6: primary.offload_config.checksum_rx.tcp6,
                            udp6: primary.offload_config.checksum_rx.udp6,
                        },
                        lso4: primary.offload_config.lso4,
                        lso6: primary.offload_config.lso6,
                    };

                    let control_messages = primary
                        .control_messages
                        .iter()
                        .map(|message| saved_state::IncomingControlMessage {
                            message_type: message.message_type,
                            data: message.data.to_vec(),
                        })
                        .collect();

                    let rss_state = primary.rss_state.as_ref().map(|rss| saved_state::RssState {
                        key: rss.key.into(),
                        indirection_table: rss.indirection_table.clone(),
                    });

                    let pending_link_action = match primary.pending_link_action {
                        PendingLinkAction::Default => None,
                        PendingLinkAction::Active(action) | PendingLinkAction::Delay(action) => {
                            Some(action)
                        }
                    };

                    // Per-channel in-flight tx and in-use rx buffers, for
                    // every queue the coordinator knows about.
                    let channels = coordinator.workers[..coordinator.num_queues as usize]
                        .iter()
                        .map(|worker| {
                            worker.state().map(|worker| {
                                if let Some(ready) = worker.state.ready() {
                                    // Transactions needing completion after
                                    // restore: already-queued completions
                                    // plus packets still in flight at the
                                    // endpoint.
                                    let pending_tx_completions = ready
                                        .state
                                        .pending_tx_completions
                                        .iter()
                                        .map(|pending| pending.transaction_id)
                                        .chain(ready.state.pending_tx_packets.iter().filter_map(
                                            |inflight| {
                                                (inflight.pending_packet_count > 0)
                                                    .then_some(inflight.transaction_id)
                                            },
                                        ))
                                        .collect();

                                    saved_state::Channel {
                                        pending_tx_completions,
                                        in_use_rx: {
                                            ready
                                                .state
                                                .rx_bufs
                                                .allocated()
                                                .map(|id| saved_state::Rx {
                                                    buffers: id.collect(),
                                                })
                                                .collect()
                                        },
                                    }
                                } else {
                                    saved_state::Channel {
                                        pending_tx_completions: Vec::new(),
                                        in_use_rx: Vec::new(),
                                    }
                                }
                            })
                        })
                        .collect();

                    // Collapse transient VF states to their stable
                    // equivalents for serialization.
                    let guest_vf_state = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::Initializing
                        | PrimaryChannelGuestVfState::Unavailable
                        | PrimaryChannelGuestVfState::Available { .. } => {
                            saved_state::GuestVfState::NoState
                        }
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                        | PrimaryChannelGuestVfState::AvailableAdvertised => {
                            saved_state::GuestVfState::AvailableAdvertised
                        }
                        PrimaryChannelGuestVfState::Ready => saved_state::GuestVfState::Ready,
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: None,
                        },
                        PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        } => saved_state::GuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result,
                        },
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSwitched
                        | PrimaryChannelGuestVfState::DataPathSynthetic => {
                            saved_state::GuestVfState::DataPathSwitched
                        }
                        PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state,
                    };

                    let worker_0_packet_filter = coordinator.workers[0]
                        .state()
                        .unwrap()
                        .channel
                        .packet_filter;
                    saved_state::Primary::Ready(saved_state::ReadyPrimary {
                        version: ready.buffers.version as u32,
                        receive_buffer: ready.buffers.recv_buffer.saved_state(),
                        send_buffer: ready.buffers.send_buffer.as_ref().map(|sb| {
                            saved_state::SendBuffer {
                                gpadl_id: sb.gpadl.id(),
                            }
                        }),
                        rndis_state,
                        guest_vf_state,
                        offload_config,
                        pending_offload_change: primary.pending_offload_change,
                        control_messages,
                        rss_state,
                        channels,
                        ndis_config: {
                            let NdisConfig { mtu, capabilities } = ready.buffers.ndis_config;
                            saved_state::NdisConfig {
                                mtu,
                                capabilities: capabilities.into(),
                            }
                        },
                        ndis_version: {
                            let NdisVersion { major, minor } = ready.buffers.ndis_version;
                            saved_state::NdisVersion { major, minor }
                        },
                        tx_spread_sent: primary.tx_spread_sent,
                        guest_link_down: !primary.guest_link_up,
                        pending_link_action,
                        packet_filter: Some(worker_0_packet_filter),
                    })
                }
                WorkerState::WaitingForCoordinator(None) => {
                    unreachable!("valid ready state")
                }
            };

            let state = saved_state::OpenState { primary };
            Some(state)
        } else {
            None
        };

        saved_state::SavedState { open }
    }
}
1904
/// Errors that terminate a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    Packet(#[source] PacketError),
    #[error("unexpected packet order: {0}")]
    UnexpectedPacketOrder(#[source] PacketOrderError),
    #[error("unknown rndis message type: {0}")]
    UnknownRndisMessageType(u32),
    #[error("junk after rndis packet message: {0:#x}")]
    NonRndisPacketAfterPacket(u32),
    #[error("memory access error")]
    Access(#[from] AccessError),
    #[error("rndis message too small")]
    RndisMessageTooSmall,
    #[error("unsupported rndis behavior")]
    UnsupportedRndisBehavior,
    #[error("vmbus queue error")]
    Queue(#[from] queue::Error),
    #[error("too many control messages")]
    TooManyControlMessages,
    #[error("invalid rndis packet completion")]
    InvalidRndisPacketCompletion,
    #[error("missing transaction id")]
    MissingTransactionId,
    #[error("invalid gpadl")]
    InvalidGpadl(#[source] guestmem::InvalidGpn),
    #[error("gpadl error")]
    GpadlError(#[source] GuestMemoryError),
    #[error("gpa direct error")]
    GpaDirectError(#[source] GuestMemoryError),
    #[error("endpoint")]
    Endpoint(#[source] anyhow::Error),
    #[error("message not supported on sub channel: {0}")]
    NotSupportedOnSubChannel(u32),
    #[error("the ring buffer ran out of space, which should not be possible")]
    OutOfSpace,
    #[error("send/receive buffer error")]
    Buffer(#[from] BufferError),
    #[error("invalid rndis state")]
    InvalidRndisState,
    #[error("rndis message type not implemented")]
    RndisMessageTypeNotImplemented,
    #[error("invalid TCP header offset")]
    InvalidTcpHeaderOffset,
    #[error("cancelled")]
    Cancelled(task_control::Cancelled),
    #[error("tearing down because send/receive buffer is revoked")]
    BufferRevoked,
    #[error("endpoint requires queue restart: {0}")]
    EndpointRequiresQueueRestart(#[source] anyhow::Error),
    #[error("Failed to send message to coordinator")]
    CoordinatorMessageSendFailed(#[source] TrySendError<CoordinatorMessage>),
}
1958
1959impl From<task_control::Cancelled> for WorkerError {
1960 fn from(value: task_control::Cancelled) -> Self {
1961 Self::Cancelled(value)
1962 }
1963}
1964
/// Errors from opening a vmbus channel for the device.
#[derive(Debug, Error)]
enum OpenError {
    #[error("error establishing ring buffer")]
    Ring(#[source] vmbus_channel::gpadl_ring::Error),
    #[error("error establishing vmbus queue")]
    Queue(#[source] queue::Error),
}
1972
/// Errors from parsing an incoming vmbus packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("UnknownType {0}")]
    UnknownType(u32),
    #[error("Access")]
    Access(#[source] AccessError),
    #[error("ExternalData")]
    ExternalData(#[source] ExternalDataError),
    #[error("InvalidSendBufferIndex")]
    InvalidSendBufferIndex,
}
1984
/// Protocol-state violations: a message arrived when the channel state does
/// not allow it.
#[derive(Debug, Error)]
enum PacketOrderError {
    #[error("Invalid PacketData")]
    InvalidPacketData,
    #[error("Unexpected RndisPacket")]
    UnexpectedRndisPacket,
    #[error("SendNdisVersion already exists")]
    SendNdisVersionExists,
    #[error("SendNdisConfig already exists")]
    SendNdisConfigExists,
    #[error("SendReceiveBuffer already exists")]
    SendReceiveBufferExists,
    #[error("SendReceiveBuffer missing MTU")]
    SendReceiveBufferMissingMTU,
    #[error("SendSendBuffer already exists")]
    SendSendBufferExists,
    #[error("SwitchDataPathCompletion during PrimaryChannelState")]
    SwitchDataPathCompletionPrimaryChannelState,
}
2004
/// The parsed body of an incoming vmbus packet: either a guest-to-host NVSP
/// message or the completion of a host-initiated message.
#[derive(Debug)]
enum PacketData {
    Init(protocol::MessageInit),
    SendNdisVersion(protocol::Message1SendNdisVersion),
    SendReceiveBuffer(protocol::Message1SendReceiveBuffer),
    SendSendBuffer(protocol::Message1SendSendBuffer),
    RevokeReceiveBuffer(Message1RevokeReceiveBuffer),
    RevokeSendBuffer(Message1RevokeSendBuffer),
    RndisPacket(protocol::Message1SendRndisPacket),
    RndisPacketComplete(protocol::Message1SendRndisPacketComplete),
    SendNdisConfig(protocol::Message2SendNdisConfig),
    SwitchDataPath(protocol::Message4SwitchDataPath),
    OidQueryEx(protocol::Message5OidQueryEx),
    SubChannelRequest(protocol::Message5SubchannelRequest),
    SendVfAssociationCompletion,
    SwitchDataPathCompletion,
}
2022
/// An incoming vmbus packet, parsed into its NVSP message plus any external
/// data ranges it references.
#[derive(Debug)]
struct Packet<'a> {
    /// The parsed message body.
    data: PacketData,
    /// Transaction ID for sending a completion, if one was requested.
    transaction_id: Option<u64>,
    /// External data ranges (GPA direct or send-buffer section).
    external_data: &'a MultiPagedRangeBuf,
}
2029
/// Reader over a packet's external data ranges in guest memory.
type PacketReader<'a> = PagedRangesReader<'a, MultiPagedRangeIter<'a>>;
2031
impl Packet<'_> {
    /// Returns a guest-memory reader over the packet's external data, which
    /// carries the RNDIS message payload.
    fn rndis_reader<'a>(&'a self, mem: &'a GuestMemory) -> PacketReader<'a> {
        PagedRanges::new(self.external_data.iter()).reader(mem)
    }
}
2037
2038fn read_packet_data<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
2039 reader: &mut impl MemoryRead,
2040) -> Result<T, PacketError> {
2041 reader.read_plain().map_err(PacketError::Access)
2042}
2043
/// Parses the next incoming vmbus packet into a [`Packet`].
///
/// `external_data` is reused as scratch space for the packet's external
/// ranges. `version` gates which message types are accepted: messages
/// introduced after the negotiated protocol version are rejected as
/// unknown.
fn parse_packet<'a, T: RingMem>(
    packet_ref: &queue::PacketRef<'_, T>,
    send_buffer: Option<&SendBuffer>,
    version: Option<Version>,
    external_data: &'a mut MultiPagedRangeBuf,
) -> Result<Packet<'a>, PacketError> {
    external_data.clear();
    let packet = match packet_ref.as_ref() {
        IncomingPacket::Data(data) => data,
        IncomingPacket::Completion(completion) => {
            // Completions of host-initiated messages are distinguished by
            // reserved transaction IDs; anything else must be an RNDIS
            // packet completion.
            let data = if completion.transaction_id() == VF_ASSOCIATION_TRANSACTION_ID {
                PacketData::SendVfAssociationCompletion
            } else if completion.transaction_id() == SWITCH_DATA_PATH_TRANSACTION_ID {
                PacketData::SwitchDataPathCompletion
            } else {
                let mut reader = completion.reader();
                let header: protocol::MessageHeader =
                    reader.read_plain().map_err(PacketError::Access)?;
                match header.message_type {
                    protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE => {
                        PacketData::RndisPacketComplete(read_packet_data(&mut reader)?)
                    }
                    typ => return Err(PacketError::UnknownType(typ)),
                }
            };
            return Ok(Packet {
                data,
                transaction_id: Some(completion.transaction_id()),
                external_data,
            });
        }
    };

    let mut reader = packet.reader();
    let header: protocol::MessageHeader = reader.read_plain().map_err(PacketError::Access)?;
    let data = match header.message_type {
        protocol::MESSAGE_TYPE_INIT => PacketData::Init(read_packet_data(&mut reader)?),
        protocol::MESSAGE1_TYPE_SEND_NDIS_VERSION if version >= Some(Version::V1) => {
            PacketData::SendNdisVersion(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_RECEIVE_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeReceiveBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::SendSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_REVOKE_SEND_BUFFER if version >= Some(Version::V1) => {
            PacketData::RevokeSendBuffer(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET if version >= Some(Version::V1) => {
            let message: protocol::Message1SendRndisPacket = read_packet_data(&mut reader)?;
            // A section index of 0xffffffff means the payload is carried in
            // the packet's external ranges instead of the send buffer.
            if message.send_buffer_section_index != 0xffffffff {
                // Send buffer sections are laid out at a 6144-byte stride;
                // validate that the referenced section is in bounds.
                let send_buffer_suballocation = send_buffer
                    .ok_or(PacketError::InvalidSendBufferIndex)?
                    .gpadl
                    .first()
                    .unwrap()
                    .try_subrange(
                        message.send_buffer_section_index as usize * 6144,
                        message.send_buffer_section_size as usize,
                    )
                    .ok_or(PacketError::InvalidSendBufferIndex)?;

                external_data.push_range(send_buffer_suballocation);
            }
            PacketData::RndisPacket(message)
        }
        protocol::MESSAGE2_TYPE_SEND_NDIS_CONFIG if version >= Some(Version::V2) => {
            PacketData::SendNdisConfig(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH if version >= Some(Version::V4) => {
            PacketData::SwitchDataPath(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_OID_QUERY_EX if version >= Some(Version::V5) => {
            PacketData::OidQueryEx(read_packet_data(&mut reader)?)
        }
        protocol::MESSAGE5_TYPE_SUB_CHANNEL if version >= Some(Version::V5) => {
            PacketData::SubChannelRequest(read_packet_data(&mut reader)?)
        }
        typ => return Err(PacketError::UnknownType(typ)),
    };
    packet
        .read_external_ranges(external_data)
        .map_err(PacketError::ExternalData)?;
    Ok(Packet {
        data,
        transaction_id: packet.transaction_id(),
        external_data,
    })
}
2137
/// An NVSP message serialized into an 8-byte-aligned buffer sized for the
/// largest supported packet size (V61).
#[derive(Debug, Copy, Clone)]
struct NvspMessage {
    buf: [u64; protocol::PACKET_SIZE_V61 / 8],
    /// The negotiated packet size, which determines the padded length sent
    /// on the ring.
    size: PacketSize,
}
2143
impl NvspMessage {
    /// Builds a message of `message_type` with payload `data`.
    ///
    /// Asserts at compile time that the payload fits even within the
    /// smaller V1 packet size, so any `size` is safe at runtime.
    fn new<P: IntoBytes + Immutable + KnownLayout>(
        size: PacketSize,
        message_type: u32,
        data: P,
    ) -> Self {
        const {
            assert!(
                size_of::<P>() <= protocol::PACKET_SIZE_V1 - size_of::<protocol::MessageHeader>(),
                "packet might not fit in message"
            )
        };
        let mut message = NvspMessage {
            buf: [0; protocol::PACKET_SIZE_V61 / 8],
            size,
        };
        let header = protocol::MessageHeader { message_type };
        // Serialize the header first, then the payload immediately after.
        header.write_to_prefix(message.buf.as_mut_bytes()).unwrap();
        data.write_to_prefix(
            &mut message.buf.as_mut_bytes()[size_of::<protocol::MessageHeader>()..],
        )
        .unwrap();
        message
    }

    /// Returns the message as u64 words, truncated/padded to the negotiated
    /// packet size rounded up to 8-byte alignment.
    fn aligned_payload(&self) -> &[u64] {
        let len = match self.size {
            PacketSize::V1 => const { protocol::PACKET_SIZE_V1.next_multiple_of(8) / 8 },
            PacketSize::V61 => const { protocol::PACKET_SIZE_V61.next_multiple_of(8) / 8 },
        };
        &self.buf[..len]
    }
}
2185
impl<T: RingMem> NetChannel<T> {
    /// Builds an NVSP message using this channel's negotiated packet size.
    fn message<P: IntoBytes + Immutable + KnownLayout>(
        &self,
        message_type: u32,
        data: P,
    ) -> NvspMessage {
        NvspMessage::new(self.packet_size, message_type, data)
    }

    /// Sends a completion packet for `transaction_id`, if one was
    /// requested, with an optional message payload.
    ///
    /// Returns [`WorkerError::OutOfSpace`] if the outgoing ring is full,
    /// which callers treat as a protocol invariant violation.
    fn send_completion(
        &mut self,
        transaction_id: Option<u64>,
        message: Option<&NvspMessage>,
    ) -> Result<(), WorkerError> {
        match transaction_id {
            // No transaction ID means the sender did not ask for a
            // completion.
            None => Ok(()),
            Some(transaction_id) => Ok(self
                .queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    transaction_id,
                    OutgoingPacketType::Completion,
                    message.map_or(&[], |m| m.aligned_payload()),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(_) => WorkerError::OutOfSpace,
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?),
        }
    }
}
2219
/// Protocol versions this device will negotiate, in ascending order.
/// Note that no V3 entry appears in this list.
static SUPPORTED_VERSIONS: &[Version] = &[
    Version::V1,
    Version::V2,
    Version::V4,
    Version::V5,
    Version::V6,
    Version::V61,
];
2228
2229fn check_version(requested_version: u32) -> Option<Version> {
2230 SUPPORTED_VERSIONS
2231 .iter()
2232 .find(|version| **version as u32 == requested_version)
2233 .copied()
2234}
2235
/// The guest-supplied receive buffer, carved into fixed-size suballocations
/// used to deliver received packets to the guest.
#[derive(Debug)]
struct ReceiveBuffer {
    gpadl: GpadlView,
    /// Buffer ID assigned by the guest.
    id: u16,
    /// Number of suballocations that fit in the buffer.
    count: u32,
    sub_allocation_size: u32,
}
2243
/// Errors from validating a guest-supplied send or receive buffer.
#[derive(Debug, Error)]
enum BufferError {
    #[error("unsupported suballocation size {0}")]
    UnsupportedSuballocationSize(u32),
    #[error("unaligned gpadl")]
    UnalignedGpadl,
    #[error("unknown gpadl ID")]
    UnknownGpadlId(#[from] UnknownGpadlId),
}
2253
impl ReceiveBuffer {
    /// Maps the guest's receive buffer GPADL and validates its layout.
    fn new(
        gpadl_map: &GpadlMapView,
        gpadl_id: GpadlId,
        id: u16,
        sub_allocation_size: u32,
    ) -> Result<Self, BufferError> {
        // Each suballocation must be able to hold a default-MTU packet.
        if sub_allocation_size < sub_allocation_size_for_mtu(DEFAULT_MTU) {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let gpadl = gpadl_map.map(gpadl_id)?;
        // The buffer must be a single contiguous, page-aligned range.
        let range = gpadl
            .contiguous_aligned()
            .ok_or(BufferError::UnalignedGpadl)?;
        // NOTE(review): `range.len() as u32` assumes the buffer is < 4 GiB;
        // larger buffers would silently truncate — confirm an upstream
        // limit exists.
        let num_sub_allocations = range.len() as u32 / sub_allocation_size;
        if num_sub_allocations == 0 {
            return Err(BufferError::UnsupportedSuballocationSize(
                sub_allocation_size,
            ));
        }
        let recv_buffer = Self {
            gpadl,
            id,
            count: num_sub_allocations,
            sub_allocation_size,
        };
        Ok(recv_buffer)
    }

    /// Returns the guest memory range of suballocation `index`.
    /// `index` must be less than `self.count`.
    fn range(&self, index: u32) -> PagedRange<'_> {
        self.gpadl.first().unwrap().subrange(
            (index * self.sub_allocation_size) as usize,
            self.sub_allocation_size as usize,
        )
    }

    /// Returns the transfer page range describing `len` bytes at
    /// suballocation `index`, for posting a receive to the guest.
    fn transfer_page_range(&self, index: u32, len: usize) -> ring::TransferPageRange {
        assert!(len <= self.sub_allocation_size as usize);
        ring::TransferPageRange {
            byte_offset: index * self.sub_allocation_size,
            byte_count: len as u32,
        }
    }

    /// Captures the parameters needed to remap this buffer at restore time.
    fn saved_state(&self) -> saved_state::ReceiveBuffer {
        saved_state::ReceiveBuffer {
            gpadl_id: self.gpadl.id(),
            id: self.id,
            sub_allocation_size: self.sub_allocation_size,
        }
    }
}
2308
/// The guest-supplied send buffer, from which the guest may source transmit
/// payloads by section index.
#[derive(Debug)]
struct SendBuffer {
    gpadl: GpadlView,
}
2313
2314impl SendBuffer {
2315 fn new(gpadl_map: &GpadlMapView, gpadl_id: GpadlId) -> Result<Self, BufferError> {
2316 let gpadl = gpadl_map.map(gpadl_id)?;
2317 gpadl
2318 .contiguous_aligned()
2319 .ok_or(BufferError::UnalignedGpadl)?;
2320 Ok(Self { gpadl })
2321 }
2322}
2323
2324impl<T: RingMem> NetChannel<T> {
    /// Queues a non-packet RNDIS message for asynchronous processing on the
    /// primary channel.
    ///
    /// Packet messages take a separate fast path; halt messages are
    /// accepted and ignored.
    fn handle_rndis_message(
        &mut self,
        state: &mut ActiveState,
        message_type: u32,
        mut reader: PacketReader<'_>,
    ) -> Result<(), WorkerError> {
        assert_ne!(
            message_type,
            rndisprot::MESSAGE_TYPE_PACKET_MSG,
            "handled elsewhere"
        );
        // Control messages are only valid on the primary channel.
        let control = state
            .primary
            .as_mut()
            .ok_or(WorkerError::NotSupportedOnSubChannel(message_type))?;

        if message_type == rndisprot::MESSAGE_TYPE_HALT_MSG {
            return Ok(());
        }

        // Cap total queued control-message bytes so a misbehaving guest
        // cannot exhaust host memory.
        const CONTROL_MESSAGE_MAX_QUEUED_BYTES: usize = 100 * 1024;
        if reader.len() == 0 {
            return Err(WorkerError::RndisMessageTooSmall);
        }
        // The check below keeps control_messages_len <= the cap, so the
        // subtraction cannot underflow (assuming the length is only ever
        // grown here).
        if CONTROL_MESSAGE_MAX_QUEUED_BYTES - control.control_messages_len < reader.len() {
            return Err(WorkerError::TooManyControlMessages);
        }

        control.control_messages_len += reader.len();
        control.control_messages.push_back(ControlMessage {
            message_type,
            data: reader.read_all()?.into(),
        });

        Ok(())
    }
2371
    /// Parses a chain of RNDIS packet messages carried in one vmbus packet,
    /// appending tx segments for each. Returns the number of RNDIS packets
    /// parsed.
    fn handle_rndis_packet_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        id: TxId,
        mut message_len: usize,
        mut reader: PacketReader<'_>,
        segments: &mut Vec<TxSegment>,
    ) -> Result<usize, WorkerError> {
        let mut num_packets = 0;
        loop {
            // Distance from the current read position to the start of the
            // next message: the declared message length minus the header
            // already consumed.
            let next_message_offset = message_len
                .checked_sub(size_of::<rndisprot::MessageHeader>())
                .ok_or(WorkerError::RndisMessageTooSmall)?;

            self.handle_rndis_packet_message(
                id,
                reader.clone(),
                &buffers.mem,
                segments,
                &mut state.stats,
            )?;
            num_packets += 1;

            reader.skip(next_message_offset)?;
            if reader.len() == 0 {
                break;
            }
            // Only packet messages may follow a packet message within a
            // single vmbus packet.
            let header: rndisprot::MessageHeader = reader.read_plain()?;
            if header.message_type != rndisprot::MESSAGE_TYPE_PACKET_MSG {
                return Err(WorkerError::NonRndisPacketAfterPacket(header.message_type));
            }
            message_len = header.message_length as usize;
        }
        Ok(num_packets)
    }
2416
2417 fn handle_rndis_packet_message(
2419 &mut self,
2420 id: TxId,
2421 reader: PacketReader<'_>,
2422 mem: &GuestMemory,
2423 segments: &mut Vec<TxSegment>,
2424 stats: &mut QueueStats,
2425 ) -> Result<(), WorkerError> {
2426 let headers = reader
2428 .clone()
2429 .into_inner()
2430 .paged_ranges()
2431 .next()
2432 .ok_or(WorkerError::RndisMessageTooSmall)?;
2433 let mut data = reader.into_inner();
2434 let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
2435 if request.num_oob_data_elements != 0
2436 || request.oob_data_length != 0
2437 || request.oob_data_offset != 0
2438 || request.vc_handle != 0
2439 {
2440 return Err(WorkerError::UnsupportedRndisBehavior);
2441 }
2442
2443 if data.len() < request.data_offset as usize
2444 || (data.len() - request.data_offset as usize) < request.data_length as usize
2445 || request.data_length == 0
2446 {
2447 return Err(WorkerError::RndisMessageTooSmall);
2448 }
2449
2450 data.skip(request.data_offset as usize);
2451 data.truncate(request.data_length as usize);
2452
2453 let mut metadata = net_backend::TxMetadata {
2454 id,
2455 len: request.data_length,
2456 ..Default::default()
2457 };
2458
2459 if request.per_packet_info_length != 0 {
2460 let mut ppi = headers
2461 .try_subrange(
2462 request.per_packet_info_offset as usize,
2463 request.per_packet_info_length as usize,
2464 )
2465 .ok_or(WorkerError::RndisMessageTooSmall)?;
2466 while !ppi.is_empty() {
2467 let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
2468 if h.size == 0 {
2469 return Err(WorkerError::RndisMessageTooSmall);
2470 }
2471 let (this, rest) = ppi
2472 .try_split(h.size as usize)
2473 .ok_or(WorkerError::RndisMessageTooSmall)?;
2474 let (_, d) = this
2475 .try_split(h.per_packet_information_offset as usize)
2476 .ok_or(WorkerError::RndisMessageTooSmall)?;
2477 match h.typ {
2478 rndisprot::PPI_TCP_IP_CHECKSUM => {
2479 let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;
2480
2481 metadata.flags.set_offload_tcp_checksum(
2482 (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(),
2483 );
2484 metadata.flags.set_offload_udp_checksum(
2485 (n.is_ipv4() || n.is_ipv6()) && !n.tcp_checksum() && n.udp_checksum(),
2486 );
2487 metadata
2488 .flags
2489 .set_offload_ip_header_checksum(n.is_ipv4() && n.ip_header_checksum());
2490 metadata.flags.set_is_ipv4(n.is_ipv4());
2491 metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2492 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2493 if metadata.flags.offload_tcp_checksum()
2494 || metadata.flags.offload_udp_checksum()
2495 {
2496 metadata.l3_len = if n.tcp_header_offset() >= metadata.l2_len as u16 {
2497 n.tcp_header_offset() - metadata.l2_len as u16
2498 } else if n.is_ipv4() {
2499 let mut reader = data.clone().reader(mem);
2500 reader.skip(metadata.l2_len as usize)?;
2501 let mut b = 0;
2502 reader.read(std::slice::from_mut(&mut b))?;
2503 (b as u16 >> 4) * 4
2504 } else {
2505 40
2507 };
2508 }
2509 }
2510 rndisprot::PPI_LSO => {
2511 let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;
2512
2513 metadata.flags.set_offload_tcp_segmentation(true);
2514 metadata.flags.set_offload_tcp_checksum(true);
2515 metadata.flags.set_offload_ip_header_checksum(n.is_ipv4());
2516 metadata.flags.set_is_ipv4(n.is_ipv4());
2517 metadata.flags.set_is_ipv6(n.is_ipv6() && !n.is_ipv4());
2518 metadata.l2_len = ETHERNET_HEADER_LEN as u8;
2519 if n.tcp_header_offset() < metadata.l2_len as u16 {
2520 return Err(WorkerError::InvalidTcpHeaderOffset);
2521 }
2522 metadata.l3_len = n.tcp_header_offset() - metadata.l2_len as u16;
2523 metadata.l4_len = {
2524 let mut reader = data.clone().reader(mem);
2525 reader
2526 .skip(metadata.l2_len as usize + metadata.l3_len as usize + 12)?;
2527 let mut b = 0;
2528 reader.read(std::slice::from_mut(&mut b))?;
2529 (b >> 4) * 4
2530 };
2531 metadata.max_tcp_segment_size = n.mss() as u16;
2532
2533 if request.data_length >= rndisprot::LSO_MAX_OFFLOAD_SIZE {
2534 stats.tx_invalid_lso_packets.increment();
2536 }
2537 }
2538 _ => {}
2539 }
2540 ppi = rest;
2541 }
2542 }
2543
2544 let start = segments.len();
2545 for range in data.paged_ranges().flat_map(|r| r.ranges()) {
2546 let range = range.map_err(WorkerError::InvalidGpadl)?;
2547 segments.push(TxSegment {
2548 ty: net_backend::TxSegmentType::Tail,
2549 gpa: range.start,
2550 len: range.len() as u32,
2551 });
2552 }
2553
2554 metadata.segment_count = (segments.len() - start) as u8;
2555
2556 stats.tx_packets.increment();
2557 if metadata.flags.offload_tcp_checksum() || metadata.flags.offload_udp_checksum() {
2558 stats.tx_checksum_packets.increment();
2559 }
2560 if metadata.flags.offload_tcp_segmentation() {
2561 stats.tx_lso_packets.increment();
2562 }
2563
2564 segments[start].ty = net_backend::TxSegmentType::Head(metadata);
2565
2566 Ok(())
2567 }
2568
    /// Sends the guest a VF association message if the negotiated protocol
    /// (V4+) and NDIS config support SR-IOV. Returns whether the message was
    /// sent.
    fn guest_vf_is_available(
        &mut self,
        guest_vf_id: Option<u32>,
        version: Version,
        config: NdisConfig,
    ) -> Result<bool, WorkerError> {
        let serial_number = guest_vf_id.map(|vfid| self.adapter.get_guest_vf_serial_number(vfid));
        if version >= Version::V4 && config.capabilities.sriov() {
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                "sending VF association message."
            );
            let message = {
                self.message(
                    protocol::MESSAGE4_TYPE_SEND_VF_ASSOCIATION,
                    protocol::Message4SendVfAssociation {
                        vf_allocated: if serial_number.is_some() { 1 } else { 0 },
                        serial_number: serial_number.unwrap_or(0),
                    },
                )
            };
            // Sent with completion (using a well-known transaction ID) so the
            // guest's response can be observed.
            self.queue
                .split()
                .1
                .batched()
                .try_write_aligned(
                    VF_ASSOCIATION_TRANSACTION_ID,
                    OutgoingPacketType::InBandWithCompletion,
                    message.aligned_payload(),
                )
                .map_err(|err| match err {
                    queue::TryWriteError::Full(len) => {
                        tracing::error!(len, "failed to write vf association message");
                        WorkerError::OutOfSpace
                    }
                    queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
                })?;
            Ok(true)
        } else {
            // Older protocol or no SR-IOV capability negotiated: nothing to
            // send.
            tracing::info!(
                available = serial_number.is_some(),
                serial_number,
                major = version.major(),
                minor = version.minor(),
                sriov_capable = config.capabilities.sriov(),
                "Skipping NvspMessage4TypeSendVFAssociation message"
            );
            Ok(false)
        }
    }
2623
2624 fn guest_send_indirection_table(&mut self, version: Version, num_channels_opened: u32) {
2626 if version < Version::V5 {
2628 return;
2629 }
2630
2631 #[repr(C)]
2632 #[derive(IntoBytes, Immutable, KnownLayout)]
2633 struct SendIndirectionMsg {
2634 pub message: protocol::Message5SendIndirectionTable,
2635 pub send_indirection_table:
2636 [u32; VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES as usize],
2637 }
2638
2639 let send_indirection_table_offset = offset_of!(SendIndirectionMsg, send_indirection_table)
2641 + size_of::<protocol::MessageHeader>();
2642 let mut data = SendIndirectionMsg {
2643 message: protocol::Message5SendIndirectionTable {
2644 table_entry_count: VMS_SWITCH_RSS_MAX_SEND_INDIRECTION_TABLE_ENTRIES,
2645 table_offset: send_indirection_table_offset as u32,
2646 },
2647 send_indirection_table: Default::default(),
2648 };
2649
2650 for i in 0..data.send_indirection_table.len() {
2651 data.send_indirection_table[i] = i as u32 % num_channels_opened;
2652 }
2653
2654 let header = protocol::MessageHeader {
2655 message_type: protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE,
2656 };
2657 let result = self
2658 .queue
2659 .split()
2660 .1
2661 .try_write(&queue::OutgoingPacket {
2662 transaction_id: 0,
2663 packet_type: OutgoingPacketType::InBandNoCompletion,
2664 payload: &[header.as_bytes(), data.as_bytes()],
2665 })
2666 .map_err(|err| match err {
2667 queue::TryWriteError::Full(len) => {
2668 tracing::error!(len, "failed to write send indirection table message");
2669 WorkerError::OutOfSpace
2670 }
2671 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2672 });
2673 if let Err(err) = result {
2674 tracing::error!(err = %err, "Failed to notify guest about the send indirection table");
2675 }
2676 }
2677
2678 fn guest_vf_data_path_switched_to_synthetic(&mut self) {
2681 let header = protocol::MessageHeader {
2682 message_type: protocol::MESSAGE4_TYPE_SWITCH_DATA_PATH,
2683 };
2684 let data = protocol::Message4SwitchDataPath {
2685 active_data_path: protocol::DataPath::SYNTHETIC.0,
2686 };
2687 let result = self
2688 .queue
2689 .split()
2690 .1
2691 .try_write(&queue::OutgoingPacket {
2692 transaction_id: SWITCH_DATA_PATH_TRANSACTION_ID,
2693 packet_type: OutgoingPacketType::InBandWithCompletion,
2694 payload: &[header.as_bytes(), data.as_bytes()],
2695 })
2696 .map_err(|err| match err {
2697 queue::TryWriteError::Full(len) => {
2698 tracing::error!(
2699 len,
2700 "failed to write MESSAGE4_TYPE_SWITCH_DATA_PATH message"
2701 );
2702 WorkerError::OutOfSpace
2703 }
2704 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
2705 });
2706 if let Err(err) = result {
2707 tracing::error!(err = %err, "Failed to notify guest that data path is now synthetic");
2708 }
2709 }
2710
    /// Advances the guest VF state machine, sending any required guest
    /// notifications (VF association, data-path switch completions).
    ///
    /// Returns a coordinator message if further coordinator processing is
    /// needed.
    async fn handle_state_change(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
    ) -> Result<Option<CoordinatorMessage>, WorkerError> {
        if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
            // A VF just became available. Advertise it only once RNDIS is
            // operational; otherwise it is advertised during RNDIS init.
            if primary.rndis_state == RndisState::Operational {
                if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                    return Ok(Some(CoordinatorMessage::Update(
                        CoordinatorMessageUpdateType {
                            guest_vf_state: true,
                            ..Default::default()
                        },
                    )));
                } else if let Some(true) = primary.is_data_path_switched {
                    tracing::error!(
                        "Data path switched, but current guest negotiation does not support VTL0 VF"
                    );
                }
            }
            return Ok(None);
        }
        // Drive transitional states to completion; `break` on reaching a
        // steady state.
        loop {
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromAvailable => {
                    // Tell the guest the VF went away (if RNDIS is up).
                    if primary.rndis_state == RndisState::Operational {
                        self.guest_vf_is_available(None, buffers.version, buffers.ndis_config)?;
                    }
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                } => {
                    // Complete the guest's pending switch request first.
                    self.send_completion(id, None)?;
                    if to_guest {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                    } else {
                        PrimaryChannelGuestVfState::UnavailableFromAvailable
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSynthetic => {
                    self.guest_vf_data_path_switched_to_synthetic();
                    PrimaryChannelGuestVfState::Ready
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                } => {
                    let result = result.expect("DataPathSwitchPending should have been processed");
                    self.send_completion(id, None)?;
                    if result {
                        if to_guest {
                            PrimaryChannelGuestVfState::DataPathSwitched
                        } else {
                            PrimaryChannelGuestVfState::Ready
                        }
                    } else {
                        if to_guest {
                            // Failed switch to guest: fall back to synthetic.
                            PrimaryChannelGuestVfState::DataPathSynthetic
                        } else {
                            tracing::error!(
                                "Failure when guest requested switch back to synthetic"
                            );
                            PrimaryChannelGuestVfState::DataPathSwitched
                        }
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(_) => {
                    // These states are resolved before workers run.
                    panic!("Invalid guest VF state: {}", primary.guest_vf_state)
                }
                _ => break,
            };
        }
        Ok(None)
    }
2806
    /// Handles a single queued RNDIS control message, writing the completion
    /// into receive buffer suballocation `id` and sending it to the guest.
    fn handle_rndis_control_message(
        &mut self,
        primary: &mut PrimaryChannelState,
        buffers: &ChannelBuffers,
        message_type: u32,
        mut reader: impl MemoryRead + Clone,
        id: u32,
    ) -> Result<(), WorkerError> {
        let mem = &buffers.mem;
        // Completions are written into this receive-buffer suballocation.
        let buffer_range = &buffers.recv_buffer.range(id);
        match message_type {
            rndisprot::MESSAGE_TYPE_INITIALIZE_MSG => {
                if primary.rndis_state != RndisState::Initializing {
                    return Err(WorkerError::InvalidRndisState);
                }

                let request: rndisprot::InitializeRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_INITIALIZE_MSG"
                );

                primary.rndis_state = RndisState::Operational;

                let mut writer = buffer_range.writer(mem);
                let message_length = write_rndis_message(
                    &mut writer,
                    rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT,
                    0,
                    &rndisprot::InitializeComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                        major_version: rndisprot::MAJOR_VERSION,
                        minor_version: rndisprot::MINOR_VERSION,
                        device_flags: rndisprot::DF_CONNECTIONLESS,
                        medium: rndisprot::MEDIUM_802_3,
                        max_packets_per_message: 8,
                        max_transfer_size: 0xEFFFFFFF,
                        packet_alignment_factor: 3,
                        af_list_offset: 0,
                        af_list_size: 0,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
                // RNDIS is now operational; advertise any waiting VF.
                if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
                    if self.guest_vf_is_available(
                        Some(vfid),
                        buffers.version,
                        buffers.ndis_config,
                    )? {
                        primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised;
                        self.send_coordinator_update_vf();
                    } else if let Some(true) = primary.is_data_path_switched {
                        tracing::error!(
                            "Data path switched, but current guest negotiation does not support VTL0 VF"
                        );
                    }
                }
            }
            rndisprot::MESSAGE_TYPE_QUERY_MSG => {
                let request: rndisprot::QueryRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

                // Reserve space for the completion header so the OID handler
                // can write its result immediately after it.
                let (header, body) = buffer_range
                    .try_split(
                        size_of::<rndisprot::MessageHeader>()
                            + size_of::<rndisprot::QueryComplete>(),
                    )
                    .ok_or(WorkerError::RndisMessageTooSmall)?;
                let (status, tx) = match self.adapter.handle_oid_query(
                    buffers,
                    primary,
                    request.oid,
                    body.writer(mem),
                ) {
                    Ok(tx) => (rndisprot::STATUS_SUCCESS, tx),
                    Err(err) => (err.as_status(), 0),
                };

                let message_length = write_rndis_message(
                    &mut header.writer(mem),
                    rndisprot::MESSAGE_TYPE_QUERY_CMPLT,
                    tx,
                    &rndisprot::QueryComplete {
                        request_id: request.request_id,
                        status,
                        information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
                        information_buffer_length: tx as u32,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_MSG => {
                let request: rndisprot::SetRequest = reader.read_plain()?;

                tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

                let status = match self.adapter.handle_oid_set(primary, request.oid, reader) {
                    Ok((restart_endpoint, packet_filter)) => {
                        // Some OIDs require restarting the endpoint to take
                        // effect.
                        if restart_endpoint {
                            self.restart = Some(CoordinatorMessage::Restart);
                        }
                        // Propagate packet filter changes to the coordinator.
                        if let Some(filter) = packet_filter {
                            if self.packet_filter != filter {
                                self.packet_filter = filter;
                                self.send_coordinator_update_filter();
                            }
                        }
                        rndisprot::STATUS_SUCCESS
                    }
                    Err(err) => {
                        tracelimit::warn_ratelimited!(oid = ?request.oid, error = &err as &dyn std::error::Error, "oid failure");
                        err.as_status()
                    }
                };

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_SET_CMPLT,
                    0,
                    &rndisprot::SetComplete {
                        request_id: request.request_id,
                        status,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_RESET_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
                let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

                tracing::trace!(
                    ?request,
                    "handling control message MESSAGE_TYPE_KEEPALIVE_MSG"
                );

                let message_length = write_rndis_message(
                    &mut buffer_range.writer(mem),
                    rndisprot::MESSAGE_TYPE_KEEPALIVE_CMPLT,
                    0,
                    &rndisprot::KeepaliveComplete {
                        request_id: request.request_id,
                        status: rndisprot::STATUS_SUCCESS,
                    },
                )?;
                self.send_rndis_control_message(buffers, id, message_length)?;
            }
            rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
                return Err(WorkerError::RndisMessageTypeNotImplemented);
            }
            _ => return Err(WorkerError::UnknownRndisMessageType(message_type)),
        };
        Ok(())
    }
2983
2984 fn try_send_rndis_message(
2985 &mut self,
2986 transaction_id: u64,
2987 channel_type: u32,
2988 recv_buffer_id: u16,
2989 transfer_pages: &[ring::TransferPageRange],
2990 ) -> Result<Option<usize>, WorkerError> {
2991 let message = self.message(
2992 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET,
2993 protocol::Message1SendRndisPacket {
2994 channel_type,
2995 send_buffer_section_index: 0xffffffff,
2996 send_buffer_section_size: 0,
2997 },
2998 );
2999 let pending_send_size = match self.queue.split().1.batched().try_write_aligned(
3000 transaction_id,
3001 OutgoingPacketType::TransferPages(recv_buffer_id, transfer_pages),
3002 message.aligned_payload(),
3003 ) {
3004 Ok(()) => None,
3005 Err(queue::TryWriteError::Full(n)) => Some(n),
3006 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
3007 };
3008 Ok(pending_send_size)
3009 }
3010
3011 fn send_rndis_control_message(
3012 &mut self,
3013 buffers: &ChannelBuffers,
3014 id: u32,
3015 message_length: usize,
3016 ) -> Result<(), WorkerError> {
3017 let result = self.try_send_rndis_message(
3018 id as u64,
3019 protocol::CONTROL_CHANNEL_TYPE,
3020 buffers.recv_buffer.id,
3021 std::slice::from_ref(&buffers.recv_buffer.transfer_page_range(id, message_length)),
3022 )?;
3023
3024 match result {
3026 None => Ok(()),
3027 Some(len) => {
3028 tracelimit::error_ratelimited!(len, "failed to write control message completion");
3029 Err(WorkerError::OutOfSpace)
3030 }
3031 }
3032 }
3033
3034 fn indicate_status(
3035 &mut self,
3036 buffers: &ChannelBuffers,
3037 id: u32,
3038 status: u32,
3039 payload: &[u8],
3040 ) -> Result<(), WorkerError> {
3041 let buffer = &buffers.recv_buffer.range(id);
3042 let mut writer = buffer.writer(&buffers.mem);
3043 let message_length = write_rndis_message(
3044 &mut writer,
3045 rndisprot::MESSAGE_TYPE_INDICATE_STATUS_MSG,
3046 payload.len(),
3047 &rndisprot::IndicateStatus {
3048 status,
3049 status_buffer_length: payload.len() as u32,
3050 status_buffer_offset: if payload.is_empty() {
3051 0
3052 } else {
3053 size_of::<rndisprot::IndicateStatus>() as u32
3054 },
3055 },
3056 )?;
3057 writer.write(payload)?;
3058 self.send_rndis_control_message(buffers, id, message_length)?;
3059 Ok(())
3060 }
3061
    /// Processes queued control messages and any pending offload-change
    /// indication, while ring space and receive-buffer suballocations are
    /// available.
    fn process_control_messages(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
    ) -> Result<(), WorkerError> {
        // Control messages only exist on the primary channel.
        let Some(primary) = &mut state.primary else {
            return Ok(());
        };

        while !primary.control_messages.is_empty()
            || (primary.pending_offload_change && primary.rndis_state == RndisState::Operational)
        {
            // Ensure there is enough ring space for a response; otherwise
            // register interest and retry when the ring drains.
            if !self.queue.split().1.can_write(MIN_CONTROL_RING_SIZE)? {
                self.pending_send_size = MIN_CONTROL_RING_SIZE;
                break;
            }
            // Each response consumes one receive-buffer suballocation; stop
            // if none are free.
            let Some(id) = primary.free_control_buffers.pop() else {
                break;
            };

            // Mark the buffer in use until the guest completes it.
            assert!(state.rx_bufs.is_free(id.0));
            state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();

            if let Some(message) = primary.control_messages.pop_front() {
                primary.control_messages_len -= message.data.len();
                self.handle_rndis_control_message(
                    primary,
                    buffers,
                    message.message_type,
                    message.data.as_ref(),
                    id.0,
                )?;
            } else if primary.pending_offload_change
                && primary.rndis_state == RndisState::Operational
            {
                // Indicate the current task offload configuration, truncated
                // to the size declared in the NDIS object header.
                let ndis_offload = primary.offload_config.ndis_offload();
                self.indicate_status(
                    buffers,
                    id.0,
                    rndisprot::STATUS_TASK_OFFLOAD_CURRENT_CONFIG,
                    &ndis_offload.as_bytes()[..ndis_offload.header.size.into()],
                )?;
                primary.pending_offload_change = false;
            } else {
                // The loop condition guarantees one of the branches above.
                unreachable!();
            }
        }
        Ok(())
    }
3115
3116 fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) {
3117 if self.restart.is_none() {
3118 self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType {
3119 guest_vf_state: guest_vf,
3120 filter_state: packet_filter,
3121 }));
3122 } else if let Some(CoordinatorMessage::Restart) = self.restart {
3123 } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart {
3127 update.guest_vf_state |= guest_vf;
3129 update.filter_state |= packet_filter;
3130 }
3131 }
3132
    /// Queues a coordinator update for guest VF state.
    fn send_coordinator_update_vf(&mut self) {
        self.send_coordinator_update_message(true, false);
    }
3136
    /// Queues a coordinator update for packet filter state.
    fn send_coordinator_update_filter(&mut self) {
        self.send_coordinator_update_message(false, true);
    }
3140}
3141
3142fn write_rndis_message<T: IntoBytes + Immutable + KnownLayout>(
3144 writer: &mut impl MemoryWrite,
3145 message_type: u32,
3146 extra: usize,
3147 payload: &T,
3148) -> Result<usize, AccessError> {
3149 let message_length = size_of::<rndisprot::MessageHeader>() + size_of_val(payload) + extra;
3150 writer.write(
3151 rndisprot::MessageHeader {
3152 message_type,
3153 message_length: message_length as u32,
3154 }
3155 .as_bytes(),
3156 )?;
3157 writer.write(payload.as_bytes())?;
3158 Ok(message_length)
3159}
3160
/// Errors produced while handling an RNDIS OID query or set.
#[derive(Debug, Error)]
enum OidError {
    // Guest memory access failed while reading/writing the OID data.
    #[error(transparent)]
    Access(#[from] AccessError),
    // The OID is not recognized by this implementation.
    #[error("unknown oid")]
    UnknownOid,
    // The OID input was malformed; the field name identifies the problem.
    #[error("invalid oid input, bad field {0}")]
    InvalidInput(&'static str),
    // The negotiated NDIS version does not support the request.
    #[error("bad ndis version")]
    BadVersion,
    // The named feature is recognized but not supported.
    #[error("feature {0} not supported")]
    NotSupported(&'static str),
}
3174
3175impl OidError {
3176 fn as_status(&self) -> u32 {
3177 match self {
3178 OidError::UnknownOid | OidError::NotSupported(_) => rndisprot::STATUS_NOT_SUPPORTED,
3179 OidError::BadVersion => rndisprot::STATUS_BAD_VERSION,
3180 OidError::InvalidInput(_) => rndisprot::STATUS_INVALID_DATA,
3181 OidError::Access(_) => rndisprot::STATUS_FAILURE,
3182 }
3183 }
3184}
3185
// Default (and minimum) MTU in bytes, including the Ethernet header.
const DEFAULT_MTU: u32 = 1514;
const MIN_MTU: u32 = DEFAULT_MTU;
// Maximum supported (jumbo frame) MTU in bytes.
const MAX_MTU: u32 = 9216;

// Length of an Ethernet (802.3) header in bytes, without VLAN tags.
const ETHERNET_HEADER_LEN: u32 = 14;
3191
3192impl Adapter {
3193 fn get_guest_vf_serial_number(&self, vfid: u32) -> u32 {
3194 if let Some(guest_os_id) = self.get_guest_os_id.as_ref().map(|f| f()) {
3195 if guest_os_id
3198 .microsoft()
3199 .unwrap_or(HvGuestOsMicrosoft::from(0))
3200 .os_id()
3201 == HvGuestOsMicrosoftIds::WINDOWS_NT.0
3202 {
3203 self.adapter_index
3204 } else {
3205 vfid
3206 }
3207 } else {
3208 vfid
3209 }
3210 }
3211
    /// Handles an RNDIS OID query, writing the result into `writer`.
    ///
    /// Returns the number of bytes written on success.
    fn handle_oid_query(
        &self,
        buffers: &ChannelBuffers,
        primary: &PrimaryChannelState,
        oid: rndisprot::Oid,
        mut writer: impl MemoryWrite,
    ) -> Result<usize, OidError> {
        tracing::debug!(?oid, "oid query");
        // Bytes written = initial capacity minus remaining capacity.
        let available_len = writer.len();
        match oid {
            rndisprot::Oid::OID_GEN_SUPPORTED_LIST => {
                // OIDs supported regardless of NDIS version.
                let supported_oids_common = &[
                    rndisprot::Oid::OID_GEN_SUPPORTED_LIST,
                    rndisprot::Oid::OID_GEN_HARDWARE_STATUS,
                    rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED,
                    rndisprot::Oid::OID_GEN_MEDIA_IN_USE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD,
                    rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE,
                    rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE,
                    rndisprot::Oid::OID_GEN_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE,
                    rndisprot::Oid::OID_GEN_VENDOR_ID,
                    rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION,
                    rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_DRIVER_VERSION,
                    rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER,
                    rndisprot::Oid::OID_GEN_PROTOCOL_OPTIONS,
                    rndisprot::Oid::OID_GEN_MAC_OPTIONS,
                    rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS,
                    rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS,
                    rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES,
                    rndisprot::Oid::OID_GEN_FRIENDLY_NAME,
                    rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_CURRENT_ADDRESS,
                    rndisprot::Oid::OID_802_3_MULTICAST_LIST,
                    rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE,
                    rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT,
                    rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION,
                    rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS,
                    rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER,
                ];

                // OIDs added for NDIS 6.x guests.
                let supported_oids_6 = &[
                    rndisprot::Oid::OID_GEN_LINK_PARAMETERS,
                    rndisprot::Oid::OID_GEN_LINK_STATE,
                    rndisprot::Oid::OID_GEN_MAX_LINK_SPEED,
                    rndisprot::Oid::OID_GEN_BYTES_RCV,
                    rndisprot::Oid::OID_GEN_BYTES_XMIT,
                    rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS,
                    rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION,
                    rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES,
                    rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG,
                ];

                // RSS OIDs, added for NDIS 6.30+ guests.
                let supported_oids_63 = &[
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES,
                    rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS,
                ];

                match buffers.ndis_version.major {
                    5 => {
                        writer.write(supported_oids_common.as_bytes())?;
                    }
                    6 => {
                        writer.write(supported_oids_common.as_bytes())?;
                        writer.write(supported_oids_6.as_bytes())?;
                        if buffers.ndis_version.minor >= 30 {
                            writer.write(supported_oids_63.as_bytes())?;
                        }
                    }
                    _ => return Err(OidError::BadVersion),
                }
            }
            rndisprot::Oid::OID_GEN_HARDWARE_STATUS => {
                // 0 = hardware ready (NdisHardwareStatusReady).
                let status: u32 = 0;
                writer.write(status.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_SUPPORTED | rndisprot::Oid::OID_GEN_MEDIA_IN_USE => {
                writer.write(rndisprot::MEDIUM_802_3.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_CURRENT_LOOKAHEAD
            | rndisprot::Oid::OID_GEN_MAXIMUM_FRAME_SIZE => {
                // Frame size excludes the Ethernet header.
                let len: u32 = buffers.ndis_config.mtu - ETHERNET_HEADER_LEN;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_TOTAL_SIZE
            | rndisprot::Oid::OID_GEN_TRANSMIT_BLOCK_SIZE
            | rndisprot::Oid::OID_GEN_RECEIVE_BLOCK_SIZE => {
                let len: u32 = buffers.ndis_config.mtu;
                writer.write(len.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_LINK_SPEED => {
                // This OID reports the speed in units of 100 bps.
                let speed: u32 = (self.link_speed / 100) as u32;
                writer.write(speed.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_TRANSMIT_BUFFER_SPACE
            | rndisprot::Oid::OID_GEN_RECEIVE_BUFFER_SPACE => {
                writer.write((256u32 * 1024).as_bytes())?
            }
            rndisprot::Oid::OID_GEN_VENDOR_ID => {
                // Microsoft's IEEE OUI (00-15-5D).
                writer.write(0x0000155du32.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_VENDOR_DESCRIPTION => writer.write(b"Microsoft Corporation")?,
            rndisprot::Oid::OID_GEN_VENDOR_DRIVER_VERSION
            | rndisprot::Oid::OID_GEN_DRIVER_VERSION => {
                // Version 1.0.
                writer.write(0x0100u16.as_bytes())?
            }
            rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => writer.write(0u32.as_bytes())?,
            rndisprot::Oid::OID_GEN_MAC_OPTIONS => {
                let options: u32 = rndisprot::MAC_OPTION_COPY_LOOKAHEAD_DATA
                    | rndisprot::MAC_OPTION_TRANSFERS_NOT_PEND
                    | rndisprot::MAC_OPTION_NO_LOOPBACK;
                writer.write(options.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MEDIA_CONNECT_STATUS => {
                writer.write(rndisprot::MEDIA_STATE_CONNECTED.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAXIMUM_SEND_PACKETS => writer.write(u32::MAX.as_bytes())?,
            rndisprot::Oid::OID_GEN_FRIENDLY_NAME => {
                // Fixed UTF-16 adapter name, zero-padded to the full field.
                let name16: Vec<u16> = "Network Device".encode_utf16().collect();
                let mut name = rndisprot::FriendlyName::new_zeroed();
                name.name[..name16.len()].copy_from_slice(&name16);
                writer.write(name.as_bytes())?
            }
            rndisprot::Oid::OID_802_3_PERMANENT_ADDRESS
            | rndisprot::Oid::OID_802_3_CURRENT_ADDRESS => {
                writer.write(&self.mac_address.to_bytes())?
            }
            rndisprot::Oid::OID_802_3_MAXIMUM_LIST_SIZE => {
                writer.write(0u32.as_bytes())?;
            }
            rndisprot::Oid::OID_802_3_RCV_ERROR_ALIGNMENT
            | rndisprot::Oid::OID_802_3_XMIT_ONE_COLLISION
            | rndisprot::Oid::OID_802_3_XMIT_MORE_COLLISIONS => writer.write(0u32.as_bytes())?,

            rndisprot::Oid::OID_GEN_LINK_STATE => {
                let link_state = rndisprot::LinkState {
                    header: rndisprot::NdisObjectHeader {
                        object_type: rndisprot::NdisObjectType::DEFAULT,
                        revision: 1,
                        size: size_of::<rndisprot::LinkState>() as u16,
                    },
                    // 1 = connected; duplex state unknown.
                    media_connect_state: 1,
                    media_duplex_state: 0,
                    padding: 0,
                    xmit_link_speed: self.link_speed,
                    rcv_link_speed: self.link_speed,
                    pause_functions: 0,
                    auto_negotiation_flags: 0,
                };
                writer.write(link_state.as_bytes())?;
            }
            rndisprot::Oid::OID_GEN_MAX_LINK_SPEED => {
                let link_speed = rndisprot::LinkSpeed {
                    xmit: self.link_speed,
                    rcv: self.link_speed,
                };
                writer.write(link_speed.as_bytes())?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_HARDWARE_CAPABILITIES => {
                // Truncate to the size declared in the NDIS object header.
                let ndis_offload = self.offload_support.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_TCP_OFFLOAD_CURRENT_CONFIG => {
                let ndis_offload = primary.offload_config.ndis_offload();
                writer.write(&ndis_offload.as_bytes()[..ndis_offload.header.size.into()])?;
            }
            rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
                writer.write(
                    &rndisprot::NdisOffloadEncapsulation {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
                            revision: 1,
                            size: rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1 as u16,
                        },
                        ipv4_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv4_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv4_header_size: ETHERNET_HEADER_LEN,
                        ipv6_enabled: rndisprot::NDIS_OFFLOAD_SUPPORTED,
                        ipv6_encapsulation_type: rndisprot::NDIS_ENCAPSULATION_IEEE_802_3,
                        ipv6_header_size: ETHERNET_HEADER_LEN,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1],
                )?;
            }
            rndisprot::Oid::OID_GEN_RECEIVE_SCALE_CAPABILITIES => {
                writer.write(
                    &rndisprot::NdisReceiveScaleCapabilities {
                        header: rndisprot::NdisObjectHeader {
                            object_type: rndisprot::NdisObjectType::RSS_CAPABILITIES,
                            revision: 2,
                            size: rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2
                                as u16,
                        },
                        capabilities_flags: rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV4
                            | rndisprot::NDIS_RSS_CAPS_HASH_TYPE_TCP_IPV6
                            | rndisprot::NDIS_HASH_FUNCTION_TOEPLITZ,
                        number_of_interrupt_messages: 1,
                        number_of_receive_queues: self.max_queues.into(),
                        // Fall back to 128 entries when no table size is
                        // configured.
                        number_of_indirection_table_entries: if self.indirection_table_size != 0 {
                            self.indirection_table_size
                        } else {
                            128
                        },
                        padding: 0,
                    }
                    .as_bytes()[..rndisprot::NDIS_SIZEOF_RECEIVE_SCALE_CAPABILITIES_REVISION_2],
                )?;
            }
            _ => {
                tracelimit::warn_ratelimited!(?oid, "query for unknown OID");
                return Err(OidError::UnknownOid);
            }
        };
        Ok(available_len - writer.len())
    }
3451
3452 fn handle_oid_set(
3453 &self,
3454 primary: &mut PrimaryChannelState,
3455 oid: rndisprot::Oid,
3456 reader: impl MemoryRead + Clone,
3457 ) -> Result<(bool, Option<u32>), OidError> {
3458 tracing::debug!(?oid, "oid set");
3459
3460 let mut restart_endpoint = false;
3461 let mut packet_filter = None;
3462 match oid {
3463 rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => {
3464 packet_filter = self.oid_set_packet_filter(reader)?;
3465 }
3466 rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => {
3467 self.oid_set_offload_parameters(reader, primary)?;
3468 }
3469 rndisprot::Oid::OID_OFFLOAD_ENCAPSULATION => {
3470 self.oid_set_offload_encapsulation(reader)?;
3471 }
3472 rndisprot::Oid::OID_GEN_RNDIS_CONFIG_PARAMETER => {
3473 self.oid_set_rndis_config_parameter(reader, primary)?;
3474 }
3475 rndisprot::Oid::OID_GEN_NETWORK_LAYER_ADDRESSES => {
3476 }
3478 rndisprot::Oid::OID_GEN_RECEIVE_SCALE_PARAMETERS => {
3479 self.oid_set_rss_parameters(reader, primary)?;
3480
3481 restart_endpoint = true;
3485 }
3486 _ => {
3487 tracelimit::warn_ratelimited!(?oid, "set of unknown OID");
3488 return Err(OidError::UnknownOid);
3489 }
3490 }
3491 Ok((restart_endpoint, packet_filter))
3492 }
3493
    /// Handles OID_GEN_RECEIVE_SCALE_PARAMETERS: parses RSS parameters from
    /// guest memory and installs them as `primary.rss_state`, or clears RSS
    /// when the guest disables it.
    fn oid_set_rss_parameters(
        &self,
        mut reader: impl MemoryRead + Clone,
        primary: &mut PrimaryChannelState,
    ) -> Result<(), OidError> {
        let mut params = rndisprot::NdisReceiveScaleParameters::new_zeroed();
        // Tolerate a short buffer: copy what the guest sent, leave the rest zeroed.
        let len = reader.len().min(size_of_val(&params));
        reader.clone().read(&mut params.as_mut_bytes()[..len])?;

        // Disabling RSS (explicitly, or by selecting no hash function) drops
        // any previously configured state.
        if ((params.flags & NDIS_RSS_PARAM_FLAG_DISABLE_RSS) != 0)
            || ((params.hash_information & NDIS_HASH_FUNCTION_MASK) == 0)
        {
            primary.rss_state = None;
            return Ok(());
        }

        // Only a 40-byte Toeplitz key is supported.
        if params.hash_secret_key_size != 40 {
            return Err(OidError::InvalidInput("hash_secret_key_size"));
        }
        // The table size is in bytes; entries are u32, so it must be a
        // multiple of 4.
        if params.indirection_table_size % 4 != 0 {
            return Err(OidError::InvalidInput("indirection_table_size"));
        }
        // Number of entries the guest provided, capped at the adapter's table size.
        let indirection_table_size =
            (params.indirection_table_size / 4).min(self.indirection_table_size) as usize;
        let mut key = [0; 40];
        let mut indirection_table = vec![0u32; self.indirection_table_size as usize];
        // Key and table live at offsets relative to the start of the buffer,
        // so re-read from a cloned reader for the key.
        reader
            .clone()
            .skip(params.hash_secret_key_offset as usize)?
            .read(&mut key)?;
        reader
            .skip(params.indirection_table_offset as usize)?
            .read(indirection_table[..indirection_table_size].as_mut_bytes())?;
        tracelimit::info_ratelimited!(?indirection_table, "OID_GEN_RECEIVE_SCALE_PARAMETERS");
        // Every entry must name a valid queue. (Unwritten tail entries are
        // still zero here, which is always a valid queue index.)
        if indirection_table
            .iter()
            .any(|&x| x >= self.max_queues as u32)
        {
            return Err(OidError::InvalidInput("indirection_table"));
        }
        // If the guest's table is smaller than ours, fill the remainder by
        // cycling through the guest-provided entries.
        let (indir_init, indir_uninit) = indirection_table.split_at_mut(indirection_table_size);
        for (src, dest) in std::iter::repeat_with(|| indir_init.iter().copied())
            .flatten()
            .zip(indir_uninit)
        {
            *dest = src;
        }
        primary.rss_state = Some(RssState {
            key,
            indirection_table: indirection_table.iter().map(|&x| x as u16).collect(),
        });
        Ok(())
    }
3548
3549 fn oid_set_packet_filter(
3550 &self,
3551 reader: impl MemoryRead + Clone,
3552 ) -> Result<Option<u32>, OidError> {
3553 let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?;
3554 tracing::debug!(filter, "set packet filter");
3555 Ok(Some(filter))
3556 }
3557
3558 fn oid_set_offload_parameters(
3559 &self,
3560 reader: impl MemoryRead + Clone,
3561 primary: &mut PrimaryChannelState,
3562 ) -> Result<(), OidError> {
3563 let offload: rndisprot::NdisOffloadParameters = read_ndis_object(
3564 reader,
3565 rndisprot::NdisObjectType::DEFAULT,
3566 1,
3567 rndisprot::NDIS_SIZEOF_OFFLOAD_PARAMETERS_REVISION_1,
3568 )?;
3569
3570 tracing::debug!(?offload, "offload parameters");
3571 let rndisprot::NdisOffloadParameters {
3572 header: _,
3573 ipv4_checksum,
3574 tcp4_checksum,
3575 udp4_checksum,
3576 tcp6_checksum,
3577 udp6_checksum,
3578 lsov1,
3579 ipsec_v1: _,
3580 lsov2_ipv4,
3581 lsov2_ipv6,
3582 tcp_connection_ipv4: _,
3583 tcp_connection_ipv6: _,
3584 reserved: _,
3585 flags: _,
3586 } = offload;
3587
3588 if lsov1 == rndisprot::OffloadParametersSimple::ENABLED {
3589 return Err(OidError::NotSupported("lsov1"));
3590 }
3591 if let Some((tx, rx)) = ipv4_checksum.tx_rx() {
3592 primary.offload_config.checksum_tx.ipv4_header = tx;
3593 primary.offload_config.checksum_rx.ipv4_header = rx;
3594 }
3595 if let Some((tx, rx)) = tcp4_checksum.tx_rx() {
3596 primary.offload_config.checksum_tx.tcp4 = tx;
3597 primary.offload_config.checksum_rx.tcp4 = rx;
3598 }
3599 if let Some((tx, rx)) = tcp6_checksum.tx_rx() {
3600 primary.offload_config.checksum_tx.tcp6 = tx;
3601 primary.offload_config.checksum_rx.tcp6 = rx;
3602 }
3603 if let Some((tx, rx)) = udp4_checksum.tx_rx() {
3604 primary.offload_config.checksum_tx.udp4 = tx;
3605 primary.offload_config.checksum_rx.udp4 = rx;
3606 }
3607 if let Some((tx, rx)) = udp6_checksum.tx_rx() {
3608 primary.offload_config.checksum_tx.udp6 = tx;
3609 primary.offload_config.checksum_rx.udp6 = rx;
3610 }
3611 if let Some(enable) = lsov2_ipv4.enable() {
3612 primary.offload_config.lso4 = enable;
3613 }
3614 if let Some(enable) = lsov2_ipv6.enable() {
3615 primary.offload_config.lso6 = enable;
3616 }
3617 primary
3618 .offload_config
3619 .mask_to_supported(&self.offload_support);
3620 primary.pending_offload_change = true;
3621 Ok(())
3622 }
3623
3624 fn oid_set_offload_encapsulation(
3625 &self,
3626 reader: impl MemoryRead + Clone,
3627 ) -> Result<(), OidError> {
3628 let encap: rndisprot::NdisOffloadEncapsulation = read_ndis_object(
3629 reader,
3630 rndisprot::NdisObjectType::OFFLOAD_ENCAPSULATION,
3631 1,
3632 rndisprot::NDIS_SIZEOF_OFFLOAD_ENCAPSULATION_REVISION_1,
3633 )?;
3634 if encap.ipv4_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3635 && (encap.ipv4_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3636 || encap.ipv4_header_size != ETHERNET_HEADER_LEN)
3637 {
3638 return Err(OidError::NotSupported("ipv4 encap"));
3639 }
3640 if encap.ipv6_enabled == rndisprot::NDIS_OFFLOAD_SET_ON
3641 && (encap.ipv6_encapsulation_type != rndisprot::NDIS_ENCAPSULATION_IEEE_802_3
3642 || encap.ipv6_header_size != ETHERNET_HEADER_LEN)
3643 {
3644 return Err(OidError::NotSupported("ipv6 encap"));
3645 }
3646 Ok(())
3647 }
3648
3649 fn oid_set_rndis_config_parameter(
3650 &self,
3651 reader: impl MemoryRead + Clone,
3652 primary: &mut PrimaryChannelState,
3653 ) -> Result<(), OidError> {
3654 let info: rndisprot::RndisConfigParameterInfo = reader.clone().read_plain()?;
3655 if info.name_length > 255 {
3656 return Err(OidError::InvalidInput("name_length"));
3657 }
3658 if info.value_length > 255 {
3659 return Err(OidError::InvalidInput("value_length"));
3660 }
3661 let name = reader
3662 .clone()
3663 .skip(info.name_offset as usize)?
3664 .read_n::<u16>(info.name_length as usize / 2)?;
3665 let name = String::from_utf16(&name).map_err(|_| OidError::InvalidInput("name"))?;
3666 let mut value = reader;
3667 value.skip(info.value_offset as usize)?;
3668 let mut value = value.limit(info.value_length as usize);
3669 match info.parameter_type {
3670 rndisprot::NdisParameterType::STRING => {
3671 let value = value.read_n::<u16>(info.value_length as usize / 2)?;
3672 let value =
3673 String::from_utf16(&value).map_err(|_| OidError::InvalidInput("value"))?;
3674 let as_num = value.as_bytes().first().map_or(0, |c| c - b'0');
3675 let tx = as_num & 1 != 0;
3676 let rx = as_num & 2 != 0;
3677
3678 tracing::debug!(name, value, "rndis config");
3679 match name.as_str() {
3680 "*IPChecksumOffloadIPv4" => {
3681 primary.offload_config.checksum_tx.ipv4_header = tx;
3682 primary.offload_config.checksum_rx.ipv4_header = rx;
3683 }
3684 "*LsoV2IPv4" => {
3685 primary.offload_config.lso4 = as_num != 0;
3686 }
3687 "*LsoV2IPv6" => {
3688 primary.offload_config.lso6 = as_num != 0;
3689 }
3690 "*TCPChecksumOffloadIPv4" => {
3691 primary.offload_config.checksum_tx.tcp4 = tx;
3692 primary.offload_config.checksum_rx.tcp4 = rx;
3693 }
3694 "*TCPChecksumOffloadIPv6" => {
3695 primary.offload_config.checksum_tx.tcp6 = tx;
3696 primary.offload_config.checksum_rx.tcp6 = rx;
3697 }
3698 "*UDPChecksumOffloadIPv4" => {
3699 primary.offload_config.checksum_tx.udp4 = tx;
3700 primary.offload_config.checksum_rx.udp4 = rx;
3701 }
3702 "*UDPChecksumOffloadIPv6" => {
3703 primary.offload_config.checksum_tx.udp6 = tx;
3704 primary.offload_config.checksum_rx.udp6 = rx;
3705 }
3706 _ => {}
3707 }
3708 primary
3709 .offload_config
3710 .mask_to_supported(&self.offload_support);
3711 }
3712 rndisprot::NdisParameterType::INTEGER => {
3713 let value: u32 = value.read_plain()?;
3714 tracing::debug!(name, value, "rndis config");
3715 }
3716 parameter_type => tracelimit::warn_ratelimited!(
3717 name,
3718 ?parameter_type,
3719 "unhandled rndis config parameter type"
3720 ),
3721 }
3722 Ok(())
3723 }
3724}
3725
3726fn read_ndis_object<T: IntoBytes + FromBytes + Debug + Immutable + KnownLayout>(
3727 mut reader: impl MemoryRead,
3728 object_type: rndisprot::NdisObjectType,
3729 min_revision: u8,
3730 min_size: usize,
3731) -> Result<T, OidError> {
3732 let mut buffer = T::new_zeroed();
3733 let sent_size = reader.len();
3734 let len = sent_size.min(size_of_val(&buffer));
3735 reader.read(&mut buffer.as_mut_bytes()[..len])?;
3736 validate_ndis_object_header(
3737 &rndisprot::NdisObjectHeader::read_from_prefix(buffer.as_bytes())
3738 .unwrap()
3739 .0, sent_size,
3741 object_type,
3742 min_revision,
3743 min_size,
3744 )?;
3745 Ok(buffer)
3746}
3747
3748fn validate_ndis_object_header(
3749 header: &rndisprot::NdisObjectHeader,
3750 sent_size: usize,
3751 object_type: rndisprot::NdisObjectType,
3752 min_revision: u8,
3753 min_size: usize,
3754) -> Result<(), OidError> {
3755 if header.object_type != object_type {
3756 return Err(OidError::InvalidInput("header.object_type"));
3757 }
3758 if sent_size < header.size as usize {
3759 return Err(OidError::InvalidInput("header.size"));
3760 }
3761 if header.revision < min_revision {
3762 return Err(OidError::InvalidInput("header.revision"));
3763 }
3764 if (header.size as usize) < min_size {
3765 return Err(OidError::InvalidInput("header.size"));
3766 }
3767 Ok(())
3768}
3769
/// Mutable state owned by the coordinator task, which manages the per-queue
/// worker tasks and reacts to channel, endpoint, and VF events.
struct Coordinator {
    // Control messages sent to the coordinator (update/restart/timer).
    recv: mpsc::Receiver<CoordinatorMessage>,
    channel_control: ChannelControl,
    // When set, the queues are torn down and rebuilt on the next loop
    // iteration of `process`.
    restart: bool,
    // One worker per possible queue; index 0 is the primary channel.
    workers: Vec<TaskControl<NetQueue, Worker<GpadlRingMem>>>,
    buffers: Option<Arc<ChannelBuffers>>,
    // Number of queues currently configured on the endpoint.
    num_queues: u16,
    // Packet filter propagated from the primary channel to subchannels.
    active_packet_filter: u32,
    // Deadline for the coordinator's one-shot timer, if armed.
    sleep_deadline: Option<Instant>,
}
3780
/// Progress of offering the virtual function (VF) device to the guest.
enum CoordinatorStatePendingVfState {
    /// No VF offer is in flight.
    Ready,
    /// The VF offer is delayed until `delay_until`.
    Delay {
        timer: PolledTimer,
        delay_until: Instant,
    },
    /// Waiting for the guest to signal it is ready for the VF device.
    Pending,
}
3795
/// Immutable-per-run resources used by the coordinator task (kept separate
/// from `Coordinator` so they can be borrowed while the coordinator's own
/// state is mutated).
struct CoordinatorState {
    // Backend network endpoint providing the queues.
    endpoint: Box<dyn Endpoint>,
    adapter: Arc<Adapter>,
    // Optional accelerated-networking virtual function.
    virtual_function: Option<Box<dyn VirtualFunction>>,
    // Where the VF offer stands, if a VF exists.
    pending_vf_state: CoordinatorStatePendingVfState,
}
3802
impl InspectTaskMut<Coordinator> for CoordinatorState {
    /// Responds to an inspect request with adapter, endpoint, and (when the
    /// coordinator is available) per-queue and primary-channel state.
    fn inspect_mut(
        &mut self,
        req: inspect::Request<'_>,
        mut coordinator: Option<&mut Coordinator>,
    ) {
        let mut resp = req.respond();

        let adapter = self.adapter.as_ref();
        resp.field("mac_address", adapter.mac_address)
            .field("max_queues", adapter.max_queues)
            .sensitivity_field(
                "offload_support",
                SensitivityLevel::Safe,
                &adapter.offload_support,
            )
            // Writable field: setting it updates the limit and pokes each
            // worker so the new limit takes effect.
            .field_mut_with("ring_size_limit", |v| -> anyhow::Result<_> {
                if let Some(v) = v {
                    let v: usize = v.parse()?;
                    adapter.ring_size_limit.store(v, Ordering::Relaxed);
                    if let Some(this) = &mut coordinator {
                        for worker in &mut this.workers {
                            worker.update_with(|_, _| ());
                        }
                    }
                }
                Ok(adapter.ring_size_limit.load(Ordering::Relaxed))
            });

        resp.field("endpoint_type", self.endpoint.endpoint_type())
            .field(
                "endpoint_max_queues",
                self.endpoint.multiqueue_support().max_queues,
            )
            .sensitivity_field_mut("endpoint", SensitivityLevel::Safe, self.endpoint.as_mut());

        if let Some(coordinator) = coordinator {
            // Only the queues currently in use are exposed.
            resp.sensitivity_child("queues", SensitivityLevel::Safe, |req| {
                let mut resp = req.respond();
                for (i, q) in coordinator.workers[..coordinator.num_queues as usize]
                    .iter_mut()
                    .enumerate()
                {
                    resp.field_mut(&i.to_string(), q);
                }
            });

            // Primary channel state lives inside worker 0; defer the
            // response and fill it in from within the worker update callback.
            resp.merge(inspect::adhoc_mut(|req| {
                let deferred = req.defer();
                coordinator.workers[0].update_with(|_, worker| {
                    let Some(worker) = worker.as_deref() else {
                        return;
                    };
                    if let Some(state) = worker.state.ready() {
                        deferred.respond(|resp| {
                            resp.merge(&state.buffers);
                            resp.sensitivity_field(
                                "primary_channel_state",
                                SensitivityLevel::Safe,
                                &state.state.primary,
                            )
                            .sensitivity_field(
                                "packet_filter",
                                SensitivityLevel::Safe,
                                inspect::AsHex(worker.channel.packet_filter),
                            );
                        });
                    }
                })
            }));
        }
    }
}
3878
impl AsyncRun<Coordinator> for CoordinatorState {
    /// Task entry point: runs the coordinator's main loop until cancelled.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        coordinator: &mut Coordinator,
    ) -> Result<(), task_control::Cancelled> {
        coordinator.process(stop, self).await
    }
}
3888
3889impl Coordinator {
    /// Main coordinator loop: restarts queues when requested, keeps workers
    /// running, and races the various event sources (control messages,
    /// endpoint actions, VF state changes, timers), handling one event per
    /// iteration until the channel disconnects.
    async fn process(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut CoordinatorState,
    ) -> Result<(), task_control::Cancelled> {
        loop {
            if self.restart {
                // Workers must be quiesced before the queues can be rebuilt.
                stop.until_stopped(self.stop_workers()).await?;
                if let Err(err) = self
                    .restart_queues(state)
                    .instrument(tracing::info_span!("netvsp_restart_queues"))
                    .await
                {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to restart queues"
                    );
                }
                // Re-query the data path state, since the endpoint restart
                // may have changed it.
                if let Some(primary) = self.primary_mut() {
                    primary.is_data_path_switched =
                        state.endpoint.get_data_path_to_guest_vf().await.ok();
                    tracing::info!(
                        is_data_path_switched = primary.is_data_path_switched,
                        "Query data path state"
                    );
                }
                self.restore_guest_vf_state(state).await;
                self.restart = false;
            }

            // Subchannel workers always run; the primary worker is held back
            // while it is waiting on the coordinator.
            for worker in &mut self.workers[1..] {
                worker.start();
            }
            if !self.workers[0].is_running()
                && self.workers[0].state().is_none_or(|worker| {
                    !matches!(worker.state, WorkerState::WaitingForCoordinator(_))
                })
            {
                self.workers[0].start();
            }

            // The events this loop can be woken by.
            enum Message {
                Internal(CoordinatorMessage),
                ChannelDisconnected,
                UpdateFromEndpoint(EndpointAction),
                UpdateFromVf(Rpc<(), ()>),
                OfferVfDevice,
                PendingVfStateComplete,
                TimerExpired,
            }
            let message = if matches!(
                state.pending_vf_state,
                CoordinatorStatePendingVfState::Pending
            ) {
                // A VF offer is in flight: wait for the guest to accept it
                // before processing anything else.
                state
                    .virtual_function
                    .as_mut()
                    .expect("Pending requires a VF")
                    .guest_ready_for_device()
                    .await;
                Message::PendingVfStateComplete
            } else {
                // One-shot coordinator timer (used for delayed link actions).
                let timer_sleep = async {
                    if let Some(deadline) = self.sleep_deadline {
                        let mut timer = PolledTimer::new(&state.adapter.driver);
                        timer.sleep_until(deadline).await;
                    } else {
                        pending::<()>().await;
                    }
                    Message::TimerExpired
                };
                let wait_for_message = async {
                    let internal_msg = self
                        .recv
                        .next()
                        .map(|x| x.map_or(Message::ChannelDisconnected, Message::Internal));
                    let endpoint_restart = state
                        .endpoint
                        .wait_for_endpoint_action()
                        .map(Message::UpdateFromEndpoint);
                    if let Some(vf) = state.virtual_function.as_mut() {
                        match state.pending_vf_state {
                            CoordinatorStatePendingVfState::Ready
                            | CoordinatorStatePendingVfState::Delay { .. } => {
                                // Offer the VF once any configured delay
                                // elapses; in the Ready state there is no
                                // offer pending, so this never resolves.
                                let offer_device = async {
                                    if let CoordinatorStatePendingVfState::Delay {
                                        timer,
                                        delay_until,
                                    } = &mut state.pending_vf_state
                                    {
                                        timer.sleep_until(*delay_until).await;
                                    } else {
                                        pending::<()>().await;
                                    }
                                    Message::OfferVfDevice
                                };
                                (
                                    internal_msg,
                                    offer_device,
                                    endpoint_restart,
                                    vf.wait_for_state_change().map(Message::UpdateFromVf),
                                    timer_sleep,
                                )
                                    .race()
                                    .await
                            }
                            // Pending is handled by the branch above.
                            CoordinatorStatePendingVfState::Pending => unreachable!(),
                        }
                    } else {
                        (internal_msg, endpoint_restart, timer_sleep).race().await
                    }
                };

                stop.until_stopped(wait_for_message).await?
            };
            match message {
                Message::Internal(msg) => {
                    self.handle_coordinator_message(msg, state).await;
                }
                Message::UpdateFromVf(rpc) => {
                    rpc.handle(async |_| {
                        self.update_guest_vf_state(state).await;
                    })
                    .await;
                }
                Message::OfferVfDevice => {
                    // Pause the primary worker while mutating its state.
                    self.workers[0].stop().await;
                    if let Some(primary) = self.primary_mut() {
                        if matches!(
                            primary.guest_vf_state,
                            PrimaryChannelGuestVfState::AvailableAdvertised
                        ) {
                            primary.guest_vf_state = PrimaryChannelGuestVfState::Ready;
                        }
                    }

                    state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                Message::PendingVfStateComplete => {
                    state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
                }
                Message::TimerExpired => {
                    self.workers[0].stop().await;
                    // Promote any delayed link action now that its delay is up.
                    if let Some(primary) = self.primary_mut() {
                        if let PendingLinkAction::Delay(up) = primary.pending_link_action {
                            primary.pending_link_action = PendingLinkAction::Active(up);
                        }
                    }
                    self.sleep_deadline = None;
                }
                Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true,
                Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => {
                    self.workers[0].stop().await;

                    // A fresh link notification supersedes any delayed action.
                    if let Some(primary) = self.primary_mut() {
                        primary.pending_link_action = PendingLinkAction::Active(connect);
                    }

                    self.sleep_deadline = None;
                }
                Message::ChannelDisconnected => {
                    break;
                }
            };
        }
        Ok(())
    }
4072
    /// Handles a single `CoordinatorMessage`, stopping the primary worker
    /// first and releasing it if it was parked waiting on the coordinator.
    async fn handle_coordinator_message(
        &mut self,
        msg: CoordinatorMessage,
        state: &mut CoordinatorState,
    ) {
        self.workers[0].stop().await;
        // If the primary worker parked itself waiting for the coordinator,
        // extract its saved ready state and return it to Ready. The two-step
        // mem::replace avoids moving out of the borrowed worker state.
        if let Some(worker) = self.workers[0].state_mut() {
            if matches!(worker.state, WorkerState::WaitingForCoordinator(_)) {
                let WorkerState::WaitingForCoordinator(Some(ready)) =
                    std::mem::replace(&mut worker.state, WorkerState::WaitingForCoordinator(None))
                else {
                    unreachable!("valid ready state")
                };
                let _ = std::mem::replace(&mut worker.state, WorkerState::Ready(ready));
            }
        }
        match msg {
            CoordinatorMessage::Update(update_type) => {
                if update_type.filter_state {
                    // Propagate the primary channel's packet filter to all
                    // subchannel workers.
                    self.stop_workers().await;
                    self.active_packet_filter =
                        self.workers[0].state().unwrap().channel.packet_filter;
                    self.workers.iter_mut().skip(1).for_each(|worker| {
                        if let Some(state) = worker.state_mut() {
                            state.channel.packet_filter = self.active_packet_filter;
                            tracing::debug!(
                                packet_filter = ?self.active_packet_filter,
                                channel_idx = state.channel_idx,
                                "update packet filter"
                            );
                        }
                    });
                }

                if update_type.guest_vf_state {
                    self.update_guest_vf_state(state).await;
                }
            }
            CoordinatorMessage::StartTimer(deadline) => {
                // Arm the one-shot timer; `process` reacts when it fires.
                self.sleep_deadline = Some(deadline);
            }
            CoordinatorMessage::Restart => self.restart = true,
        }
    }
4117
4118 async fn stop_workers(&mut self) {
4119 for worker in &mut self.workers {
4120 worker.stop().await;
4121 }
4122 }
4123
    /// Reconciles the primary channel's guest VF state with the currently
    /// present (or absent) virtual function, re-driving any VF offer or data
    /// path switch that was in progress (e.g. across save/restore or an
    /// endpoint restart).
    async fn restore_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        let primary = match self.primary_mut() {
            Some(primary) => primary,
            None => return,
        };

        let virtual_function = c_state.virtual_function.as_mut();
        let guest_vf_id = match &virtual_function {
            Some(vf) => vf.id().await,
            None => None,
        };
        if let Some(guest_vf_id) = guest_vf_id {
            // A VF is present. First decide whether a device offer needs to
            // be (re)issued, and with what urgency.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    // Advertised but data path not switched: re-offer after a
                    // short delay.
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        let timer = PolledTimer::new(&c_state.adapter.driver);
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Delay {
                            timer,
                            delay_until: Instant::now() + VF_DEVICE_DELAY,
                        };
                    }
                }
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready)
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // The guest already saw (or was about to see) the device;
                    // offer immediately.
                    c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                }
                _ => (),
            };
            // A restored switch that already has a result just needs its
            // completion delivered; skip re-driving the switch.
            if let PrimaryChannelGuestVfState::Restoring(
                saved_state::GuestVfState::DataPathSwitchPending {
                    to_guest,
                    id,
                    result,
                },
            ) = primary.guest_vf_state
            {
                if result.is_some() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest,
                        id,
                        result,
                    };
                    return;
                }
            }
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { .. },
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    // A data path switch was pending or had been undone by
                    // the VF going away; (re)issue it now. Default direction
                    // is to the guest.
                    let (to_guest, id) = match primary.guest_vf_state {
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            to_guest,
                            id,
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest, id, ..
                        }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending {
                                to_guest, id, ..
                            },
                        ) => (to_guest, id),
                        _ => (true, None),
                    };
                    // A switch implies the guest already knows about the VF;
                    // drop any offer delay and offer immediately.
                    if matches!(
                        c_state.pending_vf_state,
                        CoordinatorStatePendingVfState::Delay { .. }
                    ) {
                        c_state.pending_vf_state = CoordinatorStatePendingVfState::Pending;
                    }
                    let result = c_state.endpoint.set_data_path_to_guest_vf(to_guest).await;
                    let result = if let Err(err) = result {
                        tracing::error!(
                            err = err.as_ref() as &dyn std::error::Error,
                            to_guest,
                            "Failed to switch guest VF data path"
                        );
                        false
                    } else {
                        primary.is_data_path_switched = Some(to_guest);
                        true
                    };
                    match primary.guest_vf_state {
                        // A pending switch request gets its completion
                        // recorded so the guest can be notified.
                        PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                            ..
                        }
                        | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                        | PrimaryChannelGuestVfState::Restoring(
                            saved_state::GuestVfState::DataPathSwitchPending { .. },
                        ) => PrimaryChannelGuestVfState::DataPathSwitchPending {
                            to_guest,
                            id,
                            result: Some(result),
                        },
                        _ if result => PrimaryChannelGuestVfState::DataPathSwitched,
                        _ => PrimaryChannelGuestVfState::DataPathSynthetic,
                    }
                }
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Unavailable
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState) => {
                    PrimaryChannelGuestVfState::Available { vfid: guest_vf_id }
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                ) => {
                    if !primary.is_data_path_switched.unwrap_or(false) {
                        PrimaryChannelGuestVfState::AvailableAdvertised
                    } else {
                        PrimaryChannelGuestVfState::DataPathSwitched
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => PrimaryChannelGuestVfState::DataPathSwitched,
                PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::Ready
                }
                _ => primary.guest_vf_state,
            };
        } else {
            // No VF is present. If the data path was (or was becoming)
            // switched to the VF, put it back to synthetic best-effort.
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, .. },
                ) => {
                    // A pending switch back to synthetic still needs to run.
                    if !to_guest {
                        if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                            tracing::warn!(
                                err = err.as_ref() as &dyn std::error::Error,
                                "Failed setting data path back to synthetic after guest VF was removed."
                            );
                        }
                        primary.is_data_path_switched = Some(false);
                    }
                }
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                ) => {
                    if let Err(err) = c_state.endpoint.set_data_path_to_guest_vf(false).await {
                        tracing::warn!(
                            err = err.as_ref() as &dyn std::error::Error,
                            "Failed setting data path back to synthetic after guest VF was removed."
                        );
                    }
                    primary.is_data_path_switched = Some(false);
                }
                _ => (),
            }
            // Cancel any in-flight offer; the device no longer exists.
            if let PrimaryChannelGuestVfState::AvailableAdvertised = primary.guest_vf_state {
                c_state.pending_vf_state = CoordinatorStatePendingVfState::Ready;
            }
            // Map each state to its "VF removed" counterpart, remembering how
            // far the guest had gotten.
            primary.guest_vf_state = match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Initializing
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::NoState)
                | PrimaryChannelGuestVfState::Available { .. } => {
                    PrimaryChannelGuestVfState::Unavailable
                }
                PrimaryChannelGuestVfState::AvailableAdvertised
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::AvailableAdvertised,
                )
                | PrimaryChannelGuestVfState::Ready
                | PrimaryChannelGuestVfState::Restoring(saved_state::GuestVfState::Ready) => {
                    PrimaryChannelGuestVfState::UnavailableFromAvailable
                }
                PrimaryChannelGuestVfState::DataPathSwitchPending { to_guest, id, .. }
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitchPending { to_guest, id, .. },
                ) => PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending {
                    to_guest,
                    id,
                },
                PrimaryChannelGuestVfState::DataPathSwitched
                | PrimaryChannelGuestVfState::Restoring(
                    saved_state::GuestVfState::DataPathSwitched,
                )
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                }
                _ => primary.guest_vf_state,
            }
        }
    }
4341
4342 async fn restart_queues(&mut self, c_state: &mut CoordinatorState) -> Result<(), WorkerError> {
4343 let drivers = self
4345 .workers
4346 .iter_mut()
4347 .map(|worker| {
4348 let task = worker.task_mut();
4349 task.queue_state = None;
4350 task.driver.clone()
4351 })
4352 .collect::<Vec<_>>();
4353
4354 c_state.endpoint.stop().await;
4355
4356 let (primary_worker, subworkers) = if let [primary, sub @ ..] = self.workers.as_mut_slice()
4357 {
4358 (primary, sub)
4359 } else {
4360 unreachable!()
4361 };
4362
4363 let state = primary_worker
4364 .state_mut()
4365 .and_then(|worker| worker.state.ready_mut());
4366
4367 let state = if let Some(state) = state {
4368 state
4369 } else {
4370 return Ok(());
4371 };
4372
4373 self.buffers = Some(state.buffers.clone());
4375
4376 let num_queues = state.state.primary.as_ref().unwrap().requested_num_queues;
4377 let mut active_queues = Vec::new();
4378 let active_queue_count =
4379 if let Some(rss_state) = state.state.primary.as_ref().unwrap().rss_state.as_ref() {
4380 active_queues.clone_from(&rss_state.indirection_table);
4382 active_queues.sort();
4383 active_queues.dedup();
4384 active_queues = active_queues
4385 .into_iter()
4386 .filter(|&index| index < num_queues)
4387 .collect::<Vec<_>>();
4388 if !active_queues.is_empty() {
4389 active_queues.len() as u16
4390 } else {
4391 tracelimit::warn_ratelimited!("Invalid RSS indirection table");
4392 num_queues
4393 }
4394 } else {
4395 num_queues
4396 };
4397
4398 let (ranges, mut remote_buffer_id_recvs) =
4400 RxBufferRanges::new(state.buffers.recv_buffer.count, active_queue_count.into());
4401 let ranges = Arc::new(ranges);
4402
4403 let mut queues = Vec::new();
4404 let mut rx_buffers = Vec::new();
4405 {
4406 let buffers = &state.buffers;
4407 let guest_buffers = Arc::new(
4408 GuestBuffers::new(
4409 buffers.mem.clone(),
4410 buffers.recv_buffer.gpadl.clone(),
4411 buffers.recv_buffer.sub_allocation_size,
4412 buffers.ndis_config.mtu,
4413 )
4414 .map_err(WorkerError::GpadlError)?,
4415 );
4416
4417 let mut queue_config = Vec::new();
4420 let initial_rx;
4421 {
4422 let states = std::iter::once(Some(&*state)).chain(
4423 subworkers
4424 .iter()
4425 .map(|worker| worker.state().and_then(|worker| worker.state.ready())),
4426 );
4427
4428 initial_rx = (RX_RESERVED_CONTROL_BUFFERS..state.buffers.recv_buffer.count)
4429 .filter(|&n| states.clone().flatten().all(|s| s.state.rx_bufs.is_free(n)))
4430 .map(RxId)
4431 .collect::<Vec<_>>();
4432
4433 let mut initial_rx = initial_rx.as_slice();
4434 let mut range_start = 0;
4435 let primary_queue_excluded = !active_queues.is_empty() && active_queues[0] != 0;
4436 let first_queue = if !primary_queue_excluded {
4437 0
4438 } else {
4439 queue_config.push(QueueConfig {
4443 pool: Box::new(BufferPool::new(guest_buffers.clone())),
4444 initial_rx: &[],
4445 driver: Box::new(drivers[0].clone()),
4446 });
4447 rx_buffers.push(RxBufferRange::new(
4448 ranges.clone(),
4449 0..RX_RESERVED_CONTROL_BUFFERS,
4450 None,
4451 ));
4452 range_start = RX_RESERVED_CONTROL_BUFFERS;
4453 1
4454 };
4455 for queue_index in first_queue..num_queues {
4456 let queue_active = active_queues.is_empty()
4457 || active_queues.binary_search(&queue_index).is_ok();
4458 let (range_end, end, buffer_id_recv) = if queue_active {
4459 let range_end = if rx_buffers.len() as u16 == active_queue_count - 1 {
4460 state.buffers.recv_buffer.count
4462 } else if queue_index == 0 {
4463 RX_RESERVED_CONTROL_BUFFERS + ranges.buffers_per_queue
4465 } else {
4466 range_start + ranges.buffers_per_queue
4467 };
4468 (
4469 range_end,
4470 initial_rx.partition_point(|id| id.0 < range_end),
4471 Some(remote_buffer_id_recvs.remove(0)),
4472 )
4473 } else {
4474 (range_start, 0, None)
4475 };
4476
4477 let (this, rest) = initial_rx.split_at(end);
4478 queue_config.push(QueueConfig {
4479 pool: Box::new(BufferPool::new(guest_buffers.clone())),
4480 initial_rx: this,
4481 driver: Box::new(drivers[queue_index as usize].clone()),
4482 });
4483 initial_rx = rest;
4484 rx_buffers.push(RxBufferRange::new(
4485 ranges.clone(),
4486 range_start..range_end,
4487 buffer_id_recv,
4488 ));
4489
4490 range_start = range_end;
4491 }
4492 }
4493
4494 let primary = state.state.primary.as_mut().unwrap();
4495 tracing::debug!(num_queues, "enabling endpoint");
4496
4497 let rss = primary
4498 .rss_state
4499 .as_ref()
4500 .map(|rss| net_backend::RssConfig {
4501 key: &rss.key,
4502 indirection_table: &rss.indirection_table,
4503 flags: 0,
4504 });
4505
4506 c_state
4507 .endpoint
4508 .get_queues(queue_config, rss.as_ref(), &mut queues)
4509 .instrument(tracing::info_span!("netvsp_get_queues"))
4510 .await
4511 .map_err(WorkerError::Endpoint)?;
4512
4513 assert_eq!(queues.len(), num_queues as usize);
4514
4515 self.channel_control
4517 .enable_subchannels(num_queues - 1)
4518 .expect("already validated");
4519
4520 self.num_queues = num_queues;
4521 }
4522
4523 self.active_packet_filter = self.workers[0].state().unwrap().channel.packet_filter;
4524 for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) {
4526 worker.task_mut().queue_state = Some(QueueState {
4527 queue,
4528 target_vp_set: false,
4529 rx_buffer_range: rx_buffer,
4530 });
4531 if let Some(worker) = worker.state_mut() {
4533 worker.channel.packet_filter = self.active_packet_filter;
4534 if let Some(ready_state) = worker.state.ready_mut() {
4536 ready_state.state.pending_rx_packets.clear();
4537 }
4538 }
4539 }
4540
4541 Ok(())
4542 }
4543
4544 fn primary_mut(&mut self) -> Option<&mut PrimaryChannelState> {
4545 self.workers[0]
4546 .state_mut()
4547 .unwrap()
4548 .state
4549 .ready_mut()?
4550 .state
4551 .primary
4552 .as_mut()
4553 }
4554
    /// Stops the primary worker and re-applies the guest VF state via
    /// `restore_guest_vf_state`.
    async fn update_guest_vf_state(&mut self, c_state: &mut CoordinatorState) {
        // The primary worker must be stopped before its state is touched.
        self.workers[0].stop().await;
        self.restore_guest_vf_state(c_state).await;
    }
4559}
4560
4561impl<T: RingMem + 'static + Sync> AsyncRun<Worker<T>> for NetQueue {
4562 async fn run(
4563 &mut self,
4564 stop: &mut StopTask<'_>,
4565 worker: &mut Worker<T>,
4566 ) -> Result<(), task_control::Cancelled> {
4567 match worker.process(stop, self).await {
4568 Ok(()) | Err(WorkerError::BufferRevoked) => {}
4569 Err(WorkerError::Cancelled(cancelled)) => return Err(cancelled),
4570 Err(err) => {
4571 tracing::error!(
4572 channel_idx = worker.channel_idx,
4573 error = &err as &dyn std::error::Error,
4574 "netvsp error"
4575 );
4576 }
4577 }
4578 Ok(())
4579 }
4580}
4581
4582impl<T: RingMem + 'static> Worker<T> {
4583 async fn process(
4584 &mut self,
4585 stop: &mut StopTask<'_>,
4586 queue: &mut NetQueue,
4587 ) -> Result<(), WorkerError> {
4588 loop {
4592 match &mut self.state {
4593 WorkerState::Init(initializing) => {
4594 assert_eq!(self.channel_idx, 0);
4595
4596 tracelimit::info_ratelimited!("network accepted");
4597
4598 let (buffers, state) = stop
4599 .until_stopped(self.channel.initialize(initializing, self.mem.clone()))
4600 .await??;
4601
4602 let state = ReadyState {
4603 buffers: Arc::new(buffers),
4604 state,
4605 data: ProcessingData::new(),
4606 };
4607
4608 let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart);
4610
4611 tracelimit::info_ratelimited!("network initialized");
4612 self.state = WorkerState::WaitingForCoordinator(Some(state));
4613 }
4614 WorkerState::WaitingForCoordinator(_) => {
4615 assert_eq!(self.channel_idx, 0);
4616 stop.until_stopped(pending()).await?
4619 }
4620 WorkerState::Ready(state) => {
4621 let queue_state = if let Some(queue_state) = &mut queue.queue_state {
4622 if !queue_state.target_vp_set
4623 && self.target_vp != vmbus_core::protocol::VP_INDEX_DISABLE_INTERRUPT
4624 {
4625 tracing::debug!(
4626 channel_idx = self.channel_idx,
4627 target_vp = self.target_vp,
4628 "updating target VP"
4629 );
4630 queue_state.queue.update_target_vp(self.target_vp).await;
4631 queue_state.target_vp_set = true;
4632 }
4633
4634 queue_state
4635 } else {
4636 stop.until_stopped(pending()).await?
4638 };
4639
4640 let result = self.channel.main_loop(stop, state, queue_state).await;
4641 let msg = match result {
4642 Ok(restart) => {
4643 assert_eq!(self.channel_idx, 0);
4644 restart
4645 }
4646 Err(WorkerError::EndpointRequiresQueueRestart(err)) => {
4647 tracelimit::warn_ratelimited!(
4648 err = err.as_ref() as &dyn std::error::Error,
4649 "Endpoint requires queues to restart",
4650 );
4651 CoordinatorMessage::Restart
4652 }
4653 Err(err) => return Err(err),
4654 };
4655
4656 let WorkerState::Ready(ready) = std::mem::replace(
4657 &mut self.state,
4658 WorkerState::WaitingForCoordinator(None),
4659 ) else {
4660 unreachable!("must be running in ready state")
4661 };
4662 let _ = std::mem::replace(
4663 &mut self.state,
4664 WorkerState::WaitingForCoordinator(Some(ready)),
4665 );
4666 self.coordinator_send
4667 .try_send(msg)
4668 .map_err(WorkerError::CoordinatorMessageSendFailed)?;
4669 stop.until_stopped(pending()).await?
4670 }
4671 }
4672 }
4673 }
4674}
4675
4676impl<T: 'static + RingMem> NetChannel<T> {
4677 fn try_next_packet<'a>(
4678 &mut self,
4679 send_buffer: Option<&SendBuffer>,
4680 version: Option<Version>,
4681 external_data: &'a mut MultiPagedRangeBuf,
4682 ) -> Result<Option<Packet<'a>>, WorkerError> {
4683 let (mut read, _) = self.queue.split();
4684 let packet = match read.try_read() {
4685 Ok(packet) => parse_packet(&packet, send_buffer, version, external_data)
4686 .map_err(WorkerError::Packet)?,
4687 Err(queue::TryReadError::Empty) => return Ok(None),
4688 Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
4689 };
4690
4691 tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
4692 Ok(Some(packet))
4693 }
4694
    /// Reads the next packet from the incoming ring, waiting until one is
    /// available.
    ///
    /// If the packet is an RNDIS data packet, the ring read position is
    /// reverted so the same packet is seen again by the next read; this lets
    /// the initialization path hand the packet off to the ready-state
    /// processing loop.
    async fn next_packet<'a>(
        &mut self,
        send_buffer: Option<&'a SendBuffer>,
        version: Option<Version>,
        external_data: &'a mut MultiPagedRangeBuf,
    ) -> Result<Packet<'a>, WorkerError> {
        let (mut read, _) = self.queue.split();
        let mut packet_ref = read.read().await?;
        let packet = parse_packet(&packet_ref, send_buffer, version, external_data)
            .map_err(WorkerError::Packet)?;
        if matches!(packet.data, PacketData::RndisPacket(_)) {
            tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
            // Leave the packet in the ring for reprocessing.
            packet_ref.revert();
        }
        tracing::trace!(target: "netvsp/vmbus", data = ?packet.data, "incoming vmbus packet");
        Ok(packet)
    }
4713
4714 fn is_ready_to_initialize(initializing: &InitState, allow_missing_send_buffer: bool) -> bool {
4715 (initializing.ndis_config.is_some() || initializing.version < Version::V2)
4716 && initializing.ndis_version.is_some()
4717 && (initializing.send_buffer.is_some() || allow_missing_send_buffer)
4718 && initializing.recv_buffer.is_some()
4719 }
4720
    /// Runs the pre-ready initialization state machine on the primary
    /// channel: protocol version negotiation followed by NDIS
    /// config/version and receive/send buffer setup.
    ///
    /// Completes once enough setup messages have arrived (or an RNDIS
    /// packet arrives early — see below), returning the negotiated channel
    /// buffers and the initial active state.
    async fn initialize(
        &mut self,
        initializing: &mut Option<InitState>,
        mem: GuestMemory,
    ) -> Result<(ChannelBuffers, ActiveState), WorkerError> {
        // Set when the guest sends an RNDIS packet before finishing setup;
        // some guests skip the send buffer.
        let mut has_init_packet_arrived = false;
        loop {
            if let Some(initializing) = &mut *initializing {
                if Self::is_ready_to_initialize(initializing, false) || has_init_packet_arrived {
                    let recv_buffer = initializing.recv_buffer.take().unwrap();
                    let send_buffer = initializing.send_buffer.take();
                    let state = ActiveState::new(
                        Some(PrimaryChannelState::new(
                            self.adapter.offload_support.clone(),
                        )),
                        recv_buffer.count,
                    );
                    let buffers = ChannelBuffers {
                        version: initializing.version,
                        mem,
                        recv_buffer,
                        send_buffer,
                        ndis_version: initializing.ndis_version.take().unwrap(),
                        // Pre-V2 guests never send a config; use defaults.
                        ndis_config: initializing.ndis_config.take().unwrap_or(NdisConfig {
                            mtu: DEFAULT_MTU,
                            capabilities: protocol::NdisConfigCapabilities::new(),
                        }),
                    };

                    break Ok((buffers, state));
                }
            }

            // Ensure there is room to write a completion before reading, so
            // every setup message can be acknowledged.
            self.queue
                .split()
                .1
                .wait_ready(ring::PacketSize::completion(protocol::PACKET_SIZE_V61))
                .await?;

            let mut external_data = MultiPagedRangeBuf::new();
            let packet = self
                .next_packet(
                    None,
                    initializing.as_ref().map(|x| x.version),
                    &mut external_data,
                )
                .await?;

            if let Some(initializing) = &mut *initializing {
                match packet.data {
                    PacketData::SendNdisConfig(config) => {
                        if initializing.ndis_config.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisConfigExists,
                            ));
                        }

                        // Out-of-range MTUs fall back to the default rather
                        // than failing the negotiation.
                        let mtu = if config.mtu >= MIN_MTU && config.mtu <= MAX_MTU {
                            config.mtu
                        } else {
                            DEFAULT_MTU
                        };

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_config = Some(NdisConfig {
                            mtu,
                            capabilities: config.capabilities,
                        });
                    }
                    PacketData::SendNdisVersion(version) => {
                        if initializing.ndis_version.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendNdisVersionExists,
                            ));
                        }

                        self.send_completion(packet.transaction_id, None)?;
                        initializing.ndis_version = Some(NdisVersion {
                            major: version.ndis_major_version,
                            minor: version.ndis_minor_version,
                        });
                    }
                    PacketData::SendReceiveBuffer(message) => {
                        if initializing.recv_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferExists,
                            ));
                        }

                        // The MTU (which sizes the suballocations) must be
                        // known by now, except for pre-V2 which has a fixed
                        // default.
                        let mtu = if let Some(cfg) = &initializing.ndis_config {
                            cfg.mtu
                        } else if initializing.version < Version::V2 {
                            DEFAULT_MTU
                        } else {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendReceiveBufferMissingMTU,
                            ));
                        };

                        let sub_allocation_size = sub_allocation_size_for_mtu(mtu);

                        let recv_buffer = ReceiveBuffer::new(
                            &self.gpadl_map,
                            message.gpadl_handle,
                            message.id,
                            sub_allocation_size,
                        )?;

                        // Report the buffer carved into one section of
                        // equally-sized suballocations.
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_RECEIVE_BUFFER_COMPLETE,
                                protocol::Message1SendReceiveBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    num_sections: 1,
                                    sections: [protocol::ReceiveBufferSection {
                                        offset: 0,
                                        sub_allocation_size: recv_buffer.sub_allocation_size,
                                        num_sub_allocations: recv_buffer.count,
                                        end_offset: recv_buffer.sub_allocation_size
                                            * recv_buffer.count,
                                    }],
                                },
                            )),
                        )?;
                        initializing.recv_buffer = Some(recv_buffer);
                    }
                    PacketData::SendSendBuffer(message) => {
                        if initializing.send_buffer.is_some() {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::SendSendBufferExists,
                            ));
                        }

                        let send_buffer = SendBuffer::new(&self.gpadl_map, message.gpadl_handle)?;
                        self.send_completion(
                            packet.transaction_id,
                            Some(&self.message(
                                protocol::MESSAGE1_TYPE_SEND_SEND_BUFFER_COMPLETE,
                                protocol::Message1SendSendBufferComplete {
                                    status: protocol::Status::SUCCESS,
                                    // NOTE(review): fixed send section size;
                                    // presumably a protocol constant — worth
                                    // naming.
                                    section_size: 6144,
                                },
                            )),
                        )?;

                        initializing.send_buffer = Some(send_buffer);
                    }
                    PacketData::RndisPacket(rndis_packet) => {
                        // Some guests start RNDIS before sending the send
                        // buffer; accept this once everything else is set up.
                        if !Self::is_ready_to_initialize(initializing, true) {
                            return Err(WorkerError::UnexpectedPacketOrder(
                                PacketOrderError::UnexpectedRndisPacket,
                            ));
                        }
                        tracing::debug!(
                            channel_type = rndis_packet.channel_type,
                            "RndisPacket received during initialization, assuming MESSAGE_TYPE_INITIALIZE_MSG"
                        );
                        has_init_packet_arrived = true;
                    }
                    _ => {
                        return Err(WorkerError::UnexpectedPacketOrder(
                            PacketOrderError::InvalidPacketData,
                        ));
                    }
                }
            } else {
                match packet.data {
                    PacketData::Init(init) => {
                        let requested_version = init.protocol_version;
                        let version = check_version(requested_version);
                        let mut data = protocol::MessageInitComplete {
                            deprecated: protocol::INVALID_PROTOCOL_VERSION,
                            maximum_mdl_chain_length: 34,
                            status: protocol::Status::NONE,
                        };
                        if let Some(version) = version {
                            if version == Version::V1 {
                                data.deprecated = Version::V1 as u32;
                            }
                            data.status = protocol::Status::SUCCESS;
                        } else {
                            // Unrecognized version: complete with NONE status
                            // and stay in the pre-negotiation state.
                            tracing::debug!(requested_version, "unrecognized version");
                        }
                        let message = self.message(protocol::MESSAGE_TYPE_INIT_COMPLETE, data);
                        self.send_completion(packet.transaction_id, Some(&message))?;

                        if let Some(version) = version {
                            tracelimit::info_ratelimited!(?version, "network negotiated");

                            if version >= Version::V61 {
                                self.packet_size = PacketSize::V61;
                            }
                            *initializing = Some(InitState {
                                version,
                                ndis_config: None,
                                ndis_version: None,
                                recv_buffer: None,
                                send_buffer: None,
                            });
                        }
                    }
                    // Presumably parse_packet only yields Init before a
                    // version is negotiated — TODO confirm against
                    // parse_packet.
                    _ => unreachable!(),
                }
            }
        }
    }
4935
    /// Runs the per-queue processing loop until the coordinator's help is
    /// needed (returning the message to send it) or an error occurs.
    async fn main_loop(
        &mut self,
        stop: &mut StopTask<'_>,
        ready_state: &mut ReadyState,
        queue_state: &mut QueueState,
    ) -> Result<CoordinatorMessage, WorkerError> {
        let buffers = &ready_state.buffers;
        let state = &mut ready_state.state;
        let data = &mut ready_state.data;

        // How much of the outgoing ring to hold in reserve. When the
        // ring-size optimization applies, an adapter-provided limit is used;
        // a zero limit falls back to capacity - 2048.
        let ring_spare_capacity = {
            let (_, send) = self.queue.split();
            let mut limit = if self.can_use_ring_size_opt {
                self.adapter.ring_size_limit.load(Ordering::Relaxed)
            } else {
                0
            };
            if limit == 0 {
                limit = send.capacity() - 2048;
            }
            send.capacity() - limit
        };

        // Release rx packets that were held back while the packet filter was
        // NDIS_PACKET_TYPE_NONE.
        if !state.pending_rx_packets.is_empty()
            && self.packet_filter != rndisprot::NDIS_PACKET_TYPE_NONE
        {
            let epqueue = queue_state.queue.as_mut();
            let (front, back) = state.pending_rx_packets.as_slices();
            epqueue.rx_avail(front);
            epqueue.rx_avail(back);
            state.pending_rx_packets.clear();
        }

        if let Some(primary) = state.primary.as_mut() {
            // Once all requested subchannels are open, send the guest the tx
            // indirection table so it spreads sends across channels.
            if primary.requested_num_queues > 1 && !primary.tx_spread_sent {
                let num_channels_opened =
                    self.adapter.num_sub_channels_opened.load(Ordering::Relaxed);
                if num_channels_opened == primary.requested_num_queues as usize {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    self.guest_send_indirection_table(buffers.version, num_channels_opened as u32);
                    primary.tx_spread_sent = true;
                }
            }
            if let PendingLinkAction::Active(up) = primary.pending_link_action {
                let (_, mut send) = self.queue.split();
                stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                    .await??;
                if let Some(id) = primary.free_control_buffers.pop() {
                    let connect = if primary.guest_link_up != up {
                        primary.pending_link_action = PendingLinkAction::Default;
                        up
                    } else {
                        // The link is already in the requested state: bounce
                        // it — report the opposite state now, and schedule
                        // the requested state after a delay — so the guest
                        // observes a transition.
                        primary.pending_link_action =
                            PendingLinkAction::Delay(primary.guest_link_up);
                        !primary.guest_link_up
                    };
                    // The id came off the free list, so it must be free.
                    assert!(state.rx_bufs.is_free(id.0));
                    state.rx_bufs.allocate(std::iter::once(id.0)).unwrap();
                    let state_to_send = if connect {
                        rndisprot::STATUS_MEDIA_CONNECT
                    } else {
                        rndisprot::STATUS_MEDIA_DISCONNECT
                    };
                    tracing::info!(
                        connect,
                        mac_address = %self.adapter.mac_address,
                        "sending link status"
                    );

                    self.indicate_status(buffers, id.0, state_to_send, &[])?;
                    primary.guest_link_up = connect;
                } else {
                    // No control buffer available; retry after a delay.
                    primary.pending_link_action = PendingLinkAction::Delay(up);
                }

                match primary.pending_link_action {
                    PendingLinkAction::Delay(_) => {
                        return Ok(CoordinatorMessage::StartTimer(
                            Instant::now() + LINK_DELAY_DURATION,
                        ));
                    }
                    PendingLinkAction::Active(_) => panic!("State should not be Active"),
                    _ => {}
                }
            }
            match primary.guest_vf_state {
                PrimaryChannelGuestVfState::Available { .. }
                | PrimaryChannelGuestVfState::UnavailableFromAvailable
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::UnavailableFromDataPathSwitched
                | PrimaryChannelGuestVfState::DataPathSwitchPending { .. }
                | PrimaryChannelGuestVfState::DataPathSynthetic => {
                    let (_, mut send) = self.queue.split();
                    stop.until_stopped(send.wait_ready(MIN_STATE_CHANGE_RING_SIZE))
                        .await??;
                    if let Some(message) = self.handle_state_change(primary, buffers).await? {
                        return Ok(message);
                    }
                }
                _ => (),
            }
        }

        loop {
            // Under outgoing-ring backpressure, skip endpoint rx/tx work and
            // only drain the ring.
            let ring_full = {
                let (_, mut send) = self.queue.split();
                !send.can_write(ring_spare_capacity)?
            };

            // Non-short-circuiting `|` so every stage runs each pass.
            let did_some_work = (!ring_full
                && self.process_endpoint_rx(buffers, state, data, queue_state.queue.as_mut())?)
                | self.process_ring_buffer(buffers, state, data, queue_state)?
                | (!ring_full
                    && self.process_endpoint_tx(state, data, queue_state.queue.as_mut())?)
                | self.transmit_pending_segments(state, data, queue_state)?
                | self.send_pending_packets(state)?;

            if !did_some_work {
                state.stats.spurious_wakes.increment();
            }

            self.process_control_messages(buffers, state)?;

            // Sleep until there is more work to do or a restart is queued.
            let restart = stop
                .until_stopped(std::future::poll_fn(
                    |cx| -> Poll<Option<CoordinatorMessage>> {
                        if !ring_full {
                            if queue_state.queue.poll_ready(cx).is_ready() {
                                tracing::trace!("endpoint ready");
                                return Poll::Ready(None);
                            }
                        }

                        let (mut recv, mut send) = self.queue.split();
                        // Only poll the incoming ring when there is headroom
                        // to take on more tx work.
                        if state.free_tx_packets.len() >= self.adapter.free_tx_packet_threshold
                            && data.tx_segments.is_empty()
                            && recv.poll_ready(cx).is_ready()
                        {
                            tracing::trace!("incoming ring ready");
                            return Poll::Ready(None);
                        }

                        let mut pending_send_size = self.pending_send_size;
                        if ring_full {
                            pending_send_size = ring_spare_capacity;
                        }
                        if pending_send_size != 0
                            && send.poll_ready(cx, pending_send_size).is_ready()
                        {
                            tracing::trace!("outgoing ring ready");
                            return Poll::Ready(None);
                        }

                        // Reclaim receive buffers returned by other queues on
                        // this queue's behalf.
                        if let Some(remote_buffer_id_recv) =
                            &mut queue_state.rx_buffer_range.remote_buffer_id_recv
                        {
                            while let Poll::Ready(Some(id)) =
                                remote_buffer_id_recv.poll_next_unpin(cx)
                            {
                                if id >= RX_RESERVED_CONTROL_BUFFERS {
                                    queue_state.queue.rx_avail(&[RxId(id)]);
                                } else {
                                    state
                                        .primary
                                        .as_mut()
                                        .unwrap()
                                        .free_control_buffers
                                        .push(ControlMessageId(id));
                                }
                            }
                        }

                        if let Some(restart) = self.restart.take() {
                            return Poll::Ready(Some(restart));
                        }

                        tracing::trace!("network waiting");
                        Poll::Pending
                    },
                ))
                .await?;

            if let Some(restart) = restart {
                break Ok(restart);
            }
        }
    }
5156
    /// Moves received packets from the backend endpoint into the guest's
    /// receive buffer and posts them over the ring.
    ///
    /// Returns `Ok(true)` if any packets were polled from the endpoint.
    fn process_endpoint_rx(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let n = epqueue
            .rx_poll(&mut data.rx_ready)
            .map_err(WorkerError::Endpoint)?;
        if n == 0 {
            return Ok(false);
        }

        state.stats.rx_packets_per_wake.add_sample(n as u64);

        if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE {
            tracing::trace!(
                packet_filter = self.packet_filter,
                "rx packets dropped due to packet filter"
            );
            // Hold the buffers; main_loop returns them to the endpoint once
            // the filter is re-enabled.
            state.pending_rx_packets.extend(&data.rx_ready[..n]);
            state.stats.rx_dropped_filtered.add(n as u64);
            return Ok(false);
        }

        // The first suballocation id doubles as the transaction id; the
        // completion path uses it to locate the whole batch.
        let transaction_id = data.rx_ready[0].0.into();
        let ready_ids = data.rx_ready[..n].iter().map(|&RxId(id)| id);

        state.rx_bufs.allocate(ready_ids.clone()).unwrap();

        let len = buffers.recv_buffer.sub_allocation_size as usize;
        data.transfer_pages.clear();
        data.transfer_pages
            .extend(ready_ids.map(|id| buffers.recv_buffer.transfer_page_range(id, len)));

        match self.try_send_rndis_message(
            transaction_id,
            protocol::DATA_CHANNEL_TYPE,
            buffers.recv_buffer.id,
            &data.transfer_pages,
        )? {
            None => {
                state.stats.rx_packets.add(n as u64);
            }
            Some(_) => {
                // Ring full: drop the batch and hand the buffers straight
                // back to the endpoint.
                state.stats.rx_dropped_ring_full.add(n as u64);

                state.rx_bufs.free(data.rx_ready[0].0);
                epqueue.rx_avail(&data.rx_ready[..n]);
            }
        }

        Ok(true)
    }
5222
    /// Completes transmits that the backend endpoint has finished.
    ///
    /// Returns `Ok(true)` if any completions were processed.
    fn process_endpoint_tx(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        epqueue: &mut dyn net_backend::Queue,
    ) -> Result<bool, WorkerError> {
        let result = epqueue.tx_poll(&mut data.tx_done);

        match result {
            Ok(n) => {
                if n == 0 {
                    return Ok(false);
                }

                for &id in &data.tx_done[..n] {
                    let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                    assert!(tx_packet.pending_packet_count > 0);
                    tx_packet.pending_packet_count -= 1;
                    // Complete the guest packet only once every endpoint
                    // packet derived from it has completed.
                    if tx_packet.pending_packet_count == 0 {
                        self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                    }
                }

                Ok(true)
            }
            Err(TxError::TryRestart(err)) => {
                // The endpoint needs its queues restarted. Queue completions
                // for every in-flight transmit (reported as SUCCESS) so the
                // guest is not left waiting across the restart.
                let pending_tx = state
                    .pending_tx_packets
                    .iter_mut()
                    .enumerate()
                    .filter_map(|(id, inflight)| {
                        if inflight.pending_packet_count > 0 {
                            inflight.pending_packet_count = 0;
                            Some(PendingTxCompletion {
                                transaction_id: inflight.transaction_id,
                                tx_id: Some(TxId(id as u32)),
                                status: protocol::Status::SUCCESS,
                            })
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
                state.pending_tx_completions.extend(pending_tx);
                Err(WorkerError::EndpointRequiresQueueRestart(err))
            }
            Err(TxError::Fatal(err)) => Err(WorkerError::Endpoint(err)),
        }
    }
5274
    /// Handles a guest request to switch the data path between the synthetic
    /// path and the guest VF.
    ///
    /// When a switch is actually needed, the operation is forwarded to the
    /// coordinator (which completes the transaction later); otherwise the
    /// transaction is completed immediately.
    fn switch_data_path(
        &mut self,
        state: &mut ActiveState,
        use_guest_vf: bool,
        transaction_id: Option<u64>,
    ) -> Result<(), WorkerError> {
        let primary = state.primary.as_mut().unwrap();
        let mut queue_switch_operation = false;
        match primary.guest_vf_state {
            PrimaryChannelGuestVfState::AvailableAdvertised | PrimaryChannelGuestVfState::Ready => {
                // A switch to synthetic is only queued when the current data
                // path state is unknown; otherwise it is already effectively
                // synthetic.
                if use_guest_vf || primary.is_data_path_switched.is_none() {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: use_guest_vf,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            PrimaryChannelGuestVfState::DataPathSwitched => {
                if !use_guest_vf {
                    primary.guest_vf_state = PrimaryChannelGuestVfState::DataPathSwitchPending {
                        to_guest: false,
                        id: transaction_id,
                        result: None,
                    };
                    queue_switch_operation = true;
                }
            }
            _ if use_guest_vf => {
                // Switching to the VF in any other state is ignored (but
                // still completed below) with a warning.
                tracing::warn!(
                    state = %primary.guest_vf_state,
                    use_guest_vf,
                    "Data path switch requested while device is in wrong state"
                );
            }
            _ => (),
        };
        if queue_switch_operation {
            self.send_coordinator_update_vf();
        } else {
            self.send_completion(transaction_id, None)?;
        }
        Ok(())
    }
5328
    /// Processes incoming vmbus ring packets: guest transmits, receive
    /// completions, and control requests.
    ///
    /// Returns `Ok(true)` if any packets were consumed from the ring.
    fn process_ring_buffer(
        &mut self,
        buffers: &ChannelBuffers,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        // Don't pull more tx work from the ring while previously-built
        // segments are still waiting to go to the endpoint.
        if !data.tx_segments.is_empty() {
            return Ok(false);
        }
        let mut total_packets = 0;
        let mut did_some_work = false;
        loop {
            // Each RNDIS packet needs a free tx id; stop when exhausted.
            if state.free_tx_packets.is_empty() {
                break;
            }
            let packet = if let Some(packet) = self.try_next_packet(
                buffers.send_buffer.as_ref(),
                Some(buffers.version),
                &mut data.external_data,
            )? {
                packet
            } else {
                break;
            };

            did_some_work = true;
            match packet.data {
                PacketData::RndisPacket(_) => {
                    // Guest transmit: translate into endpoint tx segments.
                    let id = state.free_tx_packets.pop().unwrap();
                    let result: Result<usize, WorkerError> =
                        self.handle_rndis(buffers, id, state, &packet, &mut data.tx_segments);
                    match result {
                        Ok(num_packets) => {
                            total_packets += num_packets as u64;
                            // Messages that produced no tx packets (e.g.
                            // control messages) complete immediately.
                            if num_packets == 0 {
                                self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                            }
                        }
                        Err(err) => {
                            tracelimit::error_ratelimited!(
                                err = &err as &dyn std::error::Error,
                                "failed to handle RNDIS packet"
                            );
                            self.complete_tx_packet(state, id, protocol::Status::FAILURE)?;
                        }
                    };
                }
                PacketData::RndisPacketComplete(_completion) => {
                    // Guest is done with a receive batch; return the buffers
                    // to the endpoint (or to their owning queue).
                    data.rx_done.clear();
                    state
                        .release_recv_buffers(
                            packet
                                .transaction_id
                                .expect("completion packets have transaction id by construction"),
                            &queue_state.rx_buffer_range,
                            &mut data.rx_done,
                        )
                        .ok_or(WorkerError::InvalidRndisPacketCompletion)?;
                    queue_state.queue.rx_avail(&data.rx_done);
                }
                PacketData::SubChannelRequest(request) if state.primary.is_some() => {
                    let mut subchannel_count = 0;
                    let status = if request.operation == protocol::SubchannelOperation::ALLOCATE
                        && request.num_sub_channels < self.adapter.max_queues.into()
                    {
                        subchannel_count = request.num_sub_channels;
                        protocol::Status::SUCCESS
                    } else {
                        tracelimit::warn_ratelimited!(
                            operation = ?request.operation,
                            request_sub_channels = request.num_sub_channels,
                            max_supported_sub_channels = self.adapter.max_queues - 1,
                            "Subchannel request failed: either operation is not supported or requested more subchannels than supported"
                        );
                        protocol::Status::FAILURE
                    };

                    tracing::debug!(?status, subchannel_count, "subchannel request");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_SUB_CHANNEL,
                            protocol::Message5SubchannelComplete {
                                status,
                                num_sub_channels: subchannel_count,
                            },
                        )),
                    )?;

                    if subchannel_count > 0 {
                        // Ask the coordinator to restart with the new queue
                        // count.
                        let primary = state.primary.as_mut().unwrap();
                        primary.requested_num_queues = subchannel_count as u16 + 1;
                        primary.tx_spread_sent = false;
                        self.restart = Some(CoordinatorMessage::Restart);
                    }
                }
                PacketData::RevokeReceiveBuffer(Message1RevokeReceiveBuffer { id })
                | PacketData::RevokeSendBuffer(Message1RevokeSendBuffer { id })
                    if state.primary.is_some() =>
                {
                    tracing::debug!(
                        id,
                        "receive/send buffer revoked, terminating channel processing"
                    );
                    return Err(WorkerError::BufferRevoked);
                }
                PacketData::SendVfAssociationCompletion if state.primary.is_some() => (),
                PacketData::SwitchDataPath(switch_data_path) if state.primary.is_some() => {
                    self.switch_data_path(
                        state,
                        switch_data_path.active_data_path == protocol::DataPath::VF.0,
                        packet.transaction_id,
                    )?;
                }
                PacketData::SwitchDataPathCompletion if state.primary.is_some() => (),
                PacketData::OidQueryEx(oid_query) => {
                    tracing::warn!(oid = ?oid_query.oid, "unimplemented OID");
                    self.send_completion(
                        packet.transaction_id,
                        Some(&self.message(
                            protocol::MESSAGE5_TYPE_OID_QUERY_EX_COMPLETE,
                            protocol::Message5OidQueryExComplete {
                                status: rndisprot::STATUS_NOT_SUPPORTED,
                                bytes: 0,
                            },
                        )),
                    )?;
                }
                p => {
                    tracing::warn!(request = ?p, "unexpected packet");
                    return Err(WorkerError::UnexpectedPacketOrder(
                        PacketOrderError::SwitchDataPathCompletionPrimaryChannelState,
                    ));
                }
            }
        }
        // Push any newly generated tx segments to the endpoint now.
        if total_packets > 0 && !self.transmit_segments(state, data, queue_state)? {
            state.stats.tx_stalled.increment();
        }
        state.stats.tx_packets_per_wake.add_sample(total_packets);
        Ok(did_some_work)
    }
5481
5482 fn transmit_pending_segments(
5485 &mut self,
5486 state: &mut ActiveState,
5487 data: &mut ProcessingData,
5488 queue_state: &mut QueueState,
5489 ) -> Result<bool, WorkerError> {
5490 if data.tx_segments.is_empty() {
5491 return Ok(false);
5492 }
5493 let sent = data.tx_segments_sent;
5494 let did_work =
5495 self.transmit_segments(state, data, queue_state)? || data.tx_segments_sent > sent;
5496 Ok(did_work)
5497 }
5498
    /// Hands pending tx segments to the endpoint.
    ///
    /// Returns `Ok(true)` when all queued segments have been given to the
    /// endpoint (the segment list is then reset for reuse).
    fn transmit_segments(
        &mut self,
        state: &mut ActiveState,
        data: &mut ProcessingData,
        queue_state: &mut QueueState,
    ) -> Result<bool, WorkerError> {
        let segments = &data.tx_segments[data.tx_segments_sent..];
        let (sync, segments_sent) = queue_state
            .queue
            .tx_avail(segments)
            .map_err(WorkerError::Endpoint)?;

        let mut segments = &segments[..segments_sent];
        data.tx_segments_sent += segments_sent;

        if sync {
            // The endpoint completed these packets synchronously; complete
            // them to the guest now instead of waiting for tx_poll.
            while let Some(head) = segments.first() {
                let net_backend::TxSegmentType::Head(metadata) = &head.ty else {
                    unreachable!()
                };
                let id = metadata.id;
                let pending_tx_packet = &mut state.pending_tx_packets[id.0 as usize];
                pending_tx_packet.pending_packet_count -= 1;
                if pending_tx_packet.pending_packet_count == 0 {
                    self.complete_tx_packet(state, id, protocol::Status::SUCCESS)?;
                }
                // Skip past this packet's segments to the next head.
                segments = &segments[metadata.segment_count as usize..];
            }
        }

        let all_sent = data.tx_segments_sent == data.tx_segments.len();
        if all_sent {
            data.tx_segments.clear();
            data.tx_segments_sent = 0;
        }
        Ok(all_sent)
    }
5538
5539 fn handle_rndis(
5540 &mut self,
5541 buffers: &ChannelBuffers,
5542 id: TxId,
5543 state: &mut ActiveState,
5544 packet: &Packet<'_>,
5545 segments: &mut Vec<TxSegment>,
5546 ) -> Result<usize, WorkerError> {
5547 let mut total_packets = 0;
5548 let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
5549 assert_eq!(tx_packet.pending_packet_count, 0);
5550 tx_packet.transaction_id = packet
5551 .transaction_id
5552 .ok_or(WorkerError::MissingTransactionId)?;
5553
5554 packet
5558 .external_data
5559 .iter()
5560 .try_for_each(|range| buffers.mem.probe_gpns(range.gpns()))
5561 .map_err(WorkerError::GpaDirectError)?;
5562
5563 let mut reader = packet.rndis_reader(&buffers.mem);
5564 let header: rndisprot::MessageHeader = reader.read_plain()?;
5565 if header.message_type == rndisprot::MESSAGE_TYPE_PACKET_MSG {
5566 let start = segments.len();
5567 match self.handle_rndis_packet_messages(
5568 buffers,
5569 state,
5570 id,
5571 header.message_length as usize,
5572 reader,
5573 segments,
5574 ) {
5575 Ok(n) => {
5576 state.pending_tx_packets[id.0 as usize].pending_packet_count += n;
5577 total_packets += n;
5578 }
5579 Err(err) => {
5580 segments.truncate(start);
5582 return Err(err);
5583 }
5584 }
5585 } else {
5586 self.handle_rndis_message(state, header.message_type, reader)?;
5587 }
5588
5589 Ok(total_packets)
5590 }
5591
5592 fn try_send_tx_packet(
5593 &mut self,
5594 transaction_id: u64,
5595 status: protocol::Status,
5596 ) -> Result<bool, WorkerError> {
5597 let message = self.message(
5598 protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE,
5599 protocol::Message1SendRndisPacketComplete { status },
5600 );
5601 let result = self.queue.split().1.batched().try_write_aligned(
5602 transaction_id,
5603 OutgoingPacketType::Completion,
5604 message.aligned_payload(),
5605 );
5606 let sent = match result {
5607 Ok(()) => true,
5608 Err(queue::TryWriteError::Full(n)) => {
5609 self.pending_send_size = n;
5610 false
5611 }
5612 Err(queue::TryWriteError::Queue(err)) => return Err(err.into()),
5613 };
5614 Ok(sent)
5615 }
5616
    /// Flushes queued tx completions to the guest, stopping early if the
    /// outgoing ring fills up.
    fn send_pending_packets(&mut self, state: &mut ActiveState) -> Result<bool, WorkerError> {
        let mut did_some_work = false;
        while let Some(pending) = state.pending_tx_completions.front() {
            if !self.try_send_tx_packet(pending.transaction_id, pending.status)? {
                // Ring full; try_send_tx_packet recorded the needed space in
                // pending_send_size.
                return Ok(did_some_work);
            }
            did_some_work = true;
            // The tx id (if any) is reusable once the completion is written.
            if let Some(id) = pending.tx_id {
                state.free_tx_packets.push(id);
            }
            tracing::trace!(?pending, "sent tx completion");
            state.pending_tx_completions.pop_front();
        }

        // Fully drained; stop waiting for outgoing ring space.
        self.pending_send_size = 0;
        Ok(did_some_work)
    }
5634
    /// Completes a guest transmit: sends the completion immediately when
    /// possible, otherwise queues it.
    fn complete_tx_packet(
        &mut self,
        state: &mut ActiveState,
        id: TxId,
        status: protocol::Status,
    ) -> Result<(), WorkerError> {
        let tx_packet = &mut state.pending_tx_packets[id.0 as usize];
        assert_eq!(tx_packet.pending_packet_count, 0);
        // A nonzero pending_send_size means earlier completions are already
        // queued; queue this one too to preserve ordering.
        if self.pending_send_size == 0
            && self.try_send_tx_packet(tx_packet.transaction_id, status)?
        {
            tracing::trace!(id = id.0, "sent tx completion");
            state.free_tx_packets.push(id);
        } else {
            tracing::trace!(id = id.0, "pended tx completion");
            state.pending_tx_completions.push_back(PendingTxCompletion {
                transaction_id: tx_packet.transaction_id,
                tx_id: Some(id),
                status,
            });
        }
        Ok(())
    }
5658}
5659
5660impl ActiveState {
5661 fn release_recv_buffers(
5662 &mut self,
5663 transaction_id: u64,
5664 rx_buffer_range: &RxBufferRange,
5665 done: &mut Vec<RxId>,
5666 ) -> Option<()> {
5667 let first_id: u32 = transaction_id.try_into().ok()?;
5669 let ids = self.rx_bufs.free(first_id)?;
5670 for id in ids {
5671 if !rx_buffer_range.send_if_remote(id) {
5672 if id >= RX_RESERVED_CONTROL_BUFFERS {
5673 done.push(RxId(id));
5674 } else {
5675 self.primary
5676 .as_mut()?
5677 .free_control_buffers
5678 .push(ControlMessageId(id));
5679 }
5680 }
5681 }
5682 Some(())
5683 }
5684}