consomme/
udp.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use super::Access;
5use super::Client;
6use super::DropReason;
7use super::dhcp::DHCP_SERVER;
8use super::dhcpv6::DHCPV6_ALL_AGENTS_MULTICAST;
9use super::dhcpv6::DHCPV6_SERVER;
10use crate::ChecksumState;
11use crate::ConsommeState;
12use crate::IpAddresses;
13use crate::Ipv4Addresses;
14use crate::Ipv6Addresses;
15use crate::dns_resolver::DnsFlow;
16use crate::dns_resolver::DnsRequest;
17use crate::dns_resolver::DnsResponse;
18use inspect::Inspect;
19use inspect::InspectMut;
20use inspect_counters::Counter;
21use pal_async::interest::InterestSlot;
22use pal_async::interest::PollEvents;
23use pal_async::socket::PolledSocket;
24use smoltcp::phy::ChecksumCapabilities;
25use smoltcp::wire::ETHERNET_HEADER_LEN;
26use smoltcp::wire::EthernetAddress;
27use smoltcp::wire::EthernetFrame;
28use smoltcp::wire::EthernetProtocol;
29use smoltcp::wire::EthernetRepr;
30use smoltcp::wire::IPV4_HEADER_LEN;
31use smoltcp::wire::IPV6_HEADER_LEN;
32use smoltcp::wire::IpAddress;
33use smoltcp::wire::IpProtocol;
34use smoltcp::wire::Ipv4Packet;
35use smoltcp::wire::Ipv4Repr;
36use smoltcp::wire::Ipv6Packet;
37use smoltcp::wire::Ipv6Repr;
38use smoltcp::wire::UDP_HEADER_LEN;
39use smoltcp::wire::UdpPacket;
40use smoltcp::wire::UdpRepr;
41use std::collections::HashMap;
42use std::collections::hash_map;
43use std::io::ErrorKind;
44use std::net::IpAddr;
45use std::net::Ipv4Addr;
46use std::net::Ipv6Addr;
47use std::net::SocketAddr;
48use std::net::SocketAddrV4;
49use std::net::SocketAddrV6;
50use std::net::UdpSocket;
51use std::task::Context;
52use std::task::Poll;
53use std::time::Duration;
54use std::time::Instant;
55
56use crate::DNS_PORT;
57
/// UDP state: the set of live UDP connections and the idle timeout applied to
/// them.
pub(crate) struct Udp {
    /// Live connections, keyed by the address each entry was created for: the
    /// guest's source address in `handle_udp`, or the requested address in
    /// `bind_udp_port`. When a connection is polled, this key is used as the
    /// destination address of the frames delivered to the guest.
    connections: HashMap<SocketAddr, UdpConnection>,
    /// How long a connection may go without a successful send or receive
    /// before `poll_udp` removes it.
    timeout: Duration,
}
62
63impl Udp {
64    pub fn new(timeout: Duration) -> Self {
65        Self {
66            connections: HashMap::new(),
67            timeout,
68        }
69    }
70}
71
72impl InspectMut for Udp {
73    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
74        let mut resp = req.respond();
75        for (addr, conn) in &mut self.connections {
76            let key = addr.to_string();
77            resp.field_mut(&key, conn);
78        }
79    }
80}
81
/// State for one UDP connection: the host socket that carries its traffic
/// plus bookkeeping.
#[derive(InspectMut)]
struct UdpConnection {
    /// Host socket for this connection. It is `Some` from creation onwards and
    /// is only taken out briefly by `refresh_udp_driver`, which re-wraps it for
    /// a new driver; every other user unwraps it.
    #[inspect(skip)]
    socket: Option<PolledSocket<UdpSocket>>,
    /// MAC address used as the Ethernet destination of frames delivered to the
    /// guest for this connection.
    #[inspect(display)]
    guest_mac: EthernetAddress,
    /// Packet counters for this connection.
    stats: Stats,
    /// When set, `poll_conn` reports that the connection should be dropped.
    /// Nothing in this file sets it; presumably it is toggled through inspect
    /// (it is marked `mut`) to force a teardown -- TODO confirm.
    #[inspect(mut)]
    recycle: bool,
    /// Time of the last successful send or receive; compared against the UDP
    /// timeout in `poll_udp`.
    #[inspect(debug)]
    last_activity: Instant,
}
94
/// Per-connection packet counters.
#[derive(Inspect, Default)]
struct Stats {
    /// Guest datagrams successfully handed to the host socket.
    tx_packets: Counter,
    /// Guest datagrams dropped because the host socket reported `WouldBlock`.
    tx_dropped: Counter,
    /// Guest datagrams that failed to send with any other I/O error.
    tx_errors: Counter,
    /// Datagrams received from the host socket and delivered to the guest.
    rx_packets: Counter,
}
102
103impl UdpConnection {
104    fn poll_conn(
105        &mut self,
106        cx: &mut Context<'_>,
107        dst_addr: &SocketAddr,
108        state: &mut ConsommeState,
109        client: &mut impl Client,
110    ) -> bool {
111        if self.recycle {
112            return false;
113        }
114
115        let mut eth = EthernetFrame::new_unchecked(&mut state.buffer);
116        loop {
117            // Receive UDP packets while there are receive buffers available. This
118            // means we won't drop UDP packets at this level--instead, we only drop
119            // UDP packets if the kernel socket's receive buffer fills up. If this
120            // results in latency problems, then we could try sizing this buffer
121            // more carefully.
122            if client.rx_mtu() == 0 {
123                break true;
124            }
125
126            let header_offset = match dst_addr {
127                SocketAddr::V4(_) => IPV4_HEADER_LEN + UDP_HEADER_LEN,
128                SocketAddr::V6(_) => IPV6_HEADER_LEN + UDP_HEADER_LEN,
129            };
130
131            match self.socket.as_mut().unwrap().poll_io(
132                cx,
133                InterestSlot::Read,
134                PollEvents::IN,
135                |socket| {
136                    socket
137                        .get()
138                        .recv_from(&mut eth.payload_mut()[header_offset..])
139                },
140            ) {
141                Poll::Ready(Ok((n, src_addr))) => {
142                    let (packet_len, checksum_state) = match (dst_addr, src_addr.ip()) {
143                        (SocketAddr::V4(dst), IpAddr::V4(src_ip)) => {
144                            let len = build_udp_packet(
145                                &mut eth,
146                                src_ip.into(),
147                                (*dst.ip()).into(),
148                                src_addr.port(),
149                                dst.port(),
150                                n,
151                                state.params.gateway_mac,
152                                self.guest_mac,
153                            );
154                            (len, ChecksumState::UDP4)
155                        }
156                        (SocketAddr::V6(dst), IpAddr::V6(src_ip)) => {
157                            let len = build_udp_packet(
158                                &mut eth,
159                                src_ip.into(),
160                                (*dst.ip()).into(),
161                                src_addr.port(),
162                                dst.port(),
163                                n,
164                                state.params.gateway_mac,
165                                self.guest_mac,
166                            );
167                            (len, ChecksumState::NONE)
168                        }
169                        _ => unreachable!("mismatched address families"),
170                    };
171
172                    client.recv(&eth.as_ref()[..packet_len], &checksum_state);
173                    self.stats.rx_packets.increment();
174                    self.last_activity = Instant::now();
175                }
176                Poll::Ready(Err(err)) => {
177                    tracelimit::error_ratelimited!(
178                        error = &err as &dyn std::error::Error,
179                        "recv error"
180                    );
181                    break false;
182                }
183                Poll::Pending => break true,
184            }
185        }
186    }
187}
188
189impl<T: Client> Access<'_, T> {
190    pub(crate) fn poll_udp(&mut self, cx: &mut Context<'_>) {
191        let timeout = self.inner.udp.timeout;
192        let now = Instant::now();
193
194        self.inner.udp.connections.retain(|dst_addr, conn| {
195            // Check if connection has timed out
196            if now.duration_since(conn.last_activity) > timeout {
197                tracing::debug!(
198                    addr = %dst_addr,
199                    "UDP connection timed out"
200                );
201                return false;
202            }
203
204            conn.poll_conn(cx, dst_addr, &mut self.inner.state, self.client)
205        });
206        while let Some(response) =
207            self.inner
208                .dns
209                .as_mut()
210                .and_then(|dns| match dns.poll_udp_response(cx) {
211                    Poll::Ready(resp) => resp,
212                    Poll::Pending => None,
213                })
214        {
215            if let Err(e) = self.send_dns_response(&response) {
216                tracelimit::error_ratelimited!(error = ?e, "Failed to send DNS response");
217            }
218        }
219    }
220
221    pub(crate) fn refresh_udp_driver(&mut self) {
222        self.inner.udp.connections.retain(|_, conn| {
223            let socket = conn.socket.take().unwrap().into_inner();
224            match PolledSocket::new(self.client.driver(), socket) {
225                Ok(socket) => {
226                    conn.socket = Some(socket);
227                    true
228                }
229                Err(err) => {
230                    tracing::warn!(
231                        error = &err as &dyn std::error::Error,
232                        "failed to update driver for udp connection"
233                    );
234                    false
235                }
236            }
237        });
238    }
239
240    pub(crate) fn handle_udp(
241        &mut self,
242        frame: &EthernetRepr,
243        addresses: &IpAddresses,
244        payload: &[u8],
245        checksum: &ChecksumState,
246    ) -> Result<(), DropReason> {
247        let udp_packet = UdpPacket::new_checked(payload)?;
248
249        // Parse UDP header and check gateway handling
250        let (guest_addr, dst_sock_addr) = match addresses {
251            IpAddresses::V4(addrs) => {
252                let udp = UdpRepr::parse(
253                    &udp_packet,
254                    &addrs.src_addr.into(),
255                    &addrs.dst_addr.into(),
256                    &checksum.caps(),
257                )?;
258
259                // Check for gateway-destined packets
260                if addrs.dst_addr == self.inner.state.params.gateway_ip
261                    || addrs.dst_addr.is_broadcast()
262                {
263                    if self.handle_gateway_udp(frame, addrs, &udp_packet)? {
264                        return Ok(());
265                    }
266                }
267
268                let guest_addr = SocketAddr::V4(SocketAddrV4::new(addrs.src_addr, udp.src_port));
269
270                let dst_sock_addr = SocketAddr::V4(SocketAddrV4::new(addrs.dst_addr, udp.dst_port));
271
272                (guest_addr, dst_sock_addr)
273            }
274            IpAddresses::V6(addrs) => {
275                let udp = UdpRepr::parse(
276                    &udp_packet,
277                    &addrs.src_addr.into(),
278                    &addrs.dst_addr.into(),
279                    &checksum.caps(),
280                )?;
281
282                // Check for gateway-destined packets (IPv6 uses multicast instead of broadcast)
283                if addrs.dst_addr == self.inner.state.params.gateway_link_local_ipv6
284                    || addrs.dst_addr == DHCPV6_ALL_AGENTS_MULTICAST
285                {
286                    if self.handle_gateway_udp_v6(frame, addrs, &udp_packet)? {
287                        return Ok(());
288                    }
289                }
290
291                let guest_addr =
292                    SocketAddr::V6(SocketAddrV6::new(addrs.src_addr, udp.src_port, 0, 0));
293
294                let dst_sock_addr =
295                    SocketAddr::V6(SocketAddrV6::new(addrs.dst_addr, udp.dst_port, 0, 0));
296
297                (guest_addr, dst_sock_addr)
298            }
299        };
300
301        let conn = self.get_or_insert(guest_addr, Some(frame.src_addr))?;
302        match conn
303            .socket
304            .as_mut()
305            .unwrap()
306            .get()
307            .send_to(udp_packet.payload(), dst_sock_addr)
308        {
309            Ok(_) => {
310                conn.stats.tx_packets.increment();
311                conn.last_activity = Instant::now();
312                Ok(())
313            }
314            Err(err) if err.kind() == ErrorKind::WouldBlock => {
315                conn.stats.tx_dropped.increment();
316                Err(DropReason::SendBufferFull)
317            }
318            Err(err) => {
319                conn.stats.tx_errors.increment();
320                Err(DropReason::Io(err))
321            }
322        }
323    }
324
325    fn get_or_insert(
326        &mut self,
327        guest_addr: SocketAddr,
328        guest_mac: Option<EthernetAddress>,
329    ) -> Result<&mut UdpConnection, DropReason> {
330        let entry = self.inner.udp.connections.entry(guest_addr);
331        match entry {
332            hash_map::Entry::Occupied(conn) => Ok(conn.into_mut()),
333            hash_map::Entry::Vacant(e) => {
334                let bind_addr: SocketAddr = match guest_addr {
335                    SocketAddr::V4(_) => {
336                        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
337                    }
338                    SocketAddr::V6(_) => {
339                        SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
340                    }
341                };
342
343                let socket = UdpSocket::bind(bind_addr).map_err(DropReason::Io)?;
344                let socket =
345                    PolledSocket::new(self.client.driver(), socket).map_err(DropReason::Io)?;
346                let conn = UdpConnection {
347                    socket: Some(socket),
348                    guest_mac: guest_mac.unwrap_or(self.inner.state.params.client_mac),
349                    stats: Default::default(),
350                    recycle: false,
351                    last_activity: Instant::now(),
352                };
353                Ok(e.insert(conn))
354            }
355        }
356    }
357
358    fn handle_gateway_udp(
359        &mut self,
360        frame: &EthernetRepr,
361        addresses: &Ipv4Addresses,
362        udp: &UdpPacket<&[u8]>,
363    ) -> Result<bool, DropReason> {
364        match udp.dst_port() {
365            DHCP_SERVER => {
366                self.handle_dhcp(udp.payload())?;
367                Ok(true)
368            }
369            DNS_PORT => self.handle_dns(
370                frame,
371                addresses.src_addr.into(),
372                addresses.dst_addr.into(),
373                udp,
374            ),
375            _ => Ok(false),
376        }
377    }
378
379    fn handle_gateway_udp_v6(
380        &mut self,
381        frame: &EthernetRepr,
382        addresses: &Ipv6Addresses,
383        udp: &UdpPacket<&[u8]>,
384    ) -> Result<bool, DropReason> {
385        let payload = udp.payload();
386        match udp.dst_port() {
387            DHCPV6_SERVER => {
388                self.handle_dhcpv6(payload, Some(addresses.src_addr))?;
389                Ok(true)
390            }
391            DNS_PORT => self.handle_dns(
392                frame,
393                addresses.src_addr.into(),
394                addresses.dst_addr.into(),
395                udp,
396            ),
397            _ => Ok(false),
398        }
399    }
400
401    /// Binds to the specified host IP and port for forwarding inbound UDP
402    /// packets to the guest.
403    pub fn bind_udp_port(&mut self, ip_addr: Option<IpAddr>, port: u16) -> Result<(), DropReason> {
404        let guest_addr = match ip_addr {
405            Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
406            Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
407            None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)),
408        };
409        let _ = self.get_or_insert(guest_addr, None)?;
410        Ok(())
411    }
412
413    /// Unbinds from the specified host port for both IPv4 and IPv6.
414    pub fn unbind_udp_port(&mut self, port: u16) -> Result<(), DropReason> {
415        // Try to remove both IPv4 and IPv6 bindings
416        let v4_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
417        let v6_addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0));
418
419        let v4_removed = self.inner.udp.connections.remove(&v4_addr).is_some();
420        let v6_removed = self.inner.udp.connections.remove(&v6_addr).is_some();
421
422        if v4_removed || v6_removed {
423            Ok(())
424        } else {
425            Err(DropReason::PortNotBound)
426        }
427    }
428
429    fn handle_dns(
430        &mut self,
431        frame: &EthernetRepr,
432        src_addr: IpAddress,
433        dst_addr: IpAddress,
434        udp: &UdpPacket<&[u8]>,
435    ) -> Result<bool, DropReason> {
436        let Some(dns) = self.inner.dns.as_mut() else {
437            return Ok(false);
438        };
439
440        let request = DnsRequest {
441            flow: DnsFlow {
442                src_addr,
443                dst_addr,
444                src_port: udp.src_port(),
445                dst_port: udp.dst_port(),
446                gateway_mac: self.inner.state.params.gateway_mac,
447                client_mac: frame.src_addr,
448                transport: crate::dns_resolver::DnsTransport::Udp,
449            },
450            dns_query: udp.payload(),
451        };
452
453        // Submit the DNS query with addressing information
454        // The response will be queued and sent later in poll_udp
455        dns.submit_udp_query(&request).map_err(|e| {
456            tracelimit::error_ratelimited!(error = ?e, "Failed to start DNS query");
457            DropReason::Packet(smoltcp::wire::Error)
458        })?;
459
460        Ok(true)
461    }
462
463    fn send_dns_response(&mut self, response: &DnsResponse) -> Result<(), DropReason> {
464        tracing::debug!(
465            response_len = response.response_data.len(),
466            src = %response.flow.src_addr,
467            dst = %response.flow.dst_addr,
468            src_port = response.flow.src_port,
469            dst_port = response.flow.dst_port,
470            "Sending UDP DNS response"
471        );
472
473        let buffer = &mut self.inner.state.buffer;
474
475        // Determine header length based on IP version
476        let (ip_header_len, checksum_state) = match response.flow.src_addr {
477            IpAddress::Ipv4(_) => (IPV4_HEADER_LEN, ChecksumState::UDP4),
478            IpAddress::Ipv6(_) => (IPV6_HEADER_LEN, ChecksumState::NONE),
479        };
480
481        let payload_offset = ETHERNET_HEADER_LEN + ip_header_len + UDP_HEADER_LEN;
482        let required_size = payload_offset + response.response_data.len();
483
484        if required_size > buffer.len() {
485            return Err(DropReason::SendBufferFull);
486        }
487
488        buffer[payload_offset..required_size].copy_from_slice(&response.response_data);
489
490        let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer[..]);
491        let frame_len = build_udp_packet(
492            &mut eth_frame,
493            response.flow.dst_addr,
494            response.flow.src_addr,
495            response.flow.dst_port,
496            response.flow.src_port,
497            response.response_data.len(),
498            response.flow.gateway_mac,
499            response.flow.client_mac,
500        );
501
502        self.client.recv(&buffer[..frame_len], &checksum_state);
503
504        Ok(())
505    }
506
507    #[cfg(test)]
508    /// Returns the current number of active UDP connections.
509    pub fn udp_connection_count(&self) -> usize {
510        self.inner.udp.connections.len()
511    }
512}
513
514/// Helper function to build a complete UDP packet in an Ethernet frame.
515///
516/// This function constructs the Ethernet, IP (v4 or v6), and UDP headers, and assumes
517/// the UDP payload is already present in the buffer at the correct offset.
518///
519/// Returns the total length of the constructed frame.
520fn build_udp_packet<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(
521    eth_frame: &mut EthernetFrame<&mut T>,
522    src_ip: IpAddress,
523    dst_ip: IpAddress,
524    src_port: u16,
525    dst_port: u16,
526    payload_len: usize,
527    src_mac: EthernetAddress,
528    dst_mac: EthernetAddress,
529) -> usize {
530    // Build Ethernet header
531    eth_frame.set_src_addr(src_mac);
532    eth_frame.set_dst_addr(dst_mac);
533
534    match (src_ip, dst_ip) {
535        (IpAddress::Ipv4(src_ip), IpAddress::Ipv4(dst_ip)) => {
536            eth_frame.set_ethertype(EthernetProtocol::Ipv4);
537
538            // Build IPv4 header
539            let mut ipv4_packet = Ipv4Packet::new_unchecked(eth_frame.payload_mut());
540            let ipv4_repr = Ipv4Repr {
541                src_addr: src_ip,
542                dst_addr: dst_ip,
543                next_header: IpProtocol::Udp,
544                payload_len: UDP_HEADER_LEN + payload_len,
545                hop_limit: 64,
546            };
547            ipv4_repr.emit(&mut ipv4_packet, &ChecksumCapabilities::default());
548
549            // Build UDP header (payload is already in place)
550            let mut udp_packet = UdpPacket::new_unchecked(ipv4_packet.payload_mut());
551            udp_packet.set_src_port(src_port);
552            udp_packet.set_dst_port(dst_port);
553            udp_packet.set_len((UDP_HEADER_LEN + payload_len) as u16);
554            udp_packet.fill_checksum(&src_ip.into(), &dst_ip.into());
555
556            // Return total frame length
557            ETHERNET_HEADER_LEN + ipv4_packet.total_len() as usize
558        }
559        (IpAddress::Ipv6(src_ip), IpAddress::Ipv6(dst_ip)) => {
560            eth_frame.set_ethertype(EthernetProtocol::Ipv6);
561
562            // Build IPv6 header
563            let mut ipv6_packet = Ipv6Packet::new_unchecked(eth_frame.payload_mut());
564            let ipv6_repr = Ipv6Repr {
565                src_addr: src_ip,
566                dst_addr: dst_ip,
567                next_header: IpProtocol::Udp,
568                payload_len: UDP_HEADER_LEN + payload_len,
569                hop_limit: 64,
570            };
571            ipv6_repr.emit(&mut ipv6_packet);
572
573            // Build UDP header (payload is already in place)
574            let mut udp_packet = UdpPacket::new_unchecked(ipv6_packet.payload_mut());
575            udp_packet.set_src_port(src_port);
576            udp_packet.set_dst_port(dst_port);
577            udp_packet.set_len((UDP_HEADER_LEN + payload_len) as u16);
578            udp_packet.fill_checksum(&src_ip.into(), &dst_ip.into());
579
580            // Return total frame length
581            ETHERNET_HEADER_LEN + ipv6_packet.total_len()
582        }
583        _ => panic!("mismatched IP address families"),
584    }
585}
586
#[cfg(all(unix, test))]
mod tests {
    use super::*;
    use crate::Consomme;
    use crate::ConsommeParams;
    use pal_async::DefaultDriver;
    use parking_lot::Mutex;
    use smoltcp::wire::Ipv4Address;
    use std::sync::Arc;

