pal_async/
socket.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Socket-related functionality.
5
6#[cfg(unix)]
7use super::fd;
8use super::interest::InterestSlot;
9use super::interest::PollEvents;
10use crate::driver::Driver;
11use crate::driver::PollImpl;
12use futures::AsyncRead;
13use futures::AsyncWrite;
14use parking_lot::Mutex;
15use std::fmt::Debug;
16use std::future::Future;
17use std::future::poll_fn;
18use std::io;
19use std::io::Read;
20use std::io::Write;
21use std::net::Shutdown;
22#[cfg(unix)]
23use std::os::unix::prelude::*;
24#[cfg(windows)]
25use std::os::windows::prelude::*;
26use std::path::Path;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::Context;
30use std::task::Poll;
31use unix_socket::UnixStream;
32
33/// A trait for driving socket ready polling.
34pub trait SocketReadyDriver: Unpin {
35    /// The socket ready type.
36    type SocketReady: 'static + PollSocketReady;
37
38    /// Creates a new object for polling socket readiness.
39    #[cfg(windows)]
40    fn new_socket_ready(&self, socket: RawSocket) -> io::Result<Self::SocketReady>;
41    /// Creates a new object for polling socket readiness.
42    #[cfg(unix)]
43    fn new_socket_ready(&self, socket: RawFd) -> io::Result<Self::SocketReady>;
44}
45
46#[cfg(unix)]
47impl<T: fd::FdReadyDriver> SocketReadyDriver for T {
48    type SocketReady = <Self as fd::FdReadyDriver>::FdReady;
49
50    fn new_socket_ready(&self, socket: RawFd) -> io::Result<Self::SocketReady> {
51        self.new_fd_ready(socket)
52    }
53}
54
55/// A trait for polling socket readiness.
56pub trait PollSocketReady: Unpin + Send + Sync {
57    /// Polls a socket for readiness.
58    fn poll_socket_ready(
59        &mut self,
60        cx: &mut Context<'_>,
61        slot: InterestSlot,
62        events: PollEvents,
63    ) -> Poll<PollEvents>;
64
65    /// Clears cached socket readiness so that the next call to
66    /// `poll_socket_ready` will poll the OS again.
67    fn clear_socket_ready(&mut self, slot: InterestSlot);
68}
69
70#[cfg(unix)]
71impl<T: fd::PollFdReady> PollSocketReady for T {
72    fn poll_socket_ready(
73        &mut self,
74        cx: &mut Context<'_>,
75        slot: InterestSlot,
76        events: PollEvents,
77    ) -> Poll<PollEvents> {
78        self.poll_fd_ready(cx, slot, events)
79    }
80
81    fn clear_socket_ready(&mut self, slot: InterestSlot) {
82        self.clear_fd_ready(slot)
83    }
84}
85
/// A socket paired with a driver readiness registration, so that it can be
/// used from async code.
///
/// [`PolledSocket::new`] switches the wrapped socket to nonblocking mode.
pub struct PolledSocket<T> {
    // Readiness-polling registration for `socket`.
    //
    // Rust drops struct fields in declaration order, so keeping this field
    // first guarantees it is dropped before `socket`.
    poll: PollImpl<dyn PollSocketReady>, // must be first--some executors require that it's dropped before socket.
    // The wrapped socket itself.
    socket: T,
}
91
92/// Trait implemented by socket types.
93pub trait AsSockRef: Unpin {
94    /// Returns a socket reference.
95    fn as_sock_ref(&self) -> socket2::SockRef<'_>;
96}
97
98impl<T: Unpin> AsSockRef for T
99where
100    for<'a> &'a T: Into<socket2::SockRef<'a>>,
101{
102    fn as_sock_ref(&self) -> socket2::SockRef<'_> {
103        self.into()
104    }
105}
106
107impl<T: AsSockRef> PolledSocket<T> {
108    /// Creates a new polled socket.
109    pub fn new(driver: &(impl ?Sized + Driver), socket: T) -> io::Result<Self> {
110        let sock_ref = socket.as_sock_ref();
111        sock_ref.set_nonblocking(true)?;
112        #[cfg(windows)]
113        let fd = sock_ref.as_raw_socket();
114        #[cfg(unix)]
115        let fd = sock_ref.as_raw_fd();
116        Ok(Self {
117            poll: driver.new_dyn_socket_ready(fd)?,
118            socket,
119        })
120    }
121
122    /// Extracts the inner socket.
123    pub fn into_inner(self) -> T {
124        let sock_ref = self.socket.as_sock_ref();
125        sock_ref.set_nonblocking(false).unwrap();
126        self.socket
127    }
128}
129
130impl<T> PolledSocket<T> {
131    /// Gets a reference to the inner socket.
132    pub fn get(&self) -> &T {
133        &self.socket
134    }
135
136    /// Gets a mutable reference to the inner socket.
137    pub fn get_mut(&mut self) -> &mut T {
138        &mut self.socket
139    }
140
141    /// Converts the inner socket type.
142    pub fn convert<T2: From<T>>(self) -> PolledSocket<T2> {
143        PolledSocket {
144            socket: T2::from(self.socket),
145            poll: self.poll,
146        }
147    }
148}
149
150/// Trait for objects that can be polled for readiness.
151pub trait PollReady {
152    /// Polls an object for readiness.
153    fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents>;
154}
155
156/// Extension methods for implementations of [`PollReady`].
157pub trait PollReadyExt {
158    /// Waits for a socket or file to hang up.
159    fn wait_ready(&mut self, events: PollEvents) -> Ready<'_, Self>
160    where
161        Self: Unpin + Sized;
162}
163
164impl<T: PollReady + Unpin> PollReadyExt for T {
165    fn wait_ready(&mut self, events: PollEvents) -> Ready<'_, Self>
166    where
167        Self: Unpin + Sized,
168    {
169        Ready(self, events)
170    }
171}
172
/// Future for [`PollReadyExt::wait_ready`].
///
/// Holds the object being polled and the events to wait for. It resolves to
/// the [`PollEvents`] returned by [`PollReady::poll_ready`].
pub struct Ready<'a, T>(&'a mut T, PollEvents);
175
176impl<T: Unpin + PollReady> Future for Ready<'_, T> {
177    type Output = PollEvents;
178
179    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        let this = self.get_mut();
181        this.0.poll_ready(cx, this.1)
182    }
183}
184
185impl<T> PolledSocket<T> {
186    /// Calls nonblocking operation `f` when the socket has least one event in
187    /// `events` ready.
188    ///
189    /// Uses interest slot `slot` to allow multiple concurrent operations.
190    ///
191    /// If `f` returns `Err(err)` with `err.kind() ==
192    /// io::ErrorKind::WouldBlock`, then this re-polls the socket for readiness
193    /// and returns `Poll::Pending`.
194    pub fn poll_io<F, R>(
195        &mut self,
196        cx: &mut Context<'_>,
197        slot: InterestSlot,
198        events: PollEvents,
199        mut f: F,
200    ) -> Poll<io::Result<R>>
201    where
202        F: FnMut(&mut Self) -> io::Result<R>,
203    {
204        loop {
205            std::task::ready!(self.poll.poll_socket_ready(cx, slot, events));
206            match f(self) {
207                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
208                    self.poll.clear_socket_ready(slot);
209                }
210                r => break Poll::Ready(r),
211            }
212        }
213    }
214}
215
216impl<T: AsSockRef> PollReady for PolledSocket<T> {
217    fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
218        self.poll.poll_socket_ready(cx, InterestSlot::Read, events)
219    }
220}
221
222impl<T> PolledSocket<T>
223where
224    T: AsSockRef + Read + Write,
225{
226    /// Splits the socket into a read and write half that can be used
227    /// concurrently.
228    ///
229    /// This is more flexible and efficient than
230    /// [`futures::io::AsyncReadExt::split`], since it avoids holding a lock
231    /// while calling into the kernel, and it provides access to the underlying
232    /// socket for more advanced operations.
233    pub fn split(self) -> (ReadHalf<T>, WriteHalf<T>) {
234        let inner = Arc::new(SplitInner {
235            poll: Mutex::new(self.poll),
236            socket: self.socket,
237        });
238        (
239            ReadHalf {
240                inner: inner.clone(),
241            },
242            WriteHalf { inner },
243        )
244    }
245}
246
247fn is_connect_incomplete_error(err: &io::Error) -> bool {
248    // This handles the Windows and AF_UNIX case.
249    if err.kind() == io::ErrorKind::WouldBlock {
250        return true;
251    }
252    // This handles the remaining cases on Linux.
253    #[cfg(unix)]
254    if err.raw_os_error() == Some(libc::EINPROGRESS) {
255        return true;
256    }
257    false
258}
259
260impl PolledSocket<socket2::Socket> {
261    /// Connects the socket to address `addr`.
262    pub async fn connect(&mut self, addr: &socket2::SockAddr) -> io::Result<()> {
263        match self.socket.connect(addr) {
264            Ok(()) => Ok(()),
265            Err(err) if is_connect_incomplete_error(&err) => {
266                self.poll.clear_socket_ready(InterestSlot::Write);
267                poll_fn(|cx| {
268                    self.poll
269                        .poll_socket_ready(cx, InterestSlot::Write, PollEvents::OUT)
270                })
271                .await;
272                if let Some(err) = self.socket.take_error()? {
273                    return Err(err);
274                }
275                Ok(())
276            }
277            Err(err) => Err(err),
278        }
279    }
280}
281
282impl PolledSocket<UnixStream> {
283    /// Creates a new connected Unix stream socket.
284    pub async fn connect_unix(
285        driver: &(impl ?Sized + Driver),
286        addr: impl AsRef<Path>,
287    ) -> io::Result<Self> {
288        let socket = socket2::Socket::new(socket2::Domain::UNIX, socket2::Type::STREAM, None)?;
289        let mut socket = PolledSocket::new(driver, socket)?;
290        socket
291            .connect(&socket2::SockAddr::unix(addr.as_ref())?)
292            .await?;
293        Ok(socket.convert())
294    }
295}
296
297impl<T: AsSockRef + Read> AsyncRead for PolledSocket<T> {
298    fn poll_read(
299        mut self: Pin<&mut Self>,
300        cx: &mut Context<'_>,
301        buf: &mut [u8],
302    ) -> Poll<io::Result<usize>> {
303        self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
304            this.socket.read(buf)
305        })
306    }
307
308    fn poll_read_vectored(
309        mut self: Pin<&mut Self>,
310        cx: &mut Context<'_>,
311        bufs: &mut [io::IoSliceMut<'_>],
312    ) -> Poll<io::Result<usize>> {
313        self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
314            this.socket.read_vectored(bufs)
315        })
316    }
317}
318
319impl<T: AsSockRef + Write> AsyncWrite for PolledSocket<T> {
320    fn poll_write(
321        mut self: Pin<&mut Self>,
322        cx: &mut Context<'_>,
323        buf: &[u8],
324    ) -> Poll<io::Result<usize>> {
325        self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
326            this.socket.write(buf)
327        })
328    }
329
330    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
331        self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
332            this.socket.flush()
333        })
334    }
335
336    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
337        Poll::Ready(self.socket.as_sock_ref().shutdown(Shutdown::Write))
338    }
339
340    fn poll_write_vectored(
341        mut self: Pin<&mut Self>,
342        cx: &mut Context<'_>,
343        bufs: &[io::IoSlice<'_>],
344    ) -> Poll<io::Result<usize>> {
345        self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
346            this.socket.write_vectored(bufs)
347        })
348    }
349}
350
351/// Trait for listening sockets.
352pub trait Listener: AsSockRef {
353    /// The socket type.
354    type Socket: AsSockRef + Read + Write + Into<socket2::Socket>;
355    /// The socket address type.
356    type Address: Debug;
357
358    /// Accepts an incoming socket.
359    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)>;
360    /// Returns the local address of the listener.
361    fn local_addr(&self) -> io::Result<Self::Address>;
362}
363
364impl<'a, T> Listener for &'a T
365where
366    T: Listener,
367    &'a T: AsSockRef,
368{
369    type Socket = T::Socket;
370    type Address = T::Address;
371
372    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
373        (**self).accept()
374    }
375
376    fn local_addr(&self) -> io::Result<Self::Address> {
377        (**self).local_addr()
378    }
379}
380
// Implements `Listener` for a concrete listener type by forwarding to the
// type's own `accept` and `local_addr` methods.
macro_rules! listener {
    ($listener:ty, $stream:ty, $address:ty) => {
        impl Listener for $listener {
            type Socket = $stream;
            type Address = $address;

            fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
                <$listener>::accept(self)
            }

            fn local_addr(&self) -> io::Result<Self::Address> {
                <$listener>::local_addr(self)
            }
        }
    };
}
395
// TCP listeners hand out `TcpStream`s and report `SocketAddr`s.
listener!(
    std::net::TcpListener,
    std::net::TcpStream,
    std::net::SocketAddr
);

