1#[cfg(unix)]
7use super::fd;
8use super::interest::InterestSlot;
9use super::interest::PollEvents;
10use crate::driver::Driver;
11use crate::driver::PollImpl;
12use futures::AsyncRead;
13use futures::AsyncWrite;
14use parking_lot::Mutex;
15use std::fmt::Debug;
16use std::future::Future;
17use std::future::poll_fn;
18use std::io;
19use std::io::Read;
20use std::io::Write;
21use std::net::Shutdown;
22#[cfg(unix)]
23use std::os::unix::prelude::*;
24#[cfg(windows)]
25use std::os::windows::prelude::*;
26use std::path::Path;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::Context;
30use std::task::Poll;
31use unix_socket::UnixStream;
32
33pub trait SocketReadyDriver: Unpin {
35 type SocketReady: 'static + PollSocketReady;
37
38 #[cfg(windows)]
40 fn new_socket_ready(&self, socket: RawSocket) -> io::Result<Self::SocketReady>;
41 #[cfg(unix)]
43 fn new_socket_ready(&self, socket: RawFd) -> io::Result<Self::SocketReady>;
44}
45
46#[cfg(unix)]
47impl<T: fd::FdReadyDriver> SocketReadyDriver for T {
48 type SocketReady = <Self as fd::FdReadyDriver>::FdReady;
49
50 fn new_socket_ready(&self, socket: RawFd) -> io::Result<Self::SocketReady> {
51 self.new_fd_ready(socket)
52 }
53}
54
55pub trait PollSocketReady: Unpin + Send + Sync {
57 fn poll_socket_ready(
59 &mut self,
60 cx: &mut Context<'_>,
61 slot: InterestSlot,
62 events: PollEvents,
63 ) -> Poll<PollEvents>;
64
65 fn clear_socket_ready(&mut self, slot: InterestSlot);
68}
69
70#[cfg(unix)]
71impl<T: fd::PollFdReady> PollSocketReady for T {
72 fn poll_socket_ready(
73 &mut self,
74 cx: &mut Context<'_>,
75 slot: InterestSlot,
76 events: PollEvents,
77 ) -> Poll<PollEvents> {
78 self.poll_fd_ready(cx, slot, events)
79 }
80
81 fn clear_socket_ready(&mut self, slot: InterestSlot) {
82 self.clear_fd_ready(slot)
83 }
84}
85
86pub struct PolledSocket<T> {
88 poll: PollImpl<dyn PollSocketReady>, socket: T,
90}
91
92pub trait AsSockRef: Unpin {
94 fn as_sock_ref(&self) -> socket2::SockRef<'_>;
96}
97
98impl<T: Unpin> AsSockRef for T
99where
100 for<'a> &'a T: Into<socket2::SockRef<'a>>,
101{
102 fn as_sock_ref(&self) -> socket2::SockRef<'_> {
103 self.into()
104 }
105}
106
107impl<T: AsSockRef> PolledSocket<T> {
108 pub fn new(driver: &(impl ?Sized + Driver), socket: T) -> io::Result<Self> {
110 let sock_ref = socket.as_sock_ref();
111 sock_ref.set_nonblocking(true)?;
112 #[cfg(windows)]
113 let fd = sock_ref.as_raw_socket();
114 #[cfg(unix)]
115 let fd = sock_ref.as_raw_fd();
116 Ok(Self {
117 poll: driver.new_dyn_socket_ready(fd)?,
118 socket,
119 })
120 }
121
122 pub fn into_inner(self) -> T {
124 let sock_ref = self.socket.as_sock_ref();
125 sock_ref.set_nonblocking(false).unwrap();
126 self.socket
127 }
128}
129
130impl<T> PolledSocket<T> {
131 pub fn get(&self) -> &T {
133 &self.socket
134 }
135
136 pub fn get_mut(&mut self) -> &mut T {
138 &mut self.socket
139 }
140
141 pub fn convert<T2: From<T>>(self) -> PolledSocket<T2> {
143 PolledSocket {
144 socket: T2::from(self.socket),
145 poll: self.poll,
146 }
147 }
148}
149
150pub trait PollReady {
152 fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents>;
154}
155
156pub trait PollReadyExt {
158 fn wait_ready(&mut self, events: PollEvents) -> Ready<'_, Self>
160 where
161 Self: Unpin + Sized;
162}
163
164impl<T: PollReady + Unpin> PollReadyExt for T {
165 fn wait_ready(&mut self, events: PollEvents) -> Ready<'_, Self>
166 where
167 Self: Unpin + Sized,
168 {
169 Ready(self, events)
170 }
171}
172
173pub struct Ready<'a, T>(&'a mut T, PollEvents);
175
176impl<T: Unpin + PollReady> Future for Ready<'_, T> {
177 type Output = PollEvents;
178
179 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180 let this = self.get_mut();
181 this.0.poll_ready(cx, this.1)
182 }
183}
184
185impl<T> PolledSocket<T> {
186 pub fn poll_io<F, R>(
195 &mut self,
196 cx: &mut Context<'_>,
197 slot: InterestSlot,
198 events: PollEvents,
199 mut f: F,
200 ) -> Poll<io::Result<R>>
201 where
202 F: FnMut(&mut Self) -> io::Result<R>,
203 {
204 loop {
205 std::task::ready!(self.poll.poll_socket_ready(cx, slot, events));
206 match f(self) {
207 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
208 self.poll.clear_socket_ready(slot);
209 }
210 r => break Poll::Ready(r),
211 }
212 }
213 }
214}
215
216impl<T: AsSockRef> PollReady for PolledSocket<T> {
217 fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
218 self.poll.poll_socket_ready(cx, InterestSlot::Read, events)
219 }
220}
221
222impl<T> PolledSocket<T>
223where
224 T: AsSockRef + Read + Write,
225{
226 pub fn split(self) -> (ReadHalf<T>, WriteHalf<T>) {
234 let inner = Arc::new(SplitInner {
235 poll: Mutex::new(self.poll),
236 socket: self.socket,
237 });
238 (
239 ReadHalf {
240 inner: inner.clone(),
241 },
242 WriteHalf { inner },
243 )
244 }
245}
246
247fn is_connect_incomplete_error(err: &io::Error) -> bool {
248 if err.kind() == io::ErrorKind::WouldBlock {
250 return true;
251 }
252 #[cfg(unix)]
254 if err.raw_os_error() == Some(libc::EINPROGRESS) {
255 return true;
256 }
257 false
258}
259
260impl PolledSocket<socket2::Socket> {
261 pub async fn connect(&mut self, addr: &socket2::SockAddr) -> io::Result<()> {
263 match self.socket.connect(addr) {
264 Ok(()) => Ok(()),
265 Err(err) if is_connect_incomplete_error(&err) => {
266 self.poll.clear_socket_ready(InterestSlot::Write);
267 poll_fn(|cx| {
268 self.poll
269 .poll_socket_ready(cx, InterestSlot::Write, PollEvents::OUT)
270 })
271 .await;
272 if let Some(err) = self.socket.take_error()? {
273 return Err(err);
274 }
275 Ok(())
276 }
277 Err(err) => Err(err),
278 }
279 }
280}
281
282impl PolledSocket<UnixStream> {
283 pub async fn connect_unix(
285 driver: &(impl ?Sized + Driver),
286 addr: impl AsRef<Path>,
287 ) -> io::Result<Self> {
288 let socket = socket2::Socket::new(socket2::Domain::UNIX, socket2::Type::STREAM, None)?;
289 let mut socket = PolledSocket::new(driver, socket)?;
290 socket
291 .connect(&socket2::SockAddr::unix(addr.as_ref())?)
292 .await?;
293 Ok(socket.convert())
294 }
295}
296
297impl<T: AsSockRef + Read> AsyncRead for PolledSocket<T> {
298 fn poll_read(
299 mut self: Pin<&mut Self>,
300 cx: &mut Context<'_>,
301 buf: &mut [u8],
302 ) -> Poll<io::Result<usize>> {
303 self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
304 this.socket.read(buf)
305 })
306 }
307
308 fn poll_read_vectored(
309 mut self: Pin<&mut Self>,
310 cx: &mut Context<'_>,
311 bufs: &mut [io::IoSliceMut<'_>],
312 ) -> Poll<io::Result<usize>> {
313 self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
314 this.socket.read_vectored(bufs)
315 })
316 }
317}
318
319impl<T: AsSockRef + Write> AsyncWrite for PolledSocket<T> {
320 fn poll_write(
321 mut self: Pin<&mut Self>,
322 cx: &mut Context<'_>,
323 buf: &[u8],
324 ) -> Poll<io::Result<usize>> {
325 self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
326 this.socket.write(buf)
327 })
328 }
329
330 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
331 self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
332 this.socket.flush()
333 })
334 }
335
336 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
337 Poll::Ready(self.socket.as_sock_ref().shutdown(Shutdown::Write))
338 }
339
340 fn poll_write_vectored(
341 mut self: Pin<&mut Self>,
342 cx: &mut Context<'_>,
343 bufs: &[io::IoSlice<'_>],
344 ) -> Poll<io::Result<usize>> {
345 self.poll_io(cx, InterestSlot::Write, PollEvents::OUT, |this| {
346 this.socket.write_vectored(bufs)
347 })
348 }
349}
350
351pub trait Listener: AsSockRef {
353 type Socket: AsSockRef + Read + Write + Into<socket2::Socket>;
355 type Address: Debug;
357
358 fn accept(&self) -> io::Result<(Self::Socket, Self::Address)>;
360 fn local_addr(&self) -> io::Result<Self::Address>;
362}
363
364impl<'a, T> Listener for &'a T
365where
366 T: Listener,
367 &'a T: AsSockRef,
368{
369 type Socket = T::Socket;
370 type Address = T::Address;
371
372 fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
373 (**self).accept()
374 }
375
376 fn local_addr(&self) -> io::Result<Self::Address> {
377 (**self).local_addr()
378 }
379}
380
/// Implements [`Listener`] for `$ty`, using `$socket` as the accepted socket
/// type and `$addr` as the address type.
///
/// `$ty` must already have `accept` and `local_addr` associated functions
/// with matching signatures. `<$ty>::accept(self)` is expected to resolve to
/// the type's inherent function (inherent associated items take precedence
/// over trait items in path lookup), so these bodies forward rather than
/// recurse.
macro_rules! listener {
    ($ty:ty, $socket:ty, $addr:ty) => {
        impl Listener for $ty {
            type Socket = $socket;
            type Address = $addr;
            fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
                <$ty>::accept(self)
            }
            fn local_addr(&self) -> io::Result<Self::Address> {
                <$ty>::local_addr(self)
            }
        }
    };
}
395
// TCP listeners accept `TcpStream`s and report `SocketAddr`s.
listener!(
    std::net::TcpListener,
    std::net::TcpStream,
    std::net::SocketAddr
);
401
// On Unix, Unix-domain listeners report `std::os::unix::net::SocketAddr`s.
#[cfg(unix)]
listener!(
    unix_socket::UnixListener,
    UnixStream,
    std::os::unix::net::SocketAddr
);
408
/// Windows implementation: addresses are reported as `()`.
#[cfg(windows)]
impl Listener for unix_socket::UnixListener {
    type Socket = UnixStream;
    type Address = ();

