consomme/
tcp.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4mod assembler;
5mod ring;
6
7use super::Access;
8use super::Client;
9use super::DropReason;
10use crate::ChecksumState;
11use crate::ConsommeState;
12use crate::IpAddresses;
13use crate::dns_resolver::DnsResolver;
14use crate::dns_resolver::dns_tcp::DnsTcpHandler;
15use futures::AsyncRead;
16use futures::AsyncWrite;
17use inspect::Inspect;
18use inspect::InspectMut;
19use pal_async::interest::PollEvents;
20use pal_async::socket::PollReady;
21use pal_async::socket::PolledSocket;
22use smoltcp::phy::ChecksumCapabilities;
23use smoltcp::wire::ETHERNET_HEADER_LEN;
24use smoltcp::wire::EthernetFrame;
25use smoltcp::wire::EthernetProtocol;
26use smoltcp::wire::IPV4_HEADER_LEN;
27use smoltcp::wire::IPV6_HEADER_LEN;
28use smoltcp::wire::IpAddress;
29use smoltcp::wire::IpProtocol;
30use smoltcp::wire::IpRepr;
31use smoltcp::wire::Ipv4Packet;
32use smoltcp::wire::Ipv6Packet;
33use smoltcp::wire::TcpControl;
34use smoltcp::wire::TcpPacket;
35use smoltcp::wire::TcpRepr;
36use smoltcp::wire::TcpSeqNumber;
37use socket2::Domain;
38use socket2::Protocol;
39use socket2::SockAddr;
40use socket2::Socket;
41use socket2::Type;
42use std::collections::HashMap;
43use std::collections::hash_map;
44use std::io;
45use std::io::ErrorKind;
46use std::io::IoSlice;
47use std::io::IoSliceMut;
48use std::net::IpAddr;
49use std::net::Ipv4Addr;
50use std::net::Ipv6Addr;
51use std::net::Shutdown;
52use std::net::SocketAddr;
53use std::net::SocketAddrV4;
54use std::net::SocketAddrV6;
55use std::pin::Pin;
56use std::task::Context;
57use std::task::Poll;
58use thiserror::Error;
59
/// Identifies a TCP flow by its two endpoint socket addresses.
///
/// For tracked connections, `src` is the client-side endpoint and `dst` is
/// the remote peer the client is talking to.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct FourTuple {
    /// Source endpoint of the flow.
    src: SocketAddr,
    /// Destination endpoint of the flow.
    dst: SocketAddr,
}

/// Renders the flow as `<src>-<dst>`.
impl core::fmt::Display for FourTuple {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { src, dst } = self;
        write!(f, "{src}-{dst}")
    }
}
71
/// Top-level TCP state: the tracked connections, the bound listeners, and
/// the buffer sizing parameters applied to newly created connections.
#[derive(InspectMut)]
pub(crate) struct Tcp {
    /// Tracked connections, keyed by their four-tuple.
    #[inspect(iter_by_key)]
    connections: HashMap<FourTuple, TcpConnection>,
    /// Listening sockets, keyed by the bound port.
    #[inspect(iter_by_key)]
    listeners: HashMap<u16, TcpListener>,
    /// Buffer sizes used when constructing new connections.
    #[inspect(mut)]
    connection_params: ConnectionParams,
}
81
/// Tunable per-connection buffer sizes.
///
/// Both values are clamped to the range `16384..=4 << 20` when a connection
/// is created (see `TcpConnection::new_base`).
#[derive(InspectMut)]
struct ConnectionParams {
    /// Requested receive buffer size in bytes.
    #[inspect(mut)]
    rx_buffer_size: usize,
    /// Requested transmit buffer size in bytes; rounded up to a power of two
    /// after clamping.
    #[inspect(mut)]
    tx_buffer_size: usize,
}
89
/// TCP-level errors. Where these are returned in this file they are converted
/// into a `DropReason` via `.into()`.
#[derive(Debug, Error)]
pub enum TcpError {
    #[error("still connecting")]
    StillConnecting,
    #[error("unacceptable segment number")]
    Unacceptable,
    #[error("missing ack bit")]
    MissingAck,
    #[error("ack newer than sequence")]
    AckPastSequence,
    /// The client's first packet carried a window scale option greater than 14.
    #[error("invalid window scale")]
    InvalidWindowScale,
}
103
104impl Tcp {
105    pub fn new() -> Self {
106        Self {
107            connections: HashMap::new(),
108            listeners: HashMap::new(),
109            connection_params: ConnectionParams {
110                rx_buffer_size: 256 * 1024,
111                tx_buffer_size: 256 * 1024,
112            },
113        }
114    }
115}
116
/// Records whether a connection's host socket is a loopback proxy for a
/// guest-owned port.
#[derive(Inspect)]
#[inspect(tag = "info")]
enum LoopbackPortInfo {
    /// The host socket's local address is not loopback (or is unknown).
    None,
    /// The host socket is bound to a loopback address. `sending_port` is the
    /// host socket's local port and `guest_port` is the guest's source port;
    /// `poll_tcp` uses this pair to rewrite the peer port of accepted loopback
    /// connections back to the guest's port.
    ProxyForGuestPort { sending_port: u16, guest_port: u16 },
}
123
/// The I/O backend for a TCP connection.
///
/// A connection is either backed by a real host socket or a virtual DNS
/// handler that resolves DNS queries without a real socket. `poll_tcp`
/// dispatches on this enum to pick the matching poll routine.
enum TcpBackend {
    /// A real host socket. The socket may be `None` while the connection is
    /// being constructed, or after both ends have closed.
    Socket(Option<PolledSocket<Socket>>),
    /// A virtual DNS TCP handler (no real socket).
    Dns(DnsTcpHandler),
}
135
/// A single tracked TCP connection: the I/O backend plus the TCP protocol
/// state kept for the client side of the connection.
#[derive(Inspect)]
struct TcpConnection {
    /// Where the connection's data comes from and goes to (host socket or
    /// DNS handler). Not exposed through inspect.
    #[inspect(skip)]
    backend: TcpBackend,
    /// Protocol state (sequence numbers, buffers, windows).
    #[inspect(flatten)]
    inner: TcpConnectionInner,
}
143
/// Backend-independent TCP protocol state for one connection.
///
/// The `rx_*` fields describe the client -> host direction and the `tx_*`
/// fields describe the host -> client direction.
#[derive(Inspect)]
struct TcpConnectionInner {
    /// Loopback proxy bookkeeping; set by `TcpConnection::new` when the host
    /// socket's local address is loopback.
    loopback_port: LoopbackPortInfo,
    /// Current state of the connection state machine.
    state: TcpState,

    /// Ring buffer holding data received from the client. Allocated once the
    /// client's first packet has been seen.
    #[inspect(with = "|x| x.len()")]
    rx_buffer: ring::Ring,
    /// Upper bound for the receive window in bytes. Capped to `u16::MAX` when
    /// the client did not offer window scaling.
    #[inspect(hex)]
    rx_window_cap: usize,
    /// Window scale shift derived from the receive buffer size; forced to 0
    /// when the client did not offer window scaling.
    rx_window_scale: u8,
    /// Receive sequence number. Set to the client's initial sequence number
    /// plus one when the first client packet is processed.
    #[inspect(with = "inspect_seq")]
    rx_seq: TcpSeqNumber,
    // NOTE(review): presumably tracks out-of-order received ranges — confirm
    // against the `assembler` module.
    #[inspect(flatten)]
    rx_assembler: assembler::Assembler,
    needs_ack: bool,
    is_shutdown: bool,
    /// True when the client's first packet carried a window scale option.
    enable_window_scaling: bool,

