Skip to main content

task_control/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A simple asynchronous task model for execution that needs to be started,
5//! stopped, mutated, and inspected.
6
7#![forbid(unsafe_code)]
8
9use fast_select::FastSelect;
10use inspect::Inspect;
11use inspect::InspectMut;
12use pal_async::task::Spawn;
13use pal_async::task::Task;
14use parking_lot::Mutex;
15use std::future::Future;
16use std::future::poll_fn;
17use std::pin::Pin;
18use std::pin::pin;
19use std::sync::Arc;
20use std::task::Context;
21use std::task::Poll;
22use std::task::Waker;
23
24/// A method implemented by a task that can be run and stopped, storing
25/// transient state in `S`.
26pub trait AsyncRun<S>: 'static + Send {
27    /// Runs the task.
28    ///
29    /// The task should stop when `stop` becomes ready. This can be determined
30    /// either by awaiting on `stop`, or by calling [`StopTask::until_stopped`]
31    /// with a future to run.
32    ///
33    /// The function should return `Ok(())` if the task is complete, in which
34    /// case it will only run again after being removed and reinserted.
35    ///
36    /// If the function instead returns `Err(Cancelled)`, this indicates that
37    /// the task's work is not complete, and it should be restarted after
38    /// handling any incoming events. Implementations that choose to use
39    /// [`StopTask::until_stopped`] must be cancel-tolerant: because
40    /// [`StopTask::until_stopped`] drops the inner future when a stop
41    /// is signaled, implementations must ensure no in-flight work is silently
42    /// lost across a cancellation.
43    fn run(
44        &mut self,
45        stop: &mut StopTask<'_>,
46        _: &mut S,
47    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
48}
49
/// The error returned from [`AsyncRun::run`] (and produced by
/// [`StopTask::until_stopped`]) indicating the task has not yet finished
/// executing, typically because it was asked to stop.
///
/// Implements [`std::error::Error`] so it can be boxed or converted like any
/// other error value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task cancelled before completion")
    }
}

