pal_async/
local.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A local executor, for running a single task with IO on the current thread.
5
6use self::timer::Timer;
7use crate::sys::local as sys;
8use crate::timer::Instant;
9use crate::timer::PollTimer;
10use crate::timer::TimerDriver;
11use crate::timer::TimerQueue;
12use crate::timer::TimerResult;
13use crate::waker::WakerList;
14use futures::task::ArcWake;
15use futures::task::waker_ref;
16use parking_lot::Condvar;
17use parking_lot::MappedMutexGuard;
18use parking_lot::Mutex;
19use parking_lot::MutexGuard;
20use std::future::Future;
21use std::pin::Pin;
22use std::pin::pin;
23use std::sync::Arc;
24use std::task::Context;
25use std::task::Poll;
26
27/// Blocks the current thread until the given future completes.
28pub fn block_on<Fut>(fut: Fut) -> Fut::Output
29where
30    Fut: Future,
31{
32    block_with_io(|_| fut)
33}
34
35/// Polls a future that needs to issue IO until it completes.
36pub fn block_with_io<F, R>(f: F) -> R
37where
38    F: AsyncFnOnce(LocalDriver) -> R,
39{
40    let mut executor = LocalExecutor::new();
41    let fut = f(executor.driver());
42    executor.run_until(pin!(fut))
43}
44
45/// An executor that runs on a single thread and runs only one future.
46struct LocalExecutor {
47    inner: Arc<LocalInner>,
48}
49
50impl LocalExecutor {
51    fn new() -> Self {
52        Self {
53            inner: Arc::new(LocalInner::default()),
54        }
55    }
56
57    fn driver(&self) -> LocalDriver {
58        LocalDriver {
59            inner: self.inner.clone(),
60        }
61    }
62
63    fn run_until<F: Future>(&mut self, mut fut: Pin<&mut F>) -> F::Output {
64        let waker = waker_ref(&self.inner);
65        let mut cx = Context::from_waker(&waker);
66        loop {
67            match fut.as_mut().poll(&mut cx) {
68                Poll::Ready(r) => break r,
69                Poll::Pending => self.inner.wait(),
70            }
71        }
72    }
73}
74
75/// An IO driver for single-task use on a single thread.
76#[derive(Debug, Clone)]
77pub struct LocalDriver {
78    pub(crate) inner: Arc<LocalInner>,
79}
80
81#[derive(Default, Debug)]
82pub(crate) struct LocalInner {
83    state: Mutex<LocalState>,
84    wait_state: Mutex<sys::WaitState>,
85    condvar: Condvar,
86    wait_cancel: sys::WaitCancel,
87}
88
/// Where the executor currently is in its poll/wait cycle.
///
/// Guarded by `LocalState::op_state`; transitions are made by the executor
/// (`LocalInner::wait`), by its waker, and by `LocalInner::lock_state`.
#[derive(Debug, Default, PartialEq, Eq)]
enum OpState {
    /// The executor is running. This is the initial state.
    #[default]
    Running,
    /// The executor should poll its task again without waiting.
    RunAgain,
    /// The executor is waiting on IO.
    Waiting,
    /// The executor's wait has been cancelled; the executor resets this to
    /// `Running` once its wait returns.
    Woken,
}
106
107#[derive(Debug, Default)]
108struct LocalState {
109    op_state: OpState,
110    state_waiters: usize,
111    sys: sys::State,
112    timers: TimerQueue,
113}
114
impl LocalInner {
    /// Locks the executor state and returns a guard narrowed to the
    /// platform-specific `sys::State`.
    ///
    /// The lock is taken through `lock_state`, so if the executor is blocked
    /// in its wait it is woken first.
    pub fn lock_sys_state(&self) -> MappedMutexGuard<'_, sys::State> {
        MutexGuard::map(self.lock_state(), |x| &mut x.sys)
    }

    // Locks the state for mutation.
    //
    // If the executor is currently waiting, then wakes up the executor first to
    // ensure that the executor never sees state changes between pre_wait and
    // post_wait.
    //
    // Protocol, keyed on `op_state` at entry:
    // - `Running` / `RunAgain`: the executor is not inside its wait, so the
    //   guard is returned right away.
    // - `Waiting`: register as a state waiter, mark the state `Woken`, release
    //   the lock while calling `cancel_wait`, then reacquire it.
    // - `Woken`: a cancel was already requested (by another mutator or by the
    //   waker), so just register as a state waiter.
    // Registered waiters then block on `condvar` until `wait` moves the state
    // out of `Woken`.
    fn lock_state(&self) -> MutexGuard<'_, LocalState> {
        let mut guard = self.state.lock();

        match guard.op_state {
            OpState::Running | OpState::RunAgain => return guard,
            OpState::Waiting => {
                guard.state_waiters += 1;
                guard.op_state = OpState::Woken;
                // NOTE(review): the state lock is released before
                // `cancel_wait`, presumably to avoid holding it across the
                // platform cancel call -- confirm.
                drop(guard);
                self.wait_cancel.cancel_wait();
                guard = self.state.lock();
            }
            OpState::Woken => {
                guard.state_waiters += 1;
            }
        };
        // `wait` sets `Running` after its platform wait returns and then
        // calls `notify_all`.
        self.condvar
            .wait_while(&mut guard, |state| state.op_state == OpState::Woken);
        // `wait` will not start another wait while `state_waiters > 0`, and
        // this thread has not decremented its count yet.
        assert_ne!(guard.op_state, OpState::Waiting);
        guard.state_waiters -= 1;
        if guard.state_waiters == 0 {
            // Notify the executor that it can proceed with its wait after the
            // mutex guard is dropped.
            self.condvar.notify_all();
        }
        guard
    }

    // Blocks the executor thread until `sys::WaitState::wait` returns. That
    // call receives `wait_cancel` and the time remaining until the earliest
    // timer deadline. Afterwards, wakes every waker collected by
    // `sys::State::post_wait` and by expired timers.
    //
    // Returns immediately, without waiting, if the task was woken since it
    // was last polled (`RunAgain`). In this file it is called only from
    // `LocalExecutor::run_until` after the task returns `Poll::Pending`.
    fn wait(&self) {
        let mut state = self.state.lock();
        // Wait until any threads that want to manipulate the state have
        // done so.
        self.condvar.wait_while(&mut state, |state| {
            state.op_state == OpState::Running && state.state_waiters > 0
        });
        if state.op_state != OpState::Running {
            // A wake arrived after the last poll: poll again instead of waiting.
            assert_eq!(state.op_state, OpState::RunAgain);
            state.op_state = OpState::Running;
            return;
        }

        // `wait_state` is only locked here, so contention means `wait` was
        // entered twice at once.
        let mut wait_state = self
            .wait_state
            .try_lock()
            .expect("wait should not be called concurrently");

        state.sys.pre_wait(&mut wait_state, &self.wait_cancel);

        // Time until the earliest timer deadline. A deadline that has already
        // passed is clamped to zero; `None` if `next_deadline` reports none.
        let timeout = state.timers.next_deadline().map(|deadline| {
            let now = Instant::now();
            deadline.max(now) - now
        });

        {
            // `Waiting` is published while the lock is still held, so
            // `lock_state` and `wake_by_ref` know they must cancel the wait.
            state.op_state = OpState::Waiting;
            drop(state);
            wait_state.wait(&self.wait_cancel, timeout);
            state = self.state.lock();
            // Leaving `Woken`/`Waiting` releases `lock_state` callers once
            // they are notified below.
            state.op_state = OpState::Running;
        }

        let mut wakers = WakerList::default();
        state.sys.post_wait(&mut wait_state, &mut wakers);
        state.timers.wake_expired(&mut wakers);
        // Wake only after releasing the lock: a waker for this executor
        // locks `state` again in `wake_by_ref`.
        drop(state);
        wakers.wake();
        // Notify mutators that the wait has finished.
        self.condvar.notify_all();
    }
}
195
196impl ArcWake for LocalInner {
197    fn wake_by_ref(arc_self: &Arc<Self>) {
198        let mut state = arc_self.state.lock();
199        match state.op_state {
200            OpState::Running => state.op_state = OpState::RunAgain,
201            OpState::RunAgain => {}
202            OpState::Waiting => {
203                state.op_state = OpState::Woken;
204                drop(state);
205                arc_self.wait_cancel.cancel_wait();
206            }
207            OpState::Woken => {}
208        }
209    }
210}
211
212// Use a separate module so that `Timer` is not visible.
213mod timer {
214    use super::LocalInner;
215    use crate::timer::TimerQueueId;
216    use std::sync::Arc;
217
218    #[derive(Debug)]
219    pub struct Timer {
220        pub(super) inner: Arc<LocalInner>,
221        pub(super) id: TimerQueueId,
222    }
223}
224
225impl TimerDriver for LocalDriver {
226    type Timer = Timer;
227
228    fn new_timer(&self) -> Self::Timer {
229        let id = self.inner.lock_state().timers.add();
230        Timer {
231            inner: self.inner.clone(),
232            id,
233        }
234    }
235}
236
237impl Drop for Timer {
238    fn drop(&mut self) {
239        let _waker = self.inner.lock_state().timers.remove(self.id);
240    }
241}
242
243impl PollTimer for Timer {
244    fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant> {
245        let mut state = self.inner.lock_state();
246        if let Some(deadline) = deadline {
247            state.timers.set_deadline(self.id, deadline);
248        }
249        match state.timers.poll_deadline(cx, self.id) {
250            TimerResult::TimedOut(now) => Poll::Ready(now),
251            TimerResult::Pending(_) => Poll::Pending,
252        }
253    }
254
255    fn set_deadline(&mut self, deadline: Instant) {
256        self.inner
257            .lock_state()
258            .timers
259            .set_deadline(self.id, deadline);
260    }
261}
262
#[cfg(test)]
mod tests {
    // Each test runs one shared executor test suite on a fresh local executor.
    use super::block_with_io as run_local;
    use crate::executor_tests as suite;

    #[test]
    fn waker_works() {
        run_local(|_driver| suite::waker_tests())
    }

    #[test]
    fn sleep_works() {
        run_local(suite::sleep_tests)
    }

    #[test]
    fn wait_works() {
        run_local(suite::wait_tests)
    }

    #[test]
    fn socket_works() {
        run_local(suite::socket_tests)
    }

    #[cfg(windows)]
    #[test]
    fn overlapped_file_works() {
        run_local(suite::windows::overlapped_file_tests)
    }
}