1mod arp;
18mod dhcp;
19#[cfg_attr(unix, path = "dns_unix.rs")]
20#[cfg_attr(windows, path = "dns_windows.rs")]
21mod dns;
22mod icmp;
23mod tcp;
24mod udp;
25mod windows;
26
27use inspect::Inspect;
28use inspect::InspectMut;
29use pal_async::driver::Driver;
30use smoltcp::phy::Checksum;
31use smoltcp::phy::ChecksumCapabilities;
32use smoltcp::wire::DhcpMessageType;
33use smoltcp::wire::EthernetAddress;
34use smoltcp::wire::EthernetFrame;
35use smoltcp::wire::EthernetProtocol;
36use smoltcp::wire::EthernetRepr;
37use smoltcp::wire::IPV4_HEADER_LEN;
38use smoltcp::wire::IpProtocol;
39use smoltcp::wire::Ipv4Address;
40use smoltcp::wire::Ipv4Packet;
41use std::net::SocketAddrV4;
42use std::task::Context;
43use thiserror::Error;
44
/// Top-level state of the consomme network stack.
///
/// Holds the shared parameters/state plus one handler per supported IP protocol
/// (TCP, UDP, ICMP). Pair it with a [`Client`] via [`Consomme::access`] to send
/// frames or poll for progress.
#[derive(InspectMut)]
pub struct Consomme {
    /// Parameters and non-protocol-specific state.
    state: ConsommeState,
    /// TCP handler.
    tcp: tcp::Tcp,
    /// UDP handler. Unlike the other fields, it is inspected mutably.
    #[inspect(mut)]
    udp: udp::Udp,
    /// ICMP handler.
    icmp: icmp::Icmp,
}
54
/// Non-protocol-specific state owned by [`Consomme`].
#[derive(Inspect)]
struct ConsommeState {
    /// Addressing and DNS configuration.
    params: ConsommeParams,
    /// Byte buffer allocated in `Consomme::new`; excluded from inspection.
    /// NOTE(review): its use is not visible here — presumably scratch space for
    /// building outgoing packets; confirm.
    #[inspect(skip)]
    buffer: Box<[u8]>,
}
61
/// Addressing and DNS configuration for the virtual network.
///
/// Defaults come from [`ConsommeParams::new`]; the IP addresses and mask can be
/// re-derived from a CIDR string with [`ConsommeParams::set_cidr`].
#[derive(Inspect)]
pub struct ConsommeParams {
    /// Subnet mask of the virtual network (default 255.255.255.0).
    #[inspect(display)]
    pub net_mask: Ipv4Address,
    /// IPv4 address of the virtual gateway (default 10.0.0.1).
    #[inspect(display)]
    pub gateway_ip: Ipv4Address,
    /// MAC address of the virtual gateway (default 52:55:0a:00:00:01).
    #[inspect(display)]
    pub gateway_mac: EthernetAddress,
    /// IPv4 address of the client (default 10.0.0.2).
    #[inspect(display)]
    pub client_ip: Ipv4Address,
    /// MAC address of the client (default 00:00:00:00:01:00).
    #[inspect(display)]
    pub client_mac: EthernetAddress,
    /// DNS server addresses; each is shown with its `Display` form when inspected.
    #[inspect(with = "|x| inspect::iter_by_index(x).map_value(inspect::AsDisplay)")]
    pub nameservers: Vec<Ipv4Address>,
}
84
/// Error returned by [`ConsommeParams::set_cidr`] when the CIDR string cannot be used.
#[derive(Debug, Error)]
#[error("invalid CIDR")]
pub struct InvalidCidr;
89
90impl ConsommeParams {
91 pub fn new() -> Result<Self, Error> {
96 let nameservers = dns::nameservers()?;
97 Ok(Self {
98 gateway_ip: Ipv4Address::new(10, 0, 0, 1),
99 gateway_mac: EthernetAddress([0x52, 0x55, 10, 0, 0, 1]),
100 client_ip: Ipv4Address::new(10, 0, 0, 2),
101 client_mac: EthernetAddress([0x0, 0x0, 0x0, 0x0, 0x1, 0x0]),
102 net_mask: Ipv4Address::new(255, 255, 255, 0),
103 nameservers,
104 })
105 }
106
107 pub fn set_cidr(&mut self, cidr: &str) -> Result<(), InvalidCidr> {
112 let cidr: smoltcp::wire::Ipv4Cidr = cidr.parse().map_err(|()| InvalidCidr)?;
113 let base_address = cidr.network().address();
114 self.gateway_ip = base_address;
115 self.gateway_ip.0[3] += 1;
116 self.client_ip = base_address;
117 self.client_ip.0[3] += 2;
118 self.net_mask = cidr.netmask();
119 Ok(())
120 }
121}
122
/// A [`Consomme`] temporarily paired with a client of type `T`.
///
/// Created by [`Consomme::access`]; frame submission (`send`) and polling
/// (`poll`) are methods on this type.
pub struct Access<'a, T> {
    /// The borrowed network stack.
    inner: &'a mut Consomme,
    /// The borrowed client.
    client: &'a mut T,
}
128
/// Interface the network stack uses to talk back to its client.
pub trait Client {
    /// Returns the async driver.
    /// NOTE(review): presumably used to create the stack's host-side I/O
    /// objects; confirm.
    fn driver(&self) -> &dyn Driver;