impl std::error::Error for Cancelled {}
54
55/// A future indicating that the task should return for event processing or
56/// because the task was stopped.
57pub struct StopTask<'a> {
58    inner: &'a mut (dyn 'a + Send + Future<Output = ()> + Unpin),
59    fast_select: &'a mut FastSelect,
60}
61
62/// The inner polling implementation, which polls for an incoming request from
63/// `TaskControl`.
64///
65/// This is separate from `StopTask` so that the types can be erased.
66struct StopTaskInner<'a, T, S> {
67    shared: &'a Mutex<Shared<T, S>>,
68}
69
70impl<T: AsyncRun<S>, S> Future for StopTaskInner<'_, T, S> {
71    type Output = ();
72
73    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
74        let mut shared = self.get_mut().shared.lock();
75        if !shared.calls.is_empty() || shared.stop {
76            return Poll::Ready(());
77        }
78        if shared
79            .inner_waker
80            .as_ref()
81            .is_none_or(|waker| !cx.waker().will_wake(waker))
82        {
83            shared.inner_waker = Some(cx.waker().clone());
84        }
85        Poll::Pending
86    }
87}
88
89impl StopTask<'_> {
90    /// Runs `f`, providing access to `stop` via a `StopTask` passed to `f`.
91    pub async fn run_with<R>(
92        mut stop: impl Send + Future<Output = ()> + Unpin,
93        f: impl AsyncFnOnce(&mut StopTask<'_>) -> R,
94    ) -> R {
95        let mut fast_select: FastSelect = FastSelect::new();
96        let mut stop = StopTask {
97            inner: &mut stop,
98            fast_select: &mut fast_select,
99        };
100        f(&mut stop).await
101    }
102
103    /// Runs `fut` until the task is requested to stop.
104    ///
105    /// If `fut` completes, then `Ok(_)` is returned.
106    ///
107    /// If the task is requested to stop before `fut` completes, then `fut` is
108    /// dropped and `Err(Cancelled)` is returned.
109    pub async fn until_stopped<F: Future>(&mut self, fut: F) -> Result<F::Output, Cancelled> {
110        // Wrap the cancel task in a FastSelect to avoid taking the channel lock
111        // at each wakeup.
112        let mut cancel = pin!(
113            self.fast_select
114                .select((poll_fn(|cx| Pin::new(&mut self.inner).poll(cx)),))
115        );
116
117        let mut fut = pin!(fut);
118
119        // Since this is a common fast path, implement the select manually.
120        poll_fn(|cx| {
121            if let Poll::Ready(r) = fut.as_mut().poll(cx) {
122                Poll::Ready(Ok(r))
123            } else if cancel.as_mut().poll(cx).is_ready() {
124                Poll::Ready(Err(Cancelled))
125            } else {
126                Poll::Pending
127            }
128        })
129        .await
130    }
131}
132
133impl Future for StopTask<'_> {
134    type Output = ();
135
136    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        Pin::new(&mut self.inner).poll(cx)
138    }
139}
140
141/// A task wrapper that runs a task asynchronously and provides access to its
142/// state and control over its execution (start/stop).
143///
144/// Pairs a task implementation `T: AsyncRun<S>` with transient state `S`.
145/// Execution is cancelled when [`stop`](Self::stop) is invoked.
146///
147/// Outside of an actual [`stop`](Self::stop), `TaskControl` also raises the
148/// stop signal to fulfill [`update_with`](Self::update_with) and
149/// [`Inspect`]/[`InspectMut`] calls, after which `run` is re-invoked with
150/// the (possibly updated) state. See [`AsyncRun`] for the cancel-tolerance
151/// requirements this places on implementations.
152pub struct TaskControl<T, S> {
153    inner: Inner<T, S>,
154}
155
156/// A trait for inspecting a task and its associated state.
157pub trait InspectTask<S>: AsyncRun<S> {
158    /// Inspects the task and its state.
159    ///
160    /// The state may be missing if it has not yet been inserted into the
161    /// [`TaskControl`].
162    fn inspect(&self, req: inspect::Request<'_>, state: Option<&S>);
163}
164
165impl<T: InspectTask<S>, S> Inspect for TaskAndState<T, S> {
166    fn inspect(&self, req: inspect::Request<'_>) {
167        self.task.inspect(req, self.state.as_ref());
168    }
169}
170
171impl<T: InspectTask<S>, S> Inspect for TaskControl<T, S> {
172    fn inspect(&self, req: inspect::Request<'_>) {
173        match &self.inner {
174            Inner::NoState(task_and_state) => task_and_state.inspect(req),
175            Inner::WithState {
176                activity, shared, ..
177            } => match activity {
178                Activity::Stopped(task_and_state) => task_and_state.inspect(req),
179                Activity::Running => {
180                    let deferred = req.defer();
181                    Shared::push_call(
182                        shared,
183                        Box::new(|task_and_state| {
184                            deferred.inspect(&task_and_state);
185                        }),
186                    )
187                }
188            },
189            Inner::Invalid => unreachable!(),
190        }
191    }
192}
193
194/// A trait for mutably inspecting a task and its associated state.
195pub trait InspectTaskMut<T>: AsyncRun<T> {
196    /// Inspects the task and its state.
197    ///
198    /// The state may be missing if it has not yet been inserted into the
199    /// [`TaskControl`].
200    fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut T>);
201}
202
203impl<T: InspectTaskMut<S>, S> InspectMut for TaskAndState<T, S> {
204    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
205        self.task.inspect_mut(req, self.state.as_mut());
206    }
207}
208
209impl<T: InspectTaskMut<U>, U> InspectMut for TaskControl<T, U> {
210    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
211        match &mut self.inner {
212            Inner::NoState(task_and_state) => task_and_state.inspect_mut(req),
213            Inner::WithState {
214                activity, shared, ..
215            } => match activity {
216                Activity::Stopped(task_and_state) => task_and_state.inspect_mut(req),
217                Activity::Running => {
218                    let deferred = req.defer();
219                    Shared::push_call(
220                        shared,
221                        Box::new(|task_and_state| {
222                            deferred.inspect(task_and_state);
223                        }),
224                    );
225                }
226            },
227            Inner::Invalid => unreachable!(),
228        }
229    }
230}
231
232type CallFn<T, S> = Box<dyn FnOnce(&mut TaskAndState<T, S>) + Send>;
233
234enum Inner<T, S> {
235    NoState(Box<TaskAndState<T, S>>),
236    WithState {
237        activity: Activity<T, S>,
238        _backing_task: Task<()>,
239        shared: Arc<Mutex<Shared<T, S>>>,
240    },
241    Invalid,
242}
243
/// The task object together with its optional state and completion flag.
struct TaskAndState<T, S> {
    /// The task object supplied to `TaskControl::new`.
    task: T,
    /// State supplied by `TaskControl::insert`; `None` when absent.
    state: Option<S>,
    /// Set once `run` returns `Ok(())` and cleared by `insert`. A done task is
    /// not run again.
    done: bool,
}
249
250struct Shared<T, S> {
251    task_and_state: Option<Box<TaskAndState<T, S>>>,
252    calls: Vec<CallFn<T, S>>,
253    stop: bool,
254    outer_waker: Option<Waker>,
255    inner_waker: Option<Waker>,
256}
257
258impl<T, S> Shared<T, S> {
259    fn push_call(this: &Mutex<Self>, f: CallFn<T, S>) {
260        let waker = {
261            let mut this = this.lock();
262            this.calls.push(f);
263            this.inner_waker.take()
264        };
265        if let Some(waker) = waker {
266            waker.wake();
267        }
268    }
269}
270
271enum Activity<T, S> {
272    Stopped(Box<TaskAndState<T, S>>),
273    Running,
274}
275
276impl<T: AsyncRun<S>, S: 'static + Send> TaskControl<T, S> {
277    /// Creates the task control, taking the state for the task but not yet
278    /// creating or starting it.
279    pub fn new(task: T) -> Self {
280        Self {
281            inner: Inner::NoState(Box::new(TaskAndState {
282                task,
283                state: None,
284                done: false,
285            })),
286        }
287    }
288
289    /// Returns true if a task has been inserted.
290    pub fn has_state(&self) -> bool {
291        match &self.inner {
292            Inner::NoState(_) => false,
293            Inner::WithState { .. } => true,
294            Inner::Invalid => unreachable!(),
295        }
296    }
297
298    /// Returns true if a task is running.
299    pub fn is_running(&self) -> bool {
300        match &self.inner {
301            Inner::NoState(_)
302            | Inner::WithState {
303                activity: Activity::Stopped { .. },
304                ..
305            } => false,
306            Inner::WithState {
307                activity: Activity::Running,
308                ..
309            } => true,
310            Inner::Invalid => unreachable!(),
311        }
312    }
313
314    /// Gets the task.
315    ///
316    /// Panics if the task is running.
317    #[track_caller]
318    pub fn task(&self) -> &T {
319        self.get().0
320    }
321
322    /// Gets the task.
323    ///
324    /// Panics if the task is running.
325    #[track_caller]
326    pub fn task_mut(&mut self) -> &mut T {
327        self.get_mut().0
328    }
329
330    /// Gets the transient task state.
331    ///
332    /// Panics if the task is running.
333    #[track_caller]
334    pub fn state(&self) -> Option<&S> {
335        self.get().1
336    }
337
338    /// Gets the transient task state.
339    ///
340    /// Panics if the task is running.
341    #[track_caller]
342    pub fn state_mut(&mut self) -> Option<&mut S> {
343        self.get_mut().1
344    }
345
346    /// Gets the task and its state.
347    ///
348    /// Panics if the task is running.
349    #[track_caller]
350    pub fn get(&self) -> (&T, Option<&S>) {
351        let task_and_state = match &self.inner {
352            Inner::NoState(task_and_state) => task_and_state,
353            Inner::WithState {
354                activity: Activity::Stopped(task_and_state),
355                ..
356            } => task_and_state,
357            Inner::WithState {
358                activity: Activity::Running,
359                ..
360            } => panic!("attempt to access running task"),
361            Inner::Invalid => unreachable!(),
362        };
363        (&task_and_state.task, task_and_state.state.as_ref())
364    }
365
366    /// Gets the state and the task.
367    ///
368    /// Panics if the task is running.
369    #[track_caller]
370    pub fn get_mut(&mut self) -> (&mut T, Option<&mut S>) {
371        let task_and_state = match &mut self.inner {
372            Inner::NoState(task_and_state) => task_and_state,
373            Inner::WithState {
374                activity: Activity::Stopped(task_and_state),
375                ..
376            } => task_and_state,
377            Inner::WithState {
378                activity: Activity::Running,
379                ..
380            } => panic!("attempt to access running task"),
381            Inner::Invalid => unreachable!(),
382        };
383        (&mut task_and_state.task, task_and_state.state.as_mut())
384    }
385
386    /// Retrieves the task and its state.
387    ///
388    /// Panics if the task is running.
389    #[track_caller]
390    pub fn into_inner(self) -> (T, Option<S>) {
391        let task_and_state = match self.inner {
392            Inner::NoState(task_and_state) => task_and_state,
393            Inner::WithState {
394                activity: Activity::Stopped(task_and_state),
395                ..
396            } => task_and_state,
397            Inner::WithState {
398                activity: Activity::Running,
399                ..
400            } => panic!("attempt to extract running task"),
401            Inner::Invalid => unreachable!(),
402        };
403        (task_and_state.task, task_and_state.state)
404    }
405
406    /// Calls `f` against the task and its state.
407    ///
408    /// If the task is running, then `f` will run remotely and will not
409    /// necessarily finish before this routine returns.
410    pub fn update_with(&mut self, f: impl 'static + Send + FnOnce(&mut T, Option<&mut S>)) {
411        let f = |task_and_state: &mut TaskAndState<T, S>| {
412            f(&mut task_and_state.task, task_and_state.state.as_mut())
413        };
414        match &mut self.inner {
415            Inner::NoState(task_and_state) => f(task_and_state),
416            Inner::WithState {
417                activity, shared, ..
418            } => match activity {
419                Activity::Stopped(task_and_state) => f(task_and_state),
420                Activity::Running => Shared::push_call(shared, Box::new(f)),
421            },
422            Inner::Invalid => unreachable!(),
423        }
424    }
425
426    /// Inserts the state the task object will use to run and starts the backing
427    /// task, but does not start running it.
428    #[track_caller]
429    pub fn insert(&mut self, spawn: impl Spawn, name: impl Into<Arc<str>>, state: S) -> &mut S {
430        self.inner = match std::mem::replace(&mut self.inner, Inner::Invalid) {
431            Inner::NoState(mut task_and_state) => {
432                task_and_state.state = Some(state);
433                task_and_state.done = false;
434                let shared = Arc::new(Mutex::new(Shared {
435                    task_and_state: None,
436                    calls: Vec::new(),
437                    stop: true,
438                    outer_waker: None,
439                    inner_waker: None,
440                }));
441                let backing_task = spawn.spawn(name, Self::run(shared.clone()));
442                Inner::WithState {
443                    activity: Activity::Stopped(task_and_state),
444                    _backing_task: backing_task,
445                    shared,
446                }
447            }
448            Inner::WithState { .. } => panic!("attempt to insert already-present state"),
449            Inner::Invalid => unreachable!(),
450        };
451        self.state_mut().unwrap()
452    }
453
454    /// Starts the task if it is not already running.
455    ///
456    /// Returns true if the task is now running (even if it was previously
457    /// running). Returns false if the task is not running (either because its
458    /// state has not been inserted, or because it has already completed).
459    pub fn start(&mut self) -> bool {
460        match &mut self.inner {
461            Inner::WithState {
462                activity, shared, ..
463            } => match std::mem::replace(activity, Activity::Running) {
464                Activity::Stopped(task_and_state) => {
465                    if task_and_state.done {
466                        *activity = Activity::Stopped(task_and_state);
467                        return false;
468                    }
469                    let waker = {
470                        let mut shared = shared.lock();
471                        shared.task_and_state = Some(task_and_state);
472                        shared.stop = false;
473                        shared.inner_waker.take()
474                    };
475                    if let Some(waker) = waker {
476                        waker.wake();
477                    }
478                    true
479                }
480                Activity::Running => true,
481            },
482            Inner::NoState(_) => false,
483            Inner::Invalid => {
484                unreachable!()
485            }
486        }
487    }
488
489    async fn run(shared: Arc<Mutex<Shared<T, S>>>) {
490        StopTask::run_with(StopTaskInner { shared: &shared }, async |stop_task| {
491            let mut calls = Vec::new();
492            loop {
493                let (mut task_and_state, stop) = poll_fn(|cx| {
494                    let mut shared = shared.lock();
495                    let has_work = shared
496                        .task_and_state
497                        .as_ref()
498                        .is_some_and(|ts| !shared.calls.is_empty() || (!shared.stop && !ts.done));
499                    if !has_work {
500                        shared.inner_waker = Some(cx.waker().clone());
501                        return Poll::Pending;
502                    }
503                    calls.append(&mut shared.calls);
504                    Poll::Ready((shared.task_and_state.take().unwrap(), shared.stop))
505                })
506                .await;
507
508                for call in calls.drain(..) {
509                    call(&mut task_and_state);
510                }
511
512                if !stop && !task_and_state.done {
513                    task_and_state.done = task_and_state
514                        .task
515                        .run(&mut *stop_task, task_and_state.state.as_mut().unwrap())
516                        .await
517                        .is_ok();
518                }
519
520                let waker = {
521                    let mut shared = shared.lock();
522                    shared.task_and_state = Some(task_and_state);
523                    shared.outer_waker.take()
524                };
525                if let Some(waker) = waker {
526                    waker.wake();
527                }
528            }
529        })
530        .await
531    }
532
533    /// Poll variant of [`stop`](Self::stop). Signals the task to stop and polls
534    /// for completion.
535    ///
536    /// Returns `Poll::Ready(true)` if the task was previously running and has
537    /// now stopped. Returns `Poll::Ready(false)` if the task was not running,
538    /// not inserted, or had already completed. Returns `Poll::Pending` if the
539    /// task has not yet stopped.
540    pub fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
541        match &mut self.inner {
542            Inner::WithState {
543                activity, shared, ..
544            } => match activity {
545                Activity::Running => {
546                    let mut shared = shared.lock();
547                    shared.stop = true;
548                    if shared.task_and_state.is_none() || !shared.calls.is_empty() {
549                        shared.outer_waker = Some(cx.waker().clone());
550                        let waker = shared.inner_waker.take();
551                        drop(shared);
552                        if let Some(waker) = waker {
553                            waker.wake();
554                        }
555                        return Poll::Pending;
556                    }
557                    let task_and_state = shared.task_and_state.take().unwrap();
558                    drop(shared);
559                    let done = task_and_state.done;
560                    *activity = Activity::Stopped(task_and_state);
561                    Poll::Ready(!done)
562                }
563                _ => Poll::Ready(false),
564            },
565            Inner::NoState(_) => Poll::Ready(false),
566            Inner::Invalid => unreachable!(),
567        }
568    }
569
570    /// Stops the task, waiting for it to be cancelled.
571    ///
572    /// Returns true if the task was previously running. Returns false if the
573    /// task was not running, not inserted, or had already completed.
574    pub async fn stop(&mut self) -> bool {
575        poll_fn(|cx| self.poll_stop(cx)).await
576    }
577
578    /// Removes the task state.
579    ///
580    /// Panics if the task is not stopped.
581    #[track_caller]
582    pub fn remove(&mut self) -> S {
583        match std::mem::replace(&mut self.inner, Inner::Invalid) {
584            Inner::WithState {
585                activity: Activity::Stopped(mut task_and_state),
586                ..
587            } => {
588                let state = task_and_state.state.take().unwrap();
589                self.inner = Inner::NoState(task_and_state);
590                state
591            }
592            Inner::NoState(_) => panic!("attempt to remove missing state"),
593            Inner::WithState { .. } => panic!("attempt to remove state from running task"),
594            Inner::Invalid => {
595                unreachable!()
596            }
597        }
598    }
599}
600
#[cfg(test)]
mod tests {
    use crate::AsyncRun;
    use crate::Cancelled;
    use crate::StopTask;
    use crate::TaskControl;
    use futures::FutureExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::task::Poll;

