consomme/
udp.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use super::Access;
5use super::Client;
6use super::DropReason;
7use super::SocketAddress;
8use super::dhcp::DHCP_SERVER;
9use crate::ChecksumState;
10use crate::ConsommeState;
11use crate::Ipv4Addresses;
12use inspect::Inspect;
13use inspect::InspectMut;
14use inspect_counters::Counter;
15use pal_async::interest::InterestSlot;
16use pal_async::interest::PollEvents;
17use pal_async::socket::PolledSocket;
18use smoltcp::phy::ChecksumCapabilities;
19use smoltcp::wire::ETHERNET_HEADER_LEN;
20use smoltcp::wire::EthernetAddress;
21use smoltcp::wire::EthernetFrame;
22use smoltcp::wire::EthernetProtocol;
23use smoltcp::wire::EthernetRepr;
24use smoltcp::wire::IPV4_HEADER_LEN;
25use smoltcp::wire::IpProtocol;
26use smoltcp::wire::Ipv4Packet;
27use smoltcp::wire::Ipv4Repr;
28use smoltcp::wire::UDP_HEADER_LEN;
29use smoltcp::wire::UdpPacket;
30use smoltcp::wire::UdpRepr;
31use std::collections::HashMap;
32use std::collections::hash_map;
33use std::io::ErrorKind;
34use std::net::IpAddr;
35use std::net::Ipv4Addr;
36use std::net::UdpSocket;
37use std::task::Context;
38use std::task::Poll;
39
40pub(crate) struct Udp {
41    connections: HashMap<SocketAddress, UdpConnection>,
42}
43
44impl Udp {
45    pub fn new() -> Self {
46        Self {
47            connections: HashMap::new(),
48        }
49    }
50}
51
52impl InspectMut for Udp {
53    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
54        let mut resp = req.respond();
55        for (addr, conn) in &mut self.connections {
56            resp.field_mut(&format!("{}:{}", addr.ip, addr.port), conn);
57        }
58    }
59}
60
61#[derive(InspectMut)]
62struct UdpConnection {
63    #[inspect(skip)]
64    socket: Option<PolledSocket<UdpSocket>>,
65    #[inspect(display)]
66    guest_mac: EthernetAddress,
67    stats: Stats,
68    #[inspect(mut)]
69    recycle: bool,
70}
71
72#[derive(Inspect, Default)]
73struct Stats {
74    tx_packets: Counter,
75    tx_dropped: Counter,
76    tx_errors: Counter,
77    rx_packets: Counter,
78}
79
80impl UdpConnection {
81    fn poll_conn(
82        &mut self,
83        cx: &mut Context<'_>,
84        dst_addr: &SocketAddress,
85        state: &mut ConsommeState,
86        client: &mut impl Client,
87    ) -> bool {
88        if self.recycle {
89            return false;
90        }
91
92        let mut eth = EthernetFrame::new_unchecked(&mut state.buffer);
93        loop {
94            // Receive UDP packets while there are receive buffers available. This
95            // means we won't drop UDP packets at this level--instead, we only drop
96            // UDP packets if the kernel socket's receive buffer fills up. If this
97            // results in latency problems, then we could try sizing this buffer
98            // more carefully.
99            if client.rx_mtu() == 0 {
100                break true;
101            }
102            match self.socket.as_mut().unwrap().poll_io(
103                cx,
104                InterestSlot::Read,
105                PollEvents::IN,
106                |socket| {
107                    socket
108                        .get()
109                        .recv_from(&mut eth.payload_mut()[IPV4_HEADER_LEN + UDP_HEADER_LEN..])
110                },
111            ) {
112                Poll::Ready(Ok((n, src_addr))) => {
113                    let src_ip = if let IpAddr::V4(ip) = src_addr.ip() {
114                        ip
115                    } else {
116                        unreachable!()
117                    };
118                    eth.set_ethertype(EthernetProtocol::Ipv4);
119                    eth.set_src_addr(state.params.gateway_mac);
120                    eth.set_dst_addr(self.guest_mac);
121                    let mut ipv4 = Ipv4Packet::new_unchecked(eth.payload_mut());
122                    Ipv4Repr {
123                        src_addr: src_ip.into(),
124                        dst_addr: dst_addr.ip,
125                        protocol: IpProtocol::Udp,
126                        payload_len: UDP_HEADER_LEN + n,
127                        hop_limit: 64,
128                    }
129                    .emit(&mut ipv4, &ChecksumCapabilities::default());
130                    let mut udp = UdpPacket::new_unchecked(ipv4.payload_mut());
131                    udp.set_src_port(src_addr.port());
132                    udp.set_dst_port(dst_addr.port);
133                    udp.set_len((UDP_HEADER_LEN + n) as u16);
134                    udp.fill_checksum(&src_ip.into(), &dst_addr.ip.into());
135                    let len = ETHERNET_HEADER_LEN + ipv4.total_len() as usize;
136                    client.recv(&eth.as_ref()[..len], &ChecksumState::UDP4);
137                    self.stats.rx_packets.increment();
138                }
139                Poll::Ready(Err(err)) => {
140                    tracing::error!(error = &err as &dyn std::error::Error, "recv error");
141                    break false;
142                }
143                Poll::Pending => break true,
144            }
145        }
146    }
147}
148
149impl<T: Client> Access<'_, T> {
150    pub(crate) fn poll_udp(&mut self, cx: &mut Context<'_>) {
151        self.inner.udp.connections.retain(|dst_addr, conn| {
152            conn.poll_conn(cx, dst_addr, &mut self.inner.state, self.client)
153        });
154    }
155
156    pub(crate) fn refresh_udp_driver(&mut self) {
157        self.inner.udp.connections.retain(|_, conn| {
158            let socket = conn.socket.take().unwrap().into_inner();
159            match PolledSocket::new(self.client.driver(), socket) {
160                Ok(socket) => {
161                    conn.socket = Some(socket);
162                    true
163                }
164                Err(err) => {
165                    tracing::warn!(
166                        error = &err as &dyn std::error::Error,
167                        "failed to update driver for udp connection"
168                    );
169                    false
170                }
171            }
172        });
173    }
174
175    pub(crate) fn handle_udp(
176        &mut self,
177        frame: &EthernetRepr,
178        addresses: &Ipv4Addresses,
179        payload: &[u8],
180        checksum: &ChecksumState,
181    ) -> Result<(), DropReason> {
182        let udp_packet = UdpPacket::new_checked(payload)?;
183        let udp = UdpRepr::parse(
184            &udp_packet,
185            &addresses.src_addr.into(),
186            &addresses.dst_addr.into(),
187            &checksum.caps(),
188        )?;
189
190        if addresses.dst_addr == self.inner.state.params.gateway_ip
191            || addresses.dst_addr.is_broadcast()
192        {
193            if self.handle_gateway_udp(&udp_packet)? {
194                return Ok(());
195            }
196        }
197
198        let guest_addr = SocketAddress {
199            ip: addresses.src_addr,
200            port: udp.src_port,
201        };
202
203        let conn = self.get_or_insert(guest_addr, None, Some(frame.src_addr))?;
204        match conn.socket.as_mut().unwrap().get().send_to(
205            udp_packet.payload(),
206            (Ipv4Addr::from(addresses.dst_addr), udp.dst_port),
207        ) {
208            Ok(_) => {
209                conn.stats.tx_packets.increment();
210                Ok(())
211            }
212            Err(err) if err.kind() == ErrorKind::WouldBlock => {
213                conn.stats.tx_dropped.increment();
214                Err(DropReason::SendBufferFull)
215            }
216            Err(err) => {
217                conn.stats.tx_errors.increment();
218                Err(DropReason::Io(err))
219            }
220        }
221    }
222
223    fn get_or_insert(
224        &mut self,
225        guest_addr: SocketAddress,
226        host_addr: Option<Ipv4Addr>,
227        guest_mac: Option<EthernetAddress>,
228    ) -> Result<&mut UdpConnection, DropReason> {
229        let entry = self.inner.udp.connections.entry(guest_addr);
230        match entry {
231            hash_map::Entry::Occupied(conn) => Ok(conn.into_mut()),
232            hash_map::Entry::Vacant(e) => {
233                let socket = UdpSocket::bind((host_addr.unwrap_or(Ipv4Addr::UNSPECIFIED), 0))
234                    .map_err(DropReason::Io)?;
235                let socket =
236                    PolledSocket::new(self.client.driver(), socket).map_err(DropReason::Io)?;
237                let conn = UdpConnection {
238                    socket: Some(socket),
239                    guest_mac: guest_mac.unwrap_or(self.inner.state.params.client_mac),
240                    stats: Default::default(),
241                    recycle: false,
242                };
243                Ok(e.insert(conn))
244            }
245        }
246    }
247
248    fn handle_gateway_udp(&mut self, udp: &UdpPacket<&[u8]>) -> Result<bool, DropReason> {
249        let payload = udp.payload();
250        match udp.dst_port() {
251            DHCP_SERVER => {
252                self.handle_dhcp(payload)?;
253                Ok(true)
254            }
255            _ => Ok(false),
256        }
257    }
258
259    /// Binds to the specified host IP and port for forwarding inbound UDP
260    /// packets to the guest.
261    pub fn bind_udp_port(
262        &mut self,
263        ip_addr: Option<Ipv4Addr>,
264        port: u16,
265    ) -> Result<(), DropReason> {
266        let guest_addr = SocketAddress {
267            ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(),
268            port,
269        };
270        let _ = self.get_or_insert(guest_addr, ip_addr, None)?;
271        Ok(())
272    }
273
274    /// Unbinds from the specified host port.
275    pub fn unbind_udp_port(&mut self, port: u16) -> Result<(), DropReason> {
276        let guest_addr = SocketAddress {
277            ip: Ipv4Addr::UNSPECIFIED.into(),
278            port,
279        };
280        match self.inner.udp.connections.remove(&guest_addr) {
281            Some(_) => Ok(()),
282            None => Err(DropReason::PortNotBound),
283        }
284    }
285}