// On Unix, `unix_socket` listeners report standard Unix socket addresses.
#[cfg(unix)]
listener!(
    unix_socket::UnixListener,
    UnixStream,
    std::os::unix::net::SocketAddr
);
408
// Windows implementation for `unix_socket::UnixListener`. It uses `()` as the
// address type, so neither `accept` nor `local_addr` surfaces any address
// information.
#[cfg(windows)]
impl Listener for unix_socket::UnixListener {
    type Socket = UnixStream;
    type Address = ();

    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
        // NOTE(review): this relies on method resolution picking an inherent
        // `UnixListener::accept` that returns `(UnixStream, ())` over
        // `Listener::accept`. If no such inherent method exists, this call
        // would recurse into itself. Confirm against `unix_socket`.
        self.accept()
    }

    fn local_addr(&self) -> io::Result<Self::Address> {
        // No address is reported on Windows.
        Ok(())
    }
}
422
// Raw `socket2` sockets accept other `socket2` sockets and report `SockAddr`s.
listener!(socket2::Socket, socket2::Socket, socket2::SockAddr);
424
425impl PolledSocket<socket2::Socket> {
426    /// Listens for incoming connections.
427    pub fn listen(&self, backlog: i32) -> io::Result<()> {
428        self.socket.listen(backlog)
429    }
430}
431
432impl<T: Listener> PolledSocket<T> {
433    /// Polls for a new connection.
434    pub fn poll_accept(
435        &mut self,
436        cx: &mut Context<'_>,
437    ) -> Poll<io::Result<(T::Socket, T::Address)>> {
438        self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
439            this.socket.accept()
440        })
441    }
442
443    /// Accepts a new connection.
444    pub async fn accept(&mut self) -> io::Result<(T::Socket, T::Address)> {
445        poll_fn(|cx| self.poll_accept(cx)).await
446    }
447}
448
/// State shared by the [`ReadHalf`] and [`WriteHalf`] produced by
/// [`PolledSocket::split`].
struct SplitInner<T> {
    // Readiness registration, behind a mutex so both halves can poll it.
    // Rust drops fields in declaration order, so keeping it first guarantees
    // it is dropped before `socket`.
    poll: Mutex<PollImpl<dyn PollSocketReady>>, // must be first--some executors require that it's dropped before socket.
    // The socket shared by both halves.
    socket: T,
}
453
454/// The read half of a socket, via [`PolledSocket::split`].
455pub struct ReadHalf<T> {
456    inner: Arc<SplitInner<T>>,
457}
458
459impl<T> ReadHalf<T> {
460    /// Gets a reference to the inner socket.
461    pub fn get(&self) -> &T {
462        &self.inner.socket
463    }
464
465    /// Calls nonblocking operation `f` when the socket is ready for read.
466    ///
467    /// If `f` returns `Err(err)` with `err.kind() ==
468    /// io::ErrorKind::WouldBlock`, then this re-polls the socket for readiness
469    /// and returns `Poll::Pending`.
470    pub fn poll_io<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<io::Result<R>>
471    where
472        F: FnMut(&mut Self) -> io::Result<R>,
473    {
474        loop {
475            std::task::ready!(self.inner.poll.lock().poll_socket_ready(
476                cx,
477                InterestSlot::Read,
478                PollEvents::IN
479            ));
480            match f(self) {
481                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
482                    self.inner
483                        .poll
484                        .lock()
485                        .clear_socket_ready(InterestSlot::Read);
486                }
487                r => break Poll::Ready(r),
488            }
489        }
490    }
491}
492
493/// The write half of a socket, via [`PolledSocket::split`].
494pub struct WriteHalf<T> {
495    inner: Arc<SplitInner<T>>,
496}
497
498impl<T> WriteHalf<T> {
499    /// Gets a reference to the inner socket.
500    pub fn get(&self) -> &T {
501        &self.inner.socket
502    }
503
504    /// Calls nonblocking operation `f` when the socket is ready for write.
505    ///
506    /// If `f` returns `Err(err)` with `err.kind() ==
507    /// io::ErrorKind::WouldBlock`, then this re-polls the socket for readiness
508    /// and returns `Poll::Pending`.
509    pub fn poll_io<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<io::Result<R>>
510    where
511        F: FnMut(&mut Self) -> io::Result<R>,
512    {
513        loop {
514            std::task::ready!(self.inner.poll.lock().poll_socket_ready(
515                cx,
516                InterestSlot::Write,
517                PollEvents::OUT
518            ));
519            match f(self) {
520                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
521                    self.inner
522                        .poll
523                        .lock()
524                        .clear_socket_ready(InterestSlot::Write);
525                }
526                r => break Poll::Ready(r),
527            }
528        }
529    }
530}
531
532impl<T: AsSockRef> PollReady for ReadHalf<T> {
533    fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
534        self.inner
535            .poll
536            .lock()
537            .poll_socket_ready(cx, InterestSlot::Read, events)
538    }
539}
540
541impl<T: AsSockRef> AsyncRead for ReadHalf<T> {
542    fn poll_read(
543        mut self: Pin<&mut Self>,
544        cx: &mut Context<'_>,
545        buf: &mut [u8],
546    ) -> Poll<io::Result<usize>> {
547        self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).read(buf))
548    }
549
550    fn poll_read_vectored(
551        mut self: Pin<&mut Self>,
552        cx: &mut Context<'_>,
553        bufs: &mut [io::IoSliceMut<'_>],
554    ) -> Poll<io::Result<usize>> {
555        self.poll_io(cx, |this| {
556            (&*this.inner.socket.as_sock_ref()).read_vectored(bufs)
557        })
558    }
559}
560
561impl<T: AsSockRef> PollReady for WriteHalf<T> {
562    fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
563        self.inner
564            .poll
565            .lock()
566            .poll_socket_ready(cx, InterestSlot::Write, events)
567    }
568}
569
570impl<T: AsSockRef> AsyncWrite for WriteHalf<T> {
571    fn poll_write(
572        mut self: Pin<&mut Self>,
573        cx: &mut Context<'_>,
574        buf: &[u8],
575    ) -> Poll<io::Result<usize>> {
576        self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).write(buf))
577    }
578
579    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
580        self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).flush())
581    }
582
583    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
584        Poll::Ready(self.inner.socket.as_sock_ref().shutdown(Shutdown::Write))
585    }
586
587    fn poll_write_vectored(
588        mut self: Pin<&mut Self>,
589        cx: &mut Context<'_>,
590        bufs: &[io::IoSlice<'_>],
591    ) -> Poll<io::Result<usize>> {
592        self.poll_io(cx, |this| {
593            (&*this.inner.socket.as_sock_ref()).write_vectored(bufs)
594        })
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::PolledSocket;
    use crate::DefaultDriver;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async_test::async_test;
    use unix_socket::UnixStream;

    // Round-trips data through a split socket pair. One end's halves echo
    // everything back; the other end writes, reads the echo, half-closes,
    // and reads the remainder.
    #[async_test]
    async fn split(driver: DefaultDriver) {
        let (left, right) = UnixStream::pair().unwrap();
        let left = PolledSocket::new(&driver, left).unwrap();
        let right = PolledSocket::new(&driver, right).unwrap();
        let (mut left_rx, mut left_tx) = left.split();
        let (right_rx, mut right_tx) = right.split();

        // Echo side: copy everything back until the peer shuts down, then
        // close our sending direction too.
        let echo = async {
            futures::io::copy(right_rx, &mut right_tx).await.unwrap();
            right_tx.close().await.unwrap();
        };

        let client = async {
            let mut received = vec![0; 3];
            left_tx.write_all(b"abc").await.unwrap();
            left_rx.read_exact(&mut received).await.unwrap();
            left_tx.write_all(b"def").await.unwrap();
            left_tx.close().await.unwrap();
            left_rx.read_to_end(&mut received).await.unwrap();
            assert_eq!(&received, b"abcdef");
        };

        futures::future::join(echo, client).await;
    }
}