    /// Ring buffer holding data read from the backend that is destined for
    /// the client.
    #[inspect(with = "|x| x.len()")]
    tx_buffer: ring::Ring,
    /// Starts at the randomly chosen initial transmit sequence number.
    #[inspect(with = "inspect_seq")]
    tx_acked: TcpSeqNumber,
    /// Starts at the randomly chosen initial transmit sequence number.
    #[inspect(with = "inspect_seq")]
    tx_send: TcpSeqNumber,
    tx_fin_buffered: bool,
    /// Transmit window length; initialized to 1.
    #[inspect(hex)]
    tx_window_len: u16,
    /// Window scale offered by the client in its first packet (at most 14),
    /// or 0 if none was offered.
    tx_window_scale: u8,
    #[inspect(with = "inspect_seq")]
    tx_window_rx_seq: TcpSeqNumber,
    #[inspect(with = "inspect_seq")]
    tx_window_tx_seq: TcpSeqNumber,
    /// Maximum segment size for data sent to the client. Defaults to 536 and
    /// is replaced by the MSS option from the client's first packet, if any.
    #[inspect(hex)]
    tx_mss: usize,
}
179
180fn inspect_seq(seq: &TcpSeqNumber) -> inspect::AsHex<u32> {
181    inspect::AsHex(seq.0 as u32)
182}
183
/// A host socket listening for incoming connections on behalf of the client.
#[derive(Inspect)]
struct TcpListener {
    /// The listening host socket. Not exposed through inspect.
    #[inspect(skip)]
    socket: PolledSocket<Socket>,
}
189
/// Connection state machine states.
///
/// `tx_fin` and `rx_fin` classify these states by whether a FIN has been
/// sent or received.
#[derive(Debug, PartialEq, Eq, Inspect)]
enum TcpState {
    /// Initial state; for socket-backed connections the host socket's
    /// outbound connect has not yet completed.
    Connecting,
    /// A SYN has been sent to the client for a connection accepted on a
    /// listener; waiting for the client to respond.
    SynSent,
    /// The backend is ready (host connect completed, or DNS backend created).
    SynReceived,
    Established,
    FinWait1,
    FinWait2,
    CloseWait,
    Closing,
    LastAck,
    TimeWait,
}
203
204impl TcpState {
205    fn tx_fin(&self) -> bool {
206        match self {
207            TcpState::Connecting
208            | TcpState::SynSent
209            | TcpState::SynReceived
210            | TcpState::Established
211            | TcpState::CloseWait => false,
212
213            TcpState::FinWait1
214            | TcpState::FinWait2
215            | TcpState::Closing
216            | TcpState::TimeWait
217            | TcpState::LastAck => true,
218        }
219    }
220
221    fn rx_fin(&self) -> bool {
222        match self {
223            TcpState::Connecting
224            | TcpState::SynSent
225            | TcpState::SynReceived
226            | TcpState::Established
227            | TcpState::FinWait1
228            | TcpState::FinWait2 => false,
229
230            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck | TcpState::TimeWait => {
231                true
232            }
233        }
234    }
235}
236
237impl<T: Client> Access<'_, T> {
    /// Polls every TCP listener and every tracked TCP connection.
    ///
    /// First, each listener is polled for a newly accepted host connection;
    /// an accepted socket becomes a new tracked connection (unless one with
    /// the same four-tuple already exists). A listener whose poll returns an
    /// error is removed. Second, each connection's backend is polled, and
    /// connections whose poll routine returns `false` are removed.
    pub(crate) fn poll_tcp(&mut self, cx: &mut Context<'_>) {
        // Check for any new incoming connections
        self.inner
            .tcp
            .listeners
            .retain(|port, listener| match listener.poll_listener(cx) {
                Ok(result) => {
                    if let Some((socket, mut other_addr)) = result {
                        // Check for loopback requests and replace the dest port.
                        // This supports a guest owning both the sending and receiving ports.
                        if other_addr.ip().is_loopback() {
                            // Look for a still-connecting outbound connection that targets
                            // this listener's port and whose host socket port matches the
                            // accepted peer's port.
                            for (other_ft, connection) in self.inner.tcp.connections.iter() {
                                if connection.inner.state == TcpState::Connecting && other_ft.dst.port() == *port {
                                    if let LoopbackPortInfo::ProxyForGuestPort{sending_port, guest_port} = connection.inner.loopback_port {
                                        if sending_port == other_addr.port() {
                                            other_addr.set_port(guest_port);
                                            break;
                                        }
                                    }
                                }
                            }
                        }

                        // Build the four-tuple for the new connection: the client address
                        // with the listener's port on one side, the accepted peer on the other.
                        let ft = match other_addr {
                            SocketAddr::V4(_) => FourTuple {
                                dst: other_addr,
                                src: SocketAddr::V4(SocketAddrV4::new(self.inner.state.params.client_ip, *port)),
                            },
                            SocketAddr::V6(_) => {
                                let client_ipv6 = match self.inner.state.params.client_ip_ipv6 {
                                    Some(ip) => ip,
                                    None => {
                                        tracing::warn!("Received IPv6 connection but client IPv6 address is not known");
                                        // Keep the listener; the accepted socket is dropped here.
                                        return true;
                                    }
                                };
                                FourTuple {
                                    dst: other_addr,
                                    src: SocketAddr::V6(SocketAddrV6::new(client_ipv6, *port, 0, 0)),
                                }
                            }
                        };

                        match self.inner.tcp.connections.entry(ft) {
                            hash_map::Entry::Vacant(e) => {
                                let mut sender = Sender {
                                    ft: &ft,
                                    client: self.client,
                                    state: &mut self.inner.state,
                                };

                                let conn = match TcpConnection::new_from_accept(
                                    &mut sender,
                                    socket,
                                    &self.inner.tcp.connection_params,
                                ) {
                                    Ok(conn) => conn,
                                    Err(err) => {
                                        tracing::warn!(err = %err, "Failed to create connection from newly accepted socket");
                                        // Keep the listener even though this connection failed.
                                        return true;
                                    }
                                };
                                e.insert(conn);
                            }
                            hash_map::Entry::Occupied(_) => {
                                // The accepted socket is dropped; the existing connection wins.
                                tracing::warn!(
                                    address = ?ft.dst,
                                    "New client request ignored because it was already connected"
                                );
                            }
                        }
                    }
                    true
                }
                // A listener whose poll fails is removed from the map.
                Err(_) => false,
            });
        // Check for any new incoming data
        self.inner.tcp.connections.retain(|ft, conn| {
            let mut sender = Sender {
                ft,
                state: &mut self.inner.state,
                client: self.client,
            };
            // The backend poll routines return whether the connection should be kept.
            match &mut conn.backend {
                TcpBackend::Dns(dns_handler) => match &mut self.inner.dns {
                    Some(dns) => conn
                        .inner
                        .poll_dns_backend(cx, &mut sender, dns_handler, dns),
                    None => {
                        tracing::warn!("DNS TCP connection without DNS resolver, dropping");
                        false
                    }
                },
                TcpBackend::Socket(opt_socket) => {
                    conn.inner.poll_socket_backend(cx, &mut sender, opt_socket)
                }
            }
        })
    }