    /// Test double for `Client` that records every frame it is handed.
    struct TestClient {
        driver: Arc<DefaultDriver>,
        received_packets: Arc<Mutex<Vec<Vec<u8>>>>,
        rx_mtu: usize,
    }

    impl TestClient {
        /// Creates a client with an empty capture list and a 1514-byte
        /// receive MTU.
        fn new(driver: Arc<DefaultDriver>) -> Self {
            let received_packets = Arc::new(Mutex::new(Vec::new()));
            Self {
                driver,
                received_packets,
                rx_mtu: 1514, // Standard Ethernet MTU
            }
        }
    }

    impl Client for TestClient {
        fn driver(&self) -> &dyn pal_async::driver::Driver {
            &*self.driver
        }

        fn recv(&mut self, data: &[u8], _checksum: &ChecksumState) {
            let captured = data.to_vec();
            self.received_packets.lock().push(captured);
        }

        fn rx_mtu(&mut self) -> usize {
            self.rx_mtu
        }
    }

    /// Builds a `Consomme` whose UDP idle timeout is `timeout`.
    fn create_consomme_with_timeout(timeout: Duration) -> Consomme {
        let mut params = ConsommeParams::new().expect("Failed to create params");
        params.udp_timeout = timeout;
        Consomme::new(params)
    }

