1mod ring;
5
6use super::Access;
7use super::Client;
8use super::DropReason;
9use super::FourTuple;
10use super::SocketAddress;
11use crate::ChecksumState;
12use crate::ConsommeState;
13use crate::Ipv4Addresses;
14use futures::AsyncRead;
15use futures::AsyncWrite;
16use inspect::Inspect;
17use pal_async::interest::PollEvents;
18use pal_async::socket::PollReady;
19use pal_async::socket::PolledSocket;
20use smoltcp::phy::ChecksumCapabilities;
21use smoltcp::wire::ETHERNET_HEADER_LEN;
22use smoltcp::wire::EthernetFrame;
23use smoltcp::wire::EthernetProtocol;
24use smoltcp::wire::IPV4_HEADER_LEN;
25use smoltcp::wire::IpProtocol;
26use smoltcp::wire::Ipv4Packet;
27use smoltcp::wire::Ipv4Repr;
28use smoltcp::wire::TcpControl;
29use smoltcp::wire::TcpPacket;
30use smoltcp::wire::TcpRepr;
31use smoltcp::wire::TcpSeqNumber;
32use socket2::Domain;
33use socket2::Protocol;
34use socket2::SockAddr;
35use socket2::Socket;
36use socket2::Type;
37use std::collections::HashMap;
38use std::collections::VecDeque;
39use std::collections::hash_map;
40use std::io;
41use std::io::ErrorKind;
42use std::io::IoSlice;
43use std::io::IoSliceMut;
44use std::net::Ipv4Addr;
45use std::net::Shutdown;
46use std::net::SocketAddrV4;
47use std::pin::Pin;
48use std::task::Context;
49use std::task::Poll;
50use thiserror::Error;
51
/// TCP NAT state: guest connections keyed by four-tuple, and host listening
/// sockets keyed by port.
pub(crate) struct Tcp {
    /// Active connections. In each key, `src` is the guest-side endpoint and
    /// `dst` is the remote endpoint the guest is talking to.
    connections: HashMap<FourTuple, TcpConnection>,
    /// Host listeners, keyed by the port passed to `bind_tcp_port`.
    listeners: HashMap<u16, TcpListener>,
}
56
/// Reasons a guest TCP segment is rejected by a connection.
#[derive(Debug, Error)]
pub enum TcpError {
    /// A segment arrived while the host-side connect is still in progress.
    #[error("still connecting")]
    StillConnecting,
    /// The segment's sequence number fails the receive-window acceptability test.
    #[error("unacceptable segment number")]
    Unacceptable,
    /// The segment starts beyond the next expected sequence number; such
    /// segments are not buffered.
    #[error("received out of order packet")]
    OutOfOrder,
    /// The segment lacks an ACK where one is required.
    #[error("missing ack bit")]
    MissingAck,
    /// The segment acknowledges data that has not been sent yet.
    #[error("ack newer than sequence")]
    AckPastSequence,
    /// The guest offered a window scale shift greater than 14.
    #[error("invalid window scale")]
    InvalidWindowScale,
}
72
73impl Inspect for Tcp {
74 fn inspect(&self, req: inspect::Request<'_>) {
75 let mut resp = req.respond();
76 for (addr, conn) in &self.connections {
77 resp.field(
78 &format!(
79 "{}:{}-{}:{}",
80 addr.src.ip, addr.src.port, addr.dst.ip, addr.dst.port
81 ),
82 conn,
83 );
84 }
85 for port in self.listeners.keys() {
86 resp.field("listening port", port);
87 }
88 }
89}
90
91impl Tcp {
92 pub fn new() -> Self {
93 Self {
94 connections: HashMap::new(),
95 listeners: HashMap::new(),
96 }
97 }
98}
99
/// Loopback mapping for a guest-initiated connection whose host socket ended
/// up bound to a loopback address.
#[derive(Inspect)]
#[inspect(tag = "info")]
enum LoopbackPortInfo {
    /// No loopback mapping.
    None,
    /// `sending_port` is the host socket's local port; `guest_port` is the
    /// guest's source port. `poll_tcp` uses this to present `guest_port` as the
    /// peer port when one of our listeners accepts a connection from
    /// `sending_port`.
    ProxyForGuestPort { sending_port: u16, guest_port: u16 },
}
106
/// One guest TCP connection bridged to a host socket.
///
/// "rx" is the guest → host-socket direction (guest payload is buffered in
/// `rx_buffer` and written to the socket); "tx" is the host-socket → guest
/// direction (socket reads fill `tx_buffer`, which is sent to the guest).
#[derive(Inspect)]
struct TcpConnection {
    /// Host socket. Set to `None` by `poll_conn` once it is released after
    /// the host side finished, or before a socket is attached.
    #[inspect(skip)]
    socket: Option<PolledSocket<Socket>>,
    /// Loopback port mapping, if the host socket is bound to a loopback address.
    loopback_port: LoopbackPortInfo,
    state: TcpState,

    /// Guest data waiting to be written to the host socket.
    #[inspect(with = "|x| x.len()")]
    rx_buffer: VecDeque<u8>,
    /// Capacity used to compute the receive window; 0 until the handshake
    /// sets it from `rx_buffer.capacity()`.
    #[inspect(hex)]
    rx_window_cap: usize,
    /// Shift applied to the receive window we advertise.
    rx_window_scale: u8,
    /// Next sequence number expected from the guest.
    #[inspect(with = "inspect_seq")]
    rx_seq: TcpSeqNumber,
    /// Set when guest data/FIN was consumed; forces `send_data` to emit an ACK.
    needs_ack: bool,
    /// Whether the write side of the host socket has been shut down.
    is_shutdown: bool,
    /// Whether our SYN carries the window scale option (the guest offered it).
    enable_window_scaling: bool,

