1#![forbid(unsafe_code)]
8
9use fast_select::FastSelect;
10use inspect::Inspect;
11use inspect::InspectMut;
12use pal_async::task::Spawn;
13use pal_async::task::Task;
14use parking_lot::Mutex;
15use std::future::Future;
16use std::future::poll_fn;
17use std::pin::Pin;
18use std::pin::pin;
19use std::sync::Arc;
20use std::task::Context;
21use std::task::Poll;
22use std::task::Waker;
23
24pub trait AsyncRun<S>: 'static + Send {
27 fn run(
44 &mut self,
45 stop: &mut StopTask<'_>,
46 _: &mut S,
47 ) -> impl Send + Future<Output = Result<(), Cancelled>>;
48}
49
/// Error returned when work was interrupted by a stop request before it
/// completed (see `StopTask::until_stopped`).
///
/// Implements [`std::error::Error`] so it can be propagated with `?` into
/// generic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cancelled")
    }
}

impl std::error::Error for Cancelled {}
54
55pub struct StopTask<'a> {
58 inner: &'a mut (dyn 'a + Send + Future<Output = ()> + Unpin),
59 fast_select: &'a mut FastSelect,
60}
61
62struct StopTaskInner<'a, T, S> {
67 shared: &'a Mutex<Shared<T, S>>,
68}
69
70impl<T: AsyncRun<S>, S> Future for StopTaskInner<'_, T, S> {
71 type Output = ();
72
73 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
74 let mut shared = self.get_mut().shared.lock();
75 if !shared.calls.is_empty() || shared.stop {
76 return Poll::Ready(());
77 }
78 if shared
79 .inner_waker
80 .as_ref()
81 .is_none_or(|waker| !cx.waker().will_wake(waker))
82 {
83 shared.inner_waker = Some(cx.waker().clone());
84 }
85 Poll::Pending
86 }
87}
88
89impl StopTask<'_> {
90 pub async fn run_with<R>(
92 mut stop: impl Send + Future<Output = ()> + Unpin,
93 f: impl AsyncFnOnce(&mut StopTask<'_>) -> R,
94 ) -> R {
95 let mut fast_select: FastSelect = FastSelect::new();
96 let mut stop = StopTask {
97 inner: &mut stop,
98 fast_select: &mut fast_select,
99 };
100 f(&mut stop).await
101 }
102
103 pub async fn until_stopped<F: Future>(&mut self, fut: F) -> Result<F::Output, Cancelled> {
110 let mut cancel = pin!(
113 self.fast_select
114 .select((poll_fn(|cx| Pin::new(&mut self.inner).poll(cx)),))
115 );
116
117 let mut fut = pin!(fut);
118
119 poll_fn(|cx| {
121 if let Poll::Ready(r) = fut.as_mut().poll(cx) {
122 Poll::Ready(Ok(r))
123 } else if cancel.as_mut().poll(cx).is_ready() {
124 Poll::Ready(Err(Cancelled))
125 } else {
126 Poll::Pending
127 }
128 })
129 .await
130 }
131}
132
133impl Future for StopTask<'_> {
134 type Output = ();
135
136 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 Pin::new(&mut self.inner).poll(cx)
138 }
139}
140
141pub struct TaskControl<T, S> {
153 inner: Inner<T, S>,
154}
155
156pub trait InspectTask<S>: AsyncRun<S> {
158 fn inspect(&self, req: inspect::Request<'_>, state: Option<&S>);
163}
164
165impl<T: InspectTask<S>, S> Inspect for TaskAndState<T, S> {
166 fn inspect(&self, req: inspect::Request<'_>) {
167 self.task.inspect(req, self.state.as_ref());
168 }
169}
170
171impl<T: InspectTask<S>, S> Inspect for TaskControl<T, S> {
172 fn inspect(&self, req: inspect::Request<'_>) {
173 match &self.inner {
174 Inner::NoState(task_and_state) => task_and_state.inspect(req),
175 Inner::WithState {
176 activity, shared, ..
177 } => match activity {
178 Activity::Stopped(task_and_state) => task_and_state.inspect(req),
179 Activity::Running => {
180 let deferred = req.defer();
181 Shared::push_call(
182 shared,
183 Box::new(|task_and_state| {
184 deferred.inspect(&task_and_state);
185 }),
186 )
187 }
188 },
189 Inner::Invalid => unreachable!(),
190 }
191 }
192}
193
194pub trait InspectTaskMut<T>: AsyncRun<T> {
196 fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut T>);
201}
202
203impl<T: InspectTaskMut<S>, S> InspectMut for TaskAndState<T, S> {
204 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
205 self.task.inspect_mut(req, self.state.as_mut());
206 }
207}
208
209impl<T: InspectTaskMut<U>, U> InspectMut for TaskControl<T, U> {
210 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
211 match &mut self.inner {
212 Inner::NoState(task_and_state) => task_and_state.inspect_mut(req),
213 Inner::WithState {
214 activity, shared, ..
215 } => match activity {
216 Activity::Stopped(task_and_state) => task_and_state.inspect_mut(req),
217 Activity::Running => {
218 let deferred = req.defer();
219 Shared::push_call(
220 shared,
221 Box::new(|task_and_state| {
222 deferred.inspect(task_and_state);
223 }),
224 );
225 }
226 },
227 Inner::Invalid => unreachable!(),
228 }
229 }
230}
231
232type CallFn<T, S> = Box<dyn FnOnce(&mut TaskAndState<T, S>) + Send>;
233
234enum Inner<T, S> {
235 NoState(Box<TaskAndState<T, S>>),
236 WithState {
237 activity: Activity<T, S>,
238 _backing_task: Task<()>,
239 shared: Arc<Mutex<Shared<T, S>>>,
240 },
241 Invalid,
242}
243
/// A task bundled with its optional state and completion flag.
struct TaskAndState<T, S> {
    /// The task itself.
    task: T,
    /// `None` until state is inserted, and again after it is removed.
    state: Option<S>,
    /// Set when `run` returns `Ok`; a completed task is not run again.
    done: bool,
}
249
250struct Shared<T, S> {
251 task_and_state: Option<Box<TaskAndState<T, S>>>,
252 calls: Vec<CallFn<T, S>>,
253 stop: bool,
254 outer_waker: Option<Waker>,
255 inner_waker: Option<Waker>,
256}
257
258impl<T, S> Shared<T, S> {
259 fn push_call(this: &Mutex<Self>, f: CallFn<T, S>) {
260 let waker = {
261 let mut this = this.lock();
262 this.calls.push(f);
263 this.inner_waker.take()
264 };
265 if let Some(waker) = waker {
266 waker.wake();
267 }
268 }
269}
270
271enum Activity<T, S> {
272 Stopped(Box<TaskAndState<T, S>>),
273 Running,
274}
275
276impl<T: AsyncRun<S>, S: 'static + Send> TaskControl<T, S> {
277 pub fn new(task: T) -> Self {
280 Self {
281 inner: Inner::NoState(Box::new(TaskAndState {
282 task,
283 state: None,
284 done: false,
285 })),
286 }
287 }
288
289 pub fn has_state(&self) -> bool {
291 match &self.inner {
292 Inner::NoState(_) => false,
293 Inner::WithState { .. } => true,
294 Inner::Invalid => unreachable!(),
295 }
296 }
297
298 pub fn is_running(&self) -> bool {
300 match &self.inner {
301 Inner::NoState(_)
302 | Inner::WithState {
303 activity: Activity::Stopped { .. },
304 ..
305 } => false,
306 Inner::WithState {
307 activity: Activity::Running,
308 ..
309 } => true,
310 Inner::Invalid => unreachable!(),
311 }
312 }
313
314 #[track_caller]
318 pub fn task(&self) -> &T {
319 self.get().0
320 }
321
322 #[track_caller]
326 pub fn task_mut(&mut self) -> &mut T {
327 self.get_mut().0
328 }
329
330 #[track_caller]
334 pub fn state(&self) -> Option<&S> {
335 self.get().1
336 }
337
338 #[track_caller]
342 pub fn state_mut(&mut self) -> Option<&mut S> {
343 self.get_mut().1
344 }
345
346 #[track_caller]
350 pub fn get(&self) -> (&T, Option<&S>) {
351 let task_and_state = match &self.inner {
352 Inner::NoState(task_and_state) => task_and_state,
353 Inner::WithState {
354 activity: Activity::Stopped(task_and_state),
355 ..
356 } => task_and_state,
357 Inner::WithState {
358 activity: Activity::Running,
359 ..
360 } => panic!("attempt to access running task"),
361 Inner::Invalid => unreachable!(),
362 };
363 (&task_and_state.task, task_and_state.state.as_ref())
364 }
365
366 #[track_caller]
370 pub fn get_mut(&mut self) -> (&mut T, Option<&mut S>) {
371 let task_and_state = match &mut self.inner {
372 Inner::NoState(task_and_state) => task_and_state,
373 Inner::WithState {
374 activity: Activity::Stopped(task_and_state),
375 ..
376 } => task_and_state,
377 Inner::WithState {
378 activity: Activity::Running,
379 ..
380 } => panic!("attempt to access running task"),
381 Inner::Invalid => unreachable!(),
382 };
383 (&mut task_and_state.task, task_and_state.state.as_mut())
384 }
385
386 #[track_caller]
390 pub fn into_inner(self) -> (T, Option<S>) {
391 let task_and_state = match self.inner {
392 Inner::NoState(task_and_state) => task_and_state,
393 Inner::WithState {
394 activity: Activity::Stopped(task_and_state),
395 ..
396 } => task_and_state,
397 Inner::WithState {
398 activity: Activity::Running,
399 ..
400 } => panic!("attempt to extract running task"),
401 Inner::Invalid => unreachable!(),
402 };
403 (task_and_state.task, task_and_state.state)
404 }
405
406 pub fn update_with(&mut self, f: impl 'static + Send + FnOnce(&mut T, Option<&mut S>)) {
411 let f = |task_and_state: &mut TaskAndState<T, S>| {
412 f(&mut task_and_state.task, task_and_state.state.as_mut())
413 };
414 match &mut self.inner {
415 Inner::NoState(task_and_state) => f(task_and_state),
416 Inner::WithState {
417 activity, shared, ..
418 } => match activity {
419 Activity::Stopped(task_and_state) => f(task_and_state),
420 Activity::Running => Shared::push_call(shared, Box::new(f)),
421 },
422 Inner::Invalid => unreachable!(),
423 }
424 }
425
426 #[track_caller]
429 pub fn insert(&mut self, spawn: impl Spawn, name: impl Into<Arc<str>>, state: S) -> &mut S {
430 self.inner = match std::mem::replace(&mut self.inner, Inner::Invalid) {
431 Inner::NoState(mut task_and_state) => {
432 task_and_state.state = Some(state);
433 task_and_state.done = false;
434 let shared = Arc::new(Mutex::new(Shared {
435 task_and_state: None,
436 calls: Vec::new(),
437 stop: true,
438 outer_waker: None,
439 inner_waker: None,
440 }));
441 let backing_task = spawn.spawn(name, Self::run(shared.clone()));
442 Inner::WithState {
443 activity: Activity::Stopped(task_and_state),
444 _backing_task: backing_task,
445 shared,
446 }
447 }
448 Inner::WithState { .. } => panic!("attempt to insert already-present state"),
449 Inner::Invalid => unreachable!(),
450 };
451 self.state_mut().unwrap()
452 }
453
454 pub fn start(&mut self) -> bool {
460 match &mut self.inner {
461 Inner::WithState {
462 activity, shared, ..
463 } => match std::mem::replace(activity, Activity::Running) {
464 Activity::Stopped(task_and_state) => {
465 if task_and_state.done {
466 *activity = Activity::Stopped(task_and_state);
467 return false;
468 }
469 let waker = {
470 let mut shared = shared.lock();
471 shared.task_and_state = Some(task_and_state);
472 shared.stop = false;
473 shared.inner_waker.take()
474 };
475 if let Some(waker) = waker {
476 waker.wake();
477 }
478 true
479 }
480 Activity::Running => true,
481 },
482 Inner::NoState(_) => false,
483 Inner::Invalid => {
484 unreachable!()
485 }
486 }
487 }
488
489 async fn run(shared: Arc<Mutex<Shared<T, S>>>) {
490 StopTask::run_with(StopTaskInner { shared: &shared }, async |stop_task| {
491 let mut calls = Vec::new();
492 loop {
493 let (mut task_and_state, stop) = poll_fn(|cx| {
494 let mut shared = shared.lock();
495 let has_work = shared
496 .task_and_state
497 .as_ref()
498 .is_some_and(|ts| !shared.calls.is_empty() || (!shared.stop && !ts.done));
499 if !has_work {
500 shared.inner_waker = Some(cx.waker().clone());
501 return Poll::Pending;
502 }
503 calls.append(&mut shared.calls);
504 Poll::Ready((shared.task_and_state.take().unwrap(), shared.stop))
505 })
506 .await;
507
508 for call in calls.drain(..) {
509 call(&mut task_and_state);
510 }
511
512 if !stop && !task_and_state.done {
513 task_and_state.done = task_and_state
514 .task
515 .run(&mut *stop_task, task_and_state.state.as_mut().unwrap())
516 .await
517 .is_ok();
518 }
519
520 let waker = {
521 let mut shared = shared.lock();
522 shared.task_and_state = Some(task_and_state);
523 shared.outer_waker.take()
524 };
525 if let Some(waker) = waker {
526 waker.wake();
527 }
528 }
529 })
530 .await
531 }
532
533 pub fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
541 match &mut self.inner {
542 Inner::WithState {
543 activity, shared, ..
544 } => match activity {
545 Activity::Running => {
546 let mut shared = shared.lock();
547 shared.stop = true;
548 if shared.task_and_state.is_none() || !shared.calls.is_empty() {
549 shared.outer_waker = Some(cx.waker().clone());
550 let waker = shared.inner_waker.take();
551 drop(shared);
552 if let Some(waker) = waker {
553 waker.wake();
554 }
555 return Poll::Pending;
556 }
557 let task_and_state = shared.task_and_state.take().unwrap();
558 drop(shared);
559 let done = task_and_state.done;
560 *activity = Activity::Stopped(task_and_state);
561 Poll::Ready(!done)
562 }
563 _ => Poll::Ready(false),
564 },
565 Inner::NoState(_) => Poll::Ready(false),
566 Inner::Invalid => unreachable!(),
567 }
568 }
569
570 pub async fn stop(&mut self) -> bool {
575 poll_fn(|cx| self.poll_stop(cx)).await
576 }
577
578 #[track_caller]
582 pub fn remove(&mut self) -> S {
583 match std::mem::replace(&mut self.inner, Inner::Invalid) {
584 Inner::WithState {
585 activity: Activity::Stopped(mut task_and_state),
586 ..
587 } => {
588 let state = task_and_state.state.take().unwrap();
589 self.inner = Inner::NoState(task_and_state);
590 state
591 }
592 Inner::NoState(_) => panic!("attempt to remove missing state"),
593 Inner::WithState { .. } => panic!("attempt to remove state from running task"),
594 Inner::Invalid => {
595 unreachable!()
596 }
597 }
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::AsyncRun;
    use super::Cancelled;
    use super::StopTask;
    use super::TaskControl;
    use futures::FutureExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::task::Poll;

    /// Test task that counts how many times it has run.
    ///
    /// With state `false` it blocks until stopped; with state `true` it
    /// completes immediately.
    struct Foo(u32);

    impl AsyncRun<bool> for Foo {
        async fn run(
            &mut self,
            stop: &mut StopTask<'_>,
            state: &mut bool,
        ) -> Result<(), Cancelled> {
            let work = async {
                self.0 += 1;
                if !*state {
                    std::future::pending::<()>().await;
                }
            };
            stop.until_stopped(work).await
        }
    }

    /// Returns `Pending` exactly once (after requesting a wakeup), giving
    /// other tasks a chance to run.
    async fn yield_once() {
        let mut already_yielded = false;
        std::future::poll_fn(|cx| {
            if std::mem::replace(&mut already_yielded, true) {
                Poll::Ready(())
            } else {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await
    }

    #[async_test]
    async fn test(driver: DefaultDriver) {
        let mut control = TaskControl::new(Foo(5));
        control.insert(&driver, "test", false);
        control.remove();
        control.insert(&driver, "test", false);
        assert_eq!(control.task().0, 5);

        // Blocking run: stopping interrupts it before completion.
        assert!(control.start());
        yield_once().await;
        assert!(control.stop().await);
        assert_eq!(control.task().0, 6);

        // Non-blocking run: the task completes, so stop reports `false`.
        *control.state_mut().unwrap() = true;
        assert!(control.start());
        yield_once().await;
        assert!(!control.stop().await);
        assert_eq!(control.task().0, 7);

        // A completed task cannot be started again.
        assert!(!control.start());
        yield_once().await;
        assert!(!control.stop().await);
        assert_eq!(control.task().0, 7);
    }

    #[async_test]
    async fn test_cancelled_stop(driver: DefaultDriver) {
        let mut control = TaskControl::new(Foo(5));
        control.insert(&driver, "test", false);
        assert!(control.start());
        yield_once().await;
        control.update_with(|task, _| task.0 += 1);
        assert!(control.stop().now_or_never().is_none());
        control.update_with(|task, _| task.0 += 1);
        assert!(control.stop().await);
        assert_eq!(control.task_mut().0, 8);
    }

    #[async_test]
    async fn test_poll_stop(driver: DefaultDriver) {
        let mut control = TaskControl::new(Foo(5));

        // No state yet: stopping completes immediately with `false`.
        assert_eq!(
            std::future::poll_fn(|cx| Poll::Ready(control.poll_stop(cx))).await,
            Poll::Ready(false)
        );

        control.insert(&driver, "test", false);

        // State present but never started: still immediately `false`.
        assert_eq!(
            std::future::poll_fn(|cx| Poll::Ready(control.poll_stop(cx))).await,
            Poll::Ready(false)
        );

        assert!(control.start());
        yield_once().await;

        let stopped = std::future::poll_fn(|cx| control.poll_stop(cx)).await;
        assert!(stopped);
        assert_eq!(control.task().0, 6);

        // Already stopped: immediately `false`.
        assert_eq!(
            std::future::poll_fn(|cx| Poll::Ready(control.poll_stop(cx))).await,
            Poll::Ready(false)
        );
    }
}
712}