337
338    pub(crate) fn refresh_tcp_driver(&mut self) {
339        self.inner.tcp.connections.retain(|_, conn| {
340            let TcpBackend::Socket(opt_socket) = &mut conn.backend else {
341                // DNS connections have no real socket to refresh.
342                return true;
343            };
344            let Some(socket) = opt_socket.take() else {
345                return true;
346            };
347            let socket = socket.into_inner();
348            match PolledSocket::new(self.client.driver(), socket) {
349                Ok(socket) => {
350                    *opt_socket = Some(socket);
351                    true
352                }
353                Err(err) => {
354                    tracing::warn!(
355                        error = &err as &dyn std::error::Error,
356                        "failed to update driver for tcp connection"
357                    );
358                    false
359                }
360            }
361        });
362    }
363
    /// Handles one TCP segment received from the client.
    ///
    /// `payload` is the TCP segment (header plus data), `addresses` are the
    /// IP source/destination it was carried in, and `checksum` supplies the
    /// checksum capabilities used while parsing.
    ///
    /// The segment is routed to the existing connection for its four-tuple if
    /// there is one. Otherwise a SYN creates a new connection, a segment
    /// carrying an ACK is answered with a RST, and anything else is ignored.
    ///
    /// Returns an error (as a `DropReason`) if the segment fails to parse, if
    /// the connection's packet handler rejects it, or if a new connection
    /// cannot be created.
    pub(crate) fn handle_tcp(
        &mut self,
        addresses: &IpAddresses,
        payload: &[u8],
        checksum: &ChecksumState,
    ) -> Result<(), DropReason> {
        let tcp_packet = TcpPacket::new_checked(payload)?;
        let tcp = TcpRepr::parse(
            &tcp_packet,
            &addresses.src_addr(),
            &addresses.dst_addr(),
            &checksum.caps(),
        )?;

        // The flow key: `src` is the sender of this segment, `dst` its target.
        let ft = match addresses {
            IpAddresses::V4(addresses) => FourTuple {
                dst: SocketAddr::V4(SocketAddrV4::new(addresses.dst_addr, tcp.dst_port)),
                src: SocketAddr::V4(SocketAddrV4::new(addresses.src_addr, tcp.src_port)),
            },
            IpAddresses::V6(addresses) => FourTuple {
                dst: SocketAddr::V6(SocketAddrV6::new(addresses.dst_addr, tcp.dst_port, 0, 0)),
                src: SocketAddr::V6(SocketAddrV6::new(addresses.src_addr, tcp.src_port, 0, 0)),
            },
        };
        trace_tcp_packet(&tcp, tcp.payload.len(), "recv");

        // Computed up front, before `sender` takes a mutable borrow of the state.
        let is_dns_tcp =
            is_gateway_dns_tcp(&ft, &self.inner.state.params, self.inner.dns.is_some());

        let mut sender = Sender {
            ft: &ft,
            client: self.client,
            state: &mut self.inner.state,
        };

        match self.inner.tcp.connections.entry(ft) {
            hash_map::Entry::Occupied(mut e) => {
                let keep = e.get_mut().inner.handle_packet(&mut sender, &tcp)?;
                if !keep {
                    // Check for an in-flight DNS query before the connection is removed,
                    // so the resolver can be told about it afterwards.
                    let dns_in_flight = matches!(
                        e.get().backend,
                        TcpBackend::Dns(ref h) if h.is_in_flight()
                    );
                    e.remove();
                    if dns_in_flight {
                        if let Some(dns) = &mut self.inner.dns {
                            dns.complete_tcp_query();
                        }
                    }
                }
            }
            hash_map::Entry::Vacant(e) => {
                if tcp.control == TcpControl::Rst {
                    // This connection is already closed. Ignore the packet.
                } else if let Some(ack) = tcp.ack_number {
                    // This is for an old connection. Send reset.
                    // The segment's ACK number is used as the RST's sequence number.
                    sender.rst(ack, None);
                } else if tcp.control == TcpControl::Syn {
                    // New outbound connection from the client: either a virtual DNS
                    // connection or one backed by a real host socket.
                    let conn = if is_dns_tcp {
                        TcpConnection::new_dns(
                            &mut sender,
                            &tcp,
                            &self.inner.tcp.connection_params,
                        )?
                    } else {
                        TcpConnection::new(&mut sender, &tcp, &self.inner.tcp.connection_params)?
                    };
                    e.insert(conn);
                } else {
                    // Ignore the packet.
                }
            }
        }
        Ok(())
    }
439
440    /// Binds to the specified host IP and port for listening for incoming
441    /// connections.
442    pub fn bind_tcp_port(&mut self, ip_addr: Option<IpAddr>, port: u16) -> Result<(), DropReason> {
443        let ip_addr = match ip_addr {
444            Some(IpAddr::V4(ip)) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
445            Some(IpAddr::V6(ip)) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
446            None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)),
447        };
448        match self.inner.tcp.listeners.entry(port) {
449            hash_map::Entry::Occupied(_) => {
450                tracing::warn!(port, "Duplicate TCP bind for port");
451            }
452            hash_map::Entry::Vacant(e) => {
453                let ft = match ip_addr {
454                    SocketAddr::V4(ip) => FourTuple {
455                        dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
456                        src: SocketAddr::V4(ip),
457                    },
458                    SocketAddr::V6(ip) => FourTuple {
459                        dst: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
460                        src: SocketAddr::V6(ip),
461                    },
462                };
463                let mut sender = Sender {
464                    ft: &ft,
465                    client: self.client,
466                    state: &mut self.inner.state,
467                };
468
469                let listener = TcpListener::new(&mut sender)?;
470                e.insert(listener);
471            }
472        }
473        Ok(())
474    }
475
476    /// Unbinds from the specified host port.
477    pub fn unbind_tcp_port(&mut self, port: u16) -> Result<(), DropReason> {
478        match self.inner.tcp.listeners.entry(port) {
479            hash_map::Entry::Occupied(e) => {
480                e.remove();
481                Ok(())
482            }
483            hash_map::Entry::Vacant(_) => Err(DropReason::PortNotBound),
484        }
485    }
486}
487
/// Borrowed context needed to build and deliver TCP packets to the client
/// for one flow.
struct Sender<'a, T> {
    /// The flow the packets belong to. Packets are sent from `ft.dst` to
    /// `ft.src`.
    ft: &'a FourTuple,
    /// The client that receives the built frames.
    client: &'a mut T,
    /// Shared state providing the scratch packet buffer and the MAC
    /// addresses used for the Ethernet header.
    state: &'a mut ConsommeState,
}
493
494impl<T: Client> Sender<'_, T> {
    /// Builds an Ethernet + IP + TCP frame in the shared state buffer and
    /// hands it to the client.
    ///
    /// `tcp` describes the TCP header to emit. `payload`, when present, is
    /// copied into the TCP payload area after the header has been written.
    /// The frame is addressed from `ft.dst` (source) to `ft.src`
    /// (destination), i.e. in the reply direction of the flow.
    fn send_packet(&mut self, tcp: &TcpRepr<'_>, payload: Option<ring::View<'_>>) {
        let buffer = &mut self.state.buffer;
        let mut eth_packet = EthernetFrame::new_unchecked(&mut buffer[..]);
        eth_packet.set_dst_addr(self.state.params.client_mac);
        eth_packet.set_src_addr(self.state.params.gateway_mac);
        // IP payload length is the TCP header plus any data; 64 is the hop limit/TTL.
        let ip = IpRepr::new(
            self.ft.dst.ip().into(),
            self.ft.src.ip().into(),
            IpProtocol::Tcp,
            tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()),
            64,
        );
        // Set the ethernet type based on IP version
        match ip {
            IpRepr::Ipv4(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv4),
            IpRepr::Ipv6(_) => eth_packet.set_ethertype(EthernetProtocol::Ipv6),
        }

        // Emit IP packet and get the TCP payload buffer (works for both IPv4 and IPv6)
        let ip_packet_buf = eth_packet.payload_mut();
        ip.emit(&mut *ip_packet_buf, &ChecksumCapabilities::default());

        // Re-read the just-emitted IP header to find where the TCP segment starts
        // and how long the whole IP packet is.
        let (tcp_payload_buf, ip_total_len) = match self.ft.dst {
            SocketAddr::V4(_) => {
                let ipv4_packet = Ipv4Packet::new_unchecked(&*ip_packet_buf);
                let total_len = ipv4_packet.total_len() as usize;
                let payload_offset = ipv4_packet.header_len() as usize;
                (&mut ip_packet_buf[payload_offset..total_len], total_len)
            }
            SocketAddr::V6(_) => {
                let ipv6_packet = Ipv6Packet::new_unchecked(&*ip_packet_buf);
                let total_len = ipv6_packet.total_len();
                let payload_offset = IPV6_HEADER_LEN;
                (&mut ip_packet_buf[payload_offset..total_len], total_len)
            }
        };

        let dst_ip_addr: IpAddress = self.ft.dst.ip().into();
        let src_ip_addr: IpAddress = self.ft.src.ip().into();
        let mut tcp_packet = TcpPacket::new_unchecked(tcp_payload_buf);
        tcp.emit(
            &mut tcp_packet,
            &dst_ip_addr,
            &src_ip_addr,
            &ChecksumCapabilities::default(),
        );

        // Copy payload into TCP packet
        if let Some(payload) = &payload {
            payload.copy_to_slice(tcp_packet.payload_mut());
        }
        // The checksum is filled in again now that the payload bytes are in place.
        tcp_packet.fill_checksum(&self.ft.dst.ip().into(), &self.ft.src.ip().into());
        let n = ETHERNET_HEADER_LEN + ip_total_len;
        let checksum_state = match self.ft.dst {
            SocketAddr::V4(_) => ChecksumState::TCP4,
            SocketAddr::V6(_) => ChecksumState::TCP6,
        };

        // Deliver only the bytes that make up the frame, not the whole scratch buffer.
        self.client.recv(&buffer[..n], &checksum_state);
    }
