1use self::timer::Timer;
7use crate::sys::local as sys;
8use crate::timer::Instant;
9use crate::timer::PollTimer;
10use crate::timer::TimerDriver;
11use crate::timer::TimerQueue;
12use crate::timer::TimerResult;
13use crate::waker::WakerList;
14use futures::task::ArcWake;
15use futures::task::waker_ref;
16use parking_lot::Condvar;
17use parking_lot::MappedMutexGuard;
18use parking_lot::Mutex;
19use parking_lot::MutexGuard;
20use std::future::Future;
21use std::pin::Pin;
22use std::pin::pin;
23use std::sync::Arc;
24use std::task::Context;
25use std::task::Poll;
26
27pub fn block_on<Fut>(fut: Fut) -> Fut::Output
29where
30 Fut: Future,
31{
32 block_with_io(|_| fut)
33}
34
35pub fn block_with_io<F, R>(f: F) -> R
37where
38 F: AsyncFnOnce(LocalDriver) -> R,
39{
40 let mut executor = LocalExecutor::new();
41 let fut = f(executor.driver());
42 executor.run_until(pin!(fut))
43}
44
45struct LocalExecutor {
47 inner: Arc<LocalInner>,
48}
49
50impl LocalExecutor {
51 fn new() -> Self {
52 Self {
53 inner: Arc::new(LocalInner::default()),
54 }
55 }
56
57 fn driver(&self) -> LocalDriver {
58 LocalDriver {
59 inner: self.inner.clone(),
60 }
61 }
62
63 fn run_until<F: Future>(&mut self, mut fut: Pin<&mut F>) -> F::Output {
64 let waker = waker_ref(&self.inner);
65 let mut cx = Context::from_waker(&waker);
66 loop {
67 match fut.as_mut().poll(&mut cx) {
68 Poll::Ready(r) => break r,
69 Poll::Pending => self.inner.wait(),
70 }
71 }
72 }
73}
74
/// Cloneable handle to the state shared with a local executor.
///
/// `LocalExecutor::driver` builds one around a clone of the executor's
/// `Arc<LocalInner>`. The type implements `TimerDriver`, so timers created
/// through it are registered in that shared state.
#[derive(Debug, Clone)]
pub struct LocalDriver {
    /// The executor state this handle refers to.
    pub(crate) inner: Arc<LocalInner>,
}
80
/// State shared between the executor thread, its wakers, and driver handles.
#[derive(Default, Debug)]
pub(crate) struct LocalInner {
    /// Executor state machine, platform state, and timer queue. Acquire it
    /// through `lock_state` / `lock_sys_state` unless you are the executor
    /// thread or the waker.
    state: Mutex<LocalState>,
    /// Platform wait state. `wait` acquires it with `try_lock` and panics if
    /// it is already held.
    wait_state: Mutex<sys::WaitState>,
    /// Paired with `state`: used to hand off between the executor thread and
    /// threads blocked in `lock_state`.
    condvar: Condvar,
    /// Used to interrupt an in-progress `wait_state.wait(..)` call via
    /// `cancel_wait`.
    wait_cancel: sys::WaitCancel,
}
88
/// What the executor thread is currently doing, as recorded in the shared
/// state.
///
/// The default state is [`OpState::Running`].
#[derive(Debug, Default, PartialEq, Eq)]
enum OpState {
    /// The executor thread is not blocked in the platform wait. A wake that
    /// arrives in this state moves it to `RunAgain`.
    #[default]
    Running,
    /// A wake arrived while `Running`. The next call to `wait` resets the
    /// state to `Running` and returns immediately instead of blocking.
    RunAgain,
    /// The executor thread has released the state lock and is blocked in the
    /// platform wait.
    Waiting,
    /// The executor thread was `Waiting` and a cancellation of that wait has
    /// been requested, but it has not yet re-acquired the state lock.
    Woken,
}
106
/// Data protected by `LocalInner::state`.
#[derive(Debug, Default)]
struct LocalState {
    /// Current state of the executor thread.
    op_state: OpState,
    /// Number of `lock_state` callers that found the executor `Waiting` or
    /// `Woken` and are registered until they obtain the lock. `wait` does not
    /// block again while this is non-zero.
    state_waiters: usize,
    /// Platform-specific state; exposed to callers via `lock_sys_state` and
    /// passed through the `pre_wait` / `post_wait` hooks in `wait`.
    sys: sys::State,
    /// Timer queue backing the timers created through `LocalDriver`.
    timers: TimerQueue,
}
114
impl LocalInner {
    /// Locks the shared state (see `lock_state`) and narrows the guard to the
    /// platform-specific `sys` field.
    pub fn lock_sys_state(&self) -> MappedMutexGuard<'_, sys::State> {
        MutexGuard::map(self.lock_state(), |x| &mut x.sys)
    }

    /// Locks the shared state, first forcing the executor thread out of its
    /// blocking wait if necessary.
    ///
    /// The guard is only returned while the executor is `Running` or
    /// `RunAgain`. If it is `Waiting`, the wait is cancelled and this call
    /// blocks until the executor thread has left the `Woken` state.
    ///
    /// NOTE(review): presumably this exists so that changes to timers or the
    /// platform state are always observed before the executor blocks again.
    /// Confirm against the platform `pre_wait` / `post_wait` implementation.
    fn lock_state(&self) -> MutexGuard<'_, LocalState> {
        let mut guard = self.state.lock();

        match guard.op_state {
            // The executor is not blocked, so the state can be used right away.
            OpState::Running | OpState::RunAgain => return guard,
            OpState::Waiting => {
                // Register as a waiter so `wait` will not block again until
                // this caller has obtained the lock, then kick the executor
                // out of its wait. The lock is released around `cancel_wait`
                // and re-acquired afterwards.
                guard.state_waiters += 1;
                guard.op_state = OpState::Woken;
                drop(guard);
                self.wait_cancel.cancel_wait();
                guard = self.state.lock();
            }
            OpState::Woken => {
                // Someone else already requested the cancellation; only
                // register as a waiter.
                guard.state_waiters += 1;
            }
        };
        // Block until the executor thread has re-acquired the lock and moved
        // the state away from `Woken` (`wait` sets `Running` and then calls
        // `notify_all`).
        self.condvar
            .wait_while(&mut guard, |state| state.op_state == OpState::Woken);
        // `wait` does not enter `Waiting` while `state_waiters > 0`, so the
        // executor cannot have gone back to sleep yet.
        assert_ne!(guard.op_state, OpState::Waiting);
        guard.state_waiters -= 1;
        if guard.state_waiters == 0 {
            // Last registered waiter: wake the condvar so `wait` re-evaluates
            // its `state_waiters > 0` predicate.
            self.condvar.notify_all();
        }
        guard
    }

    /// Parks the executor thread until there is something to do.
    ///
    /// Returns immediately if a wake arrived since the last poll (`RunAgain`).
    /// Otherwise it blocks in the platform wait, bounded by the next timer
    /// deadline, then collects and fires the resulting wakers.
    ///
    /// # Panics
    ///
    /// Panics if `wait_state` is already locked, i.e. if `wait` is called
    /// concurrently, or if the state is neither `Running` nor `RunAgain` on
    /// entry.
    fn wait(&self) {
        let mut state = self.state.lock();
        // Do not block in the platform wait while `lock_state` callers are
        // still registered. A pending `RunAgain` ends this wait early.
        self.condvar.wait_while(&mut state, |state| {
            state.op_state == OpState::Running && state.state_waiters > 0
        });
        if state.op_state != OpState::Running {
            // A wake arrived while running: consume it and let the caller
            // poll again without blocking.
            assert_eq!(state.op_state, OpState::RunAgain);
            state.op_state = OpState::Running;
            return;
        }

        let mut wait_state = self
            .wait_state
            .try_lock()
            .expect("wait should not be called concurrently");

        // Platform hook run under the state lock before blocking.
        // NOTE(review): exact semantics live in `sys::State::pre_wait`.
        state.sys.pre_wait(&mut wait_state, &self.wait_cancel);

        // Time remaining until the next timer deadline, clamped to zero when
        // the deadline has already passed. `None` if the queue reports no
        // deadline.
        let timeout = state.timers.next_deadline().map(|deadline| {
            let now = Instant::now();
            deadline.max(now) - now
        });

        {
            // Publish `Waiting` before releasing the lock so wakers and
            // `lock_state` callers know they must cancel the wait.
            state.op_state = OpState::Waiting;
            drop(state);
            wait_state.wait(&self.wait_cancel, timeout);
            state = self.state.lock();
            // Unconditionally back to `Running`; this also clears `Woken`.
            state.op_state = OpState::Running;
        }

        // Gather wakers from the platform layer and from expired timers while
        // holding the lock.
        let mut wakers = WakerList::default();
        state.sys.post_wait(&mut wait_state, &mut wakers);
        state.timers.wake_expired(&mut wakers);
        // Fire the wakers only after the lock is released. At least this
        // type's own `wake_by_ref` takes the `state` lock.
        drop(state);
        wakers.wake();
        // Release threads blocked in `lock_state` waiting for the state to
        // leave `Woken`.
        self.condvar.notify_all();
    }
}
195
196impl ArcWake for LocalInner {
197 fn wake_by_ref(arc_self: &Arc<Self>) {
198 let mut state = arc_self.state.lock();
199 match state.op_state {
200 OpState::Running => state.op_state = OpState::RunAgain,
201 OpState::RunAgain => {}
202 OpState::Waiting => {
203 state.op_state = OpState::Woken;
204 drop(state);
205 arc_self.wait_cancel.cancel_wait();
206 }
207 OpState::Woken => {}
208 }
209 }
210}
211
/// Holds the timer type in its own module so its fields are reachable only
/// from the parent module (`pub(super)`).
mod timer {
    use super::LocalInner;
    use crate::timer::TimerQueueId;
    use std::sync::Arc;

