1use super::Access;
5use super::Client;
6use super::DropReason;
7use super::SocketAddress;
8use super::dhcp::DHCP_SERVER;
9use crate::ChecksumState;
10use crate::ConsommeState;
11use crate::Ipv4Addresses;
12use inspect::Inspect;
13use inspect::InspectMut;
14use inspect_counters::Counter;
15use pal_async::interest::InterestSlot;
16use pal_async::interest::PollEvents;
17use pal_async::socket::PolledSocket;
18use smoltcp::phy::ChecksumCapabilities;
19use smoltcp::wire::ETHERNET_HEADER_LEN;
20use smoltcp::wire::EthernetAddress;
21use smoltcp::wire::EthernetFrame;
22use smoltcp::wire::EthernetProtocol;
23use smoltcp::wire::EthernetRepr;
24use smoltcp::wire::IPV4_HEADER_LEN;
25use smoltcp::wire::IpProtocol;
26use smoltcp::wire::Ipv4Packet;
27use smoltcp::wire::Ipv4Repr;
28use smoltcp::wire::UDP_HEADER_LEN;
29use smoltcp::wire::UdpPacket;
30use smoltcp::wire::UdpRepr;
31use std::collections::HashMap;
32use std::collections::hash_map;
33use std::io::ErrorKind;
34use std::net::IpAddr;
35use std::net::Ipv4Addr;
36use std::net::UdpSocket;
37use std::task::Context;
38use std::task::Poll;
39
/// UDP NAT state: the table of active UDP connections.
pub(crate) struct Udp {
    /// Connections keyed by guest-side address (the guest's source IP and
    /// port), or, for ports bound via `bind_udp_port`, by the bind IP
    /// (0.0.0.0 if none was given) and port.
    connections: HashMap<SocketAddress, UdpConnection>,
}
43
44impl Udp {
45 pub fn new() -> Self {
46 Self {
47 connections: HashMap::new(),
48 }
49 }
50}
51
52impl InspectMut for Udp {
53 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
54 let mut resp = req.respond();
55 for (addr, conn) in &mut self.connections {
56 resp.field_mut(&format!("{}:{}", addr.ip, addr.port), conn);
57 }
58 }
59}
60
/// One UDP flow: a host socket plus the guest-side details needed to deliver
/// datagrams received on it back to the guest.
#[derive(InspectMut)]
struct UdpConnection {
    /// Host socket used to send and receive this flow's datagrams. It is set
    /// on creation and is only `None` transiently while `refresh_udp_driver`
    /// re-registers it with a new driver.
    #[inspect(skip)]
    socket: Option<PolledSocket<UdpSocket>>,
    /// Destination MAC address of frames delivered to the guest.
    #[inspect(display)]
    guest_mac: EthernetAddress,
    /// Per-connection packet counters.
    stats: Stats,
    /// When set, the connection is removed on its next poll. Marked
    /// `#[inspect(mut)]`, presumably so it can be set through inspect —
    /// confirm against the inspect crate.
    #[inspect(mut)]
    recycle: bool,
}
71
/// Per-connection UDP packet counters.
#[derive(Inspect, Default)]
struct Stats {
    /// Guest datagrams successfully sent on the host socket.
    tx_packets: Counter,
    /// Guest datagrams dropped because the host send would block.
    tx_dropped: Counter,
    /// Guest datagrams whose host send failed with any other error.
    tx_errors: Counter,
    /// Datagrams received on the host socket and delivered to the guest.
    rx_packets: Counter,
}
79
80impl UdpConnection {
81 fn poll_conn(
82 &mut self,
83 cx: &mut Context<'_>,
84 dst_addr: &SocketAddress,
85 state: &mut ConsommeState,
86 client: &mut impl Client,
87 ) -> bool {
88 if self.recycle {
89 return false;
90 }
91
92 let mut eth = EthernetFrame::new_unchecked(&mut state.buffer);
93 loop {
94 if client.rx_mtu() == 0 {
100 break true;
101 }
102 match self.socket.as_mut().unwrap().poll_io(
103 cx,
104 InterestSlot::Read,
105 PollEvents::IN,
106 |socket| {
107 socket
108 .get()
109 .recv_from(&mut eth.payload_mut()[IPV4_HEADER_LEN + UDP_HEADER_LEN..])
110 },
111 ) {
112 Poll::Ready(Ok((n, src_addr))) => {
113 let src_ip = if let IpAddr::V4(ip) = src_addr.ip() {
114 ip
115 } else {
116 unreachable!()
117 };
118 eth.set_ethertype(EthernetProtocol::Ipv4);
119 eth.set_src_addr(state.params.gateway_mac);
120 eth.set_dst_addr(self.guest_mac);
121 let mut ipv4 = Ipv4Packet::new_unchecked(eth.payload_mut());
122 Ipv4Repr {
123 src_addr: src_ip.into(),
124 dst_addr: dst_addr.ip,
125 protocol: IpProtocol::Udp,
126 payload_len: UDP_HEADER_LEN + n,
127 hop_limit: 64,
128 }
129 .emit(&mut ipv4, &ChecksumCapabilities::default());
130 let mut udp = UdpPacket::new_unchecked(ipv4.payload_mut());
131 udp.set_src_port(src_addr.port());
132 udp.set_dst_port(dst_addr.port);
133 udp.set_len((UDP_HEADER_LEN + n) as u16);
134 udp.fill_checksum(&src_ip.into(), &dst_addr.ip.into());
135 let len = ETHERNET_HEADER_LEN + ipv4.total_len() as usize;
136 client.recv(ð.as_ref()[..len], &ChecksumState::UDP4);
137 self.stats.rx_packets.increment();
138 }
139 Poll::Ready(Err(err)) => {
140 tracing::error!(error = &err as &dyn std::error::Error, "recv error");
141 break false;
142 }
143 Poll::Pending => break true,
144 }
145 }
146 }
147}
148
149impl<T: Client> Access<'_, T> {
150 pub(crate) fn poll_udp(&mut self, cx: &mut Context<'_>) {
151 self.inner.udp.connections.retain(|dst_addr, conn| {
152 conn.poll_conn(cx, dst_addr, &mut self.inner.state, self.client)
153 });
154 }
155
156 pub(crate) fn refresh_udp_driver(&mut self) {
157 self.inner.udp.connections.retain(|_, conn| {
158 let socket = conn.socket.take().unwrap().into_inner();
159 match PolledSocket::new(self.client.driver(), socket) {
160 Ok(socket) => {
161 conn.socket = Some(socket);
162 true
163 }
164 Err(err) => {
165 tracing::warn!(
166 error = &err as &dyn std::error::Error,
167 "failed to update driver for udp connection"
168 );
169 false
170 }
171 }
172 });
173 }
174
175 pub(crate) fn handle_udp(
176 &mut self,
177 frame: &EthernetRepr,
178 addresses: &Ipv4Addresses,
179 payload: &[u8],
180 checksum: &ChecksumState,
181 ) -> Result<(), DropReason> {
182 let udp_packet = UdpPacket::new_checked(payload)?;
183 let udp = UdpRepr::parse(
184 &udp_packet,
185 &addresses.src_addr.into(),
186 &addresses.dst_addr.into(),
187 &checksum.caps(),
188 )?;
189
190 if addresses.dst_addr == self.inner.state.params.gateway_ip
191 || addresses.dst_addr.is_broadcast()
192 {
193 if self.handle_gateway_udp(&udp_packet)? {
194 return Ok(());
195 }
196 }
197
198 let guest_addr = SocketAddress {
199 ip: addresses.src_addr,
200 port: udp.src_port,
201 };
202
203 let conn = self.get_or_insert(guest_addr, None, Some(frame.src_addr))?;
204 match conn.socket.as_mut().unwrap().get().send_to(
205 udp_packet.payload(),
206 (Ipv4Addr::from(addresses.dst_addr), udp.dst_port),
207 ) {
208 Ok(_) => {
209 conn.stats.tx_packets.increment();
210 Ok(())
211 }
212 Err(err) if err.kind() == ErrorKind::WouldBlock => {
213 conn.stats.tx_dropped.increment();
214 Err(DropReason::SendBufferFull)
215 }
216 Err(err) => {
217 conn.stats.tx_errors.increment();
218 Err(DropReason::Io(err))
219 }
220 }
221 }
222
223 fn get_or_insert(
224 &mut self,
225 guest_addr: SocketAddress,
226 host_addr: Option<Ipv4Addr>,
227 guest_mac: Option<EthernetAddress>,
228 ) -> Result<&mut UdpConnection, DropReason> {
229 let entry = self.inner.udp.connections.entry(guest_addr);
230 match entry {
231 hash_map::Entry::Occupied(conn) => Ok(conn.into_mut()),
232 hash_map::Entry::Vacant(e) => {
233 let socket = UdpSocket::bind((host_addr.unwrap_or(Ipv4Addr::UNSPECIFIED), 0))
234 .map_err(DropReason::Io)?;
235 let socket =
236 PolledSocket::new(self.client.driver(), socket).map_err(DropReason::Io)?;
237 let conn = UdpConnection {
238 socket: Some(socket),
239 guest_mac: guest_mac.unwrap_or(self.inner.state.params.client_mac),
240 stats: Default::default(),
241 recycle: false,
242 };
243 Ok(e.insert(conn))
244 }
245 }
246 }
247
248 fn handle_gateway_udp(&mut self, udp: &UdpPacket<&[u8]>) -> Result<bool, DropReason> {
249 let payload = udp.payload();
250 match udp.dst_port() {
251 DHCP_SERVER => {
252 self.handle_dhcp(payload)?;
253 Ok(true)
254 }
255 _ => Ok(false),
256 }
257 }
258
259 pub fn bind_udp_port(
262 &mut self,
263 ip_addr: Option<Ipv4Addr>,
264 port: u16,
265 ) -> Result<(), DropReason> {
266 let guest_addr = SocketAddress {
267 ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(),
268 port,
269 };
270 let _ = self.get_or_insert(guest_addr, ip_addr, None)?;
271 Ok(())
272 }
273
274 pub fn unbind_udp_port(&mut self, port: u16) -> Result<(), DropReason> {
276 let guest_addr = SocketAddress {
277 ip: Ipv4Addr::UNSPECIFIED.into(),
278 port,
279 };
280 match self.inner.udp.connections.remove(&guest_addr) {
281 Some(_) => Ok(()),
282 None => Err(DropReason::PortNotBound),
283 }
284 }
285}