task_control/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A simple asynchronous task model for execution that needs to be started,
5//! stopped, mutated, and inspected.
6
7#![forbid(unsafe_code)]
8
9use fast_select::FastSelect;
10use inspect::Inspect;
11use inspect::InspectMut;
12use pal_async::task::Spawn;
13use pal_async::task::Task;
14use parking_lot::Mutex;
15use std::future::Future;
16use std::future::poll_fn;
17use std::pin::Pin;
18use std::pin::pin;
19use std::sync::Arc;
20use std::task::Context;
21use std::task::Poll;
22use std::task::Waker;
23
/// A trait implemented by a task that can be run and stopped, storing
/// transient state in `S`.
///
/// The state `S` is handed to [`AsyncRun::run`] by mutable reference each time
/// the task is run.
pub trait AsyncRun<S>: 'static + Send {
    /// Runs the task.
    ///
    /// The task should stop when `stop` becomes ready. This can be determined
    /// either by awaiting on `stop`, or by calling [`StopTask::until_stopped`]
    /// with a future to run.
    ///
    /// The function should return `Ok(())` if the task is complete, in which
    /// case it will only run again after being removed and reinserted.
    ///
    /// If the function instead returns `Err(Cancelled)`, this indicates that
    /// the task's work is not complete, and it should be restarted after
    /// handling any incoming events.
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        _: &mut S,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
45
/// The return error from [`AsyncRun::run`] indicating the task has not yet
/// finished executing.
///
/// [`StopTask::until_stopped`] also returns this when the stop request becomes
/// ready before the wrapped future completes.
#[derive(Debug)]
pub struct Cancelled;
50
/// A future indicating that the task should return for event processing or
/// because the task was stopped.
///
/// Awaiting a `StopTask` directly polls the underlying stop future. Use
/// [`StopTask::until_stopped`] to race it against another future.
pub struct StopTask<'a> {
    /// The type-erased future that becomes ready when the task should return.
    inner: &'a mut (dyn 'a + Send + Future<Output = ()> + Unpin),
    /// Select state used by [`StopTask::until_stopped`] to wrap `inner`.
    fast_select: &'a mut FastSelect,
}
57
/// The inner polling implementation, which polls for an incoming request from
/// `TaskControl`.
///
/// This is separate from `StopTask` so that the types can be erased.
struct StopTaskInner<'a, T, S> {
    /// The state shared with the owning `TaskControl`; polled for a stop
    /// request or queued calls.
    shared: &'a Mutex<Shared<T, S>>,
}
65
66impl<T: AsyncRun<S>, S> Future for StopTaskInner<'_, T, S> {
67    type Output = ();
68
69    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
70        let mut shared = self.get_mut().shared.lock();
71        if !shared.calls.is_empty() || shared.stop {
72            return Poll::Ready(());
73        }
74        if shared
75            .inner_waker
76            .as_ref()
77            .is_none_or(|waker| !cx.waker().will_wake(waker))
78        {
79            shared.inner_waker = Some(cx.waker().clone());
80        }
81        Poll::Pending
82    }
83}
84
85impl StopTask<'_> {
86    /// Runs `f`, providing access to `stop` via a `StopTask` passed to `f`.
87    pub async fn run_with<R>(
88        mut stop: impl Send + Future<Output = ()> + Unpin,
89        f: impl AsyncFnOnce(&mut StopTask<'_>) -> R,
90    ) -> R {
91        let mut fast_select: FastSelect = FastSelect::new();
92        let mut stop = StopTask {
93            inner: &mut stop,
94            fast_select: &mut fast_select,
95        };
96        f(&mut stop).await
97    }
98
99    /// Runs `fut` until the task is requested to stop.
100    ///
101    /// If `fut` completes, then `Ok(_)` is returned.
102    ///
103    /// If the task is requested to stop before `fut` completes, then `fut` is
104    /// dropped and `Err(Cancelled)` is returned.
105    pub async fn until_stopped<F: Future>(&mut self, fut: F) -> Result<F::Output, Cancelled> {
106        // Wrap the cancel task in a FastSelect to avoid taking the channel lock
107        // at each wakeup.
108        let mut cancel = pin!(
109            self.fast_select
110                .select((poll_fn(|cx| Pin::new(&mut self.inner).poll(cx)),))
111        );
112
113        let mut fut = pin!(fut);
114
115        // Since this is a common fast path, implement the select manually.
116        poll_fn(|cx| {
117            if let Poll::Ready(r) = fut.as_mut().poll(cx) {
118                Poll::Ready(Ok(r))
119            } else if cancel.as_mut().poll(cx).is_ready() {
120                Poll::Ready(Err(Cancelled))
121            } else {
122                Poll::Pending
123            }
124        })
125        .await
126    }
127}
128
129impl Future for StopTask<'_> {
130    type Output = ();
131
132    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133        Pin::new(&mut self.inner).poll(cx)
134    }
135}
136
/// A task wrapper that runs the task asynchronously and provides access to its
/// state.
///
/// `T` is the task object (see [`AsyncRun`]) and `S` is the transient state
/// that is inserted before the task can be started and removed afterwards.
pub struct TaskControl<T, S> {
    /// The current lifecycle state: no state inserted, or state inserted with
    /// the task either stopped or running.
    inner: Inner<T, S>,
}
142
/// A trait for inspecting a task and its associated state.
///
/// Implementing this makes the owning [`TaskControl`] implement [`Inspect`].
pub trait InspectTask<S>: AsyncRun<S> {
    /// Inspects the task and its state.
    ///
    /// The state may be missing if it has not yet been inserted into the
    /// [`TaskControl`].
    fn inspect(&self, req: inspect::Request<'_>, state: Option<&S>);
}
151
152impl<T: InspectTask<S>, S> Inspect for TaskAndState<T, S> {
153    fn inspect(&self, req: inspect::Request<'_>) {
154        self.task.inspect(req, self.state.as_ref());
155    }
156}
157
158impl<T: InspectTask<S>, S> Inspect for TaskControl<T, S> {
159    fn inspect(&self, req: inspect::Request<'_>) {
160        match &self.inner {
161            Inner::NoState(task_and_state) => task_and_state.inspect(req),
162            Inner::WithState {
163                activity, shared, ..
164            } => match activity {
165                Activity::Stopped(task_and_state) => task_and_state.inspect(req),
166                Activity::Running => {
167                    let deferred = req.defer();
168                    Shared::push_call(
169                        shared,
170                        Box::new(|task_and_state| {
171                            deferred.inspect(&task_and_state);
172                        }),
173                    )
174                }
175            },
176            Inner::Invalid => unreachable!(),
177        }
178    }
179}
180
/// A trait for mutably inspecting a task and its associated state.
///
/// Note that the type parameter `T` here is the task's *state* type (the `S`
/// of [`AsyncRun<S>`]), not the task type itself.
pub trait InspectTaskMut<T>: AsyncRun<T> {
    /// Inspects the task and its state.
    ///
    /// The state may be missing if it has not yet been inserted into the
    /// [`TaskControl`].
    fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut T>);
}
189
190impl<T: InspectTaskMut<S>, S> InspectMut for TaskAndState<T, S> {
191    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
192        self.task.inspect_mut(req, self.state.as_mut());
193    }
194}
195
196impl<T: InspectTaskMut<U>, U> InspectMut for TaskControl<T, U> {
197    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
198        match &mut self.inner {
199            Inner::NoState(task_and_state) => task_and_state.inspect_mut(req),
200            Inner::WithState {
201                activity, shared, ..
202            } => match activity {
203                Activity::Stopped(task_and_state) => task_and_state.inspect_mut(req),
204                Activity::Running => {
205                    let deferred = req.defer();
206                    Shared::push_call(
207                        shared,
208                        Box::new(|task_and_state| {
209                            deferred.inspect(task_and_state);
210                        }),
211                    );
212                }
213            },
214            Inner::Invalid => unreachable!(),
215        }
216    }
217}
218
/// A one-shot closure queued by `TaskControl` to be run against the task and
/// its state on the backing task.
type CallFn<T, S> = Box<dyn FnOnce(&mut TaskAndState<T, S>) + Send>;
220
/// The lifecycle state of a [`TaskControl`].
enum Inner<T, S> {
    /// No state has been inserted; the task object is held locally.
    NoState(Box<TaskAndState<T, S>>),
    /// State has been inserted and a backing task has been spawned.
    WithState {
        /// Whether the task is currently stopped (held locally) or running.
        activity: Activity<T, S>,
        /// Handle to the spawned backing task; held only to keep it alive
        /// alongside this state.
        _backing_task: Task<()>,
        /// State shared with the backing task.
        shared: Arc<Mutex<Shared<T, S>>>,
    },
    /// Transient placeholder left behind by `std::mem::replace` while moving
    /// between the other variants; every other observer treats it as
    /// unreachable.
    Invalid,
}
230
/// The task object bundled with its (optional) transient state.
struct TaskAndState<T, S> {
    /// The task object.
    task: T,
    /// The transient state; `None` until inserted and again after removal.
    state: Option<S>,
    /// Set when [`AsyncRun::run`] returns `Ok(())`; reset to `false` when
    /// state is inserted.
    done: bool,
}
236
/// State shared between a [`TaskControl`] and its backing task.
struct Shared<T, S> {
    /// The task and state while they are being handed between the control and
    /// the backing task; `None` while either side has taken them.
    task_and_state: Option<Box<TaskAndState<T, S>>>,
    /// Closures queued to run against the task on the backing task.
    calls: Vec<CallFn<T, S>>,
    /// Whether the task has been asked to stop; set by `stop`, cleared by
    /// `start`.
    stop: bool,
    /// Waker for a pending `stop` call; woken when the backing task hands the
    /// task and state back.
    outer_waker: Option<Waker>,
    /// Waker for the backing task; woken when there is a new call, start, or
    /// stop request to observe.
    inner_waker: Option<Waker>,
}
244
245impl<T, S> Shared<T, S> {
246    fn push_call(this: &Mutex<Self>, f: CallFn<T, S>) {
247        let waker = {
248            let mut this = this.lock();
249            this.calls.push(f);
250            this.inner_waker.take()
251        };
252        if let Some(waker) = waker {
253            waker.wake();
254        }
255    }
256}
257
/// Whether a task with inserted state is currently stopped or running.
enum Activity<T, S> {
    /// The task is stopped; the control holds the task and state directly.
    Stopped(Box<TaskAndState<T, S>>),
    /// The task has been started; the task and state have been handed off via
    /// `Shared`.
    Running,
}
262
263impl<T: AsyncRun<S>, S: 'static + Send> TaskControl<T, S> {
264    /// Creates the task control, taking the state for the task but not yet
265    /// creating or starting it.
266    pub fn new(task: T) -> Self {
267        Self {
268            inner: Inner::NoState(Box::new(TaskAndState {
269                task,
270                state: None,
271                done: false,
272            })),
273        }
274    }
275
276    /// Returns true if a task has been inserted.
277    pub fn has_state(&self) -> bool {
278        match &self.inner {
279            Inner::NoState(_) => false,
280            Inner::WithState { .. } => true,
281            Inner::Invalid => unreachable!(),
282        }
283    }
284
285    /// Returns true if a task is running.
286    pub fn is_running(&self) -> bool {
287        match &self.inner {
288            Inner::NoState(_)
289            | Inner::WithState {
290                activity: Activity::Stopped { .. },
291                ..
292            } => false,
293            Inner::WithState {
294                activity: Activity::Running,
295                ..
296            } => true,
297            Inner::Invalid => unreachable!(),
298        }
299    }
300
301    /// Gets the task.
302    ///
303    /// Panics if the task is running.
304    #[track_caller]
305    pub fn task(&self) -> &T {
306        self.get().0
307    }
308
309    /// Gets the task.
310    ///
311    /// Panics if the task is running.
312    #[track_caller]
313    pub fn task_mut(&mut self) -> &mut T {
314        self.get_mut().0
315    }
316
317    /// Gets the transient task state.
318    ///
319    /// Panics if the task is running.
320    #[track_caller]
321    pub fn state(&self) -> Option<&S> {
322        self.get().1
323    }
324
325    /// Gets the transient task state.
326    ///
327    /// Panics if the task is running.
328    #[track_caller]
329    pub fn state_mut(&mut self) -> Option<&mut S> {
330        self.get_mut().1
331    }
332
333    /// Gets the task and its state.
334    ///
335    /// Panics if the task is running.
336    #[track_caller]
337    pub fn get(&self) -> (&T, Option<&S>) {
338        let task_and_state = match &self.inner {
339            Inner::NoState(task_and_state) => task_and_state,
340            Inner::WithState {
341                activity: Activity::Stopped(task_and_state),
342                ..
343            } => task_and_state,
344            Inner::WithState {
345                activity: Activity::Running,
346                ..
347            } => panic!("attempt to access running task"),
348            Inner::Invalid => unreachable!(),
349        };
350        (&task_and_state.task, task_and_state.state.as_ref())
351    }
352
353    /// Gets the state and the task.
354    ///
355    /// Panics if the task is running.
356    #[track_caller]
357    pub fn get_mut(&mut self) -> (&mut T, Option<&mut S>) {
358        let task_and_state = match &mut self.inner {
359            Inner::NoState(task_and_state) => task_and_state,
360            Inner::WithState {
361                activity: Activity::Stopped(task_and_state),
362                ..
363            } => task_and_state,
364            Inner::WithState {
365                activity: Activity::Running,
366                ..
367            } => panic!("attempt to access running task"),
368            Inner::Invalid => unreachable!(),
369        };
370        (&mut task_and_state.task, task_and_state.state.as_mut())
371    }
372
373    /// Retrieves the task and its state.
374    ///
375    /// Panics if the task is running.
376    #[track_caller]
377    pub fn into_inner(self) -> (T, Option<S>) {
378        let task_and_state = match self.inner {
379            Inner::NoState(task_and_state) => task_and_state,
380            Inner::WithState {
381                activity: Activity::Stopped(task_and_state),
382                ..
383            } => task_and_state,
384            Inner::WithState {
385                activity: Activity::Running,
386                ..
387            } => panic!("attempt to extract running task"),
388            Inner::Invalid => unreachable!(),
389        };
390        (task_and_state.task, task_and_state.state)
391    }
392
393    /// Calls `f` against the task and its state.
394    ///
395    /// If the task is running, then `f` will run remotely and will not
396    /// necessarily finish before this routine returns.
397    pub fn update_with(&mut self, f: impl 'static + Send + FnOnce(&mut T, Option<&mut S>)) {
398        let f = |task_and_state: &mut TaskAndState<T, S>| {
399            f(&mut task_and_state.task, task_and_state.state.as_mut())
400        };
401        match &mut self.inner {
402            Inner::NoState(task_and_state) => f(task_and_state),
403            Inner::WithState {
404                activity, shared, ..
405            } => match activity {
406                Activity::Stopped(task_and_state) => f(task_and_state),
407                Activity::Running => Shared::push_call(shared, Box::new(f)),
408            },
409            Inner::Invalid => unreachable!(),
410        }
411    }
412
413    /// Inserts the state the task object will use to run and starts the backing
414    /// task, but does not start running it.
415    #[track_caller]
416    pub fn insert(&mut self, spawn: impl Spawn, name: impl Into<Arc<str>>, state: S) -> &mut S {
417        self.inner = match std::mem::replace(&mut self.inner, Inner::Invalid) {
418            Inner::NoState(mut task_and_state) => {
419                task_and_state.state = Some(state);
420                task_and_state.done = false;
421                let shared = Arc::new(Mutex::new(Shared {
422                    task_and_state: None,
423                    calls: Vec::new(),
424                    stop: true,
425                    outer_waker: None,
426                    inner_waker: None,
427                }));
428                let backing_task = spawn.spawn(name, Self::run(shared.clone()));
429                Inner::WithState {
430                    activity: Activity::Stopped(task_and_state),
431                    _backing_task: backing_task,
432                    shared,
433                }
434            }
435            Inner::WithState { .. } => panic!("attempt to insert already-present state"),
436            Inner::Invalid => unreachable!(),
437        };
438        self.state_mut().unwrap()
439    }
440
441    /// Starts the task if it is not already running.
442    ///
443    /// Returns true if the task is now running (even if it was previously
444    /// running). Returns false if the task is not running (either because its
445    /// state has not been inserted, or because it has already completed).
446    pub fn start(&mut self) -> bool {
447        match &mut self.inner {
448            Inner::WithState {
449                activity, shared, ..
450            } => match std::mem::replace(activity, Activity::Running) {
451                Activity::Stopped(task_and_state) => {
452                    if task_and_state.done {
453                        *activity = Activity::Stopped(task_and_state);
454                        return false;
455                    }
456                    let waker = {
457                        let mut shared = shared.lock();
458                        shared.task_and_state = Some(task_and_state);
459                        shared.stop = false;
460                        shared.inner_waker.take()
461                    };
462                    if let Some(waker) = waker {
463                        waker.wake();
464                    }
465                    true
466                }
467                Activity::Running => true,
468            },
469            Inner::NoState(_) => false,
470            Inner::Invalid => {
471                unreachable!()
472            }
473        }
474    }
475
    /// The body of the backing task.
    ///
    /// Loops forever: waits until the control has handed over the task and
    /// there is something to do (queued calls, or an unfinished task that has
    /// not been asked to stop), runs the queued calls, runs the task itself if
    /// appropriate, then hands the task back and wakes any pending `stop`.
    async fn run(shared: Arc<Mutex<Shared<T, S>>>) {
        StopTask::run_with(StopTaskInner { shared: &shared }, async |stop_task| {
            // Reused across iterations so draining queued calls does not
            // allocate each time.
            let mut calls = Vec::new();
            loop {
                let (mut task_and_state, stop) = poll_fn(|cx| {
                    let mut shared = shared.lock();
                    // Work exists only once the task has been handed over, and
                    // then only if calls are queued or the task should run.
                    let has_work = shared
                        .task_and_state
                        .as_ref()
                        .is_some_and(|ts| !shared.calls.is_empty() || (!shared.stop && !ts.done));
                    if !has_work {
                        shared.inner_waker = Some(cx.waker().clone());
                        return Poll::Pending;
                    }
                    // Take the queued calls and the task under the same lock,
                    // along with a snapshot of the stop flag.
                    calls.append(&mut shared.calls);
                    Poll::Ready((shared.task_and_state.take().unwrap(), shared.stop))
                })
                .await;

                // Run queued calls outside the lock.
                for call in calls.drain(..) {
                    call(&mut task_and_state);
                }

                if !stop && !task_and_state.done {
                    // `Ok(())` marks the task complete; `Err(Cancelled)` means
                    // it should be run again on a later iteration.
                    // NOTE(review): the `unwrap` assumes state is always
                    // present once the task has been handed over — it is set
                    // by `insert` and only cleared by `remove`.
                    task_and_state.done = task_and_state
                        .task
                        .run(&mut *stop_task, task_and_state.state.as_mut().unwrap())
                        .await
                        .is_ok();
                }

                // Hand the task back and wake a pending `stop`, if any, after
                // releasing the lock.
                let waker = {
                    let mut shared = shared.lock();
                    shared.task_and_state = Some(task_and_state);
                    shared.outer_waker.take()
                };
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        })
        .await
    }
519
520    /// Stops the task, waiting for it to be cancelled.
521    ///
522    /// Returns true if the task was previously running. Returns false if the
523    /// task was not running, not inserted, or had already completed.
524    pub async fn stop(&mut self) -> bool {
525        match &mut self.inner {
526            Inner::WithState {
527                activity, shared, ..
528            } => match activity {
529                Activity::Running => {
530                    let task_and_state = poll_fn(|cx| {
531                        let mut shared = shared.lock();
532                        shared.stop = true;
533                        if shared.task_and_state.is_none() || !shared.calls.is_empty() {
534                            shared.outer_waker = Some(cx.waker().clone());
535                            let waker = shared.inner_waker.take();
536                            drop(shared);
537                            if let Some(waker) = waker {
538                                waker.wake();
539                            }
540                            return Poll::Pending;
541                        }
542                        Poll::Ready(shared.task_and_state.take().unwrap())
543                    })
544                    .await;
545
546                    let done = task_and_state.done;
547                    *activity = Activity::Stopped(task_and_state);
548                    !done
549                }
550                _ => false,
551            },
552            Inner::NoState(_) => false,
553            Inner::Invalid => unreachable!(),
554        }
555    }
556
557    /// Removes the task state.
558    ///
559    /// Panics if the task is not stopped.
560    #[track_caller]
561    pub fn remove(&mut self) -> S {
562        match std::mem::replace(&mut self.inner, Inner::Invalid) {
563            Inner::WithState {
564                activity: Activity::Stopped(mut task_and_state),
565                ..
566            } => {
567                let state = task_and_state.state.take().unwrap();
568                self.inner = Inner::NoState(task_and_state);
569                state
570            }
571            Inner::NoState(_) => panic!("attempt to remove missing state"),
572            Inner::WithState { .. } => panic!("attempt to remove state from running task"),
573            Inner::Invalid => {
574                unreachable!()
575            }
576        }
577    }
578}
579
580#[cfg(test)]
581mod tests {
582    use super::AsyncRun;
583    use crate::Cancelled;
584    use crate::StopTask;
585    use crate::TaskControl;
586    use futures::FutureExt;
587    use pal_async::DefaultDriver;
588    use pal_async::async_test;
589    use std::task::Poll;
590
    /// Test task whose counter is incremented each time it is run.
    struct Foo(u32);
592
593    impl AsyncRun<bool> for Foo {
594        async fn run(
595            &mut self,
596            stop: &mut StopTask<'_>,
597            state: &mut bool,
598        ) -> Result<(), Cancelled> {
599            stop.until_stopped(async {
600                self.0 += 1;
601                if !*state {
602                    std::future::pending::<()>().await;
603                }
604            })
605            .await
606        }
607    }
608
609    async fn yield_once() {
610        let mut yielded = false;
611        std::future::poll_fn(|cx| {
612            if yielded {
613                Poll::Ready(())
614            } else {
615                yielded = true;
616                cx.waker().wake_by_ref();
617                Poll::Pending
618            }
619        })
620        .await
621    }
622
    /// Exercises the insert/remove/start/stop lifecycle, including a task that
    /// is cancelled (and so can be restarted) and one that runs to completion
    /// (and so cannot).
    #[async_test]
    async fn test(driver: DefaultDriver) {
        let mut t = TaskControl::new(Foo(5));
        // State can be inserted, removed, and inserted again.
        t.insert(&driver, "test", false);
        t.remove();
        t.insert(&driver, "test", false);
        assert_eq!(t.task().0, 5);
        // With state `false` the task runs once and then blocks until stopped.
        assert!(t.start());
        yield_once().await;
        assert!(t.stop().await);
        assert_eq!(t.task().0, 6);
        // With state `true` the task runs to completion, so `stop` reports
        // that it was no longer running.
        *t.state_mut().unwrap() = true;
        assert!(t.start());
        yield_once().await;
        assert!(!t.stop().await);
        assert_eq!(t.task().0, 7);
        // The task has completed, so starting it again will not increment the counter.
        assert!(!t.start());
        yield_once().await;
        assert!(!t.stop().await);
        assert_eq!(t.task().0, 7);
    }
645
    /// Verifies that a `stop` future dropped before completion does not lose
    /// queued updates, and that a later `stop` still completes.
    #[async_test]
    async fn test_cancelled_stop(driver: DefaultDriver) {
        let mut t = TaskControl::new(Foo(5));
        t.insert(&driver, "test", false);
        assert!(t.start());
        yield_once().await;
        // Queued while running; `stop` cannot complete until it is processed.
        t.update_with(|t, _| t.0 += 1);
        // Poll `stop` once and drop it without letting it finish.
        assert!(t.stop().now_or_never().is_none());
        t.update_with(|t, _| t.0 += 1);
        assert!(t.stop().await);
        // 5 initially, +1 from the run, +1 from each of the two updates.
        assert_eq!(t.task_mut().0, 8);
    }
658}