    /// Host data read but not yet acknowledged by the guest; starts at `tx_acked`.
    #[inspect(with = "|x| x.len()")]
    tx_buffer: ring::Ring,
    /// Oldest sequence number not yet acknowledged by the guest.
    #[inspect(with = "inspect_seq")]
    tx_acked: TcpSeqNumber,
    /// Next sequence number to send to the guest.
    #[inspect(with = "inspect_seq")]
    tx_send: TcpSeqNumber,
    /// A FIN is queued after `tx_buffer` (sent or not) and not yet acknowledged.
    tx_fin_buffered: bool,
    /// Window most recently advertised by the guest (unscaled).
    #[inspect(hex)]
    tx_window_len: u16,
    /// The guest's window scale shift.
    tx_window_scale: u8,
    /// Sequence number of the segment that last updated `tx_window_len`
    /// (RFC 793 SND.WL1).
    #[inspect(with = "inspect_seq")]
    tx_window_rx_seq: TcpSeqNumber,
    /// Acknowledgment number of the segment that last updated `tx_window_len`
    /// (RFC 793 SND.WL2).
    #[inspect(with = "inspect_seq")]
    tx_window_tx_seq: TcpSeqNumber,
    /// Maximum payload bytes per segment sent to the guest.
    #[inspect(hex)]
    tx_mss: usize,
}
142
143fn inspect_seq(seq: &TcpSeqNumber) -> inspect::AsHex<u32> {
144 inspect::AsHex(seq.0 as u32)
145}
146
/// A host socket listening for inbound connections on a bound port.
#[derive(Inspect)]
struct TcpListener {
    /// Listening host socket; connections accepted here are forwarded to the guest.
    #[inspect(skip)]
    socket: PolledSocket<Socket>,
}
152
/// Connection state, from the NAT's point of view toward the guest.
#[derive(Debug, PartialEq, Eq, Inspect)]
enum TcpState {
    /// Guest sent a SYN; the host-side non-blocking connect is in progress.
    Connecting,
    /// A host listener accepted a connection; our SYN to the guest is pending
    /// or sent, awaiting the guest's SYN-ACK.
    SynSent,
    /// The host connect completed; our SYN-ACK to the guest is pending or sent.
    SynReceived,
    Established,
    /// Host side finished; our FIN is queued or unacknowledged.
    FinWait1,
    /// Our FIN was acknowledged; awaiting the guest's FIN.
    FinWait2,
    /// The guest sent FIN; the host side may still send.
    CloseWait,
    /// Both sides sent FIN; ours is not yet acknowledged.
    Closing,
    /// Guest FIN received first, then ours was queued; awaiting its ACK.
    LastAck,
    /// Both FINs exchanged and acknowledged.
    TimeWait,
}
166
167impl TcpState {
168 fn tx_fin(&self) -> bool {
169 match self {
170 TcpState::Connecting
171 | TcpState::SynSent
172 | TcpState::SynReceived
173 | TcpState::Established
174 | TcpState::CloseWait => false,
175
176 TcpState::FinWait1
177 | TcpState::FinWait2
178 | TcpState::Closing
179 | TcpState::TimeWait
180 | TcpState::LastAck => true,
181 }
182 }
183
184 fn rx_fin(&self) -> bool {
185 match self {
186 TcpState::Connecting
187 | TcpState::SynSent
188 | TcpState::SynReceived
189 | TcpState::Established
190 | TcpState::FinWait1
191 | TcpState::FinWait2 => false,
192
193 TcpState::CloseWait | TcpState::Closing | TcpState::LastAck | TcpState::TimeWait => {
194 true
195 }
196 }
197 }
198}
199
200impl<T: Client> Access<'_, T> {
    /// Polls host listeners for new inbound connections, then polls every
    /// existing connection.
    ///
    /// A listener is dropped when `poll_listener` fails; a connection is
    /// dropped when `poll_conn` returns `false`.
    pub(crate) fn poll_tcp(&mut self, cx: &mut Context<'_>) {
        self.inner
            .tcp
            .listeners
            .retain(|port, listener| match listener.poll_listener(cx) {
                Ok(result) => {
                    if let Some((socket, mut other_addr)) = result {
                        // A loopback peer may be the host socket of one of our own
                        // guest-initiated connections that is still connecting to this
                        // listener's port. If so, present the guest's original source
                        // port instead of the host socket's local port.
                        if other_addr.ip.is_loopback() {
                            for (other_ft, connection) in self.inner.tcp.connections.iter() {
                                if connection.state == TcpState::Connecting && other_ft.dst.port == *port {
                                    if let LoopbackPortInfo::ProxyForGuestPort{sending_port, guest_port} = connection.loopback_port {
                                        if sending_port == other_addr.port {
                                            other_addr.port = guest_port;
                                            break;
                                        }
                                    }
                                }
                            }
                        }

                        // The guest sees the connection arriving at its own IP on the
                        // listened port.
                        let ft = FourTuple { dst: other_addr, src: SocketAddress {
                            ip: self.inner.state.params.client_ip,
                            port: *port,
                        } };

                        match self.inner.tcp.connections.entry(ft) {
                            hash_map::Entry::Vacant(e) => {
                                let mut sender = Sender {
                                    ft: &ft,
                                    client: self.client,
                                    state: &mut self.inner.state,
                                };

                                let conn = match TcpConnection::new_from_accept(
                                    &mut sender,
                                    socket,
                                ) {
                                    Ok(conn) => conn,
                                    Err(err) => {
                                        // Only this accepted socket is lost; keep the listener.
                                        tracing::warn!(err = %err, "Failed to create connection from newly accepted socket");
                                        return true;
                                    }
                                };
                                e.insert(conn);
                            }
                            hash_map::Entry::Occupied(_) => {
                                // The accepted socket is dropped (closed) here.
                                tracing::warn!(
                                    address = ?ft.dst,
                                    "New client request ignored because it was already connected"
                                );
                            }
                        }
                    }
                    true
                }
                // The listener failed; remove it.
                Err(_) => false,
            });
        self.inner.tcp.connections.retain(|ft, conn| {
            conn.poll_conn(
                cx,
                &mut Sender {
                    ft,
                    state: &mut self.inner.state,
                    client: self.client,
                },
            )
        })
    }
