consomme/
tcp.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4mod ring;
5
6use super::Access;
7use super::Client;
8use super::DropReason;
9use super::FourTuple;
10use super::SocketAddress;
11use crate::ChecksumState;
12use crate::ConsommeState;
13use crate::Ipv4Addresses;
14use futures::AsyncRead;
15use futures::AsyncWrite;
16use inspect::Inspect;
17use pal_async::interest::PollEvents;
18use pal_async::socket::PollReady;
19use pal_async::socket::PolledSocket;
20use smoltcp::phy::ChecksumCapabilities;
21use smoltcp::wire::ETHERNET_HEADER_LEN;
22use smoltcp::wire::EthernetFrame;
23use smoltcp::wire::EthernetProtocol;
24use smoltcp::wire::IPV4_HEADER_LEN;
25use smoltcp::wire::IpProtocol;
26use smoltcp::wire::Ipv4Packet;
27use smoltcp::wire::Ipv4Repr;
28use smoltcp::wire::TcpControl;
29use smoltcp::wire::TcpPacket;
30use smoltcp::wire::TcpRepr;
31use smoltcp::wire::TcpSeqNumber;
32use socket2::Domain;
33use socket2::Protocol;
34use socket2::SockAddr;
35use socket2::Socket;
36use socket2::Type;
37use std::collections::HashMap;
38use std::collections::VecDeque;
39use std::collections::hash_map;
40use std::io;
41use std::io::ErrorKind;
42use std::io::IoSlice;
43use std::io::IoSliceMut;
44use std::net::Ipv4Addr;
45use std::net::Shutdown;
46use std::net::SocketAddrV4;
47use std::pin::Pin;
48use std::task::Context;
49use std::task::Poll;
50use thiserror::Error;
51
/// All TCP state tracked by the stack: the connections currently being
/// relayed and the host-side listeners that accept connections on behalf of
/// the client.
pub(crate) struct Tcp {
    /// Live connections keyed by four-tuple. `src` is the client side and
    /// `dst` is the remote side of the connection.
    connections: HashMap<FourTuple, TcpConnection>,
    /// Host listeners keyed by the port passed to `bind_tcp_port`.
    listeners: HashMap<u16, TcpListener>,
}
56
/// Reasons a TCP segment from the client is rejected.
///
/// Values are converted into `DropReason` with `.into()` at the point of
/// failure. Most of the code that produces these variants lives outside the
/// connection-setup path, so the per-variant notes below restate the error
/// message rather than the exact trigger.
#[derive(Debug, Error)]
pub enum TcpError {
    /// The connection has not finished being established yet.
    #[error("still connecting")]
    StillConnecting,
    /// The segment's sequence number is not acceptable.
    #[error("unacceptable segment number")]
    Unacceptable,
    /// The segment arrived out of order.
    #[error("received out of order packet")]
    OutOfOrder,
    /// The segment was expected to carry an ACK but did not.
    #[error("missing ack bit")]
    MissingAck,
    /// The segment acknowledges data that has not been sent.
    #[error("ack newer than sequence")]
    AckPastSequence,
    /// The client's first packet carried a window scale option greater than
    /// 14.
    #[error("invalid window scale")]
    InvalidWindowScale,
}
72
73impl Inspect for Tcp {
74    fn inspect(&self, req: inspect::Request<'_>) {
75        let mut resp = req.respond();
76        for (addr, conn) in &self.connections {
77            resp.field(
78                &format!(
79                    "{}:{}-{}:{}",
80                    addr.src.ip, addr.src.port, addr.dst.ip, addr.dst.port
81                ),
82                conn,
83            );
84        }
85        for port in self.listeners.keys() {
86            resp.field("listening port", port);
87        }
88    }
89}
90
91impl Tcp {
92    pub fn new() -> Self {
93        Self {
94            connections: HashMap::new(),
95            listeners: HashMap::new(),
96        }
97    }
98}
99
/// Records whether a connection's host socket is bound to a loopback address,
/// so that a connection looping back into a listener can be matched to the
/// client port that originated it.
#[derive(Inspect)]
#[inspect(tag = "info")]
enum LoopbackPortInfo {
    /// The host socket is not known to be on loopback.
    None,
    /// The host socket's local address is loopback.
    ///
    /// `sending_port` is the local port of the host socket; `guest_port` is
    /// the client's source port for the same connection. `poll_tcp` uses the
    /// pair to replace the peer port of a connection accepted from loopback.
    ProxyForGuestPort { sending_port: u16, guest_port: u16 },
}
106
/// State for one TCP connection relayed between the client and a host socket.
///
/// "rx" names data received from the client and forwarded to the host socket;
/// "tx" names data read from the host socket and transmitted to the client.
#[derive(Inspect)]
struct TcpConnection {
    /// The host socket. `None` once the socket has been closed (or before one
    /// has been attached by the constructors).
    #[inspect(skip)]
    socket: Option<PolledSocket<Socket>>,
    /// Loopback bookkeeping, populated by `new` when the host socket's local
    /// address is loopback.
    loopback_port: LoopbackPortInfo,
    /// Current connection state.
    state: TcpState,

    /// Bytes received from the client that have not yet been written to the
    /// host socket.
    #[inspect(with = "|x| x.len()")]
    rx_buffer: VecDeque<u8>,
    /// Capacity used when computing the advertised receive window. Zero until
    /// it is set from `rx_buffer.capacity()`.
    #[inspect(hex)]
    rx_window_cap: usize,
    /// Right shift applied to the free receive space when advertising the
    /// window; also the window scale option value sent in the SYN.
    rx_window_scale: u8,
    /// Sequence number placed in the ACK field of outgoing segments.
    #[inspect(with = "inspect_seq")]
    rx_seq: TcpSeqNumber,
    /// Forces `send_data` to emit a segment even when there is no new data.
    needs_ack: bool,
    /// Set once the write half of the host socket has been shut down.
    is_shutdown: bool,
    /// Whether the client's first packet carried a window scale option.
    enable_window_scaling: bool,

