1use super::Access;
5use super::Client;
6use super::DropReason;
7use super::dhcp::DHCP_SERVER;
8use super::dhcpv6::DHCPV6_ALL_AGENTS_MULTICAST;
9use super::dhcpv6::DHCPV6_SERVER;
10use crate::ChecksumState;
11use crate::ConsommeState;
12use crate::IpAddresses;
13use crate::Ipv4Addresses;
14use crate::Ipv6Addresses;
15use crate::dns_resolver::DnsFlow;
16use crate::dns_resolver::DnsRequest;
17use crate::dns_resolver::DnsResponse;
18use inspect::Inspect;
19use inspect::InspectMut;
20use inspect_counters::Counter;
21use pal_async::interest::InterestSlot;
22use pal_async::interest::PollEvents;
23use pal_async::socket::PolledSocket;
24use smoltcp::phy::ChecksumCapabilities;
25use smoltcp::wire::ETHERNET_HEADER_LEN;
26use smoltcp::wire::EthernetAddress;
27use smoltcp::wire::EthernetFrame;
28use smoltcp::wire::EthernetProtocol;
29use smoltcp::wire::EthernetRepr;
30use smoltcp::wire::IPV4_HEADER_LEN;
31use smoltcp::wire::IPV6_HEADER_LEN;
32use smoltcp::wire::IpAddress;
33use smoltcp::wire::IpProtocol;
34use smoltcp::wire::Ipv4Packet;
35use smoltcp::wire::Ipv4Repr;
36use smoltcp::wire::Ipv6Packet;
37use smoltcp::wire::Ipv6Repr;
38use smoltcp::wire::UDP_HEADER_LEN;
39use smoltcp::wire::UdpPacket;
40use smoltcp::wire::UdpRepr;
41use std::collections::HashMap;
42use std::collections::hash_map;
43use std::io::ErrorKind;
44use std::net::IpAddr;
45use std::net::Ipv4Addr;
46use std::net::Ipv6Addr;
47use std::net::SocketAddr;
48use std::net::SocketAddrV4;
49use std::net::SocketAddrV6;
50use std::net::UdpSocket;
51use std::task::Context;
52use std::task::Poll;
53use std::time::Duration;
54use std::time::Instant;
55
56use crate::DNS_PORT;
57
58pub(crate) struct Udp {
59 connections: HashMap<SocketAddr, UdpConnection>,
60 timeout: Duration,
61}
62
63impl Udp {
64 pub fn new(timeout: Duration) -> Self {
65 Self {
66 connections: HashMap::new(),
67 timeout,
68 }
69 }
70}
71
72impl InspectMut for Udp {
73 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
74 let mut resp = req.respond();
75 for (addr, conn) in &mut self.connections {
76 let key = addr.to_string();
77 resp.field_mut(&key, conn);
78 }
79 }
80}
81
82#[derive(InspectMut)]
83struct UdpConnection {
84 #[inspect(skip)]
85 socket: Option<PolledSocket<UdpSocket>>,
86 #[inspect(display)]
87 guest_mac: EthernetAddress,
88 stats: Stats,
89 #[inspect(mut)]
90 recycle: bool,
91 #[inspect(debug)]
92 last_activity: Instant,
93}
94
95#[derive(Inspect, Default)]
96struct Stats {
97 tx_packets: Counter,
98 tx_dropped: Counter,
99 tx_errors: Counter,
100 rx_packets: Counter,
101}
102
103impl UdpConnection {
104 fn poll_conn(
105 &mut self,
106 cx: &mut Context<'_>,
107 dst_addr: &SocketAddr,
108 state: &mut ConsommeState,
109 client: &mut impl Client,
110 ) -> bool {
111 if self.recycle {
112 return false;
113 }
114
115 let mut eth = EthernetFrame::new_unchecked(&mut state.buffer);
116 loop {
117 if client.rx_mtu() == 0 {
123 break true;
124 }
125
126 let header_offset = match dst_addr {
127 SocketAddr::V4(_) => IPV4_HEADER_LEN + UDP_HEADER_LEN,
128 SocketAddr::V6(_) => IPV6_HEADER_LEN + UDP_HEADER_LEN,
129 };
130
131 match self.socket.as_mut().unwrap().poll_io(
132 cx,
133 InterestSlot::Read,
134 PollEvents::IN,
135 |socket| {
136 socket
137 .get()
138 .recv_from(&mut eth.payload_mut()[header_offset..])
139 },
140 ) {
141 Poll::Ready(Ok((n, src_addr))) => {
142 let (packet_len, checksum_state) = match (dst_addr, src_addr.ip()) {
143 (SocketAddr::V4(dst), IpAddr::V4(src_ip)) => {
144 let len = build_udp_packet(
145 &mut eth,
146 src_ip.into(),
147 (*dst.ip()).into(),
148 src_addr.port(),
149 dst.port(),
150 n,
151 state.params.gateway_mac,
152 self.guest_mac,
153 );
154 (len, ChecksumState::UDP4)
155 }
156 (SocketAddr::V6(dst), IpAddr::V6(src_ip)) => {
157 let len = build_udp_packet(
158 &mut eth,
159 src_ip.into(),
160 (*dst.ip()).into(),
161 src_addr.port(),
162 dst.port(),
163 n,
164 state.params.gateway_mac,
165 self.guest_mac,
166 );
167 (len, ChecksumState::NONE)
168 }
169 _ => unreachable!("mismatched address families"),
170 };
171
172 client.recv(ð.as_ref()[..packet_len], &checksum_state);
173 self.stats.rx_packets.increment();
174 self.last_activity = Instant::now();
175 }
176 Poll::Ready(Err(err)) => {
177 tracelimit::error_ratelimited!(
178 error = &err as &dyn std::error::Error,
179 "recv error"
180 );
181 break false;
182 }
183 Poll::Pending => break true,
184 }
185 }
186 }
187}
188
189impl<T: Client> Access<'_, T> {
190 pub(crate) fn poll_udp(&mut self, cx: &mut Context<'_>) {
191 let timeout = self.inner.udp.timeout;
192 let now = Instant::now();
193
194 self.inner.udp.connections.retain(|dst_addr, conn| {
195 if now.duration_since(conn.last_activity) > timeout {
197 tracing::debug!(
198 addr = %dst_addr,
199 "UDP connection timed out"
200 );
201 return false;
202 }
203
204 conn.poll_conn(cx, dst_addr, &mut self.inner.state, self.client)
205 });
206 while let Some(response) =
207 self.inner
208 .dns
209 .as_mut()
210 .and_then(|dns| match dns.poll_udp_response(cx) {
211 Poll::Ready(resp) => resp,
212 Poll::Pending => None,
213 })
214 {
215 if let Err(e) = self.send_dns_response(&response) {
216 tracelimit::error_ratelimited!(error = ?e, "Failed to send DNS response");
217 }
218 }
219 }
220
221 pub(crate) fn refresh_udp_driver(&mut self) {
222 self.inner.udp.connections.retain(|_, conn| {
223 let socket = conn.socket.take().unwrap().into_inner();
224 match PolledSocket::new(self.client.driver(), socket) {
225 Ok(socket) => {
226 conn.socket = Some(socket);
227 true
228 }
229 Err(err) => {
230 tracing::warn!(
231 error = &err as &dyn std::error::Error,
232 "failed to update driver for udp connection"
233 );
234 false
235 }
236 }
237 });
238 }
239
240 pub(crate) fn handle_udp(
241 &mut self,
242 frame: &EthernetRepr,
243 addresses: &IpAddresses,
244 payload: &[u8],
245 checksum: &ChecksumState,
246 ) -> Result<(), DropReason> {
247 let udp_packet = UdpPacket::new_checked(payload)?;
248
249 let (guest_addr, dst_sock_addr) = match addresses {
251 IpAddresses::V4(addrs) => {
252 let udp = UdpRepr::parse(
253 &udp_packet,
254 &addrs.src_addr.into(),
255 &addrs.dst_addr.into(),
256 &checksum.caps(),
257 )?;
258
259 if addrs.dst_addr == self.inner.state.params.gateway_ip
261 || addrs.dst_addr.is_broadcast()
262 {
263 if self.handle_gateway_udp(frame, addrs, &udp_packet)? {
264 return Ok(());
265 }
266 }
267
268 let guest_addr = SocketAddr::V4(SocketAddrV4::new(addrs.src_addr, udp.src_port));
269
270 let dst_sock_addr = SocketAddr::V4(SocketAddrV4::new(addrs.dst_addr, udp.dst_port));
271
272 (guest_addr, dst_sock_addr)
273 }
274 IpAddresses::V6(addrs) => {
275 let udp = UdpRepr::parse(
276 &udp_packet,
277 &addrs.src_addr.into(),
278 &addrs.dst_addr.into(),
279 &checksum.caps(),
280 )?;
281
282 if addrs.dst_addr == self.inner.state.params.gateway_link_local_ipv6
284 || addrs.dst_addr == DHCPV6_ALL_AGENTS_MULTICAST
285 {
286 if self.handle_gateway_udp_v6(frame, addrs, &udp_packet)? {
287 return Ok(());
288 }
289 }
290
291 let guest_addr =
292 SocketAddr::V6(SocketAddrV6::new(addrs.src_addr, udp.src_port, 0, 0));
293
294 let dst_sock_addr =
295 SocketAddr::V6(SocketAddrV6::new(addrs.dst_addr, udp.dst_port, 0, 0));
296
297 (guest_addr, dst_sock_addr)
298 }
299 };
300
301 let conn = self.get_or_insert(guest_addr, Some(frame.src_addr))?;
302 match conn
303 .socket
304 .as_mut()
305 .unwrap()
306 .get()
307 .send_to(udp_packet.payload(), dst_sock_addr)
308 {
309 Ok(_) => {
310 conn.stats.tx_packets.increment();
311 conn.last_activity = Instant::now();
312 Ok(())
313 }
314 Err(err) if err.kind() == ErrorKind::WouldBlock => {
315 conn.stats.tx_dropped.increment();
316 Err(DropReason::SendBufferFull)
317 }
318 Err(err) => {
319 conn.stats.tx_errors.increment();
320 Err(DropReason::Io(err))
321 }
322 }
323 }
324
325 fn get_or_insert(
326 &mut self,
327 guest_addr: SocketAddr,
328 guest_mac: Option<EthernetAddress>,
329 ) -> Result<&mut UdpConnection, DropReason> {
330 let entry = self.inner.udp.connections.entry(guest_addr);
331 match entry {
332 hash_map::Entry::Occupied(conn) => Ok(conn.into_mut()),
333 hash_map::Entry::Vacant(e) => {
334 let bind_addr: SocketAddr = match guest_addr {
335 SocketAddr::V4(_) => {
336 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
337 }
338 SocketAddr::V6(_) => {
339 SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
340 }
341 };
342
343 let socket = UdpSocket::bind(bind_addr).map_err(DropReason::Io)?;
344 let socket =
345 PolledSocket::new(self.client.driver(), socket).map_err(DropReason::Io)?;
346 let conn = UdpConnection {
347 socket: Some(socket),
348 guest_mac: guest_mac.unwrap_or(self.inner.state.params.client_mac),
349 stats: Default::default(),
350 recycle: false,
351 last_activity: Instant::now(),
352 };
353 Ok(e.insert(conn))
354 }
355 }
356 }
357
358 fn handle_gateway_udp(
359 &mut self,
360 frame: &EthernetRepr,
361 addresses: &Ipv4Addresses,
362 udp: &UdpPacket<&[u8]>,
363 ) -> Result<bool, DropReason> {
364 match udp.dst_port() {
365 DHCP_SERVER => {
366 self.handle_dhcp(udp.payload())?;
367 Ok(true)
368 }
369 DNS_PORT => self.handle_dns(
370 frame,
371 addresses.src_addr.into(),
372 addresses.dst_addr.into(),
373 udp,
374 ),
375 _ => Ok(false),
376 }
377 }
378
379 fn handle_gateway_udp_v6(
380 &mut self,
381 frame: &EthernetRepr,
382 addresses: &Ipv6Addresses,
383 udp: &UdpPacket<&[u8]>,
384 ) -> Result<bool, DropReason> {
385 let payload = udp.payload();
386 match udp.dst_port() {
387 DHCPV6_SERVER => {
388 self.handle_dhcpv6(payload, Some(addresses.src_addr))?;
389 Ok(true)
390 }
391 DNS_PORT => self.handle_dns(
392 frame,
393 addresses.src_addr.into(),
394 addresses.dst_addr.into(),
395 udp,
396 ),
397 _ => Ok(false),
398 }
399 }
400
401 pub fn bind_udp_port(&mut self, ip_addr: Option<IpAddr>, port: u16) -> Result<(), DropReason> {
404 let guest_addr = match ip_addr {
405 Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
406 Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
407 None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)),
408 };
409 let _ = self.get_or_insert(guest_addr, None)?;
410 Ok(())
411 }
412
413 pub fn unbind_udp_port(&mut self, port: u16) -> Result<(), DropReason> {
415 let v4_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
417 let v6_addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0));
418
419 let v4_removed = self.inner.udp.connections.remove(&v4_addr).is_some();
420 let v6_removed = self.inner.udp.connections.remove(&v6_addr).is_some();
421
422 if v4_removed || v6_removed {
423 Ok(())
424 } else {
425 Err(DropReason::PortNotBound)
426 }
427 }
428
429 fn handle_dns(
430 &mut self,
431 frame: &EthernetRepr,
432 src_addr: IpAddress,
433 dst_addr: IpAddress,
434 udp: &UdpPacket<&[u8]>,
435 ) -> Result<bool, DropReason> {
436 let Some(dns) = self.inner.dns.as_mut() else {
437 return Ok(false);
438 };
439
440 let request = DnsRequest {
441 flow: DnsFlow {
442 src_addr,
443 dst_addr,
444 src_port: udp.src_port(),
445 dst_port: udp.dst_port(),
446 gateway_mac: self.inner.state.params.gateway_mac,
447 client_mac: frame.src_addr,
448 transport: crate::dns_resolver::DnsTransport::Udp,
449 },
450 dns_query: udp.payload(),
451 };
452
453 dns.submit_udp_query(&request).map_err(|e| {
456 tracelimit::error_ratelimited!(error = ?e, "Failed to start DNS query");
457 DropReason::Packet(smoltcp::wire::Error)
458 })?;
459
460 Ok(true)
461 }
462
463 fn send_dns_response(&mut self, response: &DnsResponse) -> Result<(), DropReason> {
464 tracing::debug!(
465 response_len = response.response_data.len(),
466 src = %response.flow.src_addr,
467 dst = %response.flow.dst_addr,
468 src_port = response.flow.src_port,
469 dst_port = response.flow.dst_port,
470 "Sending UDP DNS response"
471 );
472
473 let buffer = &mut self.inner.state.buffer;
474
475 let (ip_header_len, checksum_state) = match response.flow.src_addr {
477 IpAddress::Ipv4(_) => (IPV4_HEADER_LEN, ChecksumState::UDP4),
478 IpAddress::Ipv6(_) => (IPV6_HEADER_LEN, ChecksumState::NONE),
479 };
480
481 let payload_offset = ETHERNET_HEADER_LEN + ip_header_len + UDP_HEADER_LEN;
482 let required_size = payload_offset + response.response_data.len();
483
484 if required_size > buffer.len() {
485 return Err(DropReason::SendBufferFull);
486 }
487
488 buffer[payload_offset..required_size].copy_from_slice(&response.response_data);
489
490 let mut eth_frame = EthernetFrame::new_unchecked(&mut buffer[..]);
491 let frame_len = build_udp_packet(
492 &mut eth_frame,
493 response.flow.dst_addr,
494 response.flow.src_addr,
495 response.flow.dst_port,
496 response.flow.src_port,
497 response.response_data.len(),
498 response.flow.gateway_mac,
499 response.flow.client_mac,
500 );
501
502 self.client.recv(&buffer[..frame_len], &checksum_state);
503
504 Ok(())
505 }
506
507 #[cfg(test)]
508 pub fn udp_connection_count(&self) -> usize {
510 self.inner.udp.connections.len()
511 }
512}
513
514fn build_udp_packet<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(
521 eth_frame: &mut EthernetFrame<&mut T>,
522 src_ip: IpAddress,
523 dst_ip: IpAddress,
524 src_port: u16,
525 dst_port: u16,
526 payload_len: usize,
527 src_mac: EthernetAddress,
528 dst_mac: EthernetAddress,
529) -> usize {
530 eth_frame.set_src_addr(src_mac);
532 eth_frame.set_dst_addr(dst_mac);
533
534 match (src_ip, dst_ip) {
535 (IpAddress::Ipv4(src_ip), IpAddress::Ipv4(dst_ip)) => {
536 eth_frame.set_ethertype(EthernetProtocol::Ipv4);
537
538 let mut ipv4_packet = Ipv4Packet::new_unchecked(eth_frame.payload_mut());
540 let ipv4_repr = Ipv4Repr {
541 src_addr: src_ip,
542 dst_addr: dst_ip,
543 next_header: IpProtocol::Udp,
544 payload_len: UDP_HEADER_LEN + payload_len,
545 hop_limit: 64,
546 };
547 ipv4_repr.emit(&mut ipv4_packet, &ChecksumCapabilities::default());
548
549 let mut udp_packet = UdpPacket::new_unchecked(ipv4_packet.payload_mut());
551 udp_packet.set_src_port(src_port);
552 udp_packet.set_dst_port(dst_port);
553 udp_packet.set_len((UDP_HEADER_LEN + payload_len) as u16);
554 udp_packet.fill_checksum(&src_ip.into(), &dst_ip.into());
555
556 ETHERNET_HEADER_LEN + ipv4_packet.total_len() as usize
558 }
559 (IpAddress::Ipv6(src_ip), IpAddress::Ipv6(dst_ip)) => {
560 eth_frame.set_ethertype(EthernetProtocol::Ipv6);
561
562 let mut ipv6_packet = Ipv6Packet::new_unchecked(eth_frame.payload_mut());
564 let ipv6_repr = Ipv6Repr {
565 src_addr: src_ip,
566 dst_addr: dst_ip,
567 next_header: IpProtocol::Udp,
568 payload_len: UDP_HEADER_LEN + payload_len,
569 hop_limit: 64,
570 };
571 ipv6_repr.emit(&mut ipv6_packet);
572
573 let mut udp_packet = UdpPacket::new_unchecked(ipv6_packet.payload_mut());
575 udp_packet.set_src_port(src_port);
576 udp_packet.set_dst_port(dst_port);
577 udp_packet.set_len((UDP_HEADER_LEN + payload_len) as u16);
578 udp_packet.fill_checksum(&src_ip.into(), &dst_ip.into());
579
580 ETHERNET_HEADER_LEN + ipv6_packet.total_len()
582 }
583 _ => panic!("mismatched IP address families"),
584 }
585}
586
#[cfg(all(unix, test))]
mod tests {
    use super::*;
    use crate::Consomme;
    use crate::ConsommeParams;
    use pal_async::DefaultDriver;
    use parking_lot::Mutex;
    use smoltcp::wire::Ipv4Address;
    use std::sync::Arc;

