1mod arp;
18mod dhcp;
19mod dhcpv6;
20#[cfg_attr(unix, path = "dns_unix.rs")]
21#[cfg_attr(windows, path = "dns_windows.rs")]
22mod dns;
23mod dns_resolver;
24mod icmp;
25mod ndp;
26mod tcp;
27mod udp;
28
29mod unix;
30mod windows;
31
/// Well-known port number for DNS (53).
const DNS_PORT: u16 = 53;
34
35use inspect::Inspect;
36use inspect::InspectMut;
37use pal_async::driver::Driver;
38use smoltcp::phy::Checksum;
39use smoltcp::phy::ChecksumCapabilities;
40use smoltcp::wire::DhcpMessageType;
41use smoltcp::wire::EthernetAddress;
42use smoltcp::wire::EthernetFrame;
43use smoltcp::wire::EthernetProtocol;
44use smoltcp::wire::EthernetRepr;
45use smoltcp::wire::IPV4_HEADER_LEN;
46use smoltcp::wire::Icmpv6Packet;
47use smoltcp::wire::IpAddress;
48use smoltcp::wire::IpProtocol;
49use smoltcp::wire::Ipv4Address;
50use smoltcp::wire::Ipv4Packet;
51use smoltcp::wire::Ipv6Address;
52use smoltcp::wire::Ipv6Packet;
53use std::task::Context;
54use std::time::Duration;
55use thiserror::Error;
56
/// Top-level state of the consomme network stack.
///
/// Holds the configuration and scratch buffer, the per-protocol (TCP, UDP,
/// ICMP) state, the optional internal DNS resolver, and whether the host
/// was found to have IPv6 connectivity. Frames are exchanged with a client
/// through [`Consomme::access`].
#[derive(InspectMut)]
pub struct Consomme {
    // Configuration plus a shared scratch buffer.
    state: ConsommeState,
    #[inspect(mut)]
    tcp: tcp::Tcp,
    #[inspect(mut)]
    udp: udp::Udp,
    icmp: icmp::Icmp,
    // `None` when the internal resolver failed to initialize in
    // `Consomme::new`; the host DNS settings are used instead.
    dns: Option<dns_resolver::DnsResolver>,
    // Decided once in `Consomme::new`; when false, IPv6 frames from the
    // client are discarded in `Access::send`.
    host_has_ipv6: bool,
}
69
/// Configuration and scratch storage owned by [`Consomme`].
#[derive(Inspect)]
struct ConsommeState {
    params: ConsommeParams,
    // 64 KiB buffer allocated in `Consomme::new`.
    // NOTE(review): presumably scratch space for building packets in the
    // protocol modules — confirm.
    #[inspect(skip)]
    buffer: Box<[u8]>,
}
76
/// Network configuration for a [`Consomme`] instance.
///
/// [`ConsommeParams::new`] yields a 10.0.0.0/24 network with the gateway at
/// 10.0.0.1 and the client at 10.0.0.2; [`ConsommeParams::set_cidr`] moves it.
#[derive(Inspect, Clone)]
pub struct ConsommeParams {
    /// IPv4 netmask of the client subnet.
    #[inspect(display)]
    pub net_mask: Ipv4Address,
    /// IPv4 address of the gateway. Also used as the IPv4 nameserver when the
    /// internal DNS resolver is active.
    #[inspect(display)]
    pub gateway_ip: Ipv4Address,
    /// MAC address of the gateway.
    #[inspect(display)]
    pub gateway_mac: EthernetAddress,
    /// IPv4 address of the client.
    #[inspect(display)]
    pub client_ip: Ipv4Address,
    /// MAC address of the client.
    #[inspect(display)]
    pub client_mac: EthernetAddress,
    /// Nameservers for the client. Initialized from the host's settings and
    /// replaced with the gateway's own addresses when the internal DNS
    /// resolver starts.
    /// NOTE(review): presumably advertised via DHCP/DHCPv6 — confirm.
    #[inspect(with = "|x| inspect::iter_by_index(x).map_value(inspect::AsDisplay)")]
    pub nameservers: Vec<IpAddress>,
    /// IPv6 prefix length (64 by default).
    #[inspect(display)]
    pub prefix_len_ipv6: u8,
    /// Gateway MAC address for IPv6; `ConsommeParams::new` derives
    /// `gateway_link_local_ipv6` from it.
    #[inspect(display)]
    pub gateway_mac_ipv6: EthernetAddress,
    /// Link-local IPv6 address of the gateway. Computed from
    /// `gateway_mac_ipv6` in `ConsommeParams::new`; it is not recomputed if
    /// that field changes later. Used as the IPv6 nameserver when the host
    /// has IPv6 and the internal DNS resolver is active.
    #[inspect(display)]
    pub gateway_link_local_ipv6: Ipv6Address,
    /// IPv6 address of the client, if known (`None` by default). Inspect only
    /// reports whether it is set.
    #[inspect(with = "Option::is_some")]
    pub client_ip_ipv6: Option<Ipv6Address>,
    /// Timeout handed to the UDP state when `Consomme` is constructed
    /// (300 seconds by default).
    /// NOTE(review): presumably an idle timeout for UDP bindings — confirm.
    #[inspect(debug)]
    pub udp_timeout: Duration,
    /// When true, `Consomme::new` assumes the host has IPv6 instead of
    /// probing for it.
    #[inspect(display)]
    pub skip_ipv6_checks: bool,
}
125
/// Error returned by [`ConsommeParams::set_cidr`] when the CIDR string is
/// rejected.
#[derive(Debug, Error)]
#[error("invalid CIDR")]
pub struct InvalidCidr;
130
131impl ConsommeParams {
132 pub fn new() -> Result<Self, Error> {
139 let nameservers = dns::nameservers()?;
140 let gateway_mac_ipv6 = EthernetAddress([0x52, 0x55, 0x0A, 0x00, 0x01, 0x02]);
141
142 Ok(Self {
143 gateway_ip: Ipv4Address::new(10, 0, 0, 1),
144 gateway_mac: EthernetAddress([0x52, 0x55, 10, 0, 0, 1]),
145 client_ip: Ipv4Address::new(10, 0, 0, 2),
146 client_mac: EthernetAddress([0x0, 0x0, 0x0, 0x0, 0x1, 0x0]),
147 net_mask: Ipv4Address::new(255, 255, 255, 0),
148 nameservers,
149 prefix_len_ipv6: 64,
150 gateway_mac_ipv6,
151 gateway_link_local_ipv6: Self::compute_link_local_address(gateway_mac_ipv6),
152 client_ip_ipv6: None,
153 udp_timeout: Duration::from_secs(300),
155 skip_ipv6_checks: false,
156 })
157 }
158
159 pub fn set_cidr(&mut self, cidr: &str) -> Result<(), InvalidCidr> {
164 let cidr: smoltcp::wire::Ipv4Cidr = cidr.parse().map_err(|()| InvalidCidr)?;
165 let base_address = cidr.network().address();
166 let mut gateway_octets = base_address.octets();
167 gateway_octets[3] += 1;
168 self.gateway_ip = Ipv4Address::from(gateway_octets);
169 let mut client_octets = base_address.octets();
170 client_octets[3] += 2;
171 self.client_ip = Ipv4Address::from(client_octets);
172 self.net_mask = cidr.netmask();
173 Ok(())
174 }
175
176 pub fn compute_link_local_address(mac: EthernetAddress) -> Ipv6Address {
186 const LINK_LOCAL_PREFIX: [u8; 8] = [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
187
188 let mut addr = [0u8; 16];
189
190 addr[0..8].copy_from_slice(&LINK_LOCAL_PREFIX);
192
193 addr[8] = mac.0[0] ^ 0x02; addr[9] = mac.0[1];
198 addr[10] = mac.0[2];
199 addr[11] = 0xFF;
200 addr[12] = 0xFE;
201 addr[13] = mac.0[3];
202 addr[14] = mac.0[4];
203 addr[15] = mac.0[5];
204
205 Ipv6Address::from_octets(addr)
206 }
207
208 pub fn filtered_ipv6_nameservers(&self) -> Vec<Ipv6Address> {
215 self.nameservers
216 .iter()
217 .filter_map(|ip| match ip {
218 IpAddress::Ipv6(addr) => Some(*addr),
219 _ => None,
220 })
221 .filter(|addr| {
222 let octets = addr.octets();
223 !(addr.is_unspecified()
224 || addr.is_loopback()
225 || addr.is_multicast()
226 || matches!(octets[0], 0xfc | 0xfd) || octets.starts_with(&[0xfe, 0xc0])) })
229 .collect()
230 }
231
232 fn internal_nameservers(&self, host_has_ipv6: bool) -> Vec<IpAddress> {
236 let mut ns = vec![self.gateway_ip.into()];
237 if host_has_ipv6 {
238 ns.push(self.gateway_link_local_ipv6.into());
239 }
240 ns
241 }
242}
243
/// A [`Consomme`] paired with the [`Client`] it exchanges frames with.
///
/// Created by [`Consomme::access`]; frames from the client enter through
/// `Access::send`.
pub struct Access<'a, T> {
    inner: &'a mut Consomme,
    client: &'a mut T,
}
249
/// Callbacks consomme uses to interact with its client.
pub trait Client {
    /// Returns the async driver to use.
    /// NOTE(review): presumably for host sockets and timers created by the
    /// protocol modules — confirm.
    fn driver(&self) -> &dyn Driver;