273
274 pub(crate) fn refresh_tcp_driver(&mut self) {
275 self.inner.tcp.connections.retain(|_, conn| {
276 let Some(socket) = conn.socket.take() else {
277 return true;
278 };
279 let socket = socket.into_inner();
280 match PolledSocket::new(self.client.driver(), socket) {
281 Ok(socket) => {
282 conn.socket = Some(socket);
283 true
284 }
285 Err(err) => {
286 tracing::warn!(
287 error = &err as &dyn std::error::Error,
288 "failed to update driver for tcp connection"
289 );
290 false
291 }
292 }
293 })
294 }
295
296 pub(crate) fn handle_tcp(
297 &mut self,
298 addresses: &Ipv4Addresses,
299 payload: &[u8],
300 checksum: &ChecksumState,
301 ) -> Result<(), DropReason> {
302 let tcp_packet = TcpPacket::new_checked(payload)?;
303 let tcp = TcpRepr::parse(
304 &tcp_packet,
305 &addresses.src_addr.into(),
306 &addresses.dst_addr.into(),
307 &checksum.caps(),
308 )?;
309
310 tracing::trace!(?tcp, "tcp packet");
311
312 let ft = FourTuple {
313 dst: SocketAddress {
314 ip: addresses.dst_addr,
315 port: tcp.dst_port,
316 },
317 src: SocketAddress {
318 ip: addresses.src_addr,
319 port: tcp.src_port,
320 },
321 };
322
323 let mut sender = Sender {
324 ft: &ft,
325 client: self.client,
326 state: &mut self.inner.state,
327 };
328
329 match self.inner.tcp.connections.entry(ft) {
330 hash_map::Entry::Occupied(mut e) => {
331 let conn = e.get_mut();
332 if !conn.handle_packet(&mut sender, &tcp)? {
333 e.remove();
334 }
335 }
336 hash_map::Entry::Vacant(e) => {
337 if tcp.control == TcpControl::Rst {
338 } else if let Some(ack) = tcp.ack_number {
340 sender.rst(ack, None);
342 } else if tcp.control == TcpControl::Syn {
343 let conn = TcpConnection::new(&mut sender, &tcp)?;
344 e.insert(conn);
345 } else {
346 }
348 }
349 }
350 Ok(())
351 }
352
353 pub fn bind_tcp_port(
356 &mut self,
357 ip_addr: Option<Ipv4Addr>,
358 port: u16,
359 ) -> Result<(), DropReason> {
360 match self.inner.tcp.listeners.entry(port) {
361 hash_map::Entry::Occupied(_) => {
362 tracing::warn!(port, "Duplicate TCP bind for port");
363 }
364 hash_map::Entry::Vacant(e) => {
365 let ft = FourTuple {
366 dst: SocketAddress {
367 ip: Ipv4Addr::UNSPECIFIED.into(),
368 port: 0,
369 },
370 src: SocketAddress {
371 ip: ip_addr.unwrap_or(Ipv4Addr::UNSPECIFIED).into(),
372 port,
373 },
374 };
375 let mut sender = Sender {
376 ft: &ft,
377 client: self.client,
378 state: &mut self.inner.state,
379 };
380
381 let listener = TcpListener::new(&mut sender)?;
382 e.insert(listener);
383 }
384 }
385 Ok(())
386 }
387
388 pub fn unbind_tcp_port(&mut self, port: u16) -> Result<(), DropReason> {
390 match self.inner.tcp.listeners.entry(port) {
391 hash_map::Entry::Occupied(e) => {
392 e.remove();
393 Ok(())
394 }
395 hash_map::Entry::Vacant(_) => Err(DropReason::PortNotBound),
396 }
397 }
398}
399
/// Borrowed context for emitting packets to the client on behalf of one
/// four-tuple.
struct Sender<'a, T> {
    /// The connection's four-tuple. Outgoing packets travel from `dst` to `src`.
    ft: &'a FourTuple,
    /// Client that receives the generated frames.
    client: &'a mut T,
    /// Shared state, including the frame buffer and MAC/IP parameters.
    state: &'a mut ConsommeState,
}
405
406impl<T: Client> Sender<'_, T> {
407 fn send_packet(&mut self, tcp: &TcpRepr<'_>, payload: Option<ring::View<'_>>) {
408 let buffer = &mut self.state.buffer;
409 let mut eth_packet = EthernetFrame::new_unchecked(&mut buffer[..]);
410 eth_packet.set_ethertype(EthernetProtocol::Ipv4);
411 eth_packet.set_dst_addr(self.state.params.client_mac);
412 eth_packet.set_src_addr(self.state.params.gateway_mac);
413 let mut ipv4_packet = Ipv4Packet::new_unchecked(eth_packet.payload_mut());
414 let ipv4 = Ipv4Repr {
415 src_addr: self.ft.dst.ip,
416 dst_addr: self.ft.src.ip,
417 protocol: IpProtocol::Tcp,
418 payload_len: tcp.header_len() + payload.as_ref().map_or(0, |p| p.len()),
419 hop_limit: 64,
420 };
421 ipv4.emit(&mut ipv4_packet, &ChecksumCapabilities::default());
422 let mut tcp_packet = TcpPacket::new_unchecked(ipv4_packet.payload_mut());
423 tcp.emit(
424 &mut tcp_packet,
425 &self.ft.dst.ip.into(),
426 &self.ft.src.ip.into(),
427 &ChecksumCapabilities::default(),
428 );
429 if let Some(payload) = payload {
430 for (b, c) in tcp_packet.payload_mut().iter_mut().zip(payload.iter()) {
431 *b = *c;
432 }
433 }
434 tcp_packet.fill_checksum(&self.ft.dst.ip.into(), &self.ft.src.ip.into());
435 let n = ETHERNET_HEADER_LEN + ipv4_packet.total_len() as usize;
436 self.client.recv(&buffer[..n], &ChecksumState::TCP4);
437 }
438
439 fn rst(&mut self, seq: TcpSeqNumber, ack: Option<TcpSeqNumber>) {
440 let tcp = TcpRepr {
441 src_port: self.ft.dst.port,
442 dst_port: self.ft.src.port,
443 control: TcpControl::Rst,
444 seq_number: seq,
445 ack_number: ack,
446 window_len: 0,
447 window_scale: None,
448 max_seg_size: None,
449 sack_permitted: false,
450 sack_ranges: [None, None, None],
451 payload: &[],
452 };
453
454 tracing::trace!(?tcp, "tcp rst xmit");
455
456 self.send_packet(&tcp, None);
457 }
458}
459
460impl Default for TcpConnection {
461 fn default() -> Self {
462 let mut rx_tx_seq = [0; 8];
463 getrandom::fill(&mut rx_tx_seq[..]).expect("prng failure");
464 let rx_seq = TcpSeqNumber(i32::from_ne_bytes(
465 rx_tx_seq[0..4].try_into().expect("invalid length"),
466 ));
467 let tx_seq = TcpSeqNumber(i32::from_ne_bytes(
468 rx_tx_seq[4..8].try_into().expect("invalid length"),
469 ));
470
471 let rx_buffer_size: usize = 16384;
472 let rx_window_scale =
473 (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
474
475 let tx_buffer_size = 16384;
476
477 Self {
478 socket: None,
479 loopback_port: LoopbackPortInfo::None,
480 state: TcpState::Connecting,
481 rx_buffer: VecDeque::with_capacity(rx_buffer_size),
482 rx_window_cap: 0,
483 rx_window_scale,
484 rx_seq,
485 needs_ack: false,
486 is_shutdown: false,
487 enable_window_scaling: false,
488 tx_buffer: ring::Ring::new(tx_buffer_size),
489 tx_acked: tx_seq,
490 tx_send: tx_seq,
491 tx_window_len: 1,
492 tx_window_scale: 0,
493 tx_window_rx_seq: rx_seq,
494 tx_window_tx_seq: tx_seq,
495 tx_mss: 536,
498 tx_fin_buffered: false,
499 }
500 }
501}
502
503impl TcpConnection {
504 fn new(sender: &mut Sender<'_, impl Client>, tcp: &TcpRepr<'_>) -> Result<Self, DropReason> {
505 let mut this = Self::default();
506 this.initialize_from_first_client_packet(tcp)?;
507
508 let socket =
509 Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).map_err(DropReason::Io)?;
510
511 #[cfg(windows)]
515 if sender.ft.dst.ip.is_loopback() {
516 if let Err(err) = crate::windows::disable_connection_retries(&socket) {
517 tracing::trace!(err, "Failed to disable loopback retries");
518 }
519 }
520
521 let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
522 match socket
523 .get()
524 .connect(&SockAddr::from(SocketAddrV4::from(sender.ft.dst)))
525 {
526 Ok(_) => unreachable!(),
527 Err(err) if is_connect_incomplete_error(&err) => (),
528 Err(err) => {
529 tracing::warn!(
530 error = &err as &dyn std::error::Error,
531 "socket connect error"
532 );
533 sender.rst(TcpSeqNumber(0), Some(tcp.seq_number + tcp.segment_len()));
534 return Err(DropReason::Io(err));
535 }
536 }
537 if let Ok(addr) = socket.get().local_addr() {
538 if let Some(addr) = addr.as_socket_ipv4() {
539 if addr.ip().is_loopback() {
540 this.loopback_port = LoopbackPortInfo::ProxyForGuestPort {
541 sending_port: addr.port(),
542 guest_port: sender.ft.src.port,
543 };
544 }
545 }
546 }
547 this.socket = Some(socket);
548 Ok(this)
549 }
550
551 fn new_from_accept(
552 sender: &mut Sender<'_, impl Client>,
553 socket: Socket,
554 ) -> Result<Self, DropReason> {
555 let mut this = Self {
556 socket: Some(
557 PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?,
558 ),
559 state: TcpState::SynSent,
560 ..Default::default()
561 };
562 this.send_syn(sender, None);
563 Ok(this)
564 }
565
566 fn initialize_from_first_client_packet(&mut self, tcp: &TcpRepr<'_>) -> Result<(), DropReason> {
567 let tx_mss = tcp.max_seg_size.map_or(536, |x| x.into());
570
571 if let Some(tx_window_scale) = tcp.window_scale {
572 if tx_window_scale > 14 {
573 return Err(TcpError::InvalidWindowScale.into());
574 }
575 }
576
577 let max_rx_buffer_size = if tcp.window_scale.is_some() {
578 u32::MAX as usize
579 } else {
580 u16::MAX as usize
581 };
582 let rx_buffer_size = 16384.min(max_rx_buffer_size);
583 let rx_window_scale =
584 (usize::BITS - rx_buffer_size.leading_zeros()).saturating_sub(16) as u8;
585
586 assert!(tcp.window_scale.is_some() || rx_window_scale == 0);
587 if self.rx_buffer.capacity() < rx_buffer_size {
588 self.rx_buffer.reserve_exact(rx_buffer_size);
589 }
590
591 self.rx_window_scale = rx_window_scale;
592 self.rx_seq = tcp.seq_number + 1;
593 self.tx_window_rx_seq = tcp.seq_number + 1;
594 self.enable_window_scaling = tcp.window_scale.is_some();
595 self.tx_window_scale = tcp.window_scale.unwrap_or(0);
596 self.tx_mss = tx_mss;
597 Ok(())
598 }
599
600 fn poll_conn(&mut self, cx: &mut Context<'_>, sender: &mut Sender<'_, impl Client>) -> bool {
601 if self.state == TcpState::Connecting {
602 match self
603 .socket
604 .as_mut()
605 .unwrap()
606 .poll_ready(cx, PollEvents::OUT)
607 {
608 Poll::Ready(r) => {
609 if r.has_err() {
610 let err = take_socket_error(self.socket.as_mut().unwrap());
611 let reset = match err.kind() {
612 ErrorKind::TimedOut => {
613 tracing::debug!(
617 error = &err as &dyn std::error::Error,
618 "connect timed out"
619 );
620 false
621 }
622 ErrorKind::ConnectionRefused => {
623 tracing::debug!(
626 error = &err as &dyn std::error::Error,
627 "connection refused"
628 );
629 true
630 }
631 _ => {
632 tracing::warn!(
639 error = &err as &dyn std::error::Error,
640 "unhandled connect failure"
641 );
642 true
643 }
644 };
645 if reset {
646 sender.rst(self.tx_send, Some(self.rx_seq));
647 }
648 return false;
649 }
650
651 tracing::debug!("connection established");
652 self.state = TcpState::SynReceived;
653 self.rx_window_cap = self.rx_buffer.capacity();
654 }
655 Poll::Pending => return true,
656 }
657 } else if self.state == TcpState::SynSent {
658 return true;
660 }
661
662 if self.socket.is_some() {
664 if self.state.tx_fin() {
665 if let Poll::Ready(events) = self
666 .socket
667 .as_mut()
668 .unwrap()
669 .poll_ready(cx, PollEvents::EMPTY)
670 {
671 if events.has_err() {
672 let err = take_socket_error(self.socket.as_ref().unwrap());
673 match err.kind() {
674 ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
675 _ => tracing::warn!(
676 error = &err as &dyn std::error::Error,
677 "socket failure after fin"
678 ),
679 }
680 sender.rst(self.tx_send, Some(self.rx_seq));
681 return false;
682 }
683
684 self.socket = None;
686 }
687 } else {
688 while !self.tx_buffer.is_full() {
689 let (a, b) = self.tx_buffer.unwritten_slices_mut();
690 let mut bufs = [IoSliceMut::new(a), IoSliceMut::new(b)];
691 match Pin::new(&mut *self.socket.as_mut().unwrap())
692 .poll_read_vectored(cx, &mut bufs)
693 {
694 Poll::Ready(Ok(n)) => {
695 if n == 0 {
696 self.close();
697 break;
698 }
699 self.tx_buffer.extend_by(n);
700 }
701 Poll::Ready(Err(err)) => {
702 match err.kind() {
703 ErrorKind::ConnectionReset => tracing::trace!(
704 error = &err as &dyn std::error::Error,
705 "socket read error"
706 ),
707 _ => tracing::warn!(
708 error = &err as &dyn std::error::Error,
709 "socket read error"
710 ),
711 }
712 sender.rst(self.tx_send, Some(self.rx_seq));
713 return false;
714 }
715 Poll::Pending => break,
716 }
717 }
718 }
719 }
720
721 if self.socket.is_some() {
723 while !self.rx_buffer.is_empty() {
724 let (a, b) = self.rx_buffer.as_slices();
725 let bufs = [IoSlice::new(a), IoSlice::new(b)];
726 match Pin::new(&mut *self.socket.as_mut().unwrap()).poll_write_vectored(cx, &bufs) {
727 Poll::Ready(Ok(n)) => {
728 self.rx_buffer.drain(..n);
729 }
730 Poll::Ready(Err(err)) => {
731 match err.kind() {
732 ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {}
733 _ => {
734 tracing::warn!(
735 error = &err as &dyn std::error::Error,
736 "socket write error"
737 );
738 }
739 }
740 sender.rst(self.tx_send, Some(self.rx_seq));
741 return false;
742 }
743 Poll::Pending => break,
744 }
745 }
746 if self.rx_buffer.is_empty() && self.state.rx_fin() && !self.is_shutdown {
747 if let Err(err) = self
748 .socket
749 .as_ref()
750 .unwrap()
751 .get()
752 .shutdown(Shutdown::Write)
753 {
754 tracing::warn!(error = &err as &dyn std::error::Error, "shutdown error");
755 sender.rst(self.tx_send, Some(self.rx_seq));
756 return false;
757 }
758 self.is_shutdown = true;
759 }
760 }
761
762 self.send_next(sender);
764 true
765 }
766
767 fn rx_window_len(&self) -> u16 {
768 ((self.rx_window_cap - self.rx_buffer.len()) >> self.rx_window_scale) as u16
769 }
770
771 fn send_next(&mut self, sender: &mut Sender<'_, impl Client>) {
772 match self.state {
773 TcpState::Connecting => {}
774 TcpState::SynReceived => self.send_syn(sender, Some(self.rx_seq)),
775 _ => self.send_data(sender),
776 }
777 }
778
779 fn send_syn(&mut self, sender: &mut Sender<'_, impl Client>, ack_number: Option<TcpSeqNumber>) {
780 if self.tx_send != self.tx_acked || sender.client.rx_mtu() == 0 {
781 return;
782 }
783
784 let window_scale = self.enable_window_scaling.then_some(self.rx_window_scale);
787
788 let max_seg_size = u16::MAX;
791 let tcp = TcpRepr {
792 src_port: sender.ft.dst.port,
793 dst_port: sender.ft.src.port,
794 control: TcpControl::Syn,
795 seq_number: self.tx_send,
796 ack_number,
797 window_len: self.rx_window_len(),
798 window_scale,
799 max_seg_size: Some(max_seg_size),
800 sack_permitted: false,
801 sack_ranges: [None, None, None],
802 payload: &[],
803 };
804
805 sender.send_packet(&tcp, None);
806 self.tx_send += 1;
807 }
808
    /// Sends buffered host data (and a queued FIN) to the guest, limited by the
    /// guest's window, `tx_mss`, and the client MTU. Also sends a pure ACK when
    /// `needs_ack` is set. Stops early if the client reports an `rx_mtu` of 0.
    fn send_data(&mut self, sender: &mut Sender<'_, impl Client>) {
        // End of buffered payload in sequence space.
        let tx_payload_end = self.tx_acked + self.tx_buffer.len();
        // End including a queued FIN, which occupies one sequence number.
        let tx_end = tx_payload_end + self.tx_fin_buffered as usize;
        // Right edge of the guest's (scaled) window, relative to the oldest unacked byte.
        let tx_window_end = self.tx_acked + ((self.tx_window_len as usize) << self.tx_window_scale);
        let tx_done = seq_min([tx_end, tx_window_end]);

        while self.needs_ack || self.tx_send < tx_done {
            let rx_mtu = sender.client.rx_mtu();
            if rx_mtu == 0 {
                break;
            }

            let mut tcp = TcpRepr {
                src_port: sender.ft.dst.port,
                dst_port: sender.ft.src.port,
                control: TcpControl::None,
                seq_number: self.tx_send,
                ack_number: Some(self.rx_seq),
                window_len: self.rx_window_len(),
                window_scale: None,
                max_seg_size: None,
                sack_permitted: false,
                sack_ranges: [None, None, None],
                payload: &[],
            };

            let mut tx_next = self.tx_send;

            // Segment payload ends at the earliest of: end of data, window edge,
            // MSS, and what fits in one frame after headers. `mtu - header_len`
            // assumes the MTU is at least the header size.
            let tx_segment_end = {
                let header_len = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp.header_len();
                let mtu = rx_mtu.min(sender.state.buffer.len());
                seq_min([
                    tx_payload_end,
                    tx_window_end,
                    tx_next + self.tx_mss,
                    tx_next + (mtu - header_len),
                ])
            };

            // Offset into `tx_buffer` and length of this segment's payload.
            let (payload_start, payload_len) = if tx_next < tx_segment_end {
                (tx_next - self.tx_acked, tx_segment_end - tx_next)
            } else {
                (0, 0)
            };

            tx_next += payload_len;

            // Piggyback the FIN once all data is covered and the window has room for it.
            if self.tx_fin_buffered
                && tcp.control == TcpControl::None
                && tx_next == tx_payload_end
                && tx_next < tx_window_end
            {
                tcp.control = TcpControl::Fin;
                tx_next += 1;
            }

            assert!(tx_next <= tx_end);
            // Every iteration must either ACK or advance; otherwise the loop would spin.
            assert!(self.needs_ack || tx_next > self.tx_send);

            tracing::trace!(?tcp, %tx_next, "tcp xmit");

            let payload = self
                .tx_buffer
                .view(payload_start..payload_start + payload_len);

            sender.send_packet(&tcp, Some(payload));
            self.tx_send = tx_next;
            self.needs_ack = false;
        }

        assert!(self.tx_send <= tx_end);
    }
890
891 fn close(&mut self) {
892 tracing::trace!("fin");
893 match self.state {
894 TcpState::SynSent | TcpState::SynReceived | TcpState::Established => {
895 self.state = TcpState::FinWait1;
896 }
897 TcpState::CloseWait => {
898 self.state = TcpState::LastAck;
899 }
900 TcpState::Connecting
901 | TcpState::FinWait1
902 | TcpState::FinWait2
903 | TcpState::Closing
904 | TcpState::TimeWait
905 | TcpState::LastAck => unreachable!("fin in {:?}", self.state),
906 }
907 self.tx_fin_buffered = true;
908 }
909
910 fn ack(&self, sender: &mut Sender<'_, impl Client>) {
917 let tcp = TcpRepr {
918 src_port: sender.ft.dst.port,
919 dst_port: sender.ft.src.port,
920 control: TcpControl::None,
921 seq_number: self.tx_send,
922 ack_number: Some(self.rx_seq),
923 window_len: self.rx_window_len(),
924 window_scale: None,
925 max_seg_size: None,
926 sack_permitted: false,
927 sack_ranges: [None, None, None],
928 payload: &[],
929 };
930
931 tracing::trace!(?tcp, "tcp ack xmit");
932
933 sender.send_packet(&tcp, None);
934 }
935
    /// Handles the guest's reply while in `SynSent` (host-accepted connection).
    ///
    /// Expects a SYN-ACK with no payload that acknowledges our SYN. On success
    /// it applies the guest's options, sends an ACK, and moves to `Established`.
    ///
    /// Returns `Ok(false)` to drop the connection:
    /// - when the segment is not a payload-free SYN;
    /// - when its ACK is out of range (a RST is sent in that case).
    ///
    /// Returns an error if the ACK bit is missing or the options are invalid.
    fn handle_listen_syn(
        &mut self,
        sender: &mut Sender<'_, impl Client>,
        tcp: &TcpRepr<'_>,
    ) -> Result<bool, DropReason> {
        // A SYN with no payload has a segment length of exactly 1.
        if tcp.control != TcpControl::Syn || tcp.segment_len() != 1 {
            tracing::error!(?tcp.control, "invalid packet waiting for syn, drop connection");
            return Ok(false);
        }

        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;
        // The ACK must cover our SYN and nothing beyond what we've sent.
        if ack_number <= self.tx_acked || ack_number > self.tx_send {
            sender.rst(ack_number, None);
            return Ok(false);
        }
        self.tx_acked = ack_number;

        self.initialize_from_first_client_packet(tcp)?;
        self.tx_window_tx_seq = ack_number;
        // Open our receive window now that the handshake is completing.
        self.rx_window_cap = self.rx_buffer.capacity();
        self.tx_window_len = tcp.window_len;

        self.ack(sender);

        self.state = TcpState::Established;
        Ok(true)
    }
964
    /// Processes one guest segment for an existing connection. This covers:
    /// - RST handling;
    /// - the sequence acceptability check;
    /// - ACK processing and window updates;
    /// - buffering of in-order payload;
    /// - FIN state transitions.
    ///
    /// Returns `Ok(false)` when the connection should be removed. Returns an
    /// error for segments that are rejected; for most of these a duplicate or
    /// challenge ACK has already been sent.
    fn handle_packet(
        &mut self,
        sender: &mut Sender<'_, impl Client>,
        tcp: &TcpRepr<'_>,
    ) -> Result<bool, DropReason> {
        if self.state == TcpState::Connecting {
            return Err(TcpError::StillConnecting.into());
        } else if self.state == TcpState::SynSent {
            return self.handle_listen_syn(sender, tcp);
        }

        let rx_window_len = self.rx_window_cap - self.rx_buffer.len();
        let rx_window_end = self.rx_seq + rx_window_len;
        let segment_end = tcp.seq_number + tcp.segment_len();

        // RFC 793/9293 segment acceptability test. With an open window, the
        // segment's start or last octet must lie in the window. With a zero
        // window, only an empty segment at exactly `rx_seq` is acceptable.
        let seq_acceptable = if rx_window_len != 0 {
            (tcp.seq_number >= self.rx_seq && tcp.seq_number < rx_window_end)
                || (tcp.segment_len() > 0
                    && segment_end > self.rx_seq
                    && segment_end <= rx_window_end)
        } else {
            tcp.segment_len() == 0 && tcp.seq_number == self.rx_seq
        };

        if tcp.control == TcpControl::Rst {
            if !seq_acceptable {
                return Err(TcpError::Unacceptable.into());
            }

            // In window but not an exact match: send a challenge ACK (RFC 5961)
            // instead of resetting.
            if tcp.seq_number != self.rx_seq {
                self.ack(sender);
                return Ok(true);
            }

            tracing::debug!("connection reset");
            return Ok(false);
        }

        if !seq_acceptable {
            self.ack(sender);
            return Err(TcpError::Unacceptable.into());
        }

        // Out-of-order data is not buffered; re-ACK so the guest retransmits.
        if tcp.seq_number > self.rx_seq && tcp.segment_len() > 0 {
            self.ack(sender);
            return Err(TcpError::OutOfOrder.into());
        }

        if tcp.control == TcpControl::Syn {
            if self.state == TcpState::SynReceived {
                tracing::debug!("invalid syn, drop connection");
                return Ok(false);
            }
            // In-window SYN on a synchronized connection: challenge ACK.
            self.ack(sender);
            return Ok(true);
        }

        let ack_number = tcp.ack_number.ok_or(TcpError::MissingAck)?;

        // Completing the handshake: the ACK must cover our SYN-ACK.
        if self.state == TcpState::SynReceived {
            if ack_number <= self.tx_acked || ack_number > self.tx_send {
                sender.rst(ack_number, None);
                return Ok(false);
            }
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
            // Our SYN is acknowledged; it occupied one sequence number.
            self.tx_acked += 1;
            self.state = TcpState::Established;
        }

        if ack_number > self.tx_send {
            self.ack(sender);
            return Err(TcpError::AckPastSequence.into());
        }

        // Release newly acknowledged data, and note whether our FIN was acknowledged.
        if ack_number > self.tx_acked {
            let mut consumed = ack_number - self.tx_acked;
            if self.tx_fin_buffered && ack_number == self.tx_acked + self.tx_buffer.len() + 1 {
                self.tx_fin_buffered = false;
                // The FIN's sequence number has no byte in `tx_buffer`.
                consumed -= 1;
                match self.state {
                    TcpState::FinWait1 => self.state = TcpState::FinWait2,
                    TcpState::Closing => self.state = TcpState::TimeWait,
                    TcpState::LastAck => return Ok(false),
                    _ => unreachable!(),
                }
            }
            self.tx_buffer.consume(consumed);
            self.tx_acked = ack_number;
        }

        // Window update rule (RFC 793 SND.WL1/SND.WL2): only take the window
        // from a segment that is not older than the last one used.
        if ack_number >= self.tx_acked
            && (tcp.seq_number > self.tx_window_rx_seq
                || (tcp.seq_number == self.tx_window_rx_seq && ack_number >= self.tx_window_tx_seq))
        {
            self.tx_window_len = tcp.window_len;
            self.tx_window_rx_seq = tcp.seq_number;
            self.tx_window_tx_seq = ack_number;
        }

        let mut fin = tcp.control == TcpControl::Fin;
        // Skip the part of the segment already received.
        let segment_skip = if tcp.seq_number < self.rx_seq {
            self.rx_seq - tcp.seq_number
        } else {
            0
        };
        // Trim to the receive window; a trimmed segment's FIN is not accepted.
        let segment_end = if segment_end > rx_window_end {
            fin = false;
            rx_window_end
        } else {
            segment_end
        };
        let payload = &tcp.payload[segment_skip..segment_end - tcp.seq_number - fin as usize];

        match self.state {
            TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
            TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 => {
                self.rx_buffer.extend(payload);
                self.rx_seq = segment_end;
                if tcp.segment_len() > 0 {
                    self.needs_ack = true;
                }
            }
            // The guest already sent FIN; anything further is ignored.
            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck => {}
            TcpState::TimeWait => {
                self.ack(sender);
            }
        }

        if fin {
            match self.state {
                TcpState::Connecting | TcpState::SynReceived | TcpState::SynSent => unreachable!(),
                TcpState::Established => {
                    self.state = TcpState::CloseWait;
                }
                TcpState::FinWait1 => {
                    self.state = TcpState::Closing;
                }
                TcpState::FinWait2 => {
                    self.state = TcpState::TimeWait;
                }
                TcpState::CloseWait
                | TcpState::Closing
                | TcpState::LastAck
                | TcpState::TimeWait => {}
            }
        }

        Ok(true)
    }
1142}
1143
1144impl TcpListener {
1145 pub fn new(sender: &mut Sender<'_, impl Client>) -> Result<Self, DropReason> {
1146 let socket =
1147 Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).map_err(DropReason::Io)?;
1148
1149 let socket = PolledSocket::new(sender.client.driver(), socket).map_err(DropReason::Io)?;
1150 if let Err(err) = socket.get().bind(&sender.ft.src.into()) {
1151 tracing::warn!(
1152 address = ?sender.ft.src,
1153 error = &err as &dyn std::error::Error,
1154 "socket bind error"
1155 );
1156 return Err(DropReason::Io(err));
1157 }
1158 if let Err(err) = socket.listen(10) {
1159 tracing::warn!(
1160 error = &err as &dyn std::error::Error,
1161 "socket listen error"
1162 );
1163 return Err(DropReason::Io(err));
1164 }
1165 Ok(Self { socket })
1166 }
1167
1168 fn poll_listener(
1169 &mut self,
1170 cx: &mut Context<'_>,
1171 ) -> Result<Option<(Socket, SocketAddress)>, DropReason> {
1172 match self.socket.poll_accept(cx) {
1173 Poll::Ready(r) => match r {
1174 Ok((socket, address)) => match address.as_socket() {
1175 Some(addr) => match address.as_socket_ipv4() {
1176 Some(src_address) => Ok(Some((
1177 socket,
1178 SocketAddress {
1179 ip: (*src_address.ip()).into(),
1180 port: addr.port(),
1181 },
1182 ))),
1183 None => {
1184 tracing::warn!(?address, "Not an IPv4 address from accept");
1185 Ok(None)
1186 }
1187 },
1188 None => {
1189 tracing::warn!(?address, "Unknown address from accept");
1190 Ok(None)
1191 }
1192 },
1193 Err(_) => {
1194 let err = take_socket_error(&self.socket);
1195 tracing::warn!(error = &err as &dyn std::error::Error, "listen failure");
1196 Err(DropReason::Io(err))
1197 }
1198 },
1199 Poll::Pending => Ok(None),
1200 }
1201 }
1202}
1203
1204fn take_socket_error(socket: &PolledSocket<Socket>) -> io::Error {
1205 match socket.get().take_error() {
1206 Ok(Some(err)) => err,
1207 Ok(_) => io::Error::other("missing error"),
1208 Err(err) => err,
1209 }
1210}
1211
1212fn is_connect_incomplete_error(err: &io::Error) -> bool {
1213 if err.kind() == ErrorKind::WouldBlock {
1214 return true;
1215 }
1216 #[cfg(unix)]
1218 if err.raw_os_error() == Some(libc::EINPROGRESS) {
1219 return true;
1220 }
1221 false
1222}
1223
1224fn seq_min<const N: usize>(seqs: [TcpSeqNumber; N]) -> TcpSeqNumber {
1231 let mut min = seqs[0];
1232 for &seq in &seqs[1..] {
1233 if min > seq {
1234 min = seq;
1235 }
1236 }
1237 min
1238}