555
556    fn rst(&mut self, seq: TcpSeqNumber, ack: Option<TcpSeqNumber>) {
557        let tcp = TcpRepr {
558            src_port: self.ft.dst.port(),
559            dst_port: self.ft.src.port(),
560            control: TcpControl::Rst,
561            seq_number: seq,
562            ack_number: ack,
563            window_len: 0,
564            window_scale: None,
565            max_seg_size: None,
566            sack_permitted: false,
567            sack_ranges: [None, None, None],
568            timestamp: None,
569            payload: &[],
570        };
571
572        trace_tcp_packet(&tcp, 0, "rst xmit");
573
574        self.send_packet(&tcp, None);
575    }
576}
577
578impl TcpConnection {
579    fn new_base(params: &ConnectionParams) -> TcpConnectionInner {
580        let mut rx_tx_seq = [0; 8];
581        getrandom::fill(&mut rx_tx_seq[..]).expect("prng failure");
582        let rx_seq = TcpSeqNumber(i32::from_ne_bytes(
583            rx_tx_seq[0..4].try_into().expect("invalid length"),
584        ));
585        let tx_seq = TcpSeqNumber(i32::from_ne_bytes(
586            rx_tx_seq[4..8].try_into().expect("invalid length"),
587        ));
588
589        let rx_buffer_size: usize = params.rx_buffer_size.clamp(16384, 4 << 20);
590        let rx_window_scale =
591            (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
592
593        let tx_buffer_size = params
594            .tx_buffer_size
595            .clamp(16384, 4 << 20)
596            .next_power_of_two();
597
598        TcpConnectionInner {
599            loopback_port: LoopbackPortInfo::None,
600            state: TcpState::Connecting,
601            rx_buffer: ring::Ring::new(0),
602            rx_window_cap: rx_buffer_size,
603            rx_window_scale,
604            rx_seq,
605            rx_assembler: assembler::Assembler::new(),
606            needs_ack: false,
607            is_shutdown: false,
608            enable_window_scaling: false,
609            tx_buffer: ring::Ring::new(tx_buffer_size),
610            tx_acked: tx_seq,
611            tx_send: tx_seq,
612            tx_window_len: 1,
613            tx_window_scale: 0,
614            tx_window_rx_seq: rx_seq,
615            tx_window_tx_seq: tx_seq,
616            // The TCPv4 default maximum segment size is 536. This can be bigger for
617            // IPv6.
618            tx_mss: 536,
619            tx_fin_buffered: false,
620        }
621    }
622
623    fn new(
624        sender: &mut Sender<'_, impl Client>,
625        tcp: &TcpRepr<'_>,
626        params: &ConnectionParams,
627    ) -> Result<Self, DropReason> {
628        let mut inner = Self::new_base(params);
629        inner.initialize_from_first_client_packet(tcp)?;
630
631        let socket = Socket::new(
632            match sender.ft.dst {
633                SocketAddr::V4(_) => Domain::IPV4,
634                SocketAddr::V6(_) => Domain::IPV6,
635            },
636            Type::STREAM,
637            Some(Protocol::TCP),
638        )
639        .map_err(DropReason::Io)?;
640
641        // On Windows the default behavior for non-existent loopback sockets is
642        // to wait and try again. This is different than the Linux behavior of
643        // immediately failing. Default to the Linux behavior.
644        #[cfg(windows)]
645        if sender.ft.dst.ip().is_loopback() {
646            if let Err(err) = crate::windows::disable_connection_retries(&socket) {
647                tracing::trace!(err, "Failed to disable loopback retries");
648            }
649        }
650
651        let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
652        match socket.get().connect(&SockAddr::from(sender.ft.dst)) {
653            Ok(_) => unreachable!(),
654            Err(err) if is_connect_incomplete_error(&err) => (),
655            Err(err) => {
656                log_connect_error(&err);
657                sender.rst(TcpSeqNumber(0), Some(tcp.seq_number + tcp.segment_len()));
658                return Err(DropReason::Io(err));
659            }
660        }
661        if let Ok(addr) = socket.get().local_addr() {
662            match addr.as_socket() {
663                None => {
664                    tracing::warn!("unable to get local socket address");
665                }
666                Some(addr) => {
667                    if addr.ip().is_loopback() {
668                        inner.loopback_port = LoopbackPortInfo::ProxyForGuestPort {
669                            sending_port: addr.port(),
670                            guest_port: sender.ft.src.port(),
671                        };
672                    }
673                }
674            }
675        }
676        Ok(Self {
677            backend: TcpBackend::Socket(Some(socket)),
678            inner,
679        })
680    }
681
682    fn new_from_accept(
683        sender: &mut Sender<'_, impl Client>,
684        socket: Socket,
685        params: &ConnectionParams,
686    ) -> Result<Self, DropReason> {
687        let mut inner = TcpConnectionInner {
688            state: TcpState::SynSent,
689            ..Self::new_base(params)
690        };
691        inner.send_syn(sender, None);
692        Ok(Self {
693            backend: TcpBackend::Socket(Some(
694                PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?,
695            )),
696            inner,
697        })
698    }
699
700    /// Create a virtual DNS TCP connection (no real host socket).
701    /// The connection completes the TCP handshake with the guest and
702    /// routes DNS queries through the provided resolver backend.
703    fn new_dns(
704        sender: &mut Sender<'_, impl Client>,
705        tcp: &TcpRepr<'_>,
706        params: &ConnectionParams,
707    ) -> Result<Self, DropReason> {
708        let mut inner = Self::new_base(params);
709        inner.initialize_from_first_client_packet(tcp)?;
710
711        let flow = crate::dns_resolver::DnsFlow {
712            src_addr: sender.ft.src.ip().into(),
713            dst_addr: sender.ft.dst.ip().into(),
714            src_port: sender.ft.src.port(),
715            dst_port: sender.ft.dst.port(),
716            gateway_mac: sender.state.params.gateway_mac,
717            client_mac: sender.state.params.client_mac,
718            transport: crate::dns_resolver::DnsTransport::Tcp,
719        };
720
721        // Immediately transition to SynReceived so the handshake SYN-ACK is sent.
722        inner.state = TcpState::SynReceived;
723        inner.send_syn(sender, Some(inner.rx_seq));
724
725        Ok(Self {
726            backend: TcpBackend::Dns(DnsTcpHandler::new(flow)),
727            inner,
728        })
729    }
730}
731
732impl TcpConnectionInner {
733    fn initialize_from_first_client_packet(&mut self, tcp: &TcpRepr<'_>) -> Result<(), DropReason> {
734        // The TCPv4 default maximum segment size is 536. This can be bigger for
735        // IPv6.
736        let tx_mss = tcp.max_seg_size.map_or(536, |x| x.into());
737
738        if let Some(tx_window_scale) = tcp.window_scale {
739            if tx_window_scale > 14 {
740                return Err(TcpError::InvalidWindowScale.into());
741            }
742            self.enable_window_scaling = true;
743            self.tx_window_scale = tx_window_scale;
744        } else {
745            // Disable rx window scale. Cap the buffer and window to u16::MAX
746            // since without window scaling, the window field is only 16 bits.
747            self.enable_window_scaling = false;
748            self.rx_window_cap = self.rx_window_cap.min(u16::MAX as usize);
749            self.rx_window_scale = 0;
750        }
751
752        self.rx_buffer = ring::Ring::new(self.rx_window_cap.next_power_of_two());
753        self.rx_seq = tcp.seq_number + 1;
754        self.tx_window_rx_seq = tcp.seq_number + 1;
755        self.tx_mss = tx_mss;
756        Ok(())
757    }
758
    /// Poll the DNS TCP virtual connection backend.
    ///
    /// There is no real socket; data flows through the [`DnsTcpHandler`].
    ///
    /// The order of operations is: propagate a received FIN to the handler,
    /// pull pending responses from the handler into the tx buffer, feed
    /// buffered client data to the handler, then send whatever is ready to
    /// the client.
    ///
    /// Returns `true` to keep the connection and `false` to have the caller
    /// remove it. The connection is dropped (after sending a RST) when the
    /// handler reports a read or ingest error, and also once it has reached
    /// `TimeWait` or `LastAck`, or both directions have seen a FIN and the tx
    /// buffer is empty.
    fn poll_dns_backend(
        &mut self,
        cx: &mut Context<'_>,
        sender: &mut Sender<'_, impl Client>,
        dns_handler: &mut DnsTcpHandler,
        dns: &mut DnsResolver,
    ) -> bool {
        // Propagate guest FIN before the tx path so that poll_read can
        // detect EOF on the same iteration.
        if self.state.rx_fin() && !dns_handler.guest_fin() {
            dns_handler.set_guest_fin();
        }

        // tx path first: drain DNS responses into tx_buffer.
        // This frees up backpressure so that ingest can make progress.
        while !self.tx_buffer.is_full() {
            // The free space of the ring may be split in two pieces; offer both.
            let (a, b) = self.tx_buffer.unwritten_slices_mut();
            let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
            match dns_handler.poll_read(cx, &mut bufs, dns) {
                Poll::Ready(Ok(n)) => {
                    if n == 0 {
                        // EOF — close the connection.
                        if !self.state.tx_fin() {
                            self.close();
                        }
                        break;
                    }
                    self.tx_buffer.extend_by(n);
                }
                Poll::Ready(Err(_)) => {
                    sender.rst(self.tx_send, Some(self.rx_seq));
                    return false;
                }
                Poll::Pending => break,
            }
        }

        // rx path: feed guest data into the DNS handler for query extraction.
        let view = self.rx_buffer.view(0..self.rx_buffer.len());
        let (a, b) = view.as_slices();
        match dns_handler.ingest(&[a, b], dns) {
            // Only the bytes the handler consumed are released from the rx buffer.
            Ok(consumed) if consumed > 0 => {
                self.rx_buffer.consume(consumed);
            }
            Ok(_) => {}
            Err(_) => {
                // Invalid DNS TCP framing; reset the connection.
                sender.rst(self.tx_send, Some(self.rx_seq));
                return false;
            }
        }

        self.send_next(sender);
        // Keep the connection unless it has reached a terminal condition.
        !(self.state == TcpState::TimeWait
            || self.state == TcpState::LastAck
            || (self.state.tx_fin() && self.state.rx_fin() && self.tx_buffer.is_empty()))
    }
819
    /// Poll the real-socket TCP connection backend.
    ///
    /// Reads data from the host socket into the tx buffer (host -> guest) and
    /// writes guest rx data into the host socket (guest -> host).
    ///
    /// Steps, in order:
    /// 1. While `Connecting`, wait for the socket to become writable; on
    ///    success move to `SynReceived`. While `SynSent`, return without
    ///    touching the socket.
    /// 2. Tx path: fill `tx_buffer` from the socket, or, once
    ///    `self.state.tx_fin()` is true, only watch the socket for errors and
    ///    drop it when it reports ready without an error.
    /// 3. Rx path: drain `rx_buffer` into the socket, then shut down the
    ///    socket's write half once the buffer is empty and
    ///    `self.state.rx_fin()` is true.
    /// 4. Call `send_next` to transmit whatever is now pending.
    ///
    /// Returns `false` when no socket exists while connecting, when the
    /// connect failed, or after a socket read/write/shutdown error (an RST is
    /// sent to the guest in the error cases, except for a connect timeout).
    /// Returns `true` otherwise.
    ///
    /// NOTE(review): the caller presumably tears the connection down when this
    /// returns `false` -- confirm against the call site.
    fn poll_socket_backend(
        &mut self,
        cx: &mut Context<'_>,
        sender: &mut Sender<'_, impl Client>,
        opt_socket: &mut Option<PolledSocket<Socket>>,
    ) -> bool {
        // Wait for the outbound connection to complete.
        if self.state == TcpState::Connecting {
            // Without a socket there is nothing to wait on.
            let Some(socket) = opt_socket.as_mut() else {
                return false;
            };
            match socket.poll_ready(cx, PollEvents::OUT) {
                Poll::Ready(r) => {
                    if r.has_err() {
                        self.handle_connect_error(sender, socket);
                        return false;
                    }

                    tracing::debug!("connection established");
                    self.state = TcpState::SynReceived;
                }
                Poll::Pending => return true,
            }
        } else if self.state == TcpState::SynSent {
            // Need to establish connection with client before sending data.
            return true;
        }

        // Handle the tx path.
        if let Some(socket) = opt_socket.as_mut() {
            if self.state.tx_fin() {
                // No more data is read from the socket in this state; only
                // check whether it has become ready or failed.
                if let Poll::Ready(events) = socket.poll_ready(cx, PollEvents::EMPTY) {
                    if events.has_err() {
                        let err = take_socket_error(socket);
                        match err.kind() {
                            // These kinds are not logged.
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => tracelimit::warn_ratelimited!(
                                error = &err as &dyn std::error::Error,
                                "socket failure after fin"
                            ),
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }

                    // Both ends are closed. Close the actual socket.
                    *opt_socket = None;
                }
            } else {
                // Read from the socket until the tx buffer is full or the
                // socket has nothing more to offer right now.
                while !self.tx_buffer.is_full() {
                    let (a, b) = self.tx_buffer.unwritten_slices_mut();
                    let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
                    match Pin::new(&mut *socket).poll_read_vectored(cx, &mut bufs) {
                        Poll::Ready(Ok(n)) => {
                            if n == 0 {
                                // Zero bytes read: start closing our side
                                // (see `close`) and stop reading.
                                self.close();
                                break;
                            }
                            self.tx_buffer.extend_by(n);
                        }
                        Poll::Ready(Err(err)) => {
                            // Connection resets are only traced; anything
                            // else is a rate-limited warning.
                            match err.kind() {
                                ErrorKind::ConnectionReset => tracing::trace!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                                _ => tracelimit::warn_ratelimited!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                            }
                            sender.rst(self.tx_send, Some(self.rx_seq));
                            return false;
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }

        // Handle the rx path.
        if let Some(socket) = opt_socket.as_mut() {
            // Write buffered guest data until the buffer is drained or the
            // socket stops accepting data.
            while !self.rx_buffer.is_empty() {
                let view = self.rx_buffer.view(0..self.rx_buffer.len());
                let (a, b) = view.as_slices();
                let bufs = [IoSlice::new(a), IoSlice::new(b)];
                match Pin::new(&mut *socket).poll_write_vectored(cx, &bufs) {
                    Poll::Ready(Ok(n)) => {
                        self.rx_buffer.consume(n);
                    }
                    Poll::Ready(Err(err)) => {
                        match err.kind() {
                            // These kinds are not logged.
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => {
                                tracelimit::warn_ratelimited!(
                                    error = &err as &dyn std::error::Error,
                                    "socket write error"
                                );
                            }
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }
                    Poll::Pending => break,
                }
            }
            // Once everything buffered has been written and
            // `self.state.rx_fin()` is true, shut down the write half of the
            // socket. `is_shutdown` ensures this happens only once.
            if self.rx_buffer.is_empty() && self.state.rx_fin() && !self.is_shutdown {
                if let Err(err) = socket.get().shutdown(Shutdown::Write) {
                    tracelimit::warn_ratelimited!(
                        error = &err as &dyn std::error::Error,
                        "shutdown error"
                    );
                    sender.rst(self.tx_send, Some(self.rx_seq));
                    return false;
                }
                self.is_shutdown = true;
            }
        }

        // Send whatever needs to be sent.
        self.send_next(sender);
        true
    }
947
948    fn handle_connect_error(
949        &mut self,
950        sender: &mut Sender<'_, impl Client>,
951        socket: &mut PolledSocket<Socket>,
952    ) {
953        let err = take_socket_error(socket);
954        if err.kind() == ErrorKind::TimedOut {
955            // Avoid resetting so that the guest doesn't think there is a
956            // responding TCP stack at this address. The guest will time out on
957            // its own.
958            tracing::debug!(error = &err as &dyn std::error::Error, "connect timed out");
959        } else {
960            log_connect_error(&err);
961            sender.rst(self.tx_send, Some(self.rx_seq));
962        }
963    }
964
965    fn rx_window_len(&self) -> u16 {
966        ((self.rx_window_cap - self.rx_buffer.len()) >> self.rx_window_scale) as u16
967    }
968
969    fn send_next(&mut self, sender: &mut Sender<'_, impl Client>) {
970        match self.state {
971            TcpState::Connecting => {}
972            TcpState::SynReceived => self.send_syn(sender, Some(self.rx_seq)),
973            _ => self.send_data(sender),
974        }
975    }
976
977    fn send_syn(&mut self, sender: &mut Sender<'_, impl Client>, ack_number: Option<TcpSeqNumber>) {
978        if self.tx_send != self.tx_acked || sender.client.rx_mtu() == 0 {
979            return;
980        }
981
982        // If the client side specified a window scale option, then do the same
983        // (even with no shift) to enable window scale support.
984        let window_scale = self.enable_window_scaling.then_some(self.rx_window_scale);
985
986        // Advertise the maximum possible segment size, allowing the guest
987        // to truncate this to its own MTU calculation.
988        let max_seg_size = u16::MAX;
989        let tcp = TcpRepr {
990            src_port: sender.ft.dst.port(),
991            dst_port: sender.ft.src.port(),
992            control: TcpControl::Syn,
993            seq_number: self.tx_send,
994            ack_number,
995            window_len: if ack_number.is_some() {
996                self.rx_window_len()
997            } else {
998                0
999            },
1000            window_scale,
1001            max_seg_size: Some(max_seg_size),
1002            sack_permitted: false,
1003            sack_ranges: [None, None, None],
1004            timestamp: None,
1005            payload: &[],
1006        };
1007
1008        sender.send_packet(&tcp, None);
1009        self.tx_send += 1;
1010    }
1011
    /// Transmit pending data, FIN and/or ACK segments to the guest.
    ///
    /// Segments are produced in a loop while an ACK is owed (`needs_ack`) or
    /// there is unsent sequence space below both the end of the buffered data
    /// (plus a buffered FIN) and the end of the peer's advertised window. The
    /// loop stops early when the client reports an `rx_mtu()` of zero.
    ///
    /// Each iteration sends one segment, advances `tx_send` to the end of that
    /// segment, and clears `needs_ack`.
    ///
    /// # Panics
    ///
    /// Panics if a segment would extend past the buffered data plus FIN, or if
    /// an iteration would send nothing new while no ACK is owed (see the
    /// `assert!`s below).
    fn send_data(&mut self, sender: &mut Sender<'_, impl Client>) {
        // These computations assume syn has already been sent and acked.
        let tx_payload_end = self.tx_acked + self.tx_buffer.len();
        // A buffered FIN occupies one sequence number after the payload.
        let tx_end = tx_payload_end + self.tx_fin_buffered as usize;
        let tx_window_end = self.tx_acked + ((self.tx_window_len as usize) << self.tx_window_scale);
        // Never send past either the buffered data (plus FIN) or the window.
        let tx_done = seq_min([tx_end, tx_window_end]);

        while self.needs_ack || self.tx_send < tx_done {
            let rx_mtu = sender.client.rx_mtu();
            if rx_mtu == 0 {
                // Out of receive buffers.
                break;
            }

            let mut tcp = TcpRepr {
                src_port: sender.ft.dst.port(),
                dst_port: sender.ft.src.port(),
                control: TcpControl::None,
                seq_number: self.tx_send,
                ack_number: Some(self.rx_seq),
                window_len: self.rx_window_len(),
                window_scale: None,
                max_seg_size: None,
                sack_permitted: false,
                sack_ranges: [None, None, None],
                timestamp: None,
                payload: &[],
            };

            // Sequence number just past what this segment will cover.
            let mut tx_next = self.tx_send;

            // Compute the end of the segment buffer in sequence space to avoid
            // exceeding:
            // 1. The available buffer length.
            // 2. The current window.
            // 3. The configured maximum segment size.
            // 4. The client MTU.
            let tx_segment_end = {
                let ip_header_len = match sender.ft.dst {
                    SocketAddr::V4(_) => IPV4_HEADER_LEN,
                    SocketAddr::V6(_) => IPV6_HEADER_LEN,
                };
                let header_len = ETHERNET_HEADER_LEN + ip_header_len + tcp.header_len();
                // The frame must fit both the client's MTU and the local
                // packet buffer.
                let mtu = rx_mtu.min(sender.state.buffer.len());
                seq_min([
                    tx_payload_end,
                    tx_window_end,
                    tx_next + self.tx_mss,
                    tx_next + (mtu - header_len),
                ])
            };

            // Offset and length of the payload within `tx_buffer`, whose
            // first byte corresponds to `tx_acked`. Empty when there is no
            // room for payload.
            let (payload_start, payload_len) = if tx_next < tx_segment_end {
                (tx_next - self.tx_acked, tx_segment_end - tx_next)
            } else {
                (0, 0)
            };

            tx_next += payload_len;

            // Include the fin if present if there is still room.
            if self.tx_fin_buffered
                && tcp.control == TcpControl::None
                && tx_next == tx_payload_end
                && tx_next < tx_window_end
            {
                tcp.control = TcpControl::Fin;
                tx_next += 1;
            }

            assert!(tx_next <= tx_end);
            // Guarantees forward progress: either an ACK was owed or the
            // segment carries something new.
            assert!(self.needs_ack || tx_next > self.tx_send);

            trace_tcp_packet(&tcp, payload_len, "xmit");

            let payload = self
                .tx_buffer
                .view(payload_start..payload_start + payload_len);

            sender.send_packet(&tcp, Some(payload));
            self.tx_send = tx_next;
            self.needs_ack = false;
        }

        assert!(self.tx_send <= tx_end);
    }
1098
1099    fn close(&mut self) {
1100        tracing::trace!("fin");
1101        match self.state {
1102            TcpState::SynSent | TcpState::SynReceived | TcpState::Established => {
1103                self.state = TcpState::FinWait1;
1104            }
1105            TcpState::CloseWait => {
1106                self.state = TcpState::LastAck;
1107            }
1108            TcpState::Connecting
1109            | TcpState::FinWait1
1110            | TcpState::FinWait2
1111            | TcpState::Closing
1112            | TcpState::TimeWait
1113            | TcpState::LastAck => unreachable!("fin in {:?}", self.state),
1114        }
1115        self.tx_fin_buffered = true;
1116    }
1117
1118    /// Send an ACK using the current state of the connection.
1119    ///
1120    /// This is used when sending an ack to report a the reception of an
1121    /// unacceptable packet (duplicate, out of order, etc.). These acks
1122    /// shouldn't be combined with data so that they are interpreted correctly
1123    /// by the peer.
1124    fn ack(&self, sender: &mut Sender<'_, impl Client>) {
1125        let tcp = TcpRepr {
1126            src_port: sender.ft.dst.port(),
1127            dst_port: sender.ft.src.port(),
1128            control: TcpControl::None,
1129            seq_number: self.tx_send,
1130            ack_number: Some(self.rx_seq),
1131            window_len: self.rx_window_len(),
1132            window_scale: None,
1133            max_seg_size: None,
1134            sack_permitted: false,
1135            sack_ranges: [None, None, None],
1136            timestamp: None,
1137            payload: &[],
1138        };
1139
1140        trace_tcp_packet(&tcp, 0, "ack");
1141
1142        sender.send_packet(&tcp, None);
1143    }
1144
1145    fn handle_listen_syn(
1146        &mut self,
1147        sender: &mut Sender<'_, impl Client>,
1148        tcp: &TcpRepr<'_>,
1149    ) -> Result<bool, DropReason> {
1150        if tcp.control != TcpControl::Syn || tcp.segment_len() != 1 {
1151            tracing::error!(?tcp.control, "invalid packet waiting for syn, drop connection");
1152            return Ok(false);
1153        }
1154
1155        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;
1156        if ack_number <= self.tx_acked || ack_number > self.tx_send {
1157            sender.rst(ack_number, None);
1158            return Ok(false);
1159        }
1160        self.tx_acked = ack_number;
1161
1162        self.initialize_from_first_client_packet(tcp)?;
1163        self.tx_window_tx_seq = ack_number;
1164        self.tx_window_len = tcp.window_len;
1165
1166        // Send an ACK to complete the initial SYN handshake.
1167        self.ack(sender);
1168
1169        self.state = TcpState::Established;
1170        Ok(true)
1171    }
1172
    /// Process one TCP segment received from the guest for this connection.
    ///
    /// The segment is validated (sequence number, RST, SYN, ACK), the ACK is
    /// used to retire transmitted data and update the send window, and any
    /// in-window payload and FIN are recorded in the receive assembler and
    /// ring buffer, advancing `rx_seq` over whatever became contiguous.
    ///
    /// Returns:
    /// - `Ok(true)` when the segment was handled and the connection stays.
    /// - `Ok(false)` when the connection should be dropped (a valid RST, an
    ///   invalid SYN or ACK during the handshake, or the final ACK in
    ///   `LastAck`).
    /// - `Err(_)` when the packet itself is dropped (still connecting,
    ///   unacceptable sequence number, missing ACK, or an ACK for data that
    ///   has not been sent).
    fn handle_packet(
        &mut self,
        sender: &mut Sender<'_, impl Client>,
        tcp: &TcpRepr<'_>,
    ) -> Result<bool, DropReason> {
        if self.state == TcpState::Connecting {
            // We have not yet sent a syn (we are still deciding whether we are
            // in LISTEN or CLOSED state), so we can't send a reasonable
            // response to this. Just drop the packet.
            return Err(TcpError::StillConnecting.into());
        } else if self.state == TcpState::SynSent {
            return self.handle_listen_syn(sender, tcp);
        }

        // Receive window in bytes (unscaled) and the sequence-space bounds of
        // the window and of this segment.
        let rx_window_len = self.rx_window_cap - self.rx_buffer.len();
        let rx_window_end = self.rx_seq + rx_window_len;
        let segment_end = tcp.seq_number + tcp.segment_len();

        // Validate the sequence number per RFC 793.
        let seq_acceptable = if rx_window_len != 0 {
            // Either the segment starts inside the window, or it is non-empty
            // and ends inside the window.
            (tcp.seq_number >= self.rx_seq && tcp.seq_number < rx_window_end)
                || (tcp.segment_len() > 0
                    && segment_end > self.rx_seq
                    && segment_end <= rx_window_end)
        } else {
            // With a zero window only an empty segment at exactly `rx_seq` is
            // acceptable.
            tcp.segment_len() == 0 && tcp.seq_number == self.rx_seq
        };

        if tcp.control == TcpControl::Rst {
            if !seq_acceptable {
                // Silently drop--don't send an ACK--since the peer would then
                // immediately respond with a valid RST.
                return Err(TcpError::Unacceptable.into());
            }

            // RFC 5961
            if tcp.seq_number != self.rx_seq {
                // Send a challenge ACK.
                self.ack(sender);
                return Ok(true);
            }

            // This is a valid RST. Drop the connection.
            tracing::debug!("connection reset");
            return Ok(false);
        }

        // Send ack and drop packets with unacceptable sequence numbers.
        if !seq_acceptable {
            self.ack(sender);
            return Err(TcpError::Unacceptable.into());
        }

        // SYN should not be set for in-window segments.
        if tcp.control == TcpControl::Syn {
            if self.state == TcpState::SynReceived {
                tracing::debug!("invalid syn, drop connection");
                return Ok(false);
            }
            // RFC 5961, send a challenge ACK.
            self.ack(sender);
            return Ok(true);
        }

        // ACK should always be set at this point.
        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;

        // FUTURE: validate ack number per RFC 5961.

        // Handle ACK of our SYN.
        if self.state == TcpState::SynReceived {
            if ack_number <= self.tx_acked || ack_number > self.tx_send {
                sender.rst(ack_number, None);
                return Ok(false);
            }
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
            // Our SYN consumed one sequence number.
            self.tx_acked += 1;
            self.state = TcpState::Established;
        }

        // Ignore ACKs for segments that have not been sent.
        if ack_number > self.tx_send {
            self.ack(sender);
            return Err(TcpError::AckPastSequence.into());
        }

        // Retire the ACKed segments.
        if ack_number > self.tx_acked {
            let mut consumed = ack_number - self.tx_acked;
            // An ACK one past the end of the buffered payload acknowledges our
            // FIN; the FIN takes a sequence number but no buffer space.
            if self.tx_fin_buffered && ack_number == self.tx_acked + self.tx_buffer.len() + 1 {
                self.tx_fin_buffered = false;
                consumed -= 1;
                match self.state {
                    TcpState::FinWait1 => self.state = TcpState::FinWait2,
                    TcpState::Closing => self.state = TcpState::TimeWait,
                    TcpState::LastAck => return Ok(false),
                    _ => unreachable!(),
                }
            }
            self.tx_buffer.consume(consumed);
            self.tx_acked = ack_number;
        }

        // Update the send window.
        //
        // The update is only taken from segments that are not older than the
        // one that last updated the window (compared by sequence number, then
        // by ACK number).
        if ack_number >= self.tx_acked
            && (tcp.seq_number > self.tx_window_rx_seq
                || (tcp.seq_number == self.tx_window_rx_seq && ack_number >= self.tx_window_tx_seq))
        {
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
        }

        // Scope the data payload and FIN to the in-window portion of the segment.
        let mut fin = tcp.control == TcpControl::Fin;
        // Bytes at the start of the segment that precede `rx_seq`.
        let segment_skip = if tcp.seq_number < self.rx_seq {
            self.rx_seq - tcp.seq_number
        } else {
            0
        };
        // If the segment runs past the window, truncate it; the FIN then lies
        // outside the window and is ignored.
        let segment_end = if segment_end > rx_window_end {
            fin = false;
            rx_window_end
        } else {
            segment_end
        };
        // The FIN occupies the last sequence number of the segment, so it is
        // excluded from the payload range.
        let payload = &tcp.payload[segment_skip..segment_end - tcp.seq_number - fin as usize];

        // Set when the assembler reports that the FIN has been reached.
        let mut rx_fin = false;

        // Process the payload.
        match self.state {
            TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
            TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 => {
                if !payload.is_empty() || fin {
                    // Stage 1: Compute the byte offset from the contiguous
                    // frontier.
                    //
                    // Safety of ring_offset: the sequence acceptance check above
                    // bounds the segment to rx_window_end = rx_seq + (rx_window_cap
                    // - rx_buffer.len()), so seq_offset + payload.len() <=
                    // rx_window_cap <= ring capacity.
                    let seq_offset = if tcp.seq_number >= self.rx_seq {
                        tcp.seq_number - self.rx_seq
                    } else {
                        0
                    };
                    let ring_offset = self.rx_buffer.len() + seq_offset;

                    // Stage 2: Record the range in the assembler. Do this
                    // *before* writing to the ring so that rejected segments
                    // (TooManyGaps) don't leave stale bytes in unwritten
                    // ring space.
                    let (rx_consumed, assembler_fin, accepted) =
                        match self
                            .rx_assembler
                            .add(seq_offset as u32, payload.len() as u32, fin)
                        {
                            Ok(result) => (result.consumed as usize, result.fin, true),
                            Err(assembler::TooManyGaps) => (0, false, false),
                        };

                    // Stage 3: Write payload into the ring and advance the
                    // contiguous frontier. Only write when the assembler
                    // accepted the segment.
                    if accepted && !payload.is_empty() {
                        self.rx_buffer.write_at(ring_offset, payload);
                    }
                    self.rx_buffer.extend_by(rx_consumed);
                    self.rx_seq += rx_consumed;
                    rx_fin = assembler_fin;
                    if rx_fin {
                        // The FIN consumes one sequence number.
                        self.rx_seq += 1;
                    }
                }
                // Any segment that occupies sequence space must be ACKed.
                if tcp.segment_len() > 0 {
                    self.needs_ack = true;
                }
            }
            // No payload processing is done in these states.
            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck => {}
            TcpState::TimeWait => {
                self.ack(sender);
                // TODO: restart timer
            }
        }

        // Process FIN.
        if rx_fin {
            match self.state {
                TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
                TcpState::Established => {
                    self.state = TcpState::CloseWait;
                }
                TcpState::FinWait1 => {
                    self.state = TcpState::Closing;
                }
                TcpState::FinWait2 => {
                    self.state = TcpState::TimeWait;
                    // TODO: start timer
                }
                TcpState::CloseWait
                | TcpState::Closing
                | TcpState::LastAck
                | TcpState::TimeWait => {}
            }
        }

        Ok(true)
    }
1384}
1385
1386impl TcpListener {
1387    pub fn new(sender: &mut Sender<'_, impl Client>) -> Result<Self, DropReason> {
1388        let socket = match sender.ft.src {
1389            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)),
1390            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)),
1391        }
1392        .map_err(DropReason::Io)?;
1393
1394        let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
1395        if let Err(err) = socket.get().bind(&sender.ft.src.into()) {
1396            tracing::warn!(
1397                address = ?sender.ft.src,
1398                error = &err as &dyn std::error::Error,
1399                "socket bind error"
1400            );
1401            return Err(DropReason::Io(err));
1402        }
1403        if let Err(err) = socket.listen(10) {
1404            tracing::warn!(
1405                error = &err as &dyn std::error::Error,
1406                "socket listen error"
1407            );
1408            return Err(DropReason::Io(err));
1409        }
1410        Ok(Self { socket })
1411    }
1412
1413    fn poll_listener(
1414        &mut self,
1415        cx: &mut Context<'_>,
1416    ) -> Result<Option<(Socket, SocketAddr)>, DropReason> {
1417        match self.socket.poll_accept(cx) {
1418            Poll::Ready(r) => match r {
1419                Ok((socket, address)) => match address.as_socket() {
1420                    Some(addr) => Ok(Some((socket, addr))),
1421                    None => {
1422                        tracing::warn!(?address, "Unknown address from accept");
1423                        Ok(None)
1424                    }
1425                },
1426                Err(_) => {
1427                    let err = take_socket_error(&self.socket);
1428                    tracing::warn!(error = &err as &dyn std::error::Error, "listen failure");
1429                    Err(DropReason::Io(err))
1430                }
1431            },
1432            Poll::Pending => Ok(None),
1433        }
1434    }
1435}
1436
1437/// Trace a TCP packet with structured key/value fields.
1438///
1439/// Logs protocol-relevant fields (flags, seq, ack, window, payload length)
1440/// as individual tracing fields instead of dumping the full `TcpRepr` Debug
1441/// output which includes raw payload bytes.
1442fn trace_tcp_packet(tcp: &TcpRepr<'_>, payload_len: usize, label: &str) {
1443    tracing::trace!(
1444        label,
1445        flags = match tcp.control {
1446            TcpControl::Syn => Some("SYN"),
1447            TcpControl::Fin => Some("FIN"),
1448            TcpControl::Rst => Some("RST"),
1449            TcpControl::Psh => Some("PSH"),
1450            TcpControl::None => None,
1451        },
1452        seq = tcp.seq_number.0 as u32,
1453        next_seq = (tcp.seq_number.0 as u32).wrapping_add((payload_len + tcp.control.len()) as u32),
1454        ack = tcp.ack_number.map(|a| a.0 as u32),
1455        window = tcp.window_len,
1456        payload_len,
1457        "tcp packet",
1458    );
1459}
1460
1461fn take_socket_error(socket: &PolledSocket<Socket>) -> io::Error {
1462    match socket.get().take_error() {
1463        Ok(Some(err)) => err,
1464        Ok(_) => io::Error::other("missing error"),
1465        Err(err) => err,
1466    }
1467}
1468
1469/// Log a TCP connect error at the appropriate level.
1470///
1471/// Connection refused and network/host unreachable are expected failures logged
1472/// at debug level. Everything else is logged at warn.
1473fn log_connect_error(err: &io::Error) {
1474    match err.kind() {
1475        ErrorKind::ConnectionRefused => {
1476            tracing::debug!(error = err as &dyn std::error::Error, "connect refused");
1477        }
1478        ErrorKind::NetworkUnreachable | ErrorKind::HostUnreachable => {
1479            // FUTURE: send ICMP unreachable to guest
1480            tracing::debug!(
1481                error = err as &dyn std::error::Error,
1482                "connect failed, unreachable"
1483            );
1484        }
1485        _ => {
1486            tracelimit::warn_ratelimited!(error = err as &dyn std::error::Error, "connect failed");
1487        }
1488    }
1489}
1490
1491fn is_connect_incomplete_error(err: &io::Error) -> bool {
1492    if err.kind() == ErrorKind::WouldBlock {
1493        return true;
1494    }
1495    // This handles the remaining cases on Linux.
1496    #[cfg(unix)]
1497    if err.raw_os_error() == Some(libc::EINPROGRESS) {
1498        return true;
1499    }
1500    false
1501}
1502
1503/// Finds the smallest sequence number in a set. To get a coherent result, all
1504/// the sequence numbers must be known to be comparable, meaning they are all
1505/// within 2^31 bytes of each other.
1506///
1507/// This isn't just `Ord::min` or `Iterator::min` because `TcpSeqNumber`
1508/// implements `PartialOrd` but not `Ord`.
1509fn seq_min<const N: usize>(seqs: [TcpSeqNumber; N]) -> TcpSeqNumber {
1510    let mut min = seqs[0];
1511    for &seq in &seqs[1..] {
1512        if min > seq {
1513            min = seq;
1514        }
1515    }
1516    min
1517}
1518
1519/// Check if a TCP connection targets the gateway's DNS port.
1520fn is_gateway_dns_tcp(ft: &FourTuple, params: &crate::ConsommeParams, dns_available: bool) -> bool {
1521    if !dns_available || ft.dst.port() != crate::DNS_PORT {
1522        return false;
1523    }
1524    match ft.dst.ip() {
1525        IpAddr::V4(ip) => params.gateway_ip == ip,
1526        IpAddr::V6(ip) => params.gateway_link_local_ipv6 == ip,
1527    }
1528}
1529
1530#[cfg(test)]
1531mod tests;