net_dio/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! An endpoint built on the vmswitch DirectIO interface.
5
6#![cfg(windows)]
7#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10pub mod resolver;
11
12use anyhow::Context as _;
13use async_trait::async_trait;
14use inspect::InspectMut;
15use net_backend::BufferAccess;
16use net_backend::Endpoint;
17use net_backend::Queue;
18use net_backend::QueueConfig;
19use net_backend::RssConfig;
20use net_backend::RxId;
21use net_backend::RxMetadata;
22use net_backend::TxError;
23use net_backend::TxId;
24use net_backend::TxSegment;
25use net_backend::next_packet;
26use pal_async::driver::Driver;
27use parking_lot::Mutex;
28use std::io::ErrorKind;
29use std::sync::Arc;
30use std::task::Context;
31use std::task::Poll;
32use vmswitch::dio;
33
34/// An endpoint that uses vmswitch's DirectIO interface to plug in to a
35/// switch.
36pub struct DioEndpoint {
37    nic: Arc<Mutex<Option<dio::DioNic>>>,
38}
39
40impl DioEndpoint {
41    pub fn new(nic: dio::DioNic) -> Self {
42        Self {
43            nic: Arc::new(Mutex::new(Some(nic))),
44        }
45    }
46}
47
48impl InspectMut for DioEndpoint {
49    fn inspect_mut(&mut self, _req: inspect::Request<'_>) {
50        // TODO
51    }
52}
53
54#[async_trait]
55impl Endpoint for DioEndpoint {
56    fn endpoint_type(&self) -> &'static str {
57        "dio"
58    }
59
60    async fn get_queues(
61        &mut self,
62        mut config: Vec<QueueConfig<'_>>,
63        _rss: Option<&RssConfig<'_>>,
64        queues: &mut Vec<Box<dyn Queue>>,
65    ) -> anyhow::Result<()> {
66        assert_eq!(config.len(), 1);
67        let config = config.drain(..).next().unwrap();
68        queues.push(Box::new(DioQueue::new(
69            &config.driver,
70            self.nic.clone(),
71            config.pool,
72            config.initial_rx,
73        )));
74        Ok(())
75    }
76
77    async fn stop(&mut self) {
78        assert!(self.nic.lock().is_some(), "the queue has not been dropped");
79    }
80}
81
82/// A DirectIO queue.
83pub struct DioQueue {
84    slot: Arc<Mutex<Option<dio::DioNic>>>,
85    nic: Option<dio::DioQueue>,
86    free: Vec<RxId>,
87    rx_pool: Box<dyn BufferAccess>,
88}
89
90impl InspectMut for DioQueue {
91    fn inspect_mut(&mut self, _req: inspect::Request<'_>) {
92        // TODO
93    }
94}
95
96impl Drop for DioQueue {
97    fn drop(&mut self) {
98        // Return the NIC to the endpoint.
99        *self.slot.lock() = self.nic.take().map(|x| x.into_inner())
100    }
101}
102
103impl DioQueue {
104    fn new(
105        driver: &(impl ?Sized + Driver),
106        slot: Arc<Mutex<Option<dio::DioNic>>>,
107        rx_pool: Box<dyn BufferAccess>,
108        initial_rx: &[RxId],
109    ) -> Self {
110        let nic = slot.lock().take();
111        Self {
112            slot,
113            nic: nic.map(|nic| dio::DioQueue::new(driver, nic)),
114            free: initial_rx.to_vec(),
115            rx_pool,
116        }
117    }
118}
119
120impl Queue for DioQueue {
121    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
122        if let Some(nic) = &mut self.nic {
123            nic.poll_read_ready(cx)
124        } else {
125            Poll::Pending
126        }
127    }
128
129    fn rx_avail(&mut self, done: &[RxId]) {
130        self.free.extend(done);
131    }
132
133    fn rx_poll(&mut self, packets: &mut [RxId]) -> anyhow::Result<usize> {
134        let mut n_packets = 0;
135        if let Some(nic) = &mut self.nic {
136            // Transmit incoming packets to the guest until there are no more available.
137            for done_id in packets {
138                let id = if let Some(&id) = self.free.last() {
139                    id
140                } else {
141                    break;
142                };
143                let result = nic.read_with(|buf| {
144                    self.rx_pool.write_packet(
145                        id,
146                        &RxMetadata {
147                            offset: 0,
148                            len: buf.len(),
149                            ..Default::default()
150                        },
151                        buf,
152                    );
153                });
154                match result {
155                    Ok(()) => self.free.pop(),
156                    Err(e) if e.kind() == ErrorKind::WouldBlock => break,
157                    Err(e) => {
158                        // The DIO endpoint is in a bad state.
159                        //
160                        // Disconnect the NIC, but do not fail the operation
161                        // since that would indicate a guest error.
162                        tracing::error!(error = &e as &dyn std::error::Error, "dio error");
163                        self.nic = None;
164                        break;
165                    }
166                };
167                *done_id = id;
168                n_packets += 1;
169            }
170        }
171        Ok(n_packets)
172    }
173
174    fn tx_avail(&mut self, mut segments: &[TxSegment]) -> anyhow::Result<(bool, usize)> {
175        let n = segments.len();
176        if let Some(nic) = &mut self.nic {
177            let mem = self.rx_pool.guest_memory();
178            while !segments.is_empty() {
179                let (metadata, this, rest) = next_packet(segments);
180                segments = rest;
181                nic.write_with(metadata.len, |mut buf| -> anyhow::Result<_> {
182                    for segment in this {
183                        let (this, rest) = buf.split_at_mut(segment.len as usize);
184                        mem.read_at(segment.gpa, this)
185                            .context("failed to write guest memory")?;
186                        buf = rest;
187                    }
188                    Ok(())
189                })
190                .unwrap_or(Ok(()))?;
191            }
192        }
193        Ok((true, n))
194    }
195
196    fn tx_poll(&mut self, _done: &mut [TxId]) -> Result<usize, TxError> {
197        Ok(0)
198    }
199
200    fn buffer_access(&mut self) -> Option<&mut dyn BufferAccess> {
201        Some(self.rx_pool.as_mut())
202    }
203}