pal_async/
local.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A local executor, for running a single task with IO on the current thread.
5
6use self::timer::Timer;
7use crate::sys::local as sys;
8use crate::timer::Instant;
9use crate::timer::PollTimer;
10use crate::timer::TimerDriver;
11use crate::timer::TimerQueue;
12use crate::timer::TimerResult;
13use crate::waker::WakerList;
14use futures::task::ArcWake;
15use futures::task::waker_ref;
16use parking_lot::Condvar;
17use parking_lot::MappedMutexGuard;
18use parking_lot::Mutex;
19use parking_lot::MutexGuard;
20use std::future::Future;
21use std::pin::Pin;
22use std::pin::pin;
23use std::sync::Arc;
24use std::task::Context;
25use std::task::Poll;
26
27/// Blocks the current thread until the given future completes.
28pub fn block_on<Fut>(fut: Fut) -> Fut::Output
29where
30    Fut: Future,
31{
32    block_with_io(|_| fut)
33}
34
35/// Polls a future that needs to issue IO until it completes.
36pub fn block_with_io<F, R>(f: F) -> R
37where
38    F: AsyncFnOnce(LocalDriver) -> R,
39{
40    let mut executor = LocalExecutor::new();
41    let fut = f(executor.driver());
42    executor.run_until(pin!(fut))
43}
44
45/// An executor that runs on a single thread and runs only one future.
46struct LocalExecutor {
47    inner: Arc<LocalInner>,
48}
49
50impl LocalExecutor {
51    fn new() -> Self {
52        Self {
53            inner: Arc::new(LocalInner::default()),
54        }
55    }
56
57    fn driver(&self) -> LocalDriver {
58        LocalDriver {
59            inner: self.inner.clone(),
60        }
61    }
62
63    fn run_until<F: Future>(&mut self, mut fut: Pin<&mut F>) -> F::Output {
64        let waker = waker_ref(&self.inner);
65        let mut cx = Context::from_waker(&waker);
66        loop {
67            match fut.as_mut().poll(&mut cx) {
68                Poll::Ready(r) => break r,
69                Poll::Pending => self.inner.wait(),
70            }
71        }
72    }
73}
74
75/// An IO driver for single-task use on a single thread.
76#[derive(Debug, Clone)]
77pub struct LocalDriver {
78    pub(crate) inner: Arc<LocalInner>,
79}
80
81#[derive(Default, Debug)]
82pub(crate) struct LocalInner {
83    state: Mutex<LocalState>,
84    wait_state: Mutex<sys::WaitState>,
85    condvar: Condvar,
86    wait_cancel: sys::WaitCancel,
87}
88
/// What the executor thread is currently doing. Wakers and state mutators use
/// this to decide whether they must interrupt an IO wait.
#[derive(Debug, Default, PartialEq, Eq)]
enum OpState {
    /// The executor is running (polling or about to wait). Initial state.
    #[default]
    Running,
    /// The task was woken while running; the executor should poll it again
    /// without waiting.
    RunAgain,
    /// The executor is blocked waiting on IO.
    Waiting,
    /// The executor's wait has been cancelled but it has not yet resumed.
    Woken,
}
101
102#[derive(Debug, Default)]
103struct LocalState {
104    op_state: OpState,
105    state_waiters: usize,
106    sys: sys::State,
107    timers: TimerQueue,
108}
109
impl LocalInner {
    /// Locks the executor state (interrupting any in-progress IO wait, exactly
    /// like `lock_state`) and narrows the guard to the platform-specific
    /// `sys::State`.
    pub fn lock_sys_state(&self) -> MappedMutexGuard<'_, sys::State> {
        MutexGuard::map(self.lock_state(), |x| &mut x.sys)
    }

    // Locks the state for mutation.
    //
    // If the executor is currently waiting, then wakes up the executor first to
    // ensure that the executor never sees state changes between pre_wait and
    // post_wait.
    //
    // Protocol: a caller that interrupts a wait (or finds one already
    // interrupted) counts itself in `state_waiters` and blocks on `condvar`
    // until the executor leaves `Woken`. `wait` will not start another IO wait
    // while `state_waiters` is nonzero, and the last waiter to leave notifies
    // the executor.
    fn lock_state(&self) -> MutexGuard<'_, LocalState> {
        let mut guard = self.state.lock();

        match guard.op_state {
            // The executor is not blocked in an IO wait; mutate immediately.
            OpState::Running | OpState::RunAgain => return guard,
            OpState::Waiting => {
                guard.state_waiters += 1;
                guard.op_state = OpState::Woken;
                // Cancel the wait without holding the state lock; the executor
                // reacquires the lock when its wait returns.
                drop(guard);
                self.wait_cancel.cancel_wait();
                guard = self.state.lock();
            }
            // Someone else already cancelled the wait; just register.
            OpState::Woken => {
                guard.state_waiters += 1;
            }
        };
        // Block until the executor has left its IO wait (it sets `Running`
        // and then notifies the condvar).
        self.condvar
            .wait_while(&mut guard, |state| state.op_state == OpState::Woken);
        // The executor cannot re-enter `Waiting` while this thread is still
        // counted in `state_waiters`.
        assert_ne!(guard.op_state, OpState::Waiting);
        guard.state_waiters -= 1;
        if guard.state_waiters == 0 {
            // Notify the executor that it can proceed with its wait after the
            // mutex guard is dropped.
            self.condvar.notify_all();
        }
        guard
    }

    // Called by the executor after its task returns `Pending`. Returns once
    // the task should be polled again: either it was woken while running
    // (`RunAgain`), or the platform wait (bounded by the next timer deadline)
    // returned, after which completed IO and expired timers are processed.
    fn wait(&self) {
        let mut state = self.state.lock();
        // Wait until any threads that want to manipulate the state have
        // done so.
        self.condvar.wait_while(&mut state, |state| {
            state.op_state == OpState::Running && state.state_waiters > 0
        });
        // The task was woken since its last poll: skip the IO wait entirely.
        if state.op_state != OpState::Running {
            assert_eq!(state.op_state, OpState::RunAgain);
            state.op_state = OpState::Running;
            return;
        }

        // Only the single executor thread is expected to reach this point.
        let mut wait_state = self
            .wait_state
            .try_lock()
            .expect("wait should not be called concurrently");

        state.sys.pre_wait(&mut wait_state, &self.wait_cancel);

        // Time left until the earliest timer deadline, clamped to zero for a
        // deadline already passed; `None` when `next_deadline` returns `None`.
        let timeout = state.timers.next_deadline().map(|deadline| {
            let now = Instant::now();
            deadline.max(now) - now
        });

        {
            // Publish `Waiting` and release the lock so wakers and mutators
            // can see it and cancel the wait.
            state.op_state = OpState::Waiting;
            drop(state);
            wait_state.wait(&self.wait_cancel, timeout);
            state = self.state.lock();
            state.op_state = OpState::Running;
        }

        // Gather wakers under the lock but invoke them only after releasing
        // it: waking may call `wake_by_ref`, which locks `state` again.
        let mut wakers = WakerList::default();
        state.sys.post_wait(&mut wait_state, &mut wakers);
        state.timers.wake_expired(&mut wakers);
        drop(state);
        wakers.wake();
        // Notify mutators that the wait has finished.
        self.condvar.notify_all();
    }
}
190
191impl ArcWake for LocalInner {
192    fn wake_by_ref(arc_self: &Arc<Self>) {
193        let mut state = arc_self.state.lock();
194        match state.op_state {
195            OpState::Running => state.op_state = OpState::RunAgain,
196            OpState::RunAgain => {}
197            OpState::Waiting => {
198                state.op_state = OpState::Woken;
199                drop(state);
200                arc_self.wait_cancel.cancel_wait();
201            }
202            OpState::Woken => {}
203        }
204    }
205}
206
207// Use a separate module so that `Timer` is not visible.
208mod timer {
209    use super::LocalInner;
210    use crate::timer::TimerQueueId;
211    use std::sync::Arc;
212
213    #[derive(Debug)]
214    pub struct Timer {
215        pub(super) inner: Arc<LocalInner>,
216        pub(super) id: TimerQueueId,
217    }
218}
219
220impl TimerDriver for LocalDriver {
221    type Timer = Timer;
222
223    fn new_timer(&self) -> Self::Timer {
224        let id = self.inner.lock_state().timers.add();
225        Timer {
226            inner: self.inner.clone(),
227            id,
228        }
229    }
230}
231
232impl Drop for Timer {
233    fn drop(&mut self) {
234        let _waker = self.inner.lock_state().timers.remove(self.id);
235    }
236}
237
238impl PollTimer for Timer {
239    fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant> {
240        let mut state = self.inner.lock_state();
241        if let Some(deadline) = deadline {
242            state.timers.set_deadline(self.id, deadline);
243        }
244        match state.timers.poll_deadline(cx, self.id) {
245            TimerResult::TimedOut(now) => Poll::Ready(now),
246            TimerResult::Pending(_) => Poll::Pending,
247        }
248    }
249
250    fn set_deadline(&mut self, deadline: Instant) {
251        self.inner
252            .lock_state()
253            .timers
254            .set_deadline(self.id, deadline);
255    }
256}
257
#[cfg(test)]
mod tests {
    //! Runs the shared executor test suites on the local executor.

    use super::block_with_io;

    #[test]
    fn waker_works() {
        block_with_io(|_| crate::executor_tests::waker_tests())
    }

    #[test]
    fn sleep_works() {
        block_with_io(crate::executor_tests::sleep_tests)
    }

    #[test]
    fn wait_works() {
        block_with_io(crate::executor_tests::wait_tests)
    }

    #[test]
    fn socket_works() {
        block_with_io(crate::executor_tests::socket_tests)
    }

    #[cfg(windows)]
    #[test]
    fn overlapped_file_works() {
        block_with_io(crate::executor_tests::windows::overlapped_file_tests)
    }
}