    /// Hands a frame (`data`) to the client; `checksum` describes its checksum
    /// state.
    /// NOTE(review): presumably called for each frame the stack emits toward the
    /// client; confirm.
    fn recv(&mut self, data: &[u8], checksum: &ChecksumState);

    /// Returns the client's receive MTU in bytes.
    /// NOTE(review): presumably the largest frame `recv` may be given; confirm its
    /// relation to `MIN_MTU`.
    fn rx_mtu(&mut self) -> usize;
}
159
/// Per-frame checksum / segmentation flags exchanged with the client.
///
/// Each boolean maps to `Checksum::None` in the smoltcp capabilities built from
/// this value (i.e. smoltcp neither computes nor verifies that checksum).
#[derive(Debug, Copy, Clone)]
pub struct ChecksumState {
    /// IPv4 header checksum flag. For frames passed to `Access::send`, `true`
    /// skips verification of the IPv4 header checksum.
    pub ipv4: bool,
    /// TCP checksum flag.
    pub tcp: bool,
    /// UDP checksum flag.
    pub udp: bool,
    /// When `Some` on a frame passed to `Access::send`, the IPv4 payload length
    /// is taken from the buffer rather than the header's total-length field.
    /// NOTE(review): the `u16` is presumably the TSO segment size; confirm.
    pub tso: Option<u16>,
}
178
179impl ChecksumState {
180 const NONE: Self = Self {
181 ipv4: false,
182 tcp: false,
183 udp: false,
184 tso: None,
185 };
186 const IPV4_ONLY: Self = Self {
187 ipv4: true,
188 tcp: false,
189 udp: false,
190 tso: None,
191 };
192 const TCP4: Self = Self {
193 ipv4: true,
194 tcp: true,
195 udp: false,
196 tso: None,
197 };
198 const UDP4: Self = Self {
199 ipv4: true,
200 tcp: false,
201 udp: true,
202 tso: None,
203 };
204
205 fn caps(&self) -> ChecksumCapabilities {
206 let mut caps = ChecksumCapabilities::default();
207 if self.ipv4 {
208 caps.ipv4 = Checksum::None;
209 }
210 if self.tcp {
211 caps.tcp = Checksum::None;
212 }
213 if self.udp {
214 caps.udp = Checksum::None;
215 }
216 caps
217 }
218}
219
/// Minimum MTU in bytes: 1514 = a 1500-byte IP payload plus the 14-byte
/// Ethernet header.
/// NOTE(review): presumably the smallest value `Client::rx_mtu` may return;
/// confirm.
pub const MIN_MTU: usize = 1514;
223
/// An IPv4 address and port, convertible to `SocketAddrV4` and
/// `socket2::SockAddr`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct SocketAddress {
    /// IPv4 address.
    ip: Ipv4Address,
    /// Port number.
    port: u16,
}
229
230impl From<SocketAddress> for SocketAddrV4 {
231 fn from(addr: SocketAddress) -> Self {
232 Self::new(addr.ip.into(), addr.port)
233 }
234}
235
236impl From<SocketAddress> for socket2::SockAddr {
237 fn from(addr: SocketAddress) -> Self {
238 socket2::SockAddr::from(SocketAddrV4::from(addr))
239 }
240}
241
/// A connection identifier made of a destination and a source socket address.
/// NOTE(review): derives `Eq`/`Hash`, presumably so it can key a connection
/// table in the protocol modules; confirm.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct FourTuple {
    /// Destination address and port.
    dst: SocketAddress,
    /// Source address and port.
    src: SocketAddress,
}
247
/// Why a frame passed to `Access::send` was dropped.
#[derive(Debug, Error)]
pub enum DropReason {
    /// smoltcp reported a parsing error, or a manual header check failed with a
    /// smoltcp error kind (e.g. `Malformed`, `Fragmented`).
    #[error("packet parsing error")]
    Packet(#[from] smoltcp::Error),
    /// The Ethernet frame carries something other than IPv4 or ARP.
    #[error("unsupported ethertype {0}")]
    UnsupportedEthertype(EthernetProtocol),
    /// The IPv4 packet carries something other than TCP, UDP or ICMP.
    #[error("unsupported ip protocol {0}")]
    UnsupportedIpProtocol(IpProtocol),
    /// A DHCP message of an unhandled type (presumably raised by the DHCP module).
    #[error("unsupported dhcp message type {0:?}")]
    UnsupportedDhcp(DhcpMessageType),
    /// An unhandled ARP message (presumably raised by the ARP module).
    #[error("unsupported arp type")]
    UnsupportedArp,
    /// The IPv4 header checksum failed verification.
    #[error("ipv4 checksum failure")]
    Ipv4Checksum,
    /// A send buffer had no room for the data (raised outside this file).
    #[error("send buffer full")]
    SendBufferFull,
    /// An I/O error, kept as the error source.
    #[error("io error")]
    Io(#[source] std::io::Error),
    /// The TCP handler rejected the segment for its connection state.
    #[error("bad tcp state")]
    BadTcpState(#[from] tcp::TcpError),
    /// The destination port has no bound socket (raised outside this file).
    #[error("port is not bound")]
    PortNotBound,
}
282
/// Errors returned by [`ConsommeParams::new`].
#[derive(Debug, Error)]
pub enum Error {
    /// The nameserver list could not be determined.
    #[error("failed to initialize nameservers")]
    Dns(#[from] dns::Error),
}
290
/// Source and destination addresses of an IPv4 packet, passed from the IPv4
/// layer to the per-protocol handlers.
#[derive(Debug)]
struct Ipv4Addresses {
    /// Sender's address.
    src_addr: Ipv4Address,
    /// Destination address.
    dst_addr: Ipv4Address,
}
296
297impl Consomme {
298 pub fn new(params: ConsommeParams) -> Self {
300 Self {
301 state: ConsommeState {
302 params,
303 buffer: Box::new([0; 65536]),
304 },
305 tcp: tcp::Tcp::new(),
306 udp: udp::Udp::new(),
307 icmp: icmp::Icmp::new(),
308 }
309 }
310
311 pub fn params_mut(&mut self) -> &mut ConsommeParams {
316 &mut self.state.params
317 }
318
319 pub fn access<'a, T: Client>(&'a mut self, client: &'a mut T) -> Access<'a, T> {
321 Access {
322 inner: self,
323 client,
324 }
325 }
326}
327
328impl<T: Client> Access<'_, T> {
329 pub fn get(&self) -> &Consomme {
331 self.inner
332 }
333
334 pub fn get_mut(&mut self) -> &mut Consomme {
336 self.inner
337 }
338
339 pub fn poll(&mut self, cx: &mut Context<'_>) {
341 self.poll_udp(cx);
342 self.poll_tcp(cx);
343 self.poll_icmp(cx);
344 }
345
346 pub fn refresh_driver(&mut self) {
350 self.refresh_tcp_driver();
351 self.refresh_udp_driver();
352 }
353
354 pub fn send(&mut self, data: &[u8], checksum: &ChecksumState) -> Result<(), DropReason> {
384 let frame_packet = EthernetFrame::new_unchecked(data);
385 let frame = EthernetRepr::parse(&frame_packet)?;
386 match frame.ethertype {
387 EthernetProtocol::Ipv4 => self.handle_ipv4(&frame, frame_packet.payload(), checksum)?,
388 EthernetProtocol::Arp => self.handle_arp(&frame, frame_packet.payload())?,
389 _ => return Err(DropReason::UnsupportedEthertype(frame.ethertype)),
390 }
391 Ok(())
392 }
393
394 fn handle_ipv4(
395 &mut self,
396 frame: &EthernetRepr,
397 payload: &[u8],
398 checksum: &ChecksumState,
399 ) -> Result<(), DropReason> {
400 let ipv4 = Ipv4Packet::new_unchecked(payload);
401 if payload.len() < IPV4_HEADER_LEN
402 || ipv4.version() != 4
403 || payload.len() < ipv4.header_len().into()
404 || payload.len() < ipv4.total_len().into()
405 {
406 return Err(DropReason::Packet(smoltcp::Error::Malformed));
407 }
408
409 let total_len = if checksum.tso.is_some() {
410 payload.len()
411 } else {
412 ipv4.total_len().into()
413 };
414 if total_len < ipv4.header_len().into() {
415 return Err(DropReason::Packet(smoltcp::Error::Malformed));
416 }
417
418 if ipv4.more_frags() || ipv4.frag_offset() != 0 {
419 return Err(DropReason::Packet(smoltcp::Error::Fragmented));
420 }
421
422 if !checksum.ipv4 && !ipv4.verify_checksum() {
423 return Err(DropReason::Ipv4Checksum);
424 }
425
426 let addresses = Ipv4Addresses {
427 src_addr: ipv4.src_addr(),
428 dst_addr: ipv4.dst_addr(),
429 };
430
431 let inner = &payload[ipv4.header_len().into()..total_len];
432
433 match ipv4.protocol() {
434 IpProtocol::Tcp => self.handle_tcp(&addresses, inner, checksum)?,
435 IpProtocol::Udp => self.handle_udp(frame, &addresses, inner, checksum)?,
436 IpProtocol::Icmp => {
437 self.handle_icmp(frame, &addresses, inner, checksum, ipv4.hop_limit())?
438 }
439 p => return Err(DropReason::UnsupportedIpProtocol(p)),
440 };
441 Ok(())
442 }
443}