1mod assembler;
5mod ring;
6
7use super::Access;
8use super::Client;
9use super::DropReason;
10use crate::ChecksumState;
11use crate::ConsommeState;
12use crate::IpAddresses;
13use crate::dns_resolver::DnsResolver;
14use crate::dns_resolver::dns_tcp::DnsTcpHandler;
15use futures::AsyncRead;
16use futures::AsyncWrite;
17use inspect::Inspect;
18use inspect::InspectMut;
19use pal_async::interest::PollEvents;
20use pal_async::socket::PollReady;
21use pal_async::socket::PolledSocket;
22use smoltcp::phy::ChecksumCapabilities;
23use smoltcp::wire::ETHERNET_HEADER_LEN;
24use smoltcp::wire::EthernetFrame;
25use smoltcp::wire::EthernetProtocol;
26use smoltcp::wire::IPV4_HEADER_LEN;
27use smoltcp::wire::IPV6_HEADER_LEN;
28use smoltcp::wire::IpAddress;
29use smoltcp::wire::IpProtocol;
30use smoltcp::wire::IpRepr;
31use smoltcp::wire::Ipv4Packet;
32use smoltcp::wire::Ipv6Packet;
33use smoltcp::wire::TcpControl;
34use smoltcp::wire::TcpPacket;
35use smoltcp::wire::TcpRepr;
36use smoltcp::wire::TcpSeqNumber;
37use socket2::Domain;
38use socket2::Protocol;
39use socket2::SockAddr;
40use socket2::Socket;
41use socket2::Type;
42use std::collections::HashMap;
43use std::collections::hash_map;
44use std::io;
45use std::io::ErrorKind;
46use std::io::IoSlice;
47use std::io::IoSliceMut;
48use std::net::IpAddr;
49use std::net::Ipv4Addr;
50use std::net::Ipv6Addr;
51use std::net::Shutdown;
52use std::net::SocketAddr;
53use std::net::SocketAddrV4;
54use std::net::SocketAddrV6;
55use std::pin::Pin;
56use std::task::Context;
57use std::task::Poll;
58use thiserror::Error;
59
/// Identifies a TCP flow by its two endpoints.
///
/// The fields are named from the point of view of a packet sent by the
/// client: `src` is the client-side address and port, `dst` is the remote
/// endpoint the client is talking to.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct FourTuple {
    src: SocketAddr,
    dst: SocketAddr,
}

impl core::fmt::Display for FourTuple {
    /// Renders the flow as `<src>-<dst>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { src, dst } = self;
        write!(f, "{src}-{dst}")
    }
}
71
/// All TCP state: the active connections, the bound listeners, and the
/// buffer-sizing parameters handed to every newly created connection.
#[derive(InspectMut)]
pub(crate) struct Tcp {
    /// Active connections, keyed by flow.
    #[inspect(iter_by_key)]
    connections: HashMap<FourTuple, TcpConnection>,
    /// Listening sockets, keyed by port number.
    #[inspect(iter_by_key)]
    listeners: HashMap<u16, TcpListener>,
    /// Parameters passed to the connection constructors.
    #[inspect(mut)]
    connection_params: ConnectionParams,
}
81
/// Buffer sizes requested for new connections.
///
/// `TcpConnection::new_base` clamps both values to the range
/// 16 KiB..=4 MiB before using them.
#[derive(InspectMut)]
struct ConnectionParams {
    /// Requested receive buffer size, in bytes.
    #[inspect(mut)]
    rx_buffer_size: usize,
    /// Requested transmit buffer size, in bytes.
    #[inspect(mut)]
    tx_buffer_size: usize,
}
89
/// Reasons a TCP segment from the client is rejected.
#[derive(Debug, Error)]
pub enum TcpError {
    /// A segment arrived while the connection was still in the `Connecting` state.
    #[error("still connecting")]
    StillConnecting,
    /// The segment's sequence number was not acceptable for the receive window.
    #[error("unacceptable segment number")]
    Unacceptable,
    /// The segment carried no acknowledgment number where one is required.
    #[error("missing ack bit")]
    MissingAck,
    /// The segment acknowledged a sequence number beyond what has been sent.
    #[error("ack newer than sequence")]
    AckPastSequence,
    /// The client's SYN requested a window scale greater than 14.
    #[error("invalid window scale")]
    InvalidWindowScale,
}
103
104impl Tcp {
105 pub fn new() -> Self {
106 Self {
107 connections: HashMap::new(),
108 listeners: HashMap::new(),
109 connection_params: ConnectionParams {
110 rx_buffer_size: 256 * 1024,
111 tx_buffer_size: 256 * 1024,
112 },
113 }
114 }
115}
116
/// Port bookkeeping for connections whose host socket is bound to a
/// loopback address.
#[derive(Inspect)]
#[inspect(tag = "info")]
enum LoopbackPortInfo {
    /// No loopback port mapping is recorded for this connection.
    None,
    /// Recorded by `TcpConnection::new` when the host socket's local address
    /// is loopback: `sending_port` is the host socket's local port and
    /// `guest_port` is the client's source port for the same flow.
    ProxyForGuestPort { sending_port: u16, guest_port: u16 },
}
123
/// What sits behind a connection on the host side.
enum TcpBackend {
    /// A host socket. The option is `None` once the socket has been dropped
    /// (see `poll_socket_backend`, which clears it after the transmit side
    /// has finished).
    Socket(Option<PolledSocket<Socket>>),
    /// A DNS-over-TCP handler, serviced through the DNS resolver instead of a
    /// host socket.
    Dns(DnsTcpHandler),
}
135
/// A single TCP connection: the host-side backend plus the TCP state
/// machine that talks to the client.
#[derive(Inspect)]
struct TcpConnection {
    /// Host-side endpoint for this connection's data.
    #[inspect(skip)]
    backend: TcpBackend,
    /// Protocol state, buffers and sequence tracking.
    #[inspect(flatten)]
    inner: TcpConnectionInner,
}
143
/// Backend-independent TCP connection state.
///
/// The `rx_*` fields describe data flowing from the client to the backend;
/// the `tx_*` fields describe data flowing from the backend to the client.
#[derive(Inspect)]
struct TcpConnectionInner {
    /// Loopback port mapping, if the host socket is on a loopback address.
    loopback_port: LoopbackPortInfo,
    /// Current TCP state.
    state: TcpState,

