1mod arp;
18mod dhcp;
19#[cfg_attr(unix, path = "dns_unix.rs")]
20#[cfg_attr(windows, path = "dns_windows.rs")]
21mod dns;
22mod tcp;
23mod udp;
24mod windows;
25
26use inspect::Inspect;
27use inspect::InspectMut;
28use pal_async::driver::Driver;
29use smoltcp::phy::Checksum;
30use smoltcp::phy::ChecksumCapabilities;
31use smoltcp::wire::DhcpMessageType;
32use smoltcp::wire::EthernetAddress;
33use smoltcp::wire::EthernetFrame;
34use smoltcp::wire::EthernetProtocol;
35use smoltcp::wire::EthernetRepr;
36use smoltcp::wire::IPV4_HEADER_LEN;
37use smoltcp::wire::IpProtocol;
38use smoltcp::wire::Ipv4Address;
39use smoltcp::wire::Ipv4Packet;
40use std::net::SocketAddrV4;
41use std::task::Context;
42use thiserror::Error;
43
44#[derive(InspectMut)]
46pub struct Consomme {
47 state: ConsommeState,
48 tcp: tcp::Tcp,
49 #[inspect(mut)]
50 udp: udp::Udp,
51}
52
53#[derive(Inspect)]
54struct ConsommeState {
55 params: ConsommeParams,
56 #[inspect(skip)]
57 buffer: Box<[u8]>,
58}
59
60#[derive(Inspect)]
62pub struct ConsommeParams {
63 #[inspect(display)]
65 pub net_mask: Ipv4Address,
66 #[inspect(display)]
68 pub gateway_ip: Ipv4Address,
69 #[inspect(display)]
71 pub gateway_mac: EthernetAddress,
72 #[inspect(display)]
74 pub client_ip: Ipv4Address,
75 #[inspect(display)]
77 pub client_mac: EthernetAddress,
78 #[inspect(with = "|x| inspect::iter_by_index(x).map_value(inspect::AsDisplay)")]
80 pub nameservers: Vec<Ipv4Address>,
81}
82
83#[derive(Debug, Error)]
85#[error("invalid CIDR")]
86pub struct InvalidCidr;
87
88impl ConsommeParams {
89 pub fn new() -> Result<Self, Error> {
94 let nameservers = dns::nameservers()?;
95 Ok(Self {
96 gateway_ip: Ipv4Address::new(10, 0, 0, 1),
97 gateway_mac: EthernetAddress([0x52, 0x55, 10, 0, 0, 1]),
98 client_ip: Ipv4Address::new(10, 0, 0, 2),
99 client_mac: EthernetAddress([0x0, 0x0, 0x0, 0x0, 0x1, 0x0]),
100 net_mask: Ipv4Address::new(255, 255, 255, 0),
101 nameservers,
102 })
103 }
104
105 pub fn set_cidr(&mut self, cidr: &str) -> Result<(), InvalidCidr> {
110 let cidr: smoltcp::wire::Ipv4Cidr = cidr.parse().map_err(|()| InvalidCidr)?;
111 let base_address = cidr.network().address();
112 self.gateway_ip = base_address;
113 self.gateway_ip.0[3] += 1;
114 self.client_ip = base_address;
115 self.client_ip.0[3] += 2;
116 self.net_mask = cidr.netmask();
117 Ok(())
118 }
119}
120
121pub struct Access<'a, T> {
123 inner: &'a mut Consomme,
124 client: &'a mut T,
125}
126
127pub trait Client {
129 fn driver(&self) -> &dyn Driver;
134
135 fn recv(&mut self, data: &[u8], checksum: &ChecksumState);
147
148 fn rx_mtu(&mut self) -> usize;
156}
157
/// Per-packet checksum and segmentation flags.
///
/// A `true` flag means the matching checksum is not verified or computed by
/// this crate for the packet.
#[derive(Debug, Copy, Clone)]
pub struct ChecksumState {
    /// IPv4 header checksum flag. When set, header checksum verification is
    /// skipped on send.
    pub ipv4: bool,
    /// TCP checksum flag.
    pub tcp: bool,
    /// UDP checksum flag.
    pub udp: bool,
    /// Segmentation-offload marker. When `Some`, the IPv4 total-length field
    /// is ignored in favour of the actual payload length. The `u16` is
    /// presumably the segment size — TODO confirm.
    pub tso: Option<u16>,
}
176
177impl ChecksumState {
178 const NONE: Self = Self {
179 ipv4: false,
180 tcp: false,
181 udp: false,
182 tso: None,
183 };
184 const IPV4_ONLY: Self = Self {
185 ipv4: true,
186 tcp: false,
187 udp: false,
188 tso: None,
189 };
190 const TCP4: Self = Self {
191 ipv4: true,
192 tcp: true,
193 udp: false,
194 tso: None,
195 };
196 const UDP4: Self = Self {
197 ipv4: true,
198 tcp: false,
199 udp: true,
200 tso: None,
201 };
202
203 fn caps(&self) -> ChecksumCapabilities {
204 let mut caps = ChecksumCapabilities::default();
205 if self.ipv4 {
206 caps.ipv4 = Checksum::None;
207 }
208 if self.tcp {
209 caps.tcp = Checksum::None;
210 }
211 if self.udp {
212 caps.udp = Checksum::None;
213 }
214 caps
215 }
216}
217
/// Minimum MTU, in bytes.
///
/// NOTE(review): 1514 matches a 1500-byte payload plus a 14-byte Ethernet
/// header, so this presumably counts the Ethernet header — confirm.
pub const MIN_MTU: usize = 1514;
221
222#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
223struct SocketAddress {
224 ip: Ipv4Address,
225 port: u16,
226}
227
228impl From<SocketAddress> for SocketAddrV4 {
229 fn from(addr: SocketAddress) -> Self {
230 Self::new(addr.ip.into(), addr.port)
231 }
232}
233
234impl From<SocketAddress> for socket2::SockAddr {
235 fn from(addr: SocketAddress) -> Self {
236 socket2::SockAddr::from(SocketAddrV4::from(addr))
237 }
238}
239
240#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
241struct FourTuple {
242 dst: SocketAddress,
243 src: SocketAddress,
244}
245
246#[derive(Debug, Error)]
248pub enum DropReason {
249 #[error("packet parsing error")]
251 Packet(#[from] smoltcp::Error),
252 #[error("unsupported ethertype {0}")]
254 UnsupportedEthertype(EthernetProtocol),
255 #[error("unsupported ip protocol {0}")]
257 UnsupportedIpProtocol(IpProtocol),
258 #[error("unsupported dhcp message type {0:?}")]
260 UnsupportedDhcp(DhcpMessageType),
261 #[error("unsupported arp type")]
263 UnsupportedArp,
264 #[error("ipv4 checksum failure")]
266 Ipv4Checksum,
267 #[error("send buffer full")]
269 SendBufferFull,
270 #[error("io error")]
272 Io(#[source] std::io::Error),
273 #[error("bad tcp state")]
275 BadTcpState(#[from] tcp::TcpError),
276 #[error("port is not bound")]
278 PortNotBound,
279}
280
281#[derive(Debug, Error)]
283pub enum Error {
284 #[error("failed to initialize nameservers")]
286 Dns(#[from] dns::Error),
287}
288
289#[derive(Debug)]
290struct Ipv4Addresses {
291 src_addr: Ipv4Address,
292 dst_addr: Ipv4Address,
293}
294
295impl Consomme {
296 pub fn new(params: ConsommeParams) -> Self {
298 Self {
299 state: ConsommeState {
300 params,
301 buffer: Box::new([0; 65536]),
302 },
303 tcp: tcp::Tcp::new(),
304 udp: udp::Udp::new(),
305 }
306 }
307
308 pub fn params_mut(&mut self) -> &mut ConsommeParams {
313 &mut self.state.params
314 }
315
316 pub fn access<'a, T: Client>(&'a mut self, client: &'a mut T) -> Access<'a, T> {
318 Access {
319 inner: self,
320 client,
321 }
322 }
323}
324
325impl<T: Client> Access<'_, T> {
326 pub fn get(&self) -> &Consomme {
328 self.inner
329 }
330
331 pub fn get_mut(&mut self) -> &mut Consomme {
333 self.inner
334 }
335
336 pub fn poll(&mut self, cx: &mut Context<'_>) {
338 self.poll_udp(cx);
339 self.poll_tcp(cx);
340 }
341
342 pub fn refresh_driver(&mut self) {
346 self.refresh_tcp_driver();
347 self.refresh_udp_driver();
348 }
349
350 pub fn send(&mut self, data: &[u8], checksum: &ChecksumState) -> Result<(), DropReason> {
380 let frame_packet = EthernetFrame::new_unchecked(data);
381 let frame = EthernetRepr::parse(&frame_packet)?;
382 match frame.ethertype {
383 EthernetProtocol::Ipv4 => self.handle_ipv4(&frame, frame_packet.payload(), checksum)?,
384 EthernetProtocol::Arp => self.handle_arp(&frame, frame_packet.payload())?,
385 _ => return Err(DropReason::UnsupportedEthertype(frame.ethertype)),
386 }
387 Ok(())
388 }
389
390 fn handle_ipv4(
391 &mut self,
392 frame: &EthernetRepr,
393 payload: &[u8],
394 checksum: &ChecksumState,
395 ) -> Result<(), DropReason> {
396 let ipv4 = Ipv4Packet::new_unchecked(payload);
397 if payload.len() < IPV4_HEADER_LEN
398 || ipv4.version() != 4
399 || payload.len() < ipv4.header_len().into()
400 || payload.len() < ipv4.total_len().into()
401 {
402 return Err(DropReason::Packet(smoltcp::Error::Malformed));
403 }
404
405 let total_len = if checksum.tso.is_some() {
406 payload.len()
407 } else {
408 ipv4.total_len().into()
409 };
410 if total_len < ipv4.header_len().into() {
411 return Err(DropReason::Packet(smoltcp::Error::Malformed));
412 }
413
414 if ipv4.more_frags() || ipv4.frag_offset() != 0 {
415 return Err(DropReason::Packet(smoltcp::Error::Fragmented));
416 }
417
418 if !checksum.ipv4 && !ipv4.verify_checksum() {
419 return Err(DropReason::Ipv4Checksum);
420 }
421
422 let addresses = Ipv4Addresses {
423 src_addr: ipv4.src_addr(),
424 dst_addr: ipv4.dst_addr(),
425 };
426
427 let inner = &payload[ipv4.header_len().into()..total_len];
428
429 match ipv4.protocol() {
430 IpProtocol::Tcp => self.handle_tcp(&addresses, inner, checksum)?,
431 IpProtocol::Udp => self.handle_udp(frame, &addresses, inner, checksum)?,
432 p => return Err(DropReason::UnsupportedIpProtocol(p)),
433 };
434 Ok(())
435 }
436}