    fn accept(&self) -> io::Result<(Self::Socket, Self::Address)> {
        // NOTE(review): relies on method-call resolution picking an inherent
        // `UnixListener::accept` over this trait method (inherent methods win
        // at the same receiver type); without one this would recurse — confirm.
        self.accept()
    }

    fn local_addr(&self) -> io::Result<Self::Address> {
        // No address information is surfaced on this platform.
        Ok(())
    }
}
422
// Raw `socket2` sockets accept into another `socket2::Socket`.
listener!(socket2::Socket, socket2::Socket, socket2::SockAddr);
424
425impl PolledSocket<socket2::Socket> {
426 pub fn listen(&self, backlog: i32) -> io::Result<()> {
428 self.socket.listen(backlog)
429 }
430}
431
432impl<T: Listener> PolledSocket<T> {
433 pub fn poll_accept(
435 &mut self,
436 cx: &mut Context<'_>,
437 ) -> Poll<io::Result<(T::Socket, T::Address)>> {
438 self.poll_io(cx, InterestSlot::Read, PollEvents::IN, |this| {
439 this.socket.accept()
440 })
441 }
442
443 pub async fn accept(&mut self) -> io::Result<(T::Socket, T::Address)> {
445 poll_fn(|cx| self.poll_accept(cx)).await
446 }
447}
448
449struct SplitInner<T> {
450 poll: Mutex<PollImpl<dyn PollSocketReady>>, socket: T,
452}
453
454pub struct ReadHalf<T> {
456 inner: Arc<SplitInner<T>>,
457}
458
459impl<T> ReadHalf<T> {
460 pub fn get(&self) -> &T {
462 &self.inner.socket
463 }
464
465 pub fn poll_io<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<io::Result<R>>
471 where
472 F: FnMut(&mut Self) -> io::Result<R>,
473 {
474 loop {
475 std::task::ready!(self.inner.poll.lock().poll_socket_ready(
476 cx,
477 InterestSlot::Read,
478 PollEvents::IN
479 ));
480 match f(self) {
481 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
482 self.inner
483 .poll
484 .lock()
485 .clear_socket_ready(InterestSlot::Read);
486 }
487 r => break Poll::Ready(r),
488 }
489 }
490 }
491}
492
493pub struct WriteHalf<T> {
495 inner: Arc<SplitInner<T>>,
496}
497
498impl<T> WriteHalf<T> {
499 pub fn get(&self) -> &T {
501 &self.inner.socket
502 }
503
504 pub fn poll_io<F, R>(&mut self, cx: &mut Context<'_>, mut f: F) -> Poll<io::Result<R>>
510 where
511 F: FnMut(&mut Self) -> io::Result<R>,
512 {
513 loop {
514 std::task::ready!(self.inner.poll.lock().poll_socket_ready(
515 cx,
516 InterestSlot::Write,
517 PollEvents::OUT
518 ));
519 match f(self) {
520 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
521 self.inner
522 .poll
523 .lock()
524 .clear_socket_ready(InterestSlot::Write);
525 }
526 r => break Poll::Ready(r),
527 }
528 }
529 }
530}
531
532impl<T: AsSockRef> PollReady for ReadHalf<T> {
533 fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
534 self.inner
535 .poll
536 .lock()
537 .poll_socket_ready(cx, InterestSlot::Read, events)
538 }
539}
540
541impl<T: AsSockRef> AsyncRead for ReadHalf<T> {
542 fn poll_read(
543 mut self: Pin<&mut Self>,
544 cx: &mut Context<'_>,
545 buf: &mut [u8],
546 ) -> Poll<io::Result<usize>> {
547 self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).read(buf))
548 }
549
550 fn poll_read_vectored(
551 mut self: Pin<&mut Self>,
552 cx: &mut Context<'_>,
553 bufs: &mut [io::IoSliceMut<'_>],
554 ) -> Poll<io::Result<usize>> {
555 self.poll_io(cx, |this| {
556 (&*this.inner.socket.as_sock_ref()).read_vectored(bufs)
557 })
558 }
559}
560
561impl<T: AsSockRef> PollReady for WriteHalf<T> {
562 fn poll_ready(&mut self, cx: &mut Context<'_>, events: PollEvents) -> Poll<PollEvents> {
563 self.inner
564 .poll
565 .lock()
566 .poll_socket_ready(cx, InterestSlot::Write, events)
567 }
568}
569
570impl<T: AsSockRef> AsyncWrite for WriteHalf<T> {
571 fn poll_write(
572 mut self: Pin<&mut Self>,
573 cx: &mut Context<'_>,
574 buf: &[u8],
575 ) -> Poll<io::Result<usize>> {
576 self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).write(buf))
577 }
578
579 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
580 self.poll_io(cx, |this| (&*this.inner.socket.as_sock_ref()).flush())
581 }
582
583 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
584 Poll::Ready(self.inner.socket.as_sock_ref().shutdown(Shutdown::Write))
585 }
586
587 fn poll_write_vectored(
588 mut self: Pin<&mut Self>,
589 cx: &mut Context<'_>,
590 bufs: &[io::IoSlice<'_>],
591 ) -> Poll<io::Result<usize>> {
592 self.poll_io(cx, |this| {
593 (&*this.inner.socket.as_sock_ref()).write_vectored(bufs)
594 })
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::PolledSocket;
    use crate::DefaultDriver;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async_test::async_test;
    use unix_socket::UnixStream;

    /// One end echoes everything it reads back through its write half. The
    /// other end writes, reads the echo, writes more, closes, and checks it
    /// got everything back.
    #[async_test]
    async fn split(driver: DefaultDriver) {
        let (left, right) = UnixStream::pair().unwrap();
        let left = PolledSocket::new(&driver, left).unwrap();
        let right = PolledSocket::new(&driver, right).unwrap();
        let (mut left_read, mut left_write) = left.split();
        let (right_read, mut right_write) = right.split();

        let echo = async {
            futures::io::copy(right_read, &mut right_write).await.unwrap();
            right_write.close().await.unwrap();
        };
        let client = async {
            left_write.write_all(b"abc").await.unwrap();
            let mut received = vec![0; 3];
            left_read.read_exact(&mut received).await.unwrap();
            left_write.write_all(b"def").await.unwrap();
            left_write.close().await.unwrap();
            left_read.read_to_end(&mut received).await.unwrap();
            assert_eq!(&received, b"abcdef");
        };
        futures::future::join(echo, client).await;
    }
}