    /// Bytes read from the host socket that are waiting to be sent to, or
    /// acknowledged by, the client.
    #[inspect(with = "|x| x.len()")]
    tx_buffer: ring::Ring,
    /// Sequence number corresponding to the start of `tx_buffer`.
    /// NOTE(review): presumably the oldest unacknowledged sequence number; the
    /// code that advances it is not in this part of the file.
    #[inspect(with = "inspect_seq")]
    tx_acked: TcpSeqNumber,
    /// Sequence number of the next segment to transmit.
    #[inspect(with = "inspect_seq")]
    tx_send: TcpSeqNumber,
    /// Whether a FIN is queued to follow the data in `tx_buffer`.
    tx_fin_buffered: bool,
    /// The client's advertised window, before scaling.
    #[inspect(hex)]
    tx_window_len: u16,
    /// Left shift applied to `tx_window_len` to obtain the send window.
    tx_window_scale: u8,
    /// NOTE(review): presumably the client sequence number of the segment
    /// that last updated the send window; the update logic is not in view.
    #[inspect(with = "inspect_seq")]
    tx_window_rx_seq: TcpSeqNumber,
    /// NOTE(review): presumably the acknowledgment number of the segment that
    /// last updated the send window; the update logic is not in view.
    #[inspect(with = "inspect_seq")]
    tx_window_tx_seq: TcpSeqNumber,
    /// Maximum payload size of a single segment sent to the client.
    #[inspect(hex)]
    tx_mss: usize,
}
142
143fn inspect_seq(seq: &TcpSeqNumber) -> inspect::AsHex<u32> {
144    inspect::AsHex(seq.0 as u32)
145}
146
/// A host-side listening socket that accepts connections on behalf of the
/// client. Created by `bind_tcp_port` and polled from `poll_tcp`.
#[derive(Inspect)]
struct TcpListener {
    /// The listening host socket.
    #[inspect(skip)]
    socket: PolledSocket<Socket>,
}
152
/// Connection state as seen from this stack's side of the client connection.
///
/// Apart from `Connecting`, the names follow the conventional TCP state
/// names. See `tx_fin` and `rx_fin` for which states imply a FIN has been
/// sent or received.
#[derive(Debug, PartialEq, Eq, Inspect)]
enum TcpState {
    /// Initial state: the host socket's connect has been started and has not
    /// completed yet. Nothing is sent to the client in this state.
    Connecting,
    /// A SYN has been sent to the client for a connection accepted on a host
    /// listener.
    SynSent,
    /// The host socket connected after the client's SYN; a SYN carrying an
    /// ACK is sent to the client from this state.
    SynReceived,
    Established,
    FinWait1,
    FinWait2,
    CloseWait,
    Closing,
    LastAck,
    TimeWait,
}
166
167impl TcpState {
168    fn tx_fin(&self) -> bool {
169        match self {
170            TcpState::Connecting
171            | TcpState::SynSent
172            | TcpState::SynReceived
173            | TcpState::Established
174            | TcpState::CloseWait => false,
175
176            TcpState::FinWait1
177            | TcpState::FinWait2
178            | TcpState::Closing
179            | TcpState::TimeWait
180            | TcpState::LastAck => true,
181        }
182    }
183
184    fn rx_fin(&self) -> bool {
185        match self {
186            TcpState::Connecting
187            | TcpState::SynSent
188            | TcpState::SynReceived
189            | TcpState::Established
190            | TcpState::FinWait1
191            | TcpState::FinWait2 => false,
192
193            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck | TcpState::TimeWait => {
194                true
195            }
196        }
197    }
198}
199
200impl<T: Client> Access<'_, T> {
    /// Polls every listener for newly accepted connections and then polls
    /// every connection for socket progress.
    ///
    /// Listeners whose poll returns an error are removed from the listener
    /// map (the error itself is discarded). Connections whose `poll_conn`
    /// returns `false` are removed from the connection map.
    pub(crate) fn poll_tcp(&mut self, cx: &mut Context<'_>) {
        // Check for any new incoming connections
        self.inner
            .tcp
            .listeners
            .retain(|port, listener| match listener.poll_listener(cx) {
                Ok(result) => {
                    if let Some((socket, mut other_addr)) = result {
                        // Check for loopback requests and replace the dest port.
                        // This supports a guest owning both the sending and receiving ports.
                        //
                        // The accepted peer port is the host socket's local
                        // port; swap it for the guest port recorded on the
                        // still-connecting outbound connection to this
                        // listener's port, if there is one.
                        if other_addr.ip.is_loopback() {
                            for (other_ft, connection) in self.inner.tcp.connections.iter() {
                                if connection.state == TcpState::Connecting && other_ft.dst.port == *port {
                                    if let LoopbackPortInfo::ProxyForGuestPort{sending_port, guest_port} = connection.loopback_port {
                                        if sending_port == other_addr.port {
                                            other_addr.port = guest_port;
                                            break;
                                        }
                                    }
                                }
                            }
                        }

                        // The client is the `src` side of the four-tuple and
                        // the accepted peer is the `dst` side.
                        let ft = FourTuple { dst: other_addr, src: SocketAddress {
                            ip: self.inner.state.params.client_ip,
                            port: *port,
                        } };

                        match self.inner.tcp.connections.entry(ft) {
                            hash_map::Entry::Vacant(e) => {
                                let mut sender = Sender {
                                    ft: &ft,
                                    client: self.client,
                                    state: &mut self.inner.state,
                                };

                                let conn = match TcpConnection::new_from_accept(
                                    &mut sender,
                                    socket,
                                ) {
                                    Ok(conn) => conn,
                                    Err(err) => {
                                        tracing::warn!(err = %err, "Failed to create connection from newly accepted socket");
                                        // Keep the listener; only this
                                        // accepted socket is abandoned.
                                        return true;
                                    }
                                };
                                e.insert(conn);
                            }
                            hash_map::Entry::Occupied(_) => {
                                // The accepted socket is dropped here, which
                                // closes it.
                                tracing::warn!(
                                    address = ?ft.dst,
                                    "New client request ignored because it was already connected"
                                );
                            }
                        }
                    }
                    true
                }
                // A failed listener is dropped from the map without logging.
                Err(_) => false,
            });
        // Check for any new incoming data
        self.inner.tcp.connections.retain(|ft, conn| {
            conn.poll_conn(
                cx,
                &mut Sender {
                    ft,
                    state: &mut self.inner.state,
                    client: self.client,
                },
            )
        })
    }