    /// A timer registered in the timer queue of a local executor's shared
    /// state.
    #[derive(Debug)]
    pub struct Timer {
        /// Shared executor state whose timer queue holds this timer's entry.
        pub(super) inner: Arc<LocalInner>,
        /// Identifier of this timer's entry in that queue.
        pub(super) id: TimerQueueId,
    }
}
224
225impl TimerDriver for LocalDriver {
226 type Timer = Timer;
227
228 fn new_timer(&self) -> Self::Timer {
229 let id = self.inner.lock_state().timers.add();
230 Timer {
231 inner: self.inner.clone(),
232 id,
233 }
234 }
235}
236
237impl Drop for Timer {
238 fn drop(&mut self) {
239 let _waker = self.inner.lock_state().timers.remove(self.id);
240 }
241}
242
243impl PollTimer for Timer {
244 fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant> {
245 let mut state = self.inner.lock_state();
246 if let Some(deadline) = deadline {
247 state.timers.set_deadline(self.id, deadline);
248 }
249 match state.timers.poll_deadline(cx, self.id) {
250 TimerResult::TimedOut(now) => Poll::Ready(now),
251 TimerResult::Pending(_) => Poll::Pending,
252 }
253 }
254
255 fn set_deadline(&mut self, deadline: Instant) {
256 self.inner
257 .lock_state()
258 .timers
259 .set_deadline(self.id, deadline);
260 }
261}
262
// Runs the shared suites from `crate::executor_tests` on the local executor
// through `block_with_io`.
#[cfg(test)]
mod tests {
    use super::block_with_io;
    use crate::executor_tests;

    // The waker suite takes no driver, so the driver argument is ignored.
    #[test]
    fn waker_works() {
        block_with_io(|_| executor_tests::waker_tests())
    }

    // The remaining suites are passed directly and receive the `LocalDriver`.
    #[test]
    fn sleep_works() {
        block_with_io(executor_tests::sleep_tests)
    }

    #[test]
    fn wait_works() {
        block_with_io(executor_tests::wait_tests)
    }

    #[test]
    fn socket_works() {
        block_with_io(executor_tests::socket_tests)
    }

    // Only built on Windows.
    #[cfg(windows)]
    #[test]
    fn overlapped_file_works() {
        block_with_io(executor_tests::windows::overlapped_file_tests)
    }
}
293}