    /// Test task whose counter records how many times `run` has been entered.
    ///
    /// With the state flag `false` each run blocks until stopped. With it
    /// `true` the run finishes immediately, marking the task done.
    struct Counter(u32);

    impl AsyncRun<bool> for Counter {
        async fn run(
            &mut self,
            stop: &mut StopTask<'_>,
            finish: &mut bool,
        ) -> Result<(), Cancelled> {
            let work = async {
                self.0 += 1;
                if !*finish {
                    std::future::pending::<()>().await;
                }
            };
            stop.until_stopped(work).await
        }
    }

    /// Returns `Pending` exactly once (waking itself), so other tasks such as
    /// the backing task get a chance to be polled.
    async fn yield_once() {
        let mut pending_once = true;
        std::future::poll_fn(|cx| {
            if std::mem::take(&mut pending_once) {
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
        .await
    }

    /// Calls `poll_stop` a single time and returns exactly what it reported.
    async fn poll_stop_once(tc: &mut TaskControl<Counter, bool>) -> Poll<bool> {
        std::future::poll_fn(|cx| Poll::Ready(tc.poll_stop(cx))).await
    }

    #[async_test]
    async fn test(driver: DefaultDriver) {
        let mut tc = TaskControl::new(Counter(5));
        tc.insert(&driver, "test", false);
        tc.remove();
        tc.insert(&driver, "test", false);
        assert_eq!(tc.task().0, 5);

        // The first run blocks until stopped, so stop reports it was running.
        assert!(tc.start());
        yield_once().await;
        assert!(tc.stop().await);
        assert_eq!(tc.task().0, 6);

        // With the flag set the run completes, so stop reports "not running".
        *tc.state_mut().unwrap() = true;
        assert!(tc.start());
        yield_once().await;
        assert!(!tc.stop().await);
        assert_eq!(tc.task().0, 7);

        // A completed task refuses to restart, and the counter is unchanged.
        assert!(!tc.start());
        yield_once().await;
        assert!(!tc.stop().await);
        assert_eq!(tc.task().0, 7);
    }

    #[async_test]
    async fn test_cancelled_stop(driver: DefaultDriver) {
        let mut tc = TaskControl::new(Counter(5));
        tc.insert(&driver, "test", false);
        assert!(tc.start());
        yield_once().await;

        tc.update_with(|task, _| task.0 += 1);
        // A single poll cannot finish stopping while an update is queued.
        assert!(tc.stop().now_or_never().is_none());

        tc.update_with(|task, _| task.0 += 1);
        assert!(tc.stop().await);
        assert_eq!(tc.task_mut().0, 8);
    }

    #[async_test]
    async fn test_poll_stop(driver: DefaultDriver) {
        let mut tc = TaskControl::new(Counter(5));

        // Without state there is nothing to stop.
        assert_eq!(poll_stop_once(&mut tc).await, Poll::Ready(false));

        tc.insert(&driver, "test", false);

        // Inserted but never started: still nothing to stop.
        assert_eq!(poll_stop_once(&mut tc).await, Poll::Ready(false));

        assert!(tc.start());
        yield_once().await;

        // Driving poll_stop to completion matches `stop().await` and reports
        // that the task had been running.
        assert!(std::future::poll_fn(|cx| tc.poll_stop(cx)).await);
        assert_eq!(tc.task().0, 6);

        // Once stopped, poll_stop again reports nothing to stop.
        assert_eq!(poll_stop_once(&mut tc).await, Poll::Ready(false));
    }
}