273
274    pub(crate) fn refresh_tcp_driver(&mut self) {
275        self.inner.tcp.connections.retain(|_, conn| {
276            let Some(socket) = conn.socket.take() else {
277                return true;
278            };
279            let socket = socket.into_inner();
280            match PolledSocket::new(self.client.driver(), socket) {
281                Ok(socket) => {
282                    conn.socket = Some(socket);
283                    true
284                }
285                Err(err) => {
286                    tracing::warn!(
287                        error = &err as &dyn std::error::Error,
288                        "failed to update driver for tcp connection"
289                    );
290                    false
291                }
292            }
293        })
294    }
295
296    pub(crate) fn handle_tcp(
297        &mut self,
298        addresses: &Ipv4Addresses,
299        payload: &[u8],
300        checksum: &ChecksumState,
301    ) -> Result<(), DropReason> {
302        let tcp_packet = TcpPacket::new_checked(payload)?;
303        let tcp = TcpRepr::parse(
304            &tcp_packet,
305            &addresses.src_addr.into(),
306            &addresses.dst_addr.into(),
307            &checksum.caps(),
308        )?;
309
310        tracing::trace!(?tcp, "tcp packet");
311
312        let ft = FourTuple {
313            dst: SocketAddress {
314                ip: addresses.dst_addr,
315                port: tcp.dst_port,
316            },
317            src: SocketAddress {
318                ip: addresses.src_addr,
319                port: tcp.src_port,
320            },
321        };
322
323        let mut sender = Sender {
324            ft: &ft,
325            client: self.client,
326            state: &mut self.inner.state,
327        };
328
329        match self.inner.tcp.connections.entry(ft) {
330            hash_map::Entry::Occupied(mut e) => {
331                let conn = e.get_mut();
332                if !conn.handle_packet(&mut sender, &tcp)? {
333                    e.remove();
334                }
335            }
336            hash_map::Entry::Vacant(e) => {
337                if tcp.control == TcpControl::Rst {
338                    // This connection is already closed. Ignore the packet.
339                } else if let Some(ack) = tcp.ack_number {
340                    // This is for an old connection. Send reset.
341                    sender.rst(ack, None);
342                } else if tcp.control == TcpControl::Syn {
343                    let conn = TcpConnection::new(&mut sender, &tcp)?;
344                    e.insert(conn);
345                } else {
346                    // Ignore the packet.
347                }
348            }
349        }
350        Ok(())
351    }
352
353    /// Binds to the specified host IP and port for listening for incoming
354    /// connections.
355    pub fn bind_tcp_port(
356        &mut self,
357        ip_addr: Option<Ipv4Addr>,
358        port: u16,
359    ) -> Result<(), DropReason> {
360        match self.inner.tcp.listeners.entry(port) {
361            hash_map::Entry::Occupied(_) => {
362                tracing::warn!(port, "Duplicate TCP bind for port");
363            }
364            hash_map::Entry::Vacant(e) => {
365                let ft = FourTuple {
366                    dst: SocketAddress {
367                        ip: Ipv4Addr::UNSPECIFIED.into(),
368                        port: 0,
369                    },
370                    src: SocketAddress {
371                        ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(),
372                        port,
373                    },
374                };
375                let mut sender = Sender {
376                    ft: &ft,
377                    client: self.client,
378                    state: &mut self.inner.state,
379                };
380
381                let listener = TcpListener::new(&mut sender)?;
382                e.insert(listener);
383            }
384        }
385        Ok(())
386    }
387
388    /// Unbinds from the specified host port.
389    pub fn unbind_tcp_port(&mut self, port: u16) -> Result<(), DropReason> {
390        match self.inner.tcp.listeners.entry(port) {
391            hash_map::Entry::Occupied(e) => {
392                e.remove();
393                Ok(())
394            }
395            hash_map::Entry::Vacant(_) => Err(DropReason::PortNotBound),
396        }
397    }
398}
399
/// Everything needed to build a TCP packet for one connection and deliver it
/// to the client.
struct Sender<'a, T> {
    /// The connection's four-tuple. Packets sent to the client use `ft.dst`
    /// as their source and `ft.src` as their destination.
    ft: &'a FourTuple,
    /// The client that receives the generated packets.
    client: &'a mut T,
    /// Shared state providing the packet scratch buffer and the MAC
    /// addresses used in the Ethernet header.
    state: &'a mut ConsommeState,
}
405
impl<T: Client> Sender<'_, T> {
    /// Builds an Ethernet/IPv4/TCP frame in the shared scratch buffer and
    /// hands it to the client.
    ///
    /// `tcp` supplies the TCP header fields; `payload`, when present, is
    /// copied in after the TCP header. The addresses are taken from the
    /// four-tuple reversed, since the packet travels toward the client.
    fn send_packet(&mut self, tcp: &TcpRepr<'_>, payload: Option<ring::View<'_>>) {
        let buffer = &mut self.state.buffer;
        let mut eth_packet = EthernetFrame::new_unchecked(&mut buffer[..]);
        eth_packet.set_ethertype(EthernetProtocol::Ipv4);
        eth_packet.set_dst_addr(self.state.params.client_mac);
        eth_packet.set_src_addr(self.state.params.gateway_mac);
        let mut ipv4_packet = Ipv4Packet::new_unchecked(eth_packet.payload_mut());
        let ipv4 = Ipv4Repr {
            src_addr: self.ft.dst.ip,
            dst_addr: self.ft.src.ip,
            protocol: IpProtocol::Tcp,
            // TCP header plus payload; the IPv4 header is accounted for by
            // the representation itself.
            payload_len: tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()),
            hop_limit: 64,
        };
        ipv4.emit(&mut ipv4_packet, &ChecksumCapabilities::default());
        let mut tcp_packet = TcpPacket::new_unchecked(ipv4_packet.payload_mut());
        tcp.emit(
            &mut tcp_packet,
            &self.ft.dst.ip.into(),
            &self.ft.src.ip.into(),
            &ChecksumCapabilities::default(),
        );
        if let Some(payload) = payload {
            // `zip` stops at the shorter side, so a payload longer than the
            // space available in the packet is silently truncated.
            for (b, c) in tcp_packet.payload_mut().iter_mut().zip(payload.iter()) {
                *b = *c;
            }
        }
        // The header was emitted before the payload was copied in, so the
        // checksum is recomputed over the final contents.
        tcp_packet.fill_checksum(&self.ft.dst.ip.into(), &self.ft.src.ip.into());
        let n = ETHERNET_HEADER_LEN + ipv4_packet.total_len() as usize;
        self.client.recv(&buffer[..n], &ChecksumState::TCP4);
    }

    /// Sends a RST segment to the client with the given sequence number and
    /// optional acknowledgment number. The segment has a zero window, no
    /// options and no payload.
    fn rst(&mut self, seq: TcpSeqNumber, ack: Option<TcpSeqNumber>) {
        let tcp = TcpRepr {
            src_port: self.ft.dst.port,
            dst_port: self.ft.src.port,
            control: TcpControl::Rst,
            seq_number: seq,
            ack_number: ack,
            window_len: 0,
            window_scale: None,
            max_seg_size: None,
            sack_permitted: false,
            sack_ranges: [None, None, None],
            payload: &[],
        };

        tracing::trace!(?tcp, "tcp rst xmit");

        self.send_packet(&tcp, None);
    }
}
459
460impl Default for TcpConnection {
461    fn default() -> Self {
462        let mut rx_tx_seq = [0; 8];
463        getrandom::fill(&mut rx_tx_seq[..]).expect("prng failure");
464        let rx_seq = TcpSeqNumber(i32::from_ne_bytes(
465            rx_tx_seq[0..4].try_into().expect("invalid length"),
466        ));
467        let tx_seq = TcpSeqNumber(i32::from_ne_bytes(
468            rx_tx_seq[4..8].try_into().expect("invalid length"),
469        ));
470
471        let rx_buffer_size: usize = 16384;
472        let rx_window_scale =
473            (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
474
475        let tx_buffer_size = 16384;
476
477        Self {
478            socket: None,
479            loopback_port: LoopbackPortInfo::None,
480            state: TcpState::Connecting,
481            rx_buffer: VecDeque::with_capacity(rx_buffer_size),
482            rx_window_cap: 0,
483            rx_window_scale,
484            rx_seq,
485            needs_ack: false,
486            is_shutdown: false,
487            enable_window_scaling: false,
488            tx_buffer: ring::Ring::new(tx_buffer_size),
489            tx_acked: tx_seq,
490            tx_send: tx_seq,
491            tx_window_len: 1,
492            tx_window_scale: 0,
493            tx_window_rx_seq: rx_seq,
494            tx_window_tx_seq: tx_seq,
495            // The TCPv4 default maximum segment size is 536. This can be bigger for
496            // IPv6.
497            tx_mss: 536,
498            tx_fin_buffered: false,
499        }
500    }
501}
502
503impl TcpConnection {
504    fn new(sender: &mut Sender<'_, impl Client>, tcp: &TcpRepr<'_>) -> Result<Self, DropReason> {
505        let mut this = Self::default();
506        this.initialize_from_first_client_packet(tcp)?;
507
508        let socket =
509            Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).map_err(DropReason::Io)?;
510
511        // On Windows the default behavior for non-existent loopback sockets is
512        // to wait and try again. This is different than the Linux behavior of
513        // immediately failing. Default to the Linux behavior.
514        #[cfg(windows)]
515        if sender.ft.dst.ip.is_loopback() {
516            if let Err(err) = crate::windows::disable_connection_retries(&socket) {
517                tracing::trace!(err, "Failed to disable loopback retries");
518            }
519        }
520
521        let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
522        match socket
523            .get()
524            .connect(&SockAddr::from(SocketAddrV4::from(sender.ft.dst)))
525        {
526            Ok(_) => unreachable!(),
527            Err(err) if is_connect_incomplete_error(&err) => (),
528            Err(err) => {
529                tracing::warn!(
530                    error = &err as &dyn std::error::Error,
531                    "socket connect error"
532                );
533                sender.rst(TcpSeqNumber(0), Some(tcp.seq_number + tcp.segment_len()));
534                return Err(DropReason::Io(err));
535            }
536        }
537        if let Ok(addr) = socket.get().local_addr() {
538            if let Some(addr) = addr.as_socket_ipv4() {
539                if addr.ip().is_loopback() {
540                    this.loopback_port = LoopbackPortInfo::ProxyForGuestPort {
541                        sending_port: addr.port(),
542                        guest_port: sender.ft.src.port,
543                    };
544                }
545            }
546        }
547        this.socket = Some(socket);
548        Ok(this)
549    }
550
551    fn new_from_accept(
552        sender: &mut Sender<'_, impl Client>,
553        socket: Socket,
554    ) -> Result<Self, DropReason> {
555        let mut this = Self {
556            socket: Some(
557                PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?,
558            ),
559            state: TcpState::SynSent,
560            ..Default::default()
561        };
562        this.send_syn(sender, None);
563        Ok(this)
564    }
565
566    fn initialize_from_first_client_packet(&mut self, tcp: &TcpRepr<'_>) -> Result<(), DropReason> {
567        // The TCPv4 default maximum segment size is 536. This can be bigger for
568        // IPv6.
569        let tx_mss = tcp.max_seg_size.map_or(536, |x| x.into());
570
571        if let Some(tx_window_scale) = tcp.window_scale {
572            if tx_window_scale > 14 {
573                return Err(TcpError::InvalidWindowScale.into());
574            }
575        }
576
577        let max_rx_buffer_size = if tcp.window_scale.is_some() {
578            u32::MAX as usize
579        } else {
580            u16::MAX as usize
581        };
582        let rx_buffer_size = 16384.min(max_rx_buffer_size);
583        let rx_window_scale =
584            (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
585
586        assert!(tcp.window_scale.is_some() || rx_window_scale == 0);
587        if self.rx_buffer.capacity() < rx_buffer_size {
588            self.rx_buffer.reserve_exact(rx_buffer_size);
589        }
590
591        self.rx_window_scale = rx_window_scale;
592        self.rx_seq = tcp.seq_number + 1;
593        self.tx_window_rx_seq = tcp.seq_number + 1;
594        self.enable_window_scaling = tcp.window_scale.is_some();
595        self.tx_window_scale = tcp.window_scale.unwrap_or(0);
596        self.tx_mss = tx_mss;
597        Ok(())
598    }
599
    /// Drives one connection forward: completes a pending host connect,
    /// moves data between the host socket and the connection's buffers, and
    /// sends whatever is due to the client.
    ///
    /// Returns `true` to keep the connection and `false` to have the caller
    /// remove it. Every `false` path except a connect timeout first sends a
    /// RST to the client.
    fn poll_conn(&mut self, cx: &mut Context<'_>, sender: &mut Sender<'_, impl Client>) -> bool {
        if self.state == TcpState::Connecting {
            // Writability (or an error) signals that the connect finished.
            match self
                .socket
                .as_mut()
                .unwrap()
                .poll_ready(cx, PollEvents::OUT)
            {
                Poll::Ready(r) => {
                    if r.has_err() {
                        let err = take_socket_error(self.socket.as_mut().unwrap());
                        let reset = match err.kind() {
                            ErrorKind::TimedOut => {
                                // Avoid resetting so that the guest doesn't
                                // think there is a responding TCP stack at this
                                // address. The guest will time out on its own.
                                tracing::debug!(
                                    error = &err as &dyn std::error::Error,
                                    "connect timed out"
                                );
                                false
                            }
                            ErrorKind::ConnectionRefused => {
                                // Presumably the remote TCP stack sent a RST.
                                // Send a reset and log only at debug level.
                                tracing::debug!(
                                    error = &err as &dyn std::error::Error,
                                    "connection refused"
                                );
                                true
                            }
                            _ => {
                                // Something unexpected happened. Log and reset.
                                //
                                // FUTURE: Handle more cases, especially
                                // ENETUNREACH and similar, once we figure out
                                // the right behavior for these. They might
                                // require sending ICMP packets.
                                tracing::warn!(
                                    error = &err as &dyn std::error::Error,
                                    "unhandled connect failure"
                                );
                                true
                            }
                        };
                        if reset {
                            sender.rst(self.tx_send, Some(self.rx_seq));
                        }
                        return false;
                    }

                    tracing::debug!("connection established");
                    self.state = TcpState::SynReceived;
                    // Open the receive window now that data can be forwarded.
                    self.rx_window_cap = self.rx_buffer.capacity();
                }
                Poll::Pending => return true,
            }
        } else if self.state == TcpState::SynSent {
            // Need to establish connection with client before sending data.
            return true;
        }

        // Handle the tx path.
        if self.socket.is_some() {
            if self.state.tx_fin() {
                // Nothing more will be read from the host socket; poll with
                // no requested events to learn about errors or closure.
                if let Poll::Ready(events) = self
                    .socket
                    .as_mut()
                    .unwrap()
                    .poll_ready(cx, PollEvents::EMPTY)
                {
                    if events.has_err() {
                        let err = take_socket_error(self.socket.as_ref().unwrap());
                        match err.kind() {
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => tracing::warn!(
                                error = &err as &dyn std::error::Error,
                                "socket failure after fin"
                            ),
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }

                    // Both ends are closed. Close the actual socket.
                    self.socket = None;
                }
            } else {
                // Read from the host socket into the free part of the
                // transmit ring until it is full or the socket would block.
                while !self.tx_buffer.is_full() {
                    let (a, b) = self.tx_buffer.unwritten_slices_mut();
                    let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
                    match Pin::new(&mut *self.socket.as_mut().unwrap())
                        .poll_read_vectored(cx, &mut bufs)
                    {
                        Poll::Ready(Ok(n)) => {
                            if n == 0 {
                                // End of stream from the host side.
                                self.close();
                                break;
                            }
                            self.tx_buffer.extend_by(n);
                        }
                        Poll::Ready(Err(err)) => {
                            match err.kind() {
                                ErrorKind::ConnectionReset => tracing::trace!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                                _ => tracing::warn!(
                                    error = &err as &dyn std::error::Error,
                                    "socket read error"
                                ),
                            }
                            sender.rst(self.tx_send, Some(self.rx_seq));
                            return false;
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }

        // Handle the rx path.
        if self.socket.is_some() {
            // Write buffered client data to the host socket until the buffer
            // is empty or the socket would block.
            while !self.rx_buffer.is_empty() {
                let (a, b) = self.rx_buffer.as_slices();
                let bufs = [IoSlice::new(a), IoSlice::new(b)];
                match Pin::new(&mut *self.socket.as_mut().unwrap()).poll_write_vectored(cx, &bufs) {
                    Poll::Ready(Ok(n)) => {
                        self.rx_buffer.drain(..n);
                    }
                    Poll::Ready(Err(err)) => {
                        match err.kind() {
                            ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
                            _ => {
                                tracing::warn!(
                                    error = &err as &dyn std::error::Error,
                                    "socket write error"
                                );
                            }
                        }
                        sender.rst(self.tx_send, Some(self.rx_seq));
                        return false;
                    }
                    Poll::Pending => break,
                }
            }
            // Once the client has finished sending and everything it sent has
            // been forwarded, shut down the write half of the host socket
            // (only once).
            if self.rx_buffer.is_empty() && self.state.rx_fin() && !self.is_shutdown {
                if let Err(err) = self
                    .socket
                    .as_ref()
                    .unwrap()
                    .get()
                    .shutdown(Shutdown::Write)
                {
                    tracing::warn!(error = &err as &dyn std::error::Error, "shutdown error");
                    sender.rst(self.tx_send, Some(self.rx_seq));
                    return false;
                }
                self.is_shutdown = true;
            }
        }

        // Send whatever needs to be sent.
        self.send_next(sender);
        true
    }
766
767    fn rx_window_len(&self) -> u16 {
768        ((self.rx_window_cap - self.rx_buffer.len()) >> self.rx_window_scale) as u16
769    }
770
771    fn send_next(&mut self, sender: &mut Sender<'_, impl Client>) {
772        match self.state {
773            TcpState::Connecting => {}
774            TcpState::SynReceived => self.send_syn(sender, Some(self.rx_seq)),
775            _ => self.send_data(sender),
776        }
777    }
778
779    fn send_syn(&mut self, sender: &mut Sender<'_, impl Client>, ack_number: Option<TcpSeqNumber>) {
780        if self.tx_send != self.tx_acked || sender.client.rx_mtu() == 0 {
781            return;
782        }
783
784        // If the client side specified a window scale option, then do the same
785        // (even with no shift) to enable window scale support.
786        let window_scale = self.enable_window_scaling.then_some(self.rx_window_scale);
787
788        // Advertise the maximum possible segment size, allowing the guest
789        // to truncate this to its own MTU calculation.
790        let max_seg_size = u16::MAX;
791        let tcp = TcpRepr {
792            src_port: sender.ft.dst.port,
793            dst_port: sender.ft.src.port,
794            control: TcpControl::Syn,
795            seq_number: self.tx_send,
796            ack_number,
797            window_len: self.rx_window_len(),
798            window_scale,
799            max_seg_size: Some(max_seg_size),
800            sack_permitted: false,
801            sack_ranges: [None, None, None],
802            payload: &[],
803        };
804
805        sender.send_packet(&tcp, None);
806        self.tx_send += 1;
807    }
808
    /// Sends as many data segments (and, if buffered, the FIN) as the
    /// client's window and receive buffers allow. A single empty ACK segment
    /// is sent when `needs_ack` is set and there is nothing else to send.
    ///
    /// # Panics
    /// Panics if an iteration would send a segment that neither advances
    /// `tx_send` nor carries a needed ACK, or if `tx_send` would move past the
    /// end of the buffered data (plus FIN).
    fn send_data(&mut self, sender: &mut Sender<'_, impl Client>) {
        // These computations assume syn has already been sent and acked.
        // `tx_payload_end`: sequence number just past the buffered data.
        let tx_payload_end = self.tx_acked + self.tx_buffer.len();
        // `tx_end`: as above, plus one for the FIN if one is buffered.
        let tx_end = tx_payload_end + self.tx_fin_buffered as usize;
        // `tx_window_end`: end of the client's scaled receive window.
        let tx_window_end = self.tx_acked + ((self.tx_window_len as usize) << self.tx_window_scale);
        let tx_done = seq_min([tx_end, tx_window_end]);

        while self.needs_ack || self.tx_send < tx_done {
            let rx_mtu = sender.client.rx_mtu();
            if rx_mtu == 0 {
                // Out of receive buffers.
                break;
            }

            let mut tcp = TcpRepr {
                src_port: sender.ft.dst.port,
                dst_port: sender.ft.src.port,
                control: TcpControl::None,
                seq_number: self.tx_send,
                ack_number: Some(self.rx_seq),
                window_len: self.rx_window_len(),
                window_scale: None,
                max_seg_size: None,
                sack_permitted: false,
                sack_ranges: [None, None, None],
                payload: &[],
            };

            let mut tx_next = self.tx_send;

            // Compute the end of the segment buffer in sequence space to avoid
            // exceeding:
            // 1. The available buffer length.
            // 2. The current window.
            // 3. The configured maximum segment size.
            // 4. The client MTU.
            let tx_segment_end = {
                let header_len = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp.header_len();
                let mtu = rx_mtu.min(sender.state.buffer.len());
                // NOTE(review): `mtu - header_len` assumes the MTU is at
                // least as large as the headers — confirm against callers.
                seq_min([
                    tx_payload_end,
                    tx_window_end,
                    tx_next + self.tx_mss,
                    tx_next + (mtu - header_len),
                ])
            };

            // The payload offset is relative to the start of `tx_buffer`,
            // which corresponds to `tx_acked`.
            let (payload_start, payload_len) = if tx_next < tx_segment_end {
                (tx_next - self.tx_acked, tx_segment_end - tx_next)
            } else {
                (0, 0)
            };

            tx_next += payload_len;

            // Include the fin if present if there is still room.
            if self.tx_fin_buffered
                && tcp.control == TcpControl::None
                && tx_next == tx_payload_end
                && tx_next < tx_window_end
            {
                tcp.control = TcpControl::Fin;
                tx_next += 1;
            }

            // Every segment must either make progress or carry a needed ACK;
            // otherwise the loop would never terminate.
            assert!(tx_next <= tx_end);
            assert!(self.needs_ack || tx_next > self.tx_send);

            tracing::trace!(?tcp, %tx_next, "tcp xmit");

            let payload = self
                .tx_buffer
                .view(payload_start..payload_start + payload_len);

            sender.send_packet(&tcp, Some(payload));
            self.tx_send = tx_next;
            // The segment just sent carried the current ACK.
            self.needs_ack = false;
        }

        assert!(self.tx_send <= tx_end);
    }
890
891    fn close(&mut self) {
892        tracing::trace!("fin");
893        match self.state {
894            TcpState::SynSent | TcpState::SynReceived | TcpState::Established => {
895                self.state = TcpState::FinWait1;
896            }
897            TcpState::CloseWait => {
898                self.state = TcpState::LastAck;
899            }
900            TcpState::Connecting
901            | TcpState::FinWait1
902            | TcpState::FinWait2
903            | TcpState::Closing
904            | TcpState::TimeWait
905            | TcpState::LastAck => unreachable!("fin in {:?}", self.state),
906        }
907        self.tx_fin_buffered = true;
908    }
909
910    /// Send an ACK using the current state of the connection.
911    ///
912    /// This is used when sending an ack to report a the reception of an
913    /// unacceptable packet (duplicate, out of order, etc.). These acks
914    /// shouldn't be combined with data so that they are interpreted correctly
915    /// by the peer.
916    fn ack(&self, sender: &mut Sender<'_, impl Client>) {
917        let tcp = TcpRepr {
918            src_port: sender.ft.dst.port,
919            dst_port: sender.ft.src.port,
920            control: TcpControl::None,
921            seq_number: self.tx_send,
922            ack_number: Some(self.rx_seq),
923            window_len: self.rx_window_len(),
924            window_scale: None,
925            max_seg_size: None,
926            sack_permitted: false,
927            sack_ranges: [None, None, None],
928            payload: &[],
929        };
930
931        tracing::trace!(?tcp, "tcp ack xmit");
932
933        sender.send_packet(&tcp, None);
934    }
935
936    fn handle_listen_syn(
937        &mut self,
938        sender: &mut Sender<'_, impl Client>,
939        tcp: &TcpRepr<'_>,
940    ) -> Result<bool, DropReason> {
941        if tcp.control != TcpControl::Syn || tcp.segment_len() != 1 {
942            tracing::error!(?tcp.control, "invalid packet waiting for syn, drop connection");
943            return Ok(false);
944        }
945
946        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;
947        if ack_number <= self.tx_acked || ack_number > self.tx_send {
948            sender.rst(ack_number, None);
949            return Ok(false);
950        }
951        self.tx_acked = ack_number;
952
953        self.initialize_from_first_client_packet(tcp)?;
954        self.tx_window_tx_seq = ack_number;
955        self.rx_window_cap = self.rx_buffer.capacity();
956        self.tx_window_len = tcp.window_len;
957
958        // Send an ACK to complete the initial SYN handshake.
959        self.ack(sender);
960
961        self.state = TcpState::Established;
962        Ok(true)
963    }
964
    /// Processes one incoming TCP segment against this connection's state.
    ///
    /// The segment is validated (sequence number, RST, SYN, ACK), the ACK is
    /// used to retire sent data and update the send window, and then the
    /// in-window payload and FIN are applied.
    ///
    /// Return value, as produced by the paths below:
    /// * `Ok(true)`: the segment was handled (possibly only by sending an
    ///   ACK in response).
    /// * `Ok(false)`: returned on the paths where the connection is to be
    ///   dropped (valid RST, bad SYN/ACK while in `SynReceived`, or the ACK of
    ///   our FIN in `LastAck`).
    /// * `Err(_)`: the segment was rejected and dropped.
    ///
    /// NOTE(review): how the caller acts on the boolean is outside this
    /// block; the meaning above is taken from the log messages and comments
    /// on each path.
    fn handle_packet(
        &mut self,
        sender: &mut Sender<'_, impl Client>,
        tcp: &TcpRepr<'_>,
    ) -> Result<bool, DropReason> {
        if self.state == TcpState::Connecting {
            // We have not yet sent a syn (we are still deciding whether we are
            // in LISTEN or CLOSED state), so we can't send a reasonable
            // response to this. Just drop the packet.
            return Err(TcpError::StillConnecting.into());
        } else if self.state == TcpState::SynSent {
            return self.handle_listen_syn(sender, tcp);
        }

        // Receive window: free space left in the receive buffer, expressed
        // both as a length and as an end position in sequence space.
        let rx_window_len = self.rx_window_cap - self.rx_buffer.len();
        let rx_window_end = self.rx_seq + rx_window_len;
        let segment_end = tcp.seq_number + tcp.segment_len();

        // Validate the sequence number per RFC 793.
        // With a non-zero window, the segment is acceptable if either its
        // first or its last sequence number falls inside the window. With a
        // zero window, only an empty segment at exactly `rx_seq` is accepted.
        let seq_acceptable = if rx_window_len != 0 {
            (tcp.seq_number >= self.rx_seq && tcp.seq_number < rx_window_end)
                || (tcp.segment_len() > 0
                    && segment_end > self.rx_seq
                    && segment_end <= rx_window_end)
        } else {
            tcp.segment_len() == 0 && tcp.seq_number == self.rx_seq
        };

        if tcp.control == TcpControl::Rst {
            if !seq_acceptable {
                // Silently drop--don't send an ACK--since the peer would then
                // immediately respond with a valid RST.
                return Err(TcpError::Unacceptable.into());
            }

            // RFC 5961
            if tcp.seq_number != self.rx_seq {
                // Send a challenge ACK.
                self.ack(sender);
                return Ok(true);
            }

            // This is a valid RST. Drop the connection.
            tracing::debug!("connection reset");
            return Ok(false);
        }

        // Send ack and drop packets with unacceptable sequence numbers.
        if !seq_acceptable {
            self.ack(sender);
            return Err(TcpError::Unacceptable.into());
        }

        // Also ack+drop for out-of-order non-empty segments rather than queueing
        // them. Our environment makes out-of-order segments unlikely.
        if tcp.seq_number > self.rx_seq && tcp.segment_len() > 0 {
            self.ack(sender);
            return Err(TcpError::OutOfOrder.into());
        }

        // SYN should not be set for in-window segments.
        if tcp.control == TcpControl::Syn {
            if self.state == TcpState::SynReceived {
                tracing::debug!("invalid syn, drop connection");
                return Ok(false);
            }
            // RFC 5961, send a challenge ACK.
            self.ack(sender);
            return Ok(true);
        }

        // ACK should always be set at this point.
        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;

        // FUTURE: validate ack number per RFC 5961.

        // Handle ACK of our SYN.
        if self.state == TcpState::SynReceived {
            // The ACK must cover something sent but not yet acknowledged;
            // otherwise reset and drop the connection.
            if ack_number <= self.tx_acked || ack_number > self.tx_send {
                sender.rst(ack_number, None);
                return Ok(false);
            }
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
            // Account for the one sequence number consumed by our SYN.
            self.tx_acked += 1;
            self.state = TcpState::Established;
        }

        // Ignore ACKs for segments that have not been sent.
        if ack_number > self.tx_send {
            self.ack(sender);
            return Err(TcpError::AckPastSequence.into());
        }

        // Retire the ACKed segments.
        if ack_number > self.tx_acked {
            let mut consumed = ack_number - self.tx_acked;
            // An ACK one past the end of the buffered payload acknowledges our
            // FIN; that extra sequence number is not a byte in `tx_buffer`, so
            // it is excluded from the amount consumed.
            if self.tx_fin_buffered && ack_number == self.tx_acked + self.tx_buffer.len() + 1 {
                self.tx_fin_buffered = false;
                consumed -= 1;
                match self.state {
                    TcpState::FinWait1 => self.state = TcpState::FinWait2,
                    TcpState::Closing => self.state = TcpState::TimeWait,
                    TcpState::LastAck => return Ok(false),
                    _ => unreachable!(),
                }
            }
            self.tx_buffer.consume(consumed);
            self.tx_acked = ack_number;
        }

        // Update the send window.
        // Only take the window from a segment that is at least as recent as
        // the one that supplied the current window, judged first by its
        // sequence number and then by its ACK number.
        if ack_number >= self.tx_acked
            && (tcp.seq_number > self.tx_window_rx_seq
                || (tcp.seq_number == self.tx_window_rx_seq && ack_number >= self.tx_window_tx_seq))
        {
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
        }

        // Scope the data payload and FIN to the in-window portion of the segment.
        let mut fin = tcp.control == TcpControl::Fin;
        // Bytes at the front of the segment that precede `rx_seq` are skipped.
        let segment_skip = if tcp.seq_number < self.rx_seq {
            self.rx_seq - tcp.seq_number
        } else {
            0
        };
        // A segment that runs past the window is cut at the window end, and
        // its FIN is ignored.
        let segment_end = if segment_end > rx_window_end {
            fin = false;
            rx_window_end
        } else {
            segment_end
        };
        // `fin as usize` subtracts one from the slice end when a FIN is kept
        // (presumably because the FIN counts toward `segment_len()` but is not
        // a payload byte; confirm against `TcpRepr::segment_len`).
        let payload = &tcp.payload[segment_skip..segment_end - tcp.seq_number - fin as usize];

        // Process the payload.
        match self.state {
            TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
            TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 => {
                self.rx_buffer.extend(payload);
                self.rx_seq = segment_end;
                // Anything that occupies sequence space must be acknowledged.
                if tcp.segment_len() > 0 {
                    self.needs_ack = true;
                }
            }
            // In these states the payload is not buffered and `rx_seq` is left
            // unchanged.
            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck => {}
            TcpState::TimeWait => {
                self.ack(sender);
                // TODO: restart timer
            }
        }

        // Process FIN.
        if fin {
            match self.state {
                TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
                TcpState::Established => {
                    self.state = TcpState::CloseWait;
                }
                TcpState::FinWait1 => {
                    self.state = TcpState::Closing;
                }
                TcpState::FinWait2 => {
                    self.state = TcpState::TimeWait;
                    // TODO: start timer
                }
                // A FIN in any of these states causes no state change.
                TcpState::CloseWait
                | TcpState::Closing
                | TcpState::LastAck
                | TcpState::TimeWait => {}
            }
        }

        Ok(true)
    }
1142}
1143
1144impl TcpListener {
1145    pub fn new(sender: &mut Sender<'_, impl Client>) -> Result<Self, DropReason> {
1146        let socket =
1147            Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).map_err(DropReason::Io)?;
1148
1149        let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
1150        if let Err(err) = socket.get().bind(&sender.ft.src.into()) {
1151            tracing::warn!(
1152                address = ?sender.ft.src,
1153                error = &err as &dyn std::error::Error,
1154                "socket bind error"
1155            );
1156            return Err(DropReason::Io(err));
1157        }
1158        if let Err(err) = socket.listen(10) {
1159            tracing::warn!(
1160                error = &err as &dyn std::error::Error,
1161                "socket listen error"
1162            );
1163            return Err(DropReason::Io(err));
1164        }
1165        Ok(Self { socket })
1166    }
1167
1168    fn poll_listener(
1169        &mut self,
1170        cx: &mut Context<'_>,
1171    ) -> Result<Option<(Socket, SocketAddress)>, DropReason> {
1172        match self.socket.poll_accept(cx) {
1173            Poll::Ready(r) => match r {
1174                Ok((socket, address)) => match address.as_socket() {
1175                    Some(addr) => match address.as_socket_ipv4() {
1176                        Some(src_address) => Ok(Some((
1177                            socket,
1178                            SocketAddress {
1179                                ip: (*src_address.ip()).into(),
1180                                port: addr.port(),
1181                            },
1182                        ))),
1183                        None => {
1184                            tracing::warn!(?address, "Not an IPv4 address from accept");
1185                            Ok(None)
1186                        }
1187                    },
1188                    None => {
1189                        tracing::warn!(?address, "Unknown address from accept");
1190                        Ok(None)
1191                    }
1192                },
1193                Err(_) => {
1194                    let err = take_socket_error(&self.socket);
1195                    tracing::warn!(error = &err as &dyn std::error::Error, "listen failure");
1196                    Err(DropReason::Io(err))
1197                }
1198            },
1199            Poll::Pending => Ok(None),
1200        }
1201    }
1202}
1203
1204fn take_socket_error(socket: &PolledSocket<Socket>) -> io::Error {
1205    match socket.get().take_error() {
1206        Ok(Some(err)) => err,
1207        Ok(_) => io::Error::other("missing error"),
1208        Err(err) => err,
1209    }
1210}
1211
1212fn is_connect_incomplete_error(err: &io::Error) -> bool {
1213    if err.kind() == ErrorKind::WouldBlock {
1214        return true;
1215    }
1216    // This handles the remaining cases on Linux.
1217    #[cfg(unix)]
1218    if err.raw_os_error() == Some(libc::EINPROGRESS) {
1219        return true;
1220    }
1221    false
1222}
1223
1224/// Finds the smallest sequence number in a set. To get a coherent result, all
1225/// the sequence numbers must be known to be comparable, meaning they are all
1226/// within 2^31 bytes of each other.
1227///
1228/// This isn't just `Ord::min` or `Iterator::min` because `TcpSeqNumber`
1229/// implements `PartialOrd` but not `Ord`.
1230fn seq_min<const N: usize>(seqs: [TcpSeqNumber; N]) -> TcpSeqNumber {
1231    let mut min = seqs[0];
1232    for &seq in &seqs[1..] {
1233        if min > seq {
1234            min = seq;
1235        }
1236    }
1237    min
1238}