net_packet_capture/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! `pcapng` compatible packet capture endpoint implementation.
5
6#![expect(missing_docs)]
7#![forbid(unsafe_code)]
8
9use async_trait::async_trait;
10use futures::FutureExt;
11use futures::StreamExt;
12use futures::lock::Mutex;
13use futures_concurrency::future::Race;
14use inspect::InspectMut;
15use mesh::error::RemoteError;
16use mesh::rpc::FailableRpc;
17use mesh::rpc::RpcSend;
18use net_backend::BufferAccess;
19use net_backend::Endpoint;
20use net_backend::EndpointAction;
21use net_backend::MultiQueueSupport;
22use net_backend::Queue;
23use net_backend::QueueConfig;
24use net_backend::RssConfig;
25use net_backend::RxId;
26use net_backend::TxError;
27use net_backend::TxId;
28use net_backend::TxOffloadSupport;
29use net_backend::TxSegment;
30use net_backend::next_packet;
31use pcap_file::DataLink;
32use pcap_file::PcapError;
33use pcap_file::PcapResult;
34use pcap_file::pcapng::PcapNgWriter;
35use pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketBlock;
36use pcap_file::pcapng::blocks::interface_description::InterfaceDescriptionBlock;
37use std::borrow::Cow;
38use std::io::Write;
39use std::sync::Arc;
40use std::sync::atomic::AtomicBool;
41use std::sync::atomic::AtomicUsize;
42use std::sync::atomic::Ordering;
43use std::task::Context;
44use std::task::Poll;
45use std::time::Duration;
46use std::time::SystemTime;
47use std::time::UNIX_EPOCH;
48
/// Defines packet capture operations.
///
/// Carried in [`PacketCaptureParams::operation`] and dispatched by
/// [`PacketCaptureEndpointControl::packet_capture`].
#[derive(Debug, PartialEq, mesh::MeshPayload)]
pub enum PacketCaptureOperation {
    /// Query details.
    ///
    /// Expects [`OperationData::OpQueryData`]; the answer is the supplied
    /// stream count incremented by one.
    Query,
    /// Start packet capture.
    ///
    /// Expects [`OperationData::OpStartData`]; the first writer in its list is
    /// consumed as the capture output.
    Start,
    /// Stop packet capture. No operational data is required.
    Stop,
}
59
/// Defines start operation data.
#[derive(Debug, mesh::MeshPayload)]
pub struct StartData<W: Write> {
    /// Maximum number of bytes recorded per packet. Longer packets are
    /// truncated in the capture, while their full length is still recorded.
    pub snaplen: u32,
    /// Output streams for the capture. Starting a capture on an endpoint
    /// removes the first writer from this list and writes pcapng data to it.
    pub writers: Vec<W>,
}
66
/// Defines operational data.
#[derive(Debug, mesh::MeshPayload)]
pub enum OperationData<W: Write> {
    /// Data for [`PacketCaptureOperation::Query`]: a count of capture streams,
    /// returned incremented by one for the queried endpoint.
    OpQueryData(u32),
    /// Data for [`PacketCaptureOperation::Start`].
    OpStartData(StartData<W>),
}
73
/// Additional parameters provided as part of a network packet capture trace.
///
/// [`PacketCaptureEndpointControl::packet_capture`] hands the parameters back
/// to the caller: for a query with the updated stream count, for a start with
/// the writers that were not consumed.
#[derive(Debug, mesh::MeshPayload)]
pub struct PacketCaptureParams<W: Write> {
    /// Indicates the network capture operation.
    pub operation: PacketCaptureOperation,
    /// Operational data that is specific to the given operation.
    pub op_data: Option<OperationData<W>>,
}
82
/// Type-erased sink for pcapng blocks.
///
/// Lets the shared capture state hold a `Box<dyn PcapWriter>` without being
/// generic over the caller's output stream type.
trait PcapWriter: Send + Sync {
    /// Writes an `EnhancedPacketBlock` (one captured packet).
    fn write_pcapng_block_eb(&mut self, block: EnhancedPacketBlock<'_>) -> PcapResult<usize>;

    /// Writes an `InterfaceDescriptionBlock`.
    fn write_pcapng_block_id(&mut self, block: InterfaceDescriptionBlock<'_>) -> PcapResult<usize>;
}
90
/// [`PcapWriter`] backed by a `pcap_file` pcapng writer over an arbitrary
/// [`Write`] stream.
struct LocalPcapWriter<W: Write> {
    inner: PcapNgWriter<W>,
}
94
95impl<W: Write + Send + Sync> PcapWriter for LocalPcapWriter<W> {
96    fn write_pcapng_block_eb(&mut self, block: EnhancedPacketBlock<'_>) -> PcapResult<usize> {
97        self.inner.write_pcapng_block(block)
98    }
99
100    fn write_pcapng_block_id(&mut self, block: InterfaceDescriptionBlock<'_>) -> PcapResult<usize> {
101        self.inner.write_pcapng_block(block)
102    }
103}
104
/// Internal start/stop request delivered to the endpoint's command loop.
///
/// Sent by [`PacketCaptureEndpointControl`], and by the capture data path
/// itself when writing to the output stream fails.
struct PacketCaptureOptions {
    /// `Start` or `Stop`; the endpoint rejects `Query`.
    operation: PacketCaptureOperation,
    /// Maximum bytes recorded per packet (0 for a stop request).
    snaplen: usize,
    /// Destination for captured blocks (`None` for a stop request).
    writer: Option<Box<dyn PcapWriter>>,
}
110
111impl PacketCaptureOptions {
112    fn new_with_start<W: Write + Send + Sync + 'static>(snaplen: u32, writer: W) -> Self {
113        //TODO: Native endianness?
114        let pcap_ng_writer =
115            PcapNgWriter::with_endianness(writer, pcap_file::Endianness::Big).unwrap();
116
117        let local_writer = LocalPcapWriter {
118            inner: pcap_ng_writer,
119        };
120
121        Self {
122            operation: PacketCaptureOperation::Start,
123            snaplen: snaplen as usize,
124            writer: Some(Box::new(local_writer)),
125        }
126    }
127
128    fn new_with_stop() -> Self {
129        Self {
130            operation: PacketCaptureOperation::Stop,
131            snaplen: 0,
132            writer: None,
133        }
134    }
135}
136
137enum PacketCaptureEndpointCommand {
138    PacketCapture(FailableRpc<PacketCaptureOptions, ()>),
139}
140
141pub struct PacketCaptureEndpointControl {
142    control_tx: mesh::Sender<PacketCaptureEndpointCommand>,
143}
144
145impl PacketCaptureEndpointControl {
146    pub async fn packet_capture<W: Write + Send + Sync + 'static>(
147        &self,
148        params: PacketCaptureParams<W>,
149    ) -> anyhow::Result<PacketCaptureParams<W>> {
150        let mut params = params;
151        let options = match params.operation {
152            PacketCaptureOperation::Query | PacketCaptureOperation::Start => {
153                let Some(op_data) = &mut params.op_data else {
154                    anyhow::bail!(
155                        "Invalid input parameter. Expecting operational data, but none provided"
156                    );
157                };
158
159                match op_data {
160                    OperationData::OpQueryData(num_streams) => {
161                        return Ok(PacketCaptureParams {
162                            operation: params.operation,
163                            op_data: Some(OperationData::OpQueryData(*num_streams + 1)),
164                        });
165                    }
166                    OperationData::OpStartData(data) => {
167                        if data.writers.is_empty() {
168                            anyhow::bail!("Insufficient streams");
169                        }
170                        let socket = data.writers.remove(0);
171                        PacketCaptureOptions::new_with_start(data.snaplen, socket)
172                    }
173                }
174            }
175            PacketCaptureOperation::Stop => PacketCaptureOptions::new_with_stop(),
176        };
177
178        self.control_tx
179            .call_failable(PacketCaptureEndpointCommand::PacketCapture, options)
180            .await?;
181
182        Ok(params)
183    }
184}
185
/// [`Endpoint`] wrapper that can record the traffic of an inner endpoint to a
/// pcapng stream, controlled through [`PacketCaptureEndpointControl`].
pub struct PacketCaptureEndpoint {
    /// Identifier this endpoint uses for itself in tracing, filtering, etc.
    id: String,
    /// The wrapped endpoint; all endpoint operations are forwarded to it.
    endpoint: Box<dyn Endpoint>,
    /// Receiver for capture commands. `wait_for_endpoint_action` clones this
    /// `Arc` before locking so the guard does not borrow `self` (presumably
    /// the reason for the `Arc<Mutex<_>>` wrapper).
    control_rx: Arc<Mutex<mesh::Receiver<PacketCaptureEndpointCommand>>>,
    /// Capture state shared with every capture queue this endpoint creates.
    pcap: Arc<Pcap>,
}
194
195impl InspectMut for PacketCaptureEndpoint {
196    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
197        self.current_mut().inspect_mut(req)
198    }
199}
200
201impl PacketCaptureEndpoint {
202    pub fn new(endpoint: Box<dyn Endpoint>, id: String) -> (Self, PacketCaptureEndpointControl) {
203        let (control_tx, control_rx) = mesh::channel();
204        let control = PacketCaptureEndpointControl {
205            control_tx: control_tx.clone(),
206        };
207        let pcap = Arc::new(Pcap::new(control_tx.clone()));
208        (
209            Self {
210                id,
211                endpoint,
212                control_rx: Arc::new(Mutex::new(control_rx)),
213                pcap,
214            },
215            control,
216        )
217    }
218
219    fn current(&self) -> &dyn Endpoint {
220        self.endpoint.as_ref()
221    }
222
223    fn current_mut(&mut self) -> &mut dyn Endpoint {
224        self.endpoint.as_mut()
225    }
226}
227
228#[async_trait]
229impl Endpoint for PacketCaptureEndpoint {
230    fn endpoint_type(&self) -> &'static str {
231        self.current().endpoint_type()
232    }
233
234    async fn get_queues(
235        &mut self,
236        config: Vec<QueueConfig>,
237        rss: Option<&RssConfig<'_>>,
238        queues: &mut Vec<Box<dyn Queue>>,
239    ) -> anyhow::Result<()> {
240        if self.pcap.enabled.load(Ordering::Relaxed) {
241            tracing::trace!("using packet capture queues");
242            let mut queues_inner: Vec<Box<dyn Queue>> = Vec::new();
243            self.current_mut()
244                .get_queues(config, rss, &mut queues_inner)
245                .await?;
246            while let Some(inner) = queues_inner.pop() {
247                queues.push(Box::new(PacketCaptureQueue {
248                    queue: inner,
249                    pcap: self.pcap.clone(),
250                    scratch_segments: Vec::new(),
251                }));
252            }
253        } else {
254            tracing::trace!("using inner queues");
255            self.current_mut().get_queues(config, rss, queues).await?;
256        }
257        Ok(())
258    }
259
260    async fn stop(&mut self) {
261        self.current_mut().stop().await
262    }
263
264    fn is_ordered(&self) -> bool {
265        self.current().is_ordered()
266    }
267
268    fn tx_offload_support(&self) -> TxOffloadSupport {
269        self.current().tx_offload_support()
270    }
271
272    fn multiqueue_support(&self) -> MultiQueueSupport {
273        self.current().multiqueue_support()
274    }
275
276    fn tx_fast_completions(&self) -> bool {
277        self.current().tx_fast_completions()
278    }
279
280    async fn set_data_path_to_guest_vf(&self, use_vf: bool) -> anyhow::Result<()> {
281        self.current().set_data_path_to_guest_vf(use_vf).await
282    }
283
284    async fn get_data_path_to_guest_vf(&self) -> anyhow::Result<bool> {
285        self.current().get_data_path_to_guest_vf().await
286    }
287
288    async fn wait_for_endpoint_action(&mut self) -> EndpointAction {
289        enum Message {
290            PacketCaptureEndpointCommand(PacketCaptureEndpointCommand),
291            UpdateFromEndpoint(EndpointAction),
292        }
293        loop {
294            let receiver = self.control_rx.clone();
295            let mut receive_update = receiver.lock().await;
296            let update = async {
297                match receive_update.next().await {
298                    Some(m) => Message::PacketCaptureEndpointCommand(m),
299                    None => {
300                        std::future::pending::<()>().await;
301                        unreachable!()
302                    }
303                }
304            };
305            let ep_update = self
306                .current_mut()
307                .wait_for_endpoint_action()
308                .map(Message::UpdateFromEndpoint);
309            let m = (update, ep_update).race().await;
310            match m {
311                Message::PacketCaptureEndpointCommand(
312                    PacketCaptureEndpointCommand::PacketCapture(rpc),
313                ) => {
314                    let (options, response) = rpc.split();
315                    let result = async {
316                        let id = &self.id;
317                        let start = match options.operation {
318                            PacketCaptureOperation::Start => {
319                                tracing::info!(id, "starting trace");
320                                true
321                            }
322                            PacketCaptureOperation::Stop => {
323                                tracing::info!(id, "stopping trace");
324                                false
325                            }
326                            _ => Err(anyhow::anyhow!("Unexpected packet capture option {id}"))?,
327                        };
328
329                        // Keep the lock until all values are being set to make the update atomic.
330                        let mut pcap_writer = self.pcap.pcap_writer.lock();
331                        let restart_required = start != self.pcap.enabled.load(Ordering::Relaxed);
332                        self.pcap.snaplen.store(options.snaplen, Ordering::Relaxed);
333                        self.pcap
334                            .interface_descriptor_written
335                            .store(false, Ordering::Relaxed);
336                        self.pcap.enabled.store(start, Ordering::Relaxed);
337                        *pcap_writer = options.writer;
338                        anyhow::Ok(restart_required)
339                    }
340                    .await;
341                    let (result, restart_required) = match result {
342                        Err(e) => (Err(e), false),
343                        Ok(value) => (Ok(()), value),
344                    };
345                    response.complete(result.map_err(RemoteError::new));
346                    if restart_required {
347                        break EndpointAction::RestartRequired;
348                    }
349                }
350                Message::UpdateFromEndpoint(update) => break update,
351            }
352        }
353    }
354
355    fn link_speed(&self) -> u64 {
356        self.current().link_speed()
357    }
358}
359
/// Capture state shared between a [`PacketCaptureEndpoint`] and its capture
/// queues.
struct Pcap {
    // N.B. Lock/update semantics: hold the `pcap_writer` lock while updating
    // the other fields.
    /// Destination for captured blocks. `None` when no capture is active or
    /// after the writer failed with an I/O error.
    pcap_writer: parking_lot::Mutex<Option<Box<dyn PcapWriter>>>,
    /// Whether the interface description block was already written to the
    /// current writer.
    interface_descriptor_written: AtomicBool,
    /// Whether capture is enabled. Read without the lock on the data path.
    enabled: AtomicBool,
    /// Maximum number of bytes recorded per packet.
    snaplen: AtomicUsize,
    /// Used to ask the endpoint to stop capturing after a write failure.
    endpoint_control: mesh::Sender<PacketCaptureEndpointCommand>,
}
369
370impl Pcap {
371    fn new(endpoint_control: mesh::Sender<PacketCaptureEndpointCommand>) -> Self {
372        Self {
373            enabled: AtomicBool::new(false),
374            snaplen: AtomicUsize::new(65535),
375            pcap_writer: parking_lot::Mutex::new(None),
376            interface_descriptor_written: AtomicBool::new(false),
377            endpoint_control,
378        }
379    }
380
381    fn write_packet(
382        &self,
383        buf: &[u8],
384        original_len: u32,
385        snaplen: u32,
386        timestamp: &Duration,
387    ) -> bool {
388        let mut locked_writer = self.pcap_writer.lock();
389        let Some(pcap_writer) = &mut *locked_writer else {
390            return false;
391        };
392
393        let handle_write_result = |r: PcapResult<usize>| match r {
394            // Writer gone unexpectedly; disable packet capture.
395            Err(PcapError::IoError(_)) => {
396                // No particular benefit of using compare_exchange atomic here
397                // as the pcap writer lock is held.
398                if self.enabled.load(Ordering::Relaxed) {
399                    self.enabled.store(false, Ordering::Relaxed);
400                    let stop = PacketCaptureOptions::new_with_stop();
401                    // Best effort.
402                    drop(
403                        self.endpoint_control
404                            .call(PacketCaptureEndpointCommand::PacketCapture, stop),
405                    );
406                }
407                Err(())
408            }
409            _ => Ok(()),
410        };
411
412        if !self.interface_descriptor_written.load(Ordering::Relaxed) {
413            let interface = InterfaceDescriptionBlock {
414                linktype: DataLink::ETHERNET,
415                snaplen,
416                options: vec![],
417            };
418            if handle_write_result(pcap_writer.write_pcapng_block_id(interface)).is_err() {
419                *locked_writer = None;
420                return false;
421            }
422            self.interface_descriptor_written
423                .store(true, Ordering::Relaxed);
424        }
425
426        let packet = EnhancedPacketBlock {
427            interface_id: 0,
428            timestamp: *timestamp,
429            original_len,
430            data: Cow::Borrowed(buf),
431            options: vec![],
432        };
433
434        if handle_write_result(pcap_writer.write_pcapng_block_eb(packet)).is_err() {
435            *locked_writer = None;
436            return false;
437        }
438
439        true
440    }
441}
442
/// [`Queue`] wrapper that records traffic into the shared capture state.
///
/// Received packets are recorded after the inner queue reports them, and
/// transmitted packets are recorded before being handed to the inner queue.
struct PacketCaptureQueue {
    /// The wrapped queue.
    queue: Box<dyn Queue>,
    /// Capture state shared with the owning endpoint.
    pcap: Arc<Pcap>,
    /// Reusable buffer for the guest address segments of a received packet.
    scratch_segments: Vec<net_backend::RxBufferSegment>,
}
448
449impl PacketCaptureQueue {
450    fn current_mut(&mut self) -> &mut dyn Queue {
451        self.queue.as_mut()
452    }
453}
454
455#[async_trait]
456impl Queue for PacketCaptureQueue {
457    async fn update_target_vp(&mut self, target_vp: u32) {
458        self.current_mut().update_target_vp(target_vp).await
459    }
460
461    fn poll_ready(&mut self, cx: &mut Context<'_>, pool: &mut dyn BufferAccess) -> Poll<()> {
462        self.current_mut().poll_ready(cx, pool)
463    }
464
465    fn rx_avail(&mut self, pool: &mut dyn BufferAccess, done: &[RxId]) {
466        self.current_mut().rx_avail(pool, done)
467    }
468
469    fn rx_poll(
470        &mut self,
471        pool: &mut dyn BufferAccess,
472        packets: &mut [RxId],
473    ) -> anyhow::Result<usize> {
474        let n = self.current_mut().rx_poll(pool, packets)?;
475        if self.pcap.enabled.load(Ordering::Relaxed) {
476            let timestamp = SystemTime::now()
477                .duration_since(UNIX_EPOCH)
478                .unwrap_or(Duration::new(0, 0));
479            let snaplen = self.pcap.snaplen.load(Ordering::Relaxed);
480            for id in &packets[..n] {
481                let mut buf = vec![0; snaplen];
482                let mut len = 0;
483                let mut pkt_len = 0;
484                self.scratch_segments.clear();
485                pool.push_guest_addresses(*id, &mut self.scratch_segments);
486                for segment in &self.scratch_segments {
487                    pkt_len += segment.len;
488                    if len == buf.len() {
489                        continue;
490                    }
491
492                    let copy_length = std::cmp::min(buf.len() - len, segment.len as usize);
493                    let _ = pool.guest_memory().read_at(segment.gpa, &mut buf[len..]);
494                    len += copy_length;
495                }
496
497                if len == 0 {
498                    continue;
499                }
500
501                if !self
502                    .pcap
503                    .write_packet(&buf[..len], pkt_len, snaplen as u32, &timestamp)
504                {
505                    break;
506                }
507            }
508        }
509        Ok(n)
510    }
511
512    fn tx_avail(
513        &mut self,
514        pool: &mut dyn BufferAccess,
515        segments: &[TxSegment],
516    ) -> anyhow::Result<(bool, usize)> {
517        if self.pcap.enabled.load(Ordering::Relaxed) {
518            let mut segments = segments;
519            let timestamp = SystemTime::now()
520                .duration_since(UNIX_EPOCH)
521                .unwrap_or(Duration::new(0, 0));
522            let snaplen = self.pcap.snaplen.load(Ordering::Relaxed);
523            while !segments.is_empty() {
524                let (metadata, this, rest) = next_packet(segments);
525                segments = rest;
526                if metadata.len == 0 {
527                    continue;
528                }
529                let mut buf = vec![0; snaplen];
530                let mut len = 0;
531                for segment in this {
532                    if len == buf.len() {
533                        break;
534                    }
535
536                    let copy_length = std::cmp::min(buf.len() - len, segment.len as usize);
537                    let _ = pool.guest_memory().read_at(segment.gpa, &mut buf[len..]);
538                    len += copy_length;
539                }
540
541                if len == 0 {
542                    continue;
543                }
544
545                if !self
546                    .pcap
547                    .write_packet(&buf[..len], metadata.len, snaplen as u32, &timestamp)
548                {
549                    break;
550                }
551            }
552        }
553        self.current_mut().tx_avail(pool, segments)
554    }
555
556    fn tx_poll(
557        &mut self,
558        pool: &mut dyn BufferAccess,
559        done: &mut [TxId],
560    ) -> Result<usize, TxError> {
561        self.current_mut().tx_poll(pool, done)
562    }
563}
564
565impl InspectMut for PacketCaptureQueue {
566    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
567        self.current_mut().inspect_mut(req)
568    }
569}