    /// Delivers `data` to the client.
    ///
    /// NOTE(review): presumably `data` is a complete Ethernet frame and
    /// `checksum` reports which checksums were left for the client to
    /// handle — confirm against callers.
    fn recv(&mut self, data: &[u8], checksum: &ChecksumState);

    /// Returns the client's receive MTU.
    ///
    /// NOTE(review): presumably the largest frame `recv` may be passed, and
    /// expected to be at least [`MIN_MTU`] — confirm.
    fn rx_mtu(&mut self) -> usize;
}
280
/// Checksum and segmentation-offload state that accompanies a frame.
#[derive(Debug, Copy, Clone)]
pub struct ChecksumState {
    /// When true, the IPv4 header checksum is not verified on frames from
    /// the client, and smoltcp is told not to handle it (`Checksum::None`).
    pub ipv4: bool,
    /// When true, smoltcp is told not to handle the TCP checksum
    /// (`Checksum::None`).
    pub tcp: bool,
    /// When true, smoltcp is told not to handle the UDP checksum
    /// (`Checksum::None`).
    pub udp: bool,
    /// When `Some`, the IPv4 payload extent of a frame from the client is
    /// taken from the buffer length rather than the Total Length field.
    /// NOTE(review): presumably the segment size for TCP segmentation
    /// offload — confirm.
    pub tso: Option<u16>,
}
299
300impl ChecksumState {
301 const NONE: Self = Self {
302 ipv4: false,
303 tcp: false,
304 udp: false,
305 tso: None,
306 };
307 const IPV4_ONLY: Self = Self {
308 ipv4: true,
309 tcp: false,
310 udp: false,
311 tso: None,
312 };
313 const TCP4: Self = Self {
314 ipv4: true,
315 tcp: true,
316 udp: false,
317 tso: None,
318 };
319 const UDP4: Self = Self {
320 ipv4: true,
321 tcp: false,
322 udp: true,
323 tso: None,
324 };
325 const TCP6: Self = Self {
326 ipv4: false,
327 tcp: true,
328 udp: false,
329 tso: None,
330 };
331
332 fn caps(&self) -> ChecksumCapabilities {
333 let mut caps = ChecksumCapabilities::default();
334 if self.ipv4 {
335 caps.ipv4 = Checksum::None;
336 }
337 if self.tcp {
338 caps.tcp = Checksum::None;
339 }
340 if self.udp {
341 caps.udp = Checksum::None;
342 }
343 caps
344 }
345}
346
/// Minimum MTU, in bytes.
///
/// NOTE(review): presumably the smallest value `Client::rx_mtu` may return;
/// 1514 equals a 1500-byte Ethernet payload plus the 14-byte Ethernet
/// header — confirm.
pub const MIN_MTU: usize = 1514;
350
/// Why a packet was dropped; returned by `Access::send`.
#[derive(Debug, Error)]
pub enum DropReason {
    /// smoltcp failed to parse a header.
    #[error("packet parsing error")]
    Packet(#[from] smoltcp::wire::Error),
    /// The Ethernet frame carried an ethertype other than IPv4, IPv6, or ARP.
    #[error("unsupported ethertype {0}")]
    UnsupportedEthertype(EthernetProtocol),
    /// The IP next-header value is not handled, or an ICMPv6 message is not
    /// a neighbor/router solicitation or advertisement.
    #[error("unsupported ip protocol {0}")]
    UnsupportedIpProtocol(IpProtocol),
    /// NOTE(review): presumably returned by the DHCP module for message
    /// types it does not answer — confirm.
    #[error("unsupported dhcp message type {0:?}")]
    UnsupportedDhcp(DhcpMessageType),
    /// NOTE(review): presumably returned by the ARP module — confirm.
    #[error("unsupported arp type")]
    UnsupportedArp,
    /// The IPv4 header checksum did not verify, and verification was not
    /// offloaded.
    #[error("ipv4 checksum failure")]
    Ipv4Checksum,
    /// NOTE(review): presumably returned when a host-side send buffer has no
    /// room — confirm.
    #[error("send buffer full")]
    SendBufferFull,
    /// An I/O error, kept as the error source.
    #[error("io error")]
    Io(#[source] std::io::Error),
    /// A TCP error from the TCP module.
    #[error("bad tcp state")]
    BadTcpState(#[from] tcp::TcpError),
    /// NOTE(review): presumably returned when traffic arrives for a port
    /// with no binding — confirm.
    #[error("port is not bound")]
    PortNotBound,
    /// NOTE(review): presumably returned by the DHCPv6 module — confirm.
    #[error("unsupported dhcpv6 message type {0:?}")]
    UnsupportedDhcpv6(dhcpv6::MessageType),
    /// NOTE(review): presumably returned by the NDP module — confirm.
    #[error("unsupported ndp message type {0:?}")]
    UnsupportedNdp(ndp::NdpMessageType),
    /// A length, version, or header-size check on the IP header failed.
    #[error("packet is malformed")]
    MalformedPacket,
    /// An IPv4 fragment (MF set or nonzero fragment offset) was received.
    #[error("packet fragmentation is not supported")]
    FragmentedPacket,
}
399
/// Error returned by [`ConsommeParams::new`].
#[derive(Debug, Error)]
pub enum Error {
    /// Reading the host nameservers failed.
    #[error("failed to initialize nameservers")]
    Dns(#[from] dns::Error),
}
407
/// Source and destination addresses taken from an IPv4 header.
#[derive(Debug)]
struct Ipv4Addresses {
    src_addr: Ipv4Address,
    dst_addr: Ipv4Address,
}
413
/// Source and destination addresses taken from an IPv6 header.
#[derive(Debug)]
struct Ipv6Addresses {
    src_addr: Ipv6Address,
    dst_addr: Ipv6Address,
}
419
/// Address pair of an IPv4 or IPv6 packet, passed to the version-agnostic
/// TCP and UDP handlers.
#[derive(Debug)]
enum IpAddresses {
    V4(Ipv4Addresses),
    V6(Ipv6Addresses),
}
425
426impl IpAddresses {
427 fn src_addr(&self) -> IpAddress {
428 match self {
429 IpAddresses::V4(addrs) => IpAddress::Ipv4(addrs.src_addr),
430 IpAddresses::V6(addrs) => IpAddress::Ipv6(addrs.src_addr),
431 }
432 }
433
434 fn dst_addr(&self) -> IpAddress {
435 match self {
436 IpAddresses::V4(addrs) => IpAddress::Ipv4(addrs.dst_addr),
437 IpAddresses::V6(addrs) => IpAddress::Ipv6(addrs.dst_addr),
438 }
439 }
440}
441
/// Returns true unless `addr` is loopback (`::1`), unspecified (`::`), or
/// unicast link-local (`fe80::/10`).
fn is_routable_ipv6(addr: &std::net::Ipv6Addr) -> bool {
    let excluded = addr.is_loopback() || addr.is_unspecified() || addr.is_unicast_link_local();
    !excluded
}
447
448impl Consomme {
449 pub fn new(mut params: ConsommeParams) -> Self {
451 let host_has_ipv6 = if params.skip_ipv6_checks {
452 true
453 } else {
454 #[cfg(windows)]
455 let host_has_ipv6_result = windows::host_has_ipv6_address().map_err(|e| e.to_string());
456 #[cfg(unix)]
457 let host_has_ipv6_result = unix::host_has_ipv6_address().map_err(|e| e.to_string());
458
459 match host_has_ipv6_result {
460 Ok(has_ipv6) => has_ipv6,
461 Err(e) => {
462 tracelimit::warn_ratelimited!(
463 "failed to check for host IPv6 address, assuming no IPv6 support: {e}"
464 );
465 false
466 }
467 }
468 };
469 let dns =
470 match dns_resolver::DnsResolver::new(dns_resolver::DEFAULT_MAX_PENDING_DNS_REQUESTS) {
471 Ok(dns) => {
472 params.nameservers = params.internal_nameservers(host_has_ipv6);
474 Some(dns)
475 }
476 Err(_) => {
477 tracelimit::warn_ratelimited!(
478 "failed to initialize DNS resolver, falling back to using host DNS settings"
479 );
480 None
481 }
482 };
483 let timeout = params.udp_timeout;
484 Self {
485 state: ConsommeState {
486 params,
487 buffer: Box::new([0; 65536]),
488 },
489 tcp: tcp::Tcp::new(),
490 udp: udp::Udp::new(timeout),
491 icmp: icmp::Icmp::new(),
492 dns,
493 host_has_ipv6,
494 }
495 }
496
497 pub fn params_mut(&mut self) -> &mut ConsommeParams {
502 &mut self.state.params
503 }
504
505 pub fn access<'a, T: Client>(&'a mut self, client: &'a mut T) -> Access<'a, T> {
507 Access {
508 inner: self,
509 client,
510 }
511 }
512}
513
514impl<T: Client> Access<'_, T> {
515 pub fn get(&self) -> &Consomme {
517 self.inner
518 }
519
520 pub fn get_mut(&mut self) -> &mut Consomme {
522 self.inner
523 }
524
525 pub fn poll(&mut self, cx: &mut Context<'_>) {
527 self.poll_udp(cx);
528 self.poll_tcp(cx);
529 self.poll_icmp(cx);
530 }
531
532 pub fn refresh_driver(&mut self) {
536 self.refresh_tcp_driver();
537 self.refresh_udp_driver();
538 }
539
540 pub fn send(&mut self, data: &[u8], checksum: &ChecksumState) -> Result<(), DropReason> {
570 let frame_packet = EthernetFrame::new_unchecked(data);
571 let frame = EthernetRepr::parse(&frame_packet)?;
572 match frame.ethertype {
573 EthernetProtocol::Ipv4 => self.handle_ipv4(&frame, frame_packet.payload(), checksum)?,
574 EthernetProtocol::Ipv6 => {
575 if self.inner.host_has_ipv6 {
576 self.handle_ipv6(&frame, frame_packet.payload(), checksum)?
577 }
578 }
579 EthernetProtocol::Arp => self.handle_arp(&frame, frame_packet.payload())?,
580 _ => return Err(DropReason::UnsupportedEthertype(frame.ethertype)),
581 }
582 Ok(())
583 }
584
585 fn handle_ipv4(
586 &mut self,
587 frame: &EthernetRepr,
588 payload: &[u8],
589 checksum: &ChecksumState,
590 ) -> Result<(), DropReason> {
591 let ipv4 = Ipv4Packet::new_unchecked(payload);
592 if payload.len() < IPV4_HEADER_LEN
593 || ipv4.version() != 4
594 || payload.len() < ipv4.header_len().into()
595 || payload.len() < ipv4.total_len().into()
596 {
597 return Err(DropReason::MalformedPacket);
598 }
599
600 let total_len = if checksum.tso.is_some() {
601 payload.len()
602 } else {
603 ipv4.total_len().into()
604 };
605 if total_len < ipv4.header_len().into() {
606 return Err(DropReason::MalformedPacket);
607 }
608
609 if ipv4.more_frags() || ipv4.frag_offset() != 0 {
610 return Err(DropReason::FragmentedPacket);
611 }
612
613 if !checksum.ipv4 && !ipv4.verify_checksum() {
614 return Err(DropReason::Ipv4Checksum);
615 }
616
617 let addresses = Ipv4Addresses {
618 src_addr: ipv4.src_addr(),
619 dst_addr: ipv4.dst_addr(),
620 };
621
622 let inner = &payload[ipv4.header_len().into()..total_len];
623
624 match ipv4.next_header() {
625 IpProtocol::Tcp => self.handle_tcp(&IpAddresses::V4(addresses), inner, checksum)?,
626 IpProtocol::Udp => {
627 self.handle_udp(frame, &IpAddresses::V4(addresses), inner, checksum)?
628 }
629 IpProtocol::Icmp => {
630 self.handle_icmp(frame, &addresses, inner, checksum, ipv4.hop_limit())?
631 }
632 p => return Err(DropReason::UnsupportedIpProtocol(p)),
633 };
634 Ok(())
635 }
636
637 fn handle_ipv6(
638 &mut self,
639 frame: &EthernetRepr,
640 payload: &[u8],
641 checksum: &ChecksumState,
642 ) -> Result<(), DropReason> {
643 let ipv6 = Ipv6Packet::new_unchecked(payload);
644 if payload.len() < smoltcp::wire::IPV6_HEADER_LEN || ipv6.version() != 6 {
645 return Err(DropReason::MalformedPacket);
646 }
647
648 let required_len = smoltcp::wire::IPV6_HEADER_LEN + ipv6.payload_len() as usize;
649 if payload.len() < required_len {
650 return Err(DropReason::MalformedPacket);
651 }
652
653 let next_header = ipv6.next_header();
654 let inner = &payload[smoltcp::wire::IPV6_HEADER_LEN..];
655 let addresses = Ipv6Addresses {
656 src_addr: ipv6.src_addr(),
657 dst_addr: ipv6.dst_addr(),
658 };
659
660 match next_header {
661 IpProtocol::Udp => {
662 self.handle_udp(frame, &IpAddresses::V6(addresses), inner, checksum)?
663 }
664 IpProtocol::Tcp => self.handle_tcp(&IpAddresses::V6(addresses), inner, checksum)?,
665 IpProtocol::Icmpv6 => {
666 let icmpv6_packet = Icmpv6Packet::new_unchecked(inner);
668 let msg_type = icmpv6_packet.msg_type();
669
670 if msg_type == smoltcp::wire::Icmpv6Message::NeighborSolicit
671 || msg_type == smoltcp::wire::Icmpv6Message::NeighborAdvert
672 || msg_type == smoltcp::wire::Icmpv6Message::RouterSolicit
673 || msg_type == smoltcp::wire::Icmpv6Message::RouterAdvert
674 {
675 self.handle_ndp(frame, inner, ipv6.src_addr())?;
676 } else {
677 return Err(DropReason::UnsupportedIpProtocol(next_header));
678 }
679 }
680
681 p => return Err(DropReason::UnsupportedIpProtocol(p)),
682 };
683 Ok(())
684 }
685
686 pub fn update_dns_nameservers(&mut self) {
688 if self.inner.dns.is_some() {
689 self.inner.state.params.nameservers = self
690 .inner
691 .state
692 .params
693 .internal_nameservers(self.inner.host_has_ipv6);
694 }
695 }
696}