    /// Data received from the client that has not yet been handed to the backend.
    #[inspect(with = "|x| x.len()")]
    rx_buffer: ring::Ring,
    /// Upper bound, in bytes, on the receive window offered to the client.
    #[inspect(hex)]
    rx_window_cap: usize,
    /// Right-shift applied to the free receive space when advertising the window.
    rx_window_scale: u8,
    /// Sequence number sent as the acknowledgment number in outgoing segments.
    #[inspect(with = "inspect_seq")]
    rx_seq: TcpSeqNumber,
    /// Tracks which ranges of incoming data have been received.
    // NOTE(review): presumably handles out-of-order segments; confirm in `assembler`.
    #[inspect(flatten)]
    rx_assembler: assembler::Assembler,
    /// When set, `send_data` emits a segment even if there is nothing new to send.
    needs_ack: bool,
    /// Whether the write half of the host socket has been shut down.
    is_shutdown: bool,
    /// Whether the client's first packet carried a window scale option.
    enable_window_scaling: bool,

    /// Data read from the backend that the client has not yet acknowledged.
    #[inspect(with = "|x| x.len()")]
    tx_buffer: ring::Ring,
    /// Sequence number of the first byte held in `tx_buffer`.
    #[inspect(with = "inspect_seq")]
    tx_acked: TcpSeqNumber,
    /// Sequence number of the next segment to send.
    #[inspect(with = "inspect_seq")]
    tx_send: TcpSeqNumber,
    /// Whether a FIN is queued after the buffered data.
    tx_fin_buffered: bool,
    /// Window advertised by the client, before scaling.
    #[inspect(hex)]
    tx_window_len: u16,
    /// Left-shift applied to `tx_window_len` to obtain the client's window in bytes.
    tx_window_scale: u8,
    /// Sequence number of the client segment that last updated the send window.
    #[inspect(with = "inspect_seq")]
    tx_window_rx_seq: TcpSeqNumber,
    /// Acknowledgment number of the client segment that last updated the send window.
    #[inspect(with = "inspect_seq")]
    tx_window_tx_seq: TcpSeqNumber,
    /// Maximum payload size for segments sent to the client.
    #[inspect(hex)]
    tx_mss: usize,
}
179
180fn inspect_seq(seq: &TcpSeqNumber) -> inspect::AsHex<u32> {
181 inspect::AsHex(seq.0 as u32)
182}
183
/// A host socket listening for inbound connections on behalf of the client.
#[derive(Inspect)]
struct TcpListener {
    /// The listening socket.
    #[inspect(skip)]
    socket: PolledSocket<Socket>,
}
189
/// Connection states. `Connecting` is specific to this proxy; the remaining
/// names follow the standard TCP state machine.
#[derive(Debug, PartialEq, Eq, Inspect)]
enum TcpState {
    /// The host socket's outbound connect has not completed yet.
    Connecting,
    /// A SYN has been sent to the client (used for accepted inbound connections).
    SynSent,
    /// The client's SYN has been received; a SYN-ACK is sent from this state.
    SynReceived,
    Established,
    FinWait1,
    FinWait2,
    CloseWait,
    Closing,
    LastAck,
    TimeWait,
}
203
204impl TcpState {
205 fn tx_fin(&self) -> bool {
206 match self {
207 TcpState::Connecting
208 | TcpState::SynSent
209 | TcpState::SynReceived
210 | TcpState::Established
211 | TcpState::CloseWait => false,
212
213 TcpState::FinWait1
214 | TcpState::FinWait2
215 | TcpState::Closing
216 | TcpState::TimeWait
217 | TcpState::LastAck => true,
218 }
219 }
220
221 fn rx_fin(&self) -> bool {
222 match self {
223 TcpState::Connecting
224 | TcpState::SynSent
225 | TcpState::SynReceived
226 | TcpState::Established
227 | TcpState::FinWait1
228 | TcpState::FinWait2 => false,
229
230 TcpState::CloseWait | TcpState::Closing | TcpState::LastAck | TcpState::TimeWait => {
231 true
232 }
233 }
234 }
235}
236
237impl<T: Client> Access<'_, T> {
238 pub(crate) fn poll_tcp(&mut self, cx: &mut Context<'_>) {
239 self.inner
241 .tcp
242 .listeners
243 .retain(|port, listener| match listener.poll_listener(cx) {
244 Ok(result) => {
245 if let Some((socket, mut other_addr)) = result {
246 if other_addr.ip().is_loopback() {
249 for (other_ft, connection) in self.inner.tcp.connections.iter() {
250 if connection.inner.state == TcpState::Connecting && other_ft.dst.port() == *port {
251 if let LoopbackPortInfo::ProxyForGuestPort{sending_port, guest_port} = connection.inner.loopback_port {
252 if sending_port == other_addr.port() {
253 other_addr.set_port(guest_port);
254 break;
255 }
256 }
257 }
258 }
259 }
260
261 let ft = match other_addr {
262 SocketAddr::V4(_) => FourTuple {
263 dst: other_addr,
264 src: SocketAddr::V4(SocketAddrV4::new(self.inner.state.params.client_ip, *port)),
265 },
266 SocketAddr::V6(_) => {
267 let client_ipv6 = match self.inner.state.params.client_ip_ipv6 {
268 Some(ip) => ip,
269 None => {
270 tracing::warn!("Received IPv6 connection but client IPv6 address is not known");
271 return true;
272 }
273 };
274 FourTuple {
275 dst: other_addr,
276 src: SocketAddr::V6(SocketAddrV6::new(client_ipv6, *port, 0, 0)),
277 }
278 }
279 };
280
281 match self.inner.tcp.connections.entry(ft) {
282 hash_map::Entry::Vacant(e) => {
283 let mut sender = Sender {
284 ft: &ft,
285 client: self.client,
286 state: &mut self.inner.state,
287 };
288
289 let conn = match TcpConnection::new_from_accept(
290 &mut sender,
291 socket,
292 &self.inner.tcp.connection_params,
293 ) {
294 Ok(conn) => conn,
295 Err(err) => {
296 tracing::warn!(err = %err, "Failed to create connection from newly accepted socket");
297 return true;
298 }
299 };
300 e.insert(conn);
301 }
302 hash_map::Entry::Occupied(_) => {
303 tracing::warn!(
304 address = ?ft.dst,
305 "New client request ignored because it was already connected"
306 );
307 }
308 }
309 }
310 true
311 }
312 Err(_) => false,
313 });
314 self.inner.tcp.connections.retain(|ft, conn| {
316 let mut sender = Sender {
317 ft,
318 state: &mut self.inner.state,
319 client: self.client,
320 };
321 match &mut conn.backend {
322 TcpBackend::Dns(dns_handler) => match &mut self.inner.dns {
323 Some(dns) => conn
324 .inner
325 .poll_dns_backend(cx, &mut sender, dns_handler, dns),
326 None => {
327 tracing::warn!("DNS TCP connection without DNS resolver, dropping");
328 false
329 }
330 },
331 TcpBackend::Socket(opt_socket) => {
332 conn.inner.poll_socket_backend(cx, &mut sender, opt_socket)
333 }
334 }
335 })
336 }
337
338 pub(crate) fn refresh_tcp_driver(&mut self) {
339 self.inner.tcp.connections.retain(|_, conn| {
340 let TcpBackend::Socket(opt_socket) = &mut conn.backend else {
341 return true;
343 };
344 let Some(socket) = opt_socket.take() else {
345 return true;
346 };
347 let socket = socket.into_inner();
348 match PolledSocket::new(self.client.driver(), socket) {
349 Ok(socket) => {
350 *opt_socket = Some(socket);
351 true
352 }
353 Err(err) => {
354 tracing::warn!(
355 error = &err as &dyn std::error::Error,
356 "failed to update driver for tcp connection"
357 );
358 false
359 }
360 }
361 });
362 }
363
364 pub(crate) fn handle_tcp(
365 &mut self,
366 addresses: &IpAddresses,
367 payload: &[u8],
368 checksum: &ChecksumState,
369 ) -> Result<(), DropReason> {
370 let tcp_packet = TcpPacket::new_checked(payload)?;
371 let tcp = TcpRepr::parse(
372 &tcp_packet,
373 &addresses.src_addr(),
374 &addresses.dst_addr(),
375 &checksum.caps(),
376 )?;
377
378 let ft = match addresses {
379 IpAddresses::V4(addresses) => FourTuple {
380 dst: SocketAddr::V4(SocketAddrV4::new(addresses.dst_addr, tcp.dst_port)),
381 src: SocketAddr::V4(SocketAddrV4::new(addresses.src_addr, tcp.src_port)),
382 },
383 IpAddresses::V6(addresses) => FourTuple {
384 dst: SocketAddr::V6(SocketAddrV6::new(addresses.dst_addr, tcp.dst_port, 0, 0)),
385 src: SocketAddr::V6(SocketAddrV6::new(addresses.src_addr, tcp.src_port, 0, 0)),
386 },
387 };
388 trace_tcp_packet(&tcp, tcp.payload.len(), "recv");
389
390 let is_dns_tcp =
391 is_gateway_dns_tcp(&ft, &self.inner.state.params, self.inner.dns.is_some());
392
393 let mut sender = Sender {
394 ft: &ft,
395 client: self.client,
396 state: &mut self.inner.state,
397 };
398
399 match self.inner.tcp.connections.entry(ft) {
400 hash_map::Entry::Occupied(mut e) => {
401 let keep = e.get_mut().inner.handle_packet(&mut sender, &tcp)?;
402 if !keep {
403 let dns_in_flight = matches!(
404 e.get().backend,
405 TcpBackend::Dns(ref h) if h.is_in_flight()
406 );
407 e.remove();
408 if dns_in_flight {
409 if let Some(dns) = &mut self.inner.dns {
410 dns.complete_tcp_query();
411 }
412 }
413 }
414 }
415 hash_map::Entry::Vacant(e) => {
416 if tcp.control == TcpControl::Rst {
417 } else if let Some(ack) = tcp.ack_number {
419 sender.rst(ack, None);
421 } else if tcp.control == TcpControl::Syn {
422 let conn = if is_dns_tcp {
423 TcpConnection::new_dns(
424 &mut sender,
425 &tcp,
426 &self.inner.tcp.connection_params,
427 )?
428 } else {
429 TcpConnection::new(&mut sender, &tcp, &self.inner.tcp.connection_params)?
430 };
431 e.insert(conn);
432 } else {
433 }
435 }
436 }
437 Ok(())
438 }
439
440 pub fn bind_tcp_port(&mut self, ip_addr: Option<IpAddr>, port: u16) -> Result<(), DropReason> {
443 let ip_addr = match ip_addr {
444 Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
445 Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
446 None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)),
447 };
448 match self.inner.tcp.listeners.entry(port) {
449 hash_map::Entry::Occupied(_) => {
450 tracing::warn!(port, "Duplicate TCP bind for port");
451 }
452 hash_map::Entry::Vacant(e) => {
453 let ft = match ip_addr {
454 SocketAddr::V4(ip) => FourTuple {
455 dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
456 src: SocketAddr::V4(ip),
457 },
458 SocketAddr::V6(ip) => FourTuple {
459 dst: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
460 src: SocketAddr::V6(ip),
461 },
462 };
463 let mut sender = Sender {
464 ft: &ft,
465 client: self.client,
466 state: &mut self.inner.state,
467 };
468
469 let listener = TcpListener::new(&mut sender)?;
470 e.insert(listener);
471 }
472 }
473 Ok(())
474 }
475
476 pub fn unbind_tcp_port(&mut self, port: u16) -> Result<(), DropReason> {
478 match self.inner.tcp.listeners.entry(port) {
479 hash_map::Entry::Occupied(e) => {
480 e.remove();
481 Ok(())
482 }
483 hash_map::Entry::Vacant(_) => Err(DropReason::PortNotBound),
484 }
485 }
486}
487
/// Everything needed to build and deliver a TCP packet to the client for
/// one flow.
struct Sender<'a, T> {
    /// The flow being serviced. Outgoing packets travel from `ft.dst` to `ft.src`.
    ft: &'a FourTuple,
    /// The client that receives the finished frame.
    client: &'a mut T,
    /// Shared state providing the packet buffer and the MAC addresses.
    state: &'a mut ConsommeState,
}
493
494impl<T: Client> Sender<'_, T> {
495 fn send_packet(&mut self, tcp: &TcpRepr<'_>, payload: Option<ring::View<'_>>) {
496 let buffer = &mut self.state.buffer;
497 let mut eth_packet = EthernetFrame::new_unchecked(&mut buffer[..]);
498 eth_packet.set_dst_addr(self.state.params.client_mac);
499 eth_packet.set_src_addr(self.state.params.gateway_mac);
500 let ip = IpRepr::new(
501 self.ft.dst.ip().into(),
502 self.ft.src.ip().into(),
503 IpProtocol::Tcp,
504 tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()),
505 64,
506 );
507 match ip {
509 IpRepr::Ipv4(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv4),
510 IpRepr::Ipv6(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv6),
511 }
512
513 let ip_packet_buf = eth_packet.payload_mut();
515 ip.emit(&mut *ip_packet_buf, &ChecksumCapabilities::default());
516
517 let (tcp_payload_buf, ip_total_len) = match self.ft.dst {
518 SocketAddr::V4(_) => {
519 let ipv4_packet = Ipv4Packet::new_unchecked(&*ip_packet_buf);
520 let total_len = ipv4_packet.total_len() as usize;
521 let payload_offset = ipv4_packet.header_len() as usize;
522 (&mut ip_packet_buf[payload_offset..total_len], total_len)
523 }
524 SocketAddr::V6(_) => {
525 let ipv6_packet = Ipv6Packet::new_unchecked(&*ip_packet_buf);
526 let total_len = ipv6_packet.total_len();
527 let payload_offset = IPV6_HEADER_LEN;
528 (&mut ip_packet_buf[payload_offset..total_len], total_len)
529 }
530 };
531
532 let dst_ip_addr: IpAddress = self.ft.dst.ip().into();
533 let src_ip_addr: IpAddress = self.ft.src.ip().into();
534 let mut tcp_packet = TcpPacket::new_unchecked(tcp_payload_buf);
535 tcp.emit(
536 &mut tcp_packet,
537 &dst_ip_addr,
538 &src_ip_addr,
539 &ChecksumCapabilities::default(),
540 );
541
542 if let Some(payload) = &payload {
544 payload.copy_to_slice(tcp_packet.payload_mut());
545 }
546 tcp_packet.fill_checksum(&self.ft.dst.ip().into(), &self.ft.src.ip().into());
547 let n = ETHERNET_HEADER_LEN + ip_total_len;
548 let checksum_state = match self.ft.dst {
549 SocketAddr::V4(_) => ChecksumState::TCP4,
550 SocketAddr::V6(_) => ChecksumState::TCP6,
551 };
552
553 self.client.recv(&buffer[..n], &checksum_state);
554 }
555
556 fn rst(&mut self, seq: TcpSeqNumber, ack: Option<TcpSeqNumber>) {
557 let tcp = TcpRepr {
558 src_port: self.ft.dst.port(),
559 dst_port: self.ft.src.port(),
560 control: TcpControl::Rst,
561 seq_number: seq,
562 ack_number: ack,
563 window_len: 0,
564 window_scale: None,
565 max_seg_size: None,
566 sack_permitted: false,
567 sack_ranges: [None, None, None],
568 timestamp: None,
569 payload: &[],
570 };
571
572 trace_tcp_packet(&tcp, 0, "rst xmit");
573
574 self.send_packet(&tcp, None);
575 }
576}
577
578impl TcpConnection {
579 fn new_base(params: &ConnectionParams) -> TcpConnectionInner {
580 let mut rx_tx_seq = [0; 8];
581 getrandom::fill(&mut rx_tx_seq[..]).expect("prng failure");
582 let rx_seq = TcpSeqNumber(i32::from_ne_bytes(
583 rx_tx_seq[0..4].try_into().expect("invalid length"),
584 ));
585 let tx_seq = TcpSeqNumber(i32::from_ne_bytes(
586 rx_tx_seq[4..8].try_into().expect("invalid length"),
587 ));
588
589 let rx_buffer_size: usize = params.rx_buffer_size.clamp(16384, 4 << 20);
590 let rx_window_scale =
591 (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
592
593 let tx_buffer_size = params
594 .tx_buffer_size
595 .clamp(16384, 4 << 20)
596 .next_power_of_two();
597
598 TcpConnectionInner {
599 loopback_port: LoopbackPortInfo::None,
600 state: TcpState::Connecting,
601 rx_buffer: ring::Ring::new(0),
602 rx_window_cap: rx_buffer_size,
603 rx_window_scale,
604 rx_seq,
605 rx_assembler: assembler::Assembler::new(),
606 needs_ack: false,
607 is_shutdown: false,
608 enable_window_scaling: false,
609 tx_buffer: ring::Ring::new(tx_buffer_size),
610 tx_acked: tx_seq,
611 tx_send: tx_seq,
612 tx_window_len: 1,
613 tx_window_scale: 0,
614 tx_window_rx_seq: rx_seq,
615 tx_window_tx_seq: tx_seq,
616 tx_mss: 536,
619 tx_fin_buffered: false,
620 }
621 }
622
623 fn new(
624 sender: &mut Sender<'_, impl Client>,
625 tcp: &TcpRepr<'_>,
626 params: &ConnectionParams,
627 ) -> Result<Self, DropReason> {
628 let mut inner = Self::new_base(params);
629 inner.initialize_from_first_client_packet(tcp)?;
630
631 let socket = Socket::new(
632 match sender.ft.dst {
633 SocketAddr::V4(_) => Domain::IPV4,
634 SocketAddr::V6(_) => Domain::IPV6,
635 },
636 Type::STREAM,
637 Some(Protocol::TCP),
638 )
639 .map_err(DropReason::Io)?;
640
641 #[cfg(windows)]
645 if sender.ft.dst.ip().is_loopback() {
646 if let Err(err) = crate::windows::disable_connection_retries(&socket) {
647 tracing::trace!(err, "Failed to disable loopback retries");
648 }
649 }
650
651 let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
652 match socket.get().connect(&SockAddr::from(sender.ft.dst)) {
653 Ok(_) => unreachable!(),
654 Err(err) if is_connect_incomplete_error(&err) => (),
655 Err(err) => {
656 log_connect_error(&err);
657 sender.rst(TcpSeqNumber(0), Some(tcp.seq_number + tcp.segment_len()));
658 return Err(DropReason::Io(err));
659 }
660 }
661 if let Ok(addr) = socket.get().local_addr() {
662 match addr.as_socket() {
663 None => {
664 tracing::warn!("unable to get local socket address");
665 }
666 Some(addr) => {
667 if addr.ip().is_loopback() {
668 inner.loopback_port = LoopbackPortInfo::ProxyForGuestPort {
669 sending_port: addr.port(),
670 guest_port: sender.ft.src.port(),
671 };
672 }
673 }
674 }
675 }
676 Ok(Self {
677 backend: TcpBackend::Socket(Some(socket)),
678 inner,
679 })
680 }
681
682 fn new_from_accept(
683 sender: &mut Sender<'_, impl Client>,
684 socket: Socket,
685 params: &ConnectionParams,
686 ) -> Result<Self, DropReason> {
687 let mut inner = TcpConnectionInner {
688 state: TcpState::SynSent,
689 ..Self::new_base(params)
690 };
691 inner.send_syn(sender, None);
692 Ok(Self {
693 backend: TcpBackend::Socket(Some(
694 PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?,
695 )),
696 inner,
697 })
698 }
699
700 fn new_dns(
704 sender: &mut Sender<'_, impl Client>,
705 tcp: &TcpRepr<'_>,
706 params: &ConnectionParams,
707 ) -> Result<Self, DropReason> {
708 let mut inner = Self::new_base(params);
709 inner.initialize_from_first_client_packet(tcp)?;
710
711 let flow = crate::dns_resolver::DnsFlow {
712 src_addr: sender.ft.src.ip().into(),
713 dst_addr: sender.ft.dst.ip().into(),
714 src_port: sender.ft.src.port(),
715 dst_port: sender.ft.dst.port(),
716 gateway_mac: sender.state.params.gateway_mac,
717 client_mac: sender.state.params.client_mac,
718 transport: crate::dns_resolver::DnsTransport::Tcp,
719 };
720
721 inner.state = TcpState::SynReceived;
723 inner.send_syn(sender, Some(inner.rx_seq));
724
725 Ok(Self {
726 backend: TcpBackend::Dns(DnsTcpHandler::new(flow)),
727 inner,
728 })
729 }
730}
731
732impl TcpConnectionInner {
733 fn initialize_from_first_client_packet(&mut self, tcp: &TcpRepr<'_>) -> Result<(), DropReason> {
734 let tx_mss = tcp.max_seg_size.map_or(536, |x| x.into());
737
738 if let Some(tx_window_scale) = tcp.window_scale {
739 if tx_window_scale > 14 {
740 return Err(TcpError::InvalidWindowScale.into());
741 }
742 self.enable_window_scaling = true;
743 self.tx_window_scale = tx_window_scale;
744 } else {
745 self.enable_window_scaling = false;
748 self.rx_window_cap = self.rx_window_cap.min(u16::MAX as usize);
749 self.rx_window_scale = 0;
750 }
751
752 self.rx_buffer = ring::Ring::new(self.rx_window_cap.next_power_of_two());
753 self.rx_seq = tcp.seq_number + 1;
754 self.tx_window_rx_seq = tcp.seq_number + 1;
755 self.tx_mss = tx_mss;
756 Ok(())
757 }
758
    /// Services a DNS-backed connection: pulls response bytes from the DNS
    /// handler into the transmit buffer, feeds bytes received from the client
    /// to the handler, and sends whatever can be sent.
    ///
    /// Returns `false` when the connection should be removed (`poll_tcp`
    /// uses the result as its `retain` predicate).
    fn poll_dns_backend(
        &mut self,
        cx: &mut Context<'_>,
        sender: &mut Sender<'_, impl Client>,
        dns_handler: &mut DnsTcpHandler,
        dns: &mut DnsResolver,
    ) -> bool {
        // Tell the handler, once, that the client has finished sending.
        if self.state.rx_fin() && !dns_handler.guest_fin() {
            dns_handler.set_guest_fin();
        }

        // Fill the free space of the transmit ring with response data.
        while !self.tx_buffer.is_full() {
            let (a, b) = self.tx_buffer.unwritten_slices_mut();
            let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
            match dns_handler.poll_read(cx, &mut bufs, dns) {
                Poll::Ready(Ok(n)) => {
                    if n == 0 {
                        // A zero-length read ends the response stream: queue
                        // our FIN unless one has already been queued.
                        if !self.state.tx_fin() {
                            self.close();
                        }
                        break;
                    }
                    self.tx_buffer.extend_by(n);
                }
                Poll::Ready(Err(_)) => {
                    sender.rst(self.tx_send, Some(self.rx_seq));
                    return false;
                }
                Poll::Pending => break,
            }
        }

        // Hand everything received so far to the handler and drop whatever
        // it reports as consumed.
        let view = self.rx_buffer.view(0..self.rx_buffer.len());
        let (a, b) = view.as_slices();
        match dns_handler.ingest(&[a, b], dns) {
            Ok(consumed) if consumed > 0 => {
                self.rx_buffer.consume(consumed);
            }
            Ok(_) => {}
            Err(_) => {
                sender.rst(self.tx_send, Some(self.rx_seq));
                return false;
            }
        }

        self.send_next(sender);
        // Drop the connection once it reaches `TimeWait` or `LastAck`, or
        // once both directions have finished and nothing is left in the
        // transmit buffer.
        !(self.state == TcpState::TimeWait
            || self.state == TcpState::LastAck
            || (self.state.tx_fin() && self.state.rx_fin() && self.tx_buffer.is_empty()))
    }
819
    /// Services a socket-backed connection: completes a pending connect,
    /// reads from the host socket into the transmit buffer, writes data
    /// received from the client to the host socket, and sends whatever can
    /// be sent to the client.
    ///
    /// Returns `false` when the connection should be removed (`poll_tcp`
    /// uses the result as its `retain` predicate).
    fn poll_socket_backend(
        &mut self,
        cx: &mut Context<'_>,
        sender: &mut Sender<'_, impl Client>,
        opt_socket: &mut Option<PolledSocket<Socket>>,
    ) -> bool {
        if self.state == TcpState::Connecting {
            // Wait for the outbound connect to finish. A connecting
            // connection without a socket cannot make progress, so drop it.
            let Some(socket) = opt_socket.as_mut() else {
                return false;
            };
            match socket.poll_ready(cx, PollEvents::OUT) {
                Poll::Ready(r) => {
                    if r.has_err() {
                        self.handle_connect_error(sender, socket);
                        return false;
                    }

                    tracing::debug!("connection established");
                    self.state = TcpState::SynReceived;
                }
                Poll::Pending => return true,
            }
        } else if self.state == TcpState::SynSent {
            // Still waiting for the client to answer our SYN; nothing to
            // move yet.
            return true;
        }

        // Host socket -> client direction.
        if let Some(socket) = opt_socket.as_mut() {
            if self.state.tx_fin() {
                // Our FIN is already queued, so there is nothing more to
                // read. Only watch the socket for errors.
                if let Poll::Ready(events) = socket.poll_ready(cx, PollEvents::EMPTY) {
                    if events.has_err() {
                        let err = take_socket_error(socket);
                        match err.kind() {
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => tracelimit::warn_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "socket failure after fin"
                            ),
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }

                    // Ready without an error: drop the socket.
                    *opt_socket = None;
                }
            } else {
                // Fill the free space of the transmit ring from the socket.
                while !self.tx_buffer.is_full() {
                    let (a, b) = self.tx_buffer.unwritten_slices_mut();
                    let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
                    match Pin::new(&mut *socket).poll_read_vectored(cx, &mut bufs) {
                        Poll::Ready(Ok(n)) => {
                            if n == 0 {
                                // End of stream from the host: queue our FIN.
                                self.close();
                                break;
                            }
                            self.tx_buffer.extend_by(n);
                        }
                        Poll::Ready(Err(err)) => {
                            match err.kind() {
                                ErrorKind::ConnectionReset => tracing::trace!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                                _ => tracelimit::warn_ratelimited!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                            }
                            sender.rst(self.tx_send, Some(self.rx_seq));
                            return false;
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }

        // Client -> host socket direction. The option is looked up again
        // because the block above may have cleared it.
        if let Some(socket) = opt_socket.as_mut() {
            while !self.rx_buffer.is_empty() {
                let view = self.rx_buffer.view(0..self.rx_buffer.len());
                let (a, b) = view.as_slices();
                let bufs = [IoSlice::new(a), IoSlice::new(b)];
                match Pin::new(&mut *socket).poll_write_vectored(cx, &bufs) {
                    Poll::Ready(Ok(n)) => {
                        self.rx_buffer.consume(n);
                    }
                    Poll::Ready(Err(err)) => {
                        match err.kind() {
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => {
                                tracelimit::warn_ratelimited!(
                                    error = &err as &dyn std::error::Error,
                                    "socket write error"
                                );
                            }
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }
                    Poll::Pending => break,
                }
            }
            // Once the client's FIN has been seen and all of its data has
            // been written, shut down the socket's write half exactly once.
            if self.rx_buffer.is_empty() && self.state.rx_fin() && !self.is_shutdown {
                if let Err(err) = socket.get().shutdown(Shutdown::Write) {
                    tracelimit::warn_ratelimited!(
                        error = &err as &dyn std::error::Error,
                        "shutdown error"
                    );
                    sender.rst(self.tx_send, Some(self.rx_seq));
                    return false;
                }
                self.is_shutdown = true;
            }
        }

        self.send_next(sender);
        true
    }
947
948 fn handle_connect_error(
949 &mut self,
950 sender: &mut Sender<'_, impl Client>,
951 socket: &mut PolledSocket<Socket>,
952 ) {
953 let err = take_socket_error(socket);
954 if err.kind() == ErrorKind::TimedOut {
955 tracing::debug!(error = &err as &dyn std::error::Error, "connect timed out");
959 } else {
960 log_connect_error(&err);
961 sender.rst(self.tx_send, Some(self.rx_seq));
962 }
963 }
964
965 fn rx_window_len(&self) -> u16 {
966 ((self.rx_window_cap - self.rx_buffer.len()) >> self.rx_window_scale) as u16
967 }
968
969 fn send_next(&mut self, sender: &mut Sender<'_, impl Client>) {
970 match self.state {
971 TcpState::Connecting => {}
972 TcpState::SynReceived => self.send_syn(sender, Some(self.rx_seq)),
973 _ => self.send_data(sender),
974 }
975 }
976
977 fn send_syn(&mut self, sender: &mut Sender<'_, impl Client>, ack_number: Option<TcpSeqNumber>) {
978 if self.tx_send != self.tx_acked || sender.client.rx_mtu() == 0 {
979 return;
980 }
981
982 let window_scale = self.enable_window_scaling.then_some(self.rx_window_scale);
985
986 let max_seg_size = u16::MAX;
989 let tcp = TcpRepr {
990 src_port: sender.ft.dst.port(),
991 dst_port: sender.ft.src.port(),
992 control: TcpControl::Syn,
993 seq_number: self.tx_send,
994 ack_number,
995 window_len: if ack_number.is_some() {
996 self.rx_window_len()
997 } else {
998 0
999 },
1000 window_scale,
1001 max_seg_size: Some(max_seg_size),
1002 sack_permitted: false,
1003 sack_ranges: [None, None, None],
1004 timestamp: None,
1005 payload: &[],
1006 };
1007
1008 sender.send_packet(&tcp, None);
1009 self.tx_send += 1;
1010 }
1011
    /// Sends as many segments as the client's window and MTU allow, plus a
    /// bare ACK if one is owed.
    ///
    /// Each segment carries the current acknowledgment number and receive
    /// window. A buffered FIN is attached to the segment that carries the
    /// last payload byte, provided the window has room for it.
    ///
    /// # Panics
    ///
    /// Panics if the internal sequence bookkeeping is inconsistent (a
    /// segment would pass the end of the buffered data, or a segment would
    /// make no progress while no ACK is owed).
    fn send_data(&mut self, sender: &mut Sender<'_, impl Client>) {
        // End of the buffered payload in sequence space.
        let tx_payload_end = self.tx_acked + self.tx_buffer.len();
        // A buffered FIN occupies one more sequence number after the payload.
        let tx_end = tx_payload_end + self.tx_fin_buffered as usize;
        // End of the client's advertised window in sequence space.
        let tx_window_end = self.tx_acked + ((self.tx_window_len as usize) << self.tx_window_scale);
        let tx_done = seq_min([tx_end, tx_window_end]);

        while self.needs_ack || self.tx_send < tx_done {
            let rx_mtu = sender.client.rx_mtu();
            if rx_mtu == 0 {
                // The client cannot take a packet right now.
                break;
            }

            let mut tcp = TcpRepr {
                src_port: sender.ft.dst.port(),
                dst_port: sender.ft.src.port(),
                control: TcpControl::None,
                seq_number: self.tx_send,
                ack_number: Some(self.rx_seq),
                window_len: self.rx_window_len(),
                window_scale: None,
                max_seg_size: None,
                sack_permitted: false,
                sack_ranges: [None, None, None],
                timestamp: None,
                payload: &[],
            };

            let mut tx_next = self.tx_send;

            // The segment ends at the earliest of: end of buffered payload,
            // end of the client's window, the client's MSS, and what fits in
            // one frame after the headers.
            let tx_segment_end = {
                let ip_header_len = match sender.ft.dst {
                    SocketAddr::V4(_) => IPV4_HEADER_LEN,
                    SocketAddr::V6(_) => IPV6_HEADER_LEN,
                };
                let header_len = ETHERNET_HEADER_LEN + ip_header_len + tcp.header_len();
                let mtu = rx_mtu.min(sender.state.buffer.len());
                seq_min([
                    tx_payload_end,
                    tx_window_end,
                    tx_next + self.tx_mss,
                    tx_next + (mtu - header_len),
                ])
            };

            // Offset into the transmit ring and length of this segment's
            // payload; both are zero for a bare ACK.
            let (payload_start, payload_len) = if tx_next < tx_segment_end {
                (tx_next - self.tx_acked, tx_segment_end - tx_next)
            } else {
                (0, 0)
            };

            tx_next += payload_len;

            // Attach the FIN once this segment reaches the end of the
            // payload and the window still has room for it.
            if self.tx_fin_buffered
                && tcp.control == TcpControl::None
                && tx_next == tx_payload_end
                && tx_next < tx_window_end
            {
                tcp.control = TcpControl::Fin;
                tx_next += 1;
            }

            assert!(tx_next <= tx_end);
            assert!(self.needs_ack || tx_next > self.tx_send);

            trace_tcp_packet(&tcp, payload_len, "xmit");

            let payload = self
                .tx_buffer
                .view(payload_start..payload_start + payload_len);

            sender.send_packet(&tcp, Some(payload));
            self.tx_send = tx_next;
            // Every segment carries an acknowledgment, so nothing is owed now.
            self.needs_ack = false;
        }

        assert!(self.tx_send <= tx_end);
    }
1098
1099 fn close(&mut self) {
1100 tracing::trace!("fin");
1101 match self.state {
1102 TcpState::SynSent | TcpState::SynReceived | TcpState::Established => {
1103 self.state = TcpState::FinWait1;
1104 }
1105 TcpState::CloseWait => {
1106 self.state = TcpState::LastAck;
1107 }
1108 TcpState::Connecting
1109 | TcpState::FinWait1
1110 | TcpState::FinWait2
1111 | TcpState::Closing
1112 | TcpState::TimeWait
1113 | TcpState::LastAck => unreachable!("fin in {:?}", self.state),
1114 }
1115 self.tx_fin_buffered = true;
1116 }
1117
1118 fn ack(&self, sender: &mut Sender<'_, impl Client>) {
1125 let tcp = TcpRepr {
1126 src_port: sender.ft.dst.port(),
1127 dst_port: sender.ft.src.port(),
1128 control: TcpControl::None,
1129 seq_number: self.tx_send,
1130 ack_number: Some(self.rx_seq),
1131 window_len: self.rx_window_len(),
1132 window_scale: None,
1133 max_seg_size: None,
1134 sack_permitted: false,
1135 sack_ranges: [None, None, None],
1136 timestamp: None,
1137 payload: &[],
1138 };
1139
1140 trace_tcp_packet(&tcp, 0, "ack");
1141
1142 sender.send_packet(&tcp, None);
1143 }
1144
1145 fn handle_listen_syn(
1146 &mut self,
1147 sender: &mut Sender<'_, impl Client>,
1148 tcp: &TcpRepr<'_>,
1149 ) -> Result<bool, DropReason> {
1150 if tcp.control != TcpControl::Syn || tcp.segment_len() != 1 {
1151 tracing::error!(?tcp.control, "invalid packet waiting for syn, drop connection");
1152 return Ok(false);
1153 }
1154
1155 let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;
1156 if ack_number <= self.tx_acked || ack_number > self.tx_send {
1157 sender.rst(ack_number, None);
1158 return Ok(false);
1159 }
1160 self.tx_acked = ack_number;
1161
1162 self.initialize_from_first_client_packet(tcp)?;
1163 self.tx_window_tx_seq = ack_number;
1164 self.tx_window_len = tcp.window_len;
1165
1166 self.ack(sender);
1168
1169 self.state = TcpState::Established;
1170 Ok(true)
1171 }
1172
1173 fn handle_packet(
1174 &mut self,
1175 sender: &mut Sender<'_, impl Client>,
1176 tcp: &TcpRepr<'_>,
1177 ) -> Result<bool, DropReason> {
1178 if self.state == TcpState::Connecting {
1179 return Err(TcpError::StillConnecting.into());
1183 } else if self.state == TcpState::SynSent {
1184 return self.handle_listen_syn(sender, tcp);
1185 }
1186
1187 let rx_window_len = self.rx_window_cap - self.rx_buffer.len();
1188 let rx_window_end = self.rx_seq + rx_window_len;
1189 let segment_end = tcp.seq_number + tcp.segment_len();
1190
1191 let seq_acceptable = if rx_window_len != 0 {
1193 (tcp.seq_number >= self.rx_seq && tcp.seq_number < rx_window_end)
1194 || (tcp.segment_len() > 0
1195 && segment_end > self.rx_seq
1196 && segment_end <= rx_window_end)
1197 } else {
1198 tcp.segment_len() == 0 && tcp.seq_number == self.rx_seq
1199 };
1200
1201 if tcp.control == TcpControl::Rst {
1202 if !seq_acceptable {
1203 return Err(TcpError::Unacceptable.into());
1206 }
1207
1208 if tcp.seq_number != self.rx_seq {
1210 self.ack(sender);
1212 return Ok(true);
1213 }
1214
1215 tracing::debug!("connection reset");
1217 return Ok(false);
1218 }
1219
1220 if !seq_acceptable {
1222 self.ack(sender);
1223 return Err(TcpError::Unacceptable.into());
1224 }
1225
1226 if tcp.control == TcpControl::Syn {
1228 if self.state == TcpState::SynReceived {
1229 tracing::debug!("invalid syn, drop connection");
1230 return Ok(false);
1231 }
1232 self.ack(sender);
1234 return Ok(true);
1235 }
1236
1237 let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;
1239
1240 if self.state == TcpState::SynReceived {
1244 if ack_number <= self.tx_acked || ack_number > self.tx_send {
1245 sender.rst(ack_number, None);
1246 return Ok(false);
1247 }
1248 self.tx_window_len = tcp.window_len;
1249 self.tx_window_rx_seq = tcp.seq_number;
1250 self.tx_window_tx_seq = ack_number;
1251 self.tx_acked += 1;
1252 self.state = TcpState::Established;
1253 }
1254
1255 if ack_number > self.tx_send {
1257 self.ack(sender);
1258 return Err(TcpError::AckPastSequence.into());
1259 }
1260
1261 if ack_number > self.tx_acked {
1263 let mut consumed = ack_number - self.tx_acked;
1264 if self.tx_fin_buffered && ack_number == self.tx_acked + self.tx_buffer.len() + 1 {
1265 self.tx_fin_buffered = false;
1266 consumed -= 1;
1267 match self.state {
1268 TcpState::FinWait1 => self.state = TcpState::FinWait2,
1269 TcpState::Closing => self.state = TcpState::TimeWait,
1270 TcpState::LastAck => return Ok(false),
1271 _ => unreachable!(),
1272 }
1273 }
1274 self.tx_buffer.consume(consumed);
1275 self.tx_acked = ack_number;
1276 }
1277
1278 if ack_number >= self.tx_acked
1280 && (tcp.seq_number > self.tx_window_rx_seq
1281 || (tcp.seq_number == self.tx_window_rx_seq && ack_number >= self.tx_window_tx_seq))
1282 {
1283 self.tx_window_len = tcp.window_len;
1284 self.tx_window_rx_seq = tcp.seq_number;
1285 self.tx_window_tx_seq = ack_number;
1286 }
1287
1288 let mut fin = tcp.control == TcpControl::Fin;
1290 let segment_skip = if tcp.seq_number < self.rx_seq {
1291 self.rx_seq - tcp.seq_number
1292 } else {
1293 0
1294 };
1295 let segment_end = if segment_end > rx_window_end {
1296 fin = false;
1297 rx_window_end
1298 } else {
1299 segment_end
1300 };
1301 let payload = &tcp.payload[segment_skip..segment_end - tcp.seq_number - fin as usize];
1302
1303 let mut rx_fin = false;
1304
1305 match self.state {
1307 TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
1308 TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 => {
1309 if !payload.is_empty() || fin {
1310 let seq_offset = if tcp.seq_number >= self.rx_seq {
1318 tcp.seq_number - self.rx_seq
1319 } else {
1320 0
1321 };
1322 let ring_offset = self.rx_buffer.len() + seq_offset;
1323
1324 let (rx_consumed, assembler_fin, accepted) =
1329 match self
1330 .rx_assembler
1331 .add(seq_offset as u32, payload.len() as u32, fin)
1332 {
1333 Ok(result) => (result.consumed as usize, result.fin, true),
1334 Err(assembler::TooManyGaps) => (0, false, false),
1335 };
1336
1337 if accepted && !payload.is_empty() {
1341 self.rx_buffer.write_at(ring_offset, payload);
1342 }
1343 self.rx_buffer.extend_by(rx_consumed);
1344 self.rx_seq += rx_consumed;
1345 rx_fin = assembler_fin;
1346 if rx_fin {
1347 self.rx_seq += 1;
1348 }
1349 }
1350 if tcp.segment_len() > 0 {
1351 self.needs_ack = true;
1352 }
1353 }
1354 TcpState::CloseWait | TcpState::Closing | TcpState::LastAck => {}
1355 TcpState::TimeWait => {
1356 self.ack(sender);
1357 }
1359 }
1360
1361 if rx_fin {
1363 match self.state {
1364 TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
1365 TcpState::Established => {
1366 self.state = TcpState::CloseWait;
1367 }
1368 TcpState::FinWait1 => {
1369 self.state = TcpState::Closing;
1370 }
1371 TcpState::FinWait2 => {
1372 self.state = TcpState::TimeWait;
1373 }
1375 TcpState::CloseWait
1376 | TcpState::Closing
1377 | TcpState::LastAck
1378 | TcpState::TimeWait => {}
1379 }
1380 }
1381
1382 Ok(true)
1383 }
1384}
1385
1386impl TcpListener {
1387 pub fn new(sender: &mut Sender<'_, impl Client>) -> Result<Self, DropReason> {
1388 let socket = match sender.ft.src {
1389 SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)),
1390 SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)),
1391 }
1392 .map_err(DropReason::Io)?;
1393
1394 let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
1395 if let Err(err) = socket.get().bind(&sender.ft.src.into()) {
1396 tracing::warn!(
1397 address = ?sender.ft.src,
1398 error = &err as &dyn std::error::Error,
1399 "socket bind error"
1400 );
1401 return Err(DropReason::Io(err));
1402 }
1403 if let Err(err) = socket.listen(10) {
1404 tracing::warn!(
1405 error = &err as &dyn std::error::Error,
1406 "socket listen error"
1407 );
1408 return Err(DropReason::Io(err));
1409 }
1410 Ok(Self { socket })
1411 }
1412
1413 fn poll_listener(
1414 &mut self,
1415 cx: &mut Context<'_>,
1416 ) -> Result<Option<(Socket, SocketAddr)>, DropReason> {
1417 match self.socket.poll_accept(cx) {
1418 Poll::Ready(r) => match r {
1419 Ok((socket, address)) => match address.as_socket() {
1420 Some(addr) => Ok(Some((socket, addr))),
1421 None => {
1422 tracing::warn!(?address, "Unknown address from accept");
1423 Ok(None)
1424 }
1425 },
1426 Err(_) => {
1427 let err = take_socket_error(&self.socket);
1428 tracing::warn!(error = &err as &dyn std::error::Error, "listen failure");
1429 Err(DropReason::Io(err))
1430 }
1431 },
1432 Poll::Pending => Ok(None),
1433 }
1434 }
1435}
1436
1437fn trace_tcp_packet(tcp: &TcpRepr<'_>, payload_len: usize, label: &str) {
1443 tracing::trace!(
1444 label,
1445 flags = match tcp.control {
1446 TcpControl::Syn => Some("SYN"),
1447 TcpControl::Fin => Some("FIN"),
1448 TcpControl::Rst => Some("RST"),
1449 TcpControl::Psh => Some("PSH"),
1450 TcpControl::None => None,
1451 },
1452 seq = tcp.seq_number.0 as u32,
1453 next_seq = (tcp.seq_number.0 as u32).wrapping_add((payload_len + tcp.control.len()) as u32),
1454 ack = tcp.ack_number.map(|a| a.0 as u32),
1455 window = tcp.window_len,
1456 payload_len,
1457 "tcp packet",
1458 );
1459}
1460
1461fn take_socket_error(socket: &PolledSocket<Socket>) -> io::Error {
1462 match socket.get().take_error() {
1463 Ok(Some(err)) => err,
1464 Ok(_) => io::Error::other("missing error"),
1465 Err(err) => err,
1466 }
1467}
1468
1469fn log_connect_error(err: &io::Error) {
1474 match err.kind() {
1475 ErrorKind::ConnectionRefused => {
1476 tracing::debug!(error = err as &dyn std::error::Error, "connect refused");
1477 }
1478 ErrorKind::NetworkUnreachable | ErrorKind::HostUnreachable => {
1479 tracing::debug!(
1481 error = err as &dyn std::error::Error,
1482 "connect failed, unreachable"
1483 );
1484 }
1485 _ => {
1486 tracelimit::warn_ratelimited!(error = err as &dyn std::error::Error, "connect failed");
1487 }
1488 }
1489}
1490
1491fn is_connect_incomplete_error(err: &io::Error) -> bool {
1492 if err.kind() == ErrorKind::WouldBlock {
1493 return true;
1494 }
1495 #[cfg(unix)]
1497 if err.raw_os_error() == Some(libc::EINPROGRESS) {
1498 return true;
1499 }
1500 false
1501}
1502
1503fn seq_min<const N: usize>(seqs: [TcpSeqNumber; N]) -> TcpSeqNumber {
1510 let mut min = seqs[0];
1511 for &seq in &seqs[1..] {
1512 if min > seq {
1513 min = seq;
1514 }
1515 }
1516 min
1517}
1518
1519fn is_gateway_dns_tcp(ft: &FourTuple, params: &crate::ConsommeParams, dns_available: bool) -> bool {
1521 if !dns_available || ft.dst.port() != crate::DNS_PORT {
1522 return false;
1523 }
1524 match ft.dst.ip() {
1525 IpAddr::V4(ip) => params.gateway_ip == ip,
1526 IpAddr::V6(ip) => params.gateway_link_local_ipv6 == ip,
1527 }
1528}
1529
1530#[cfg(test)]
1531mod tests;