    /// Minimal [`Client`] that records every frame delivered to the guest.
    struct TestClient {
        driver: Arc<DefaultDriver>,
        received_packets: Arc<Mutex<Vec<Vec<u8>>>>,
        rx_mtu: usize,
    }

    impl TestClient {
        fn new(driver: Arc<DefaultDriver>) -> Self {
            let received_packets = Arc::new(Mutex::new(Vec::new()));
            // 1514 = 1500 bytes of IP packet plus the 14-byte Ethernet header.
            let rx_mtu = 1514;
            Self {
                driver,
                received_packets,
                rx_mtu,
            }
        }
    }

    impl Client for TestClient {
        fn driver(&self) -> &dyn pal_async::driver::Driver {
            &*self.driver
        }

        fn recv(&mut self, data: &[u8], _checksum: &ChecksumState) {
            self.received_packets.lock().push(data.to_vec());
        }

        fn rx_mtu(&mut self) -> usize {
            self.rx_mtu
        }
    }

    /// Builds a `Consomme` whose UDP connections expire after `timeout`.
    fn create_consomme_with_timeout(timeout: Duration) -> Consomme {
        let mut params = ConsommeParams::new().expect("Failed to create params");
        params.udp_timeout = timeout;
        Consomme::new(params)
    }

    /// A guest datagram creates a connection, which is removed once it has
    /// been idle for longer than the configured timeout.
    #[pal_async::async_test]
    async fn test_udp_connection_timeout(driver: DefaultDriver) {
        let shared_driver = Arc::new(driver);
        let mut consomme = create_consomme_with_timeout(Duration::from_millis(100));
        let mut client = TestClient::new(shared_driver);

        let params = consomme.params_mut();
        let (guest_mac, gateway_mac) = (params.client_mac, params.gateway_mac);
        let guest_ip: Ipv4Address = params.client_ip;
        let target_ip: Ipv4Address = Ipv4Addr::LOCALHOST;

        // Lay out the payload where `build_udp_packet` expects it, then let it
        // write the headers in front.
        const PAYLOAD: &[u8] = b"test";
        let payload_offset = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN;
        let mut frame = vec![0u8; payload_offset + PAYLOAD.len()];
        frame[payload_offset..].copy_from_slice(PAYLOAD);

        let frame_len = build_udp_packet(
            &mut EthernetFrame::new_unchecked(&mut frame[..]),
            IpAddress::Ipv4(guest_ip),
            IpAddress::Ipv4(target_ip),
            12345,
            54321,
            PAYLOAD.len(),
            guest_mac,
            gateway_mac,
        );

        let mut access = consomme.access(&mut client);
        let _ = access.send(&frame[..frame_len], &ChecksumState::NONE);

        let mut cx = Context::from_waker(std::task::Waker::noop());
        access.poll(&mut cx);
        assert_eq!(
            access.udp_connection_count(),
            1,
            "Connection should be created"
        );

        // Backdate activity past the 100 ms timeout instead of sleeping.
        let stale = Instant::now() - Duration::from_millis(150);
        access
            .inner
            .udp
            .connections
            .values_mut()
            .for_each(|conn| conn.last_activity = stale);

        access.poll(&mut cx);
        assert_eq!(
            access.udp_connection_count(),
            0,
            "Connection should be removed after timeout"
        );
    }
}