    /// Builds a guest-to-localhost IPv4 UDP frame (port 12345 -> 54321)
    /// carrying `payload`, using the MAC and IP addresses configured in
    /// `consomme`.
    fn guest_udp_frame(consomme: &mut Consomme, payload: &[u8]) -> Vec<u8> {
        let params = consomme.params_mut();
        let guest_mac = params.client_mac;
        let gateway_mac = params.gateway_mac;
        let guest_ip: Ipv4Address = params.client_ip;
        let target_ip: Ipv4Address = Ipv4Addr::LOCALHOST;

        // The payload goes in first; the headers are written around it.
        let payload_offset = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN;
        let mut buffer = vec![0u8; payload_offset + payload.len()];
        buffer[payload_offset..].copy_from_slice(payload);

        let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer[..]);
        let packet_len = build_udp_packet(
            &mut eth_frame,
            IpAddress::Ipv4(guest_ip),
            IpAddress::Ipv4(target_ip),
            12345,
            54321,
            payload.len(),
            guest_mac,
            gateway_mac,
        );

        buffer.truncate(packet_len);
        buffer
    }

    #[pal_async::async_test]
    async fn test_udp_connection_timeout(driver: DefaultDriver) {
        let mut consomme = create_consomme_with_timeout(Duration::from_millis(100));
        let mut client = TestClient::new(Arc::new(driver));

        let frame = guest_udp_frame(&mut consomme, b"test");

        let mut access = consomme.access(&mut client);
        let _ = access.send(&frame, &ChecksumState::NONE);

        let mut cx = Context::from_waker(std::task::Waker::noop());
        access.poll(&mut cx);

        assert_eq!(
            access.udp_connection_count(),
            1,
            "Connection should be created"
        );

        // Backdate the activity timestamp past the 100 ms timeout instead of
        // sleeping.
        let stale = Instant::now() - Duration::from_millis(150);
        access
            .inner
            .udp
            .connections
            .values_mut()
            .for_each(|conn| conn.last_activity = stale);

        // The next poll must discard the idle connection.
        access.poll(&mut cx);

        assert_eq!(
            access.udp_connection_count(),
            0,
            "Connection should be removed after timeout"
        );
    }
}