pal_async/unix/
local.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A thread-local executor based on poll(2).
5
6use super::wait::FdWait;
7use crate::fd::FdReadyDriver;
8use crate::fd::PollFdReady;
9use crate::interest::InterestSlot;
10use crate::interest::PollEvents;
11use crate::interest::PollInterestSet;
12use crate::local::LocalDriver;
13use crate::local::LocalInner;
14use crate::sparsevec::SparseVec;
15use crate::wait::WaitDriver;
16use crate::waker::WakerList;
17use pal::unix::SyscallResult;
18use pal::unix::while_eintr;
19use pal_event::Event;
20use std::io;
21use std::os::unix::prelude::*;
22use std::sync::Arc;
23use std::sync::OnceLock;
24use std::task::Context;
25use std::task::Poll;
26use std::time::Duration;
27
28#[derive(Debug, Default)]
29pub(crate) struct WaitState {
30    pollfds: Vec<libc::pollfd>,
31}
32
33#[derive(Debug, Default)]
34pub(crate) struct WaitCancel {
35    event: OnceLock<Event>,
36}
37
38impl WaitCancel {
39    pub fn cancel_wait(&self) {
40        self.event.get().unwrap().signal();
41    }
42}
43
44#[derive(Debug, Default)]
45pub(crate) struct State {
46    entries: SparseVec<FdEntry>,
47}
48
49#[derive(Debug)]
50struct FdEntry {
51    fd: RawFd,
52    interests: PollInterestSet,
53}
54
55impl State {
56    fn add_fd(&mut self, fd: RawFd) -> usize {
57        self.entries.add(FdEntry {
58            fd,
59            interests: Default::default(),
60        })
61    }
62
63    fn remove_fd(&mut self, index: usize) {
64        self.entries.remove(index);
65    }
66
67    fn poll_fd(
68        &mut self,
69        cx: &mut Context<'_>,
70        index: usize,
71        slot: InterestSlot,
72        events: PollEvents,
73    ) -> Poll<PollEvents> {
74        let entry = &mut self.entries[index];
75        entry.interests.poll_ready(cx, slot, events)
76    }
77
78    fn clear_fd_ready(&mut self, index: usize, slot: InterestSlot) {
79        let entry = &mut self.entries[index];
80        entry.interests.clear_ready(slot)
81    }
82
83    pub fn pre_wait(&mut self, wait_state: &mut WaitState, wait_cancel: &WaitCancel) {
84        wait_state.pollfds.clear();
85        wait_state
86            .pollfds
87            .extend(self.entries.iter().map(|(_, entry)| {
88                let events = entry.interests.events_to_poll();
89                if !events.is_empty() {
90                    libc::pollfd {
91                        fd: entry.fd,
92                        events: events.to_poll_events(),
93                        revents: 0,
94                    }
95                } else {
96                    libc::pollfd {
97                        fd: -1,
98                        events: 0,
99                        revents: 0,
100                    }
101                }
102            }));
103
104        let event = wait_cancel.event.get_or_init(Event::new);
105        wait_state.pollfds.push(libc::pollfd {
106            fd: event.as_fd().as_raw_fd(),
107            events: libc::POLLIN,
108            revents: 0,
109        });
110    }
111
112    pub fn post_wait(&mut self, wait_state: &mut WaitState, wakers: &mut WakerList) {
113        for ((_, entry), pollfd) in self.entries.iter_mut().zip(wait_state.pollfds.iter_mut()) {
114            let revents = PollEvents::from_poll_events(pollfd.revents);
115            if !revents.is_empty() {
116                entry.interests.wake_ready(revents, wakers);
117            }
118        }
119    }
120}
121
122#[cfg(target_os = "linux")]
123fn poll(pollfds: &mut [libc::pollfd], timeout: Option<&Duration>) -> i32 {
124    let timeout = timeout.map(|timeout| libc::timespec {
125        tv_sec: timeout.as_secs().try_into().unwrap(),
126        tv_nsec: timeout.subsec_nanos().into(),
127    });
128
129    // SAFETY: calling as documented.
130    unsafe {
131        libc::ppoll(
132            pollfds.as_mut_ptr(),
133            pollfds.len().try_into().unwrap(),
134            timeout.as_ref().map_or(std::ptr::null(), |t| t),
135            std::ptr::null(),
136        )
137    }
138}
139
/// Waits for events on `pollfds` using `poll(2)`.
///
/// `timeout` of `None` blocks indefinitely. Otherwise the duration is
/// converted to whole milliseconds, rounded *up* and clamped to `i32::MAX`.
///
/// Returns the raw `poll` return value: the number of slots with non-zero
/// `revents`, `0` on timeout, or `-1` on error (with `errno` set).
#[cfg(not(target_os = "linux"))]
fn poll(pollfds: &mut [libc::pollfd], timeout: Option<&Duration>) -> i32 {
    let timeout_ms = timeout.map_or(-1, |t| {
        // `Duration::as_millis` truncates. Round up so that a non-zero
        // sub-millisecond timeout does not collapse to zero, which would make
        // poll return immediately, before the requested time has elapsed. An
        // explicit zero timeout still maps to zero.
        let rounded_up = t
            .checked_add(Duration::from_nanos(999_999))
            .unwrap_or(*t);
        rounded_up.as_millis().min(i32::MAX as u128) as i32
    });

    // SAFETY: the fd pointer and count come from the same live mutable slice.
    unsafe {
        libc::poll(
            pollfds.as_mut_ptr(),
            pollfds.len().try_into().unwrap(),
            timeout_ms,
        )
    }
}
151
152impl WaitState {
153    pub fn wait(&mut self, wait_cancel: &WaitCancel, timeout: Option<Duration>) {
154        while_eintr(|| poll(&mut self.pollfds, timeout.as_ref()).syscall_result())
155            .expect("ppoll unexpectedly failed");
156        if self.pollfds.last().unwrap().revents != 0 {
157            // Consume the wake event.
158            assert!(wait_cancel.event.get().unwrap().try_wait());
159        }
160    }
161}
162
163impl FdReadyDriver for LocalDriver {
164    type FdReady = FdReady;
165
166    fn new_fd_ready(&self, socket: RawFd) -> io::Result<Self::FdReady> {
167        let index = self.inner.lock_sys_state().add_fd(socket);
168        Ok(FdReady {
169            inner: self.inner.clone(),
170            index,
171        })
172    }
173}
174
175#[derive(Debug)]
176pub struct FdReady {
177    inner: Arc<LocalInner>,
178    index: usize,
179}
180
181impl Drop for FdReady {
182    fn drop(&mut self) {
183        self.inner.lock_sys_state().remove_fd(self.index);
184    }
185}
186
187impl PollFdReady for FdReady {
188    fn poll_fd_ready(
189        &mut self,
190        cx: &mut Context<'_>,
191        slot: InterestSlot,
192        events: PollEvents,
193    ) -> Poll<PollEvents> {
194        self.inner
195            .lock_sys_state()
196            .poll_fd(cx, self.index, slot, events)
197    }
198
199    fn clear_fd_ready(&mut self, slot: InterestSlot) {
200        self.inner.lock_sys_state().clear_fd_ready(self.index, slot)
201    }
202}
203
204impl WaitDriver for LocalDriver {
205    type Wait = FdWait<FdReady>;
206
207    fn new_wait(&self, fd: RawFd, read_size: usize) -> io::Result<Self::Wait> {
208        Ok(FdWait::new(fd, self.new_fd_ready(fd)?, read_size))
209    }
210}