1#![forbid(unsafe_code)]
8
9use fast_select::FastSelect;
10use inspect::Inspect;
11use inspect::InspectMut;
12use pal_async::task::Spawn;
13use pal_async::task::Task;
14use parking_lot::Mutex;
15use std::future::Future;
16use std::future::poll_fn;
17use std::pin::Pin;
18use std::pin::pin;
19use std::sync::Arc;
20use std::task::Context;
21use std::task::Poll;
22use std::task::Waker;
23
/// A task that can be started, interrupted, and restarted by a [`TaskControl`].
///
/// `S` is the task's state. It is owned by the `TaskControl` and lent to the
/// task mutably for the duration of each `run` call.
pub trait AsyncRun<S>: 'static + Send {
    /// Runs the task with mutable access to its state.
    ///
    /// `stop` completes when the task should return: either because the
    /// task is being stopped, or because a queued operation (such as
    /// `TaskControl::update_with` or a deferred inspect) needs the task and
    /// state. Use [`StopTask::until_stopped`] to race work against it.
    ///
    /// Returning `Ok(())` marks the task as complete. It is not run again
    /// until state is re-inserted.
    ///
    /// Returning `Err(Cancelled)` leaves the task incomplete. It is run again
    /// on the next start, or after queued operations are processed if it was
    /// not stopped.
    fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        _: &mut S,
    ) -> impl Send + Future<Output = Result<(), Cancelled>>;
}
45
/// Error returned when an operation is abandoned because the task's stop
/// signal fired (see [`StopTask::until_stopped`]).
///
/// Implements [`std::error::Error`] so it can be propagated with `?` into
/// generic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation cancelled")
    }
}

impl std::error::Error for Cancelled {}
50
/// A handle passed to [`AsyncRun::run`] that signals when the task should stop.
///
/// It is itself a future that completes when the wrapped stop signal does. It
/// also provides [`until_stopped`](Self::until_stopped) for racing other work
/// against that signal.
pub struct StopTask<'a> {
    /// The wrapped stop-signal future.
    inner: &'a mut (dyn 'a + Send + Future<Output = ()> + Unpin),
    /// Selection helper used by `until_stopped` to poll `inner`.
    ///
    /// A single instance is created by `run_with` and shared across every
    /// `until_stopped` call on this handle.
    fast_select: &'a mut FastSelect,
}
57
/// The stop signal used by a [`TaskControl`]'s backing task.
///
/// When polled, it completes as soon as the shared stop flag is set or there
/// are queued calls waiting to run against the task and state.
struct StopTaskInner<'a, T, S> {
    /// State shared with the owning `TaskControl`.
    shared: &'a Mutex<Shared<T, S>>,
}
65
66impl<T: AsyncRun<S>, S> Future for StopTaskInner<'_, T, S> {
67 type Output = ();
68
69 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
70 let mut shared = self.get_mut().shared.lock();
71 if !shared.calls.is_empty() || shared.stop {
72 return Poll::Ready(());
73 }
74 if shared
75 .inner_waker
76 .as_ref()
77 .is_none_or(|waker| !cx.waker().will_wake(waker))
78 {
79 shared.inner_waker = Some(cx.waker().clone());
80 }
81 Poll::Pending
82 }
83}
84
85impl StopTask<'_> {
86 pub async fn run_with<R>(
88 mut stop: impl Send + Future<Output = ()> + Unpin,
89 f: impl AsyncFnOnce(&mut StopTask<'_>) -> R,
90 ) -> R {
91 let mut fast_select: FastSelect = FastSelect::new();
92 let mut stop = StopTask {
93 inner: &mut stop,
94 fast_select: &mut fast_select,
95 };
96 f(&mut stop).await
97 }
98
99 pub async fn until_stopped<F: Future>(&mut self, fut: F) -> Result<F::Output, Cancelled> {
106 let mut cancel = pin!(
109 self.fast_select
110 .select((poll_fn(|cx| Pin::new(&mut self.inner).poll(cx)),))
111 );
112
113 let mut fut = pin!(fut);
114
115 poll_fn(|cx| {
117 if let Poll::Ready(r) = fut.as_mut().poll(cx) {
118 Poll::Ready(Ok(r))
119 } else if cancel.as_mut().poll(cx).is_ready() {
120 Poll::Ready(Err(Cancelled))
121 } else {
122 Poll::Pending
123 }
124 })
125 .await
126 }
127}
128
129impl Future for StopTask<'_> {
130 type Output = ();
131
132 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133 Pin::new(&mut self.inner).poll(cx)
134 }
135}
136
/// Controls starting and stopping an [`AsyncRun`] task, and gives access to
/// the task object and its state while the task is stopped.
///
/// Once state is inserted with [`TaskControl::insert`], the task runs on a
/// dedicated backing task spawned at that point.
pub struct TaskControl<T, S> {
    inner: Inner<T, S>,
}
142
/// Inspection support for an [`AsyncRun`] task together with its state.
pub trait InspectTask<S>: AsyncRun<S> {
    /// Inspects the task. `state` is `None` when no state is inserted.
    ///
    /// For a running task, this is deferred and called on the backing task
    /// after the current `run` returns.
    fn inspect(&self, req: inspect::Request<'_>, state: Option<&S>);
}
151
152impl<T: InspectTask<S>, S> Inspect for TaskAndState<T, S> {
153 fn inspect(&self, req: inspect::Request<'_>) {
154 self.task.inspect(req, self.state.as_ref());
155 }
156}
157
158impl<T: InspectTask<S>, S> Inspect for TaskControl<T, S> {
159 fn inspect(&self, req: inspect::Request<'_>) {
160 match &self.inner {
161 Inner::NoState(task_and_state) => task_and_state.inspect(req),
162 Inner::WithState {
163 activity, shared, ..
164 } => match activity {
165 Activity::Stopped(task_and_state) => task_and_state.inspect(req),
166 Activity::Running => {
167 let deferred = req.defer();
168 Shared::push_call(
169 shared,
170 Box::new(|task_and_state| {
171 deferred.inspect(&task_and_state);
172 }),
173 )
174 }
175 },
176 Inner::Invalid => unreachable!(),
177 }
178 }
179}
180
/// Mutable inspection support for an [`AsyncRun`] task together with its
/// state.
pub trait InspectTaskMut<T>: AsyncRun<T> {
    /// Inspects the task. `state` is `None` when no state is inserted.
    ///
    /// For a running task, this is deferred and called on the backing task
    /// after the current `run` returns.
    fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut T>);
}
189
190impl<T: InspectTaskMut<S>, S> InspectMut for TaskAndState<T, S> {
191 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
192 self.task.inspect_mut(req, self.state.as_mut());
193 }
194}
195
196impl<T: InspectTaskMut<U>, U> InspectMut for TaskControl<T, U> {
197 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
198 match &mut self.inner {
199 Inner::NoState(task_and_state) => task_and_state.inspect_mut(req),
200 Inner::WithState {
201 activity, shared, ..
202 } => match activity {
203 Activity::Stopped(task_and_state) => task_and_state.inspect_mut(req),
204 Activity::Running => {
205 let deferred = req.defer();
206 Shared::push_call(
207 shared,
208 Box::new(|task_and_state| {
209 deferred.inspect(task_and_state);
210 }),
211 );
212 }
213 },
214 Inner::Invalid => unreachable!(),
215 }
216 }
217}
218
/// An operation queued by the control side, to be run against the task and
/// state on the backing task.
type CallFn<T, S> = Box<dyn FnOnce(&mut TaskAndState<T, S>) + Send>;
220
/// The state machine behind [`TaskControl`].
enum Inner<T, S> {
    /// No state is inserted, and no backing task exists.
    NoState(Box<TaskAndState<T, S>>),
    /// State is inserted and a backing task has been spawned.
    WithState {
        /// Which side currently owns the task and state.
        activity: Activity<T, S>,
        /// Handle to the spawned backing task, kept for as long as the state
        /// is present.
        ///
        /// NOTE(review): presumably dropping it cancels the backing task
        /// (`pal_async::task::Task` semantics); confirm.
        _backing_task: Task<()>,
        /// State shared with the backing task.
        shared: Arc<Mutex<Shared<T, S>>>,
    },
    /// Transient placeholder used while moving out of `TaskControl::inner`.
    ///
    /// Methods treat observing it as unreachable.
    Invalid,
}
230
/// A task object together with its optional state.
struct TaskAndState<T, S> {
    /// The task object.
    task: T,
    /// `None` until state is inserted, and again after it is removed.
    state: Option<S>,
    /// Set once `AsyncRun::run` returns `Ok(())`.
    ///
    /// A done task is not run again; `insert` clears this flag.
    done: bool,
}
236
/// State shared between a `TaskControl` and its backing task.
struct Shared<T, S> {
    /// The task and state while parked here: handed over by `start`, or
    /// returned by the backing task.
    ///
    /// `None` while either side is holding them.
    task_and_state: Option<Box<TaskAndState<T, S>>>,
    /// Operations queued to run against the task and state on the backing
    /// task.
    calls: Vec<CallFn<T, S>>,
    /// Set when the task should not be run: initially, and by `stop`.
    ///
    /// Cleared by `start`.
    stop: bool,
    /// Waker of a pending `stop` call.
    ///
    /// Woken when the backing task returns the task and state.
    outer_waker: Option<Waker>,
    /// Waker registered by the backing task.
    ///
    /// Woken when calls are queued or the stop flag is changed.
    inner_waker: Option<Waker>,
}
244
245impl<T, S> Shared<T, S> {
246 fn push_call(this: &Mutex<Self>, f: CallFn<T, S>) {
247 let waker = {
248 let mut this = this.lock();
249 this.calls.push(f);
250 this.inner_waker.take()
251 };
252 if let Some(waker) = waker {
253 waker.wake();
254 }
255 }
256}
257
/// Which side currently owns the task and state.
enum Activity<T, S> {
    /// Stopped: the control side holds the task and state.
    Stopped(Box<TaskAndState<T, S>>),
    /// Running: the task and state have been handed to the backing task,
    /// via `Shared::task_and_state`.
    Running,
}
262
263impl<T: AsyncRun<S>, S: 'static + Send> TaskControl<T, S> {
264 pub fn new(task: T) -> Self {
267 Self {
268 inner: Inner::NoState(Box::new(TaskAndState {
269 task,
270 state: None,
271 done: false,
272 })),
273 }
274 }
275
276 pub fn has_state(&self) -> bool {
278 match &self.inner {
279 Inner::NoState(_) => false,
280 Inner::WithState { .. } => true,
281 Inner::Invalid => unreachable!(),
282 }
283 }
284
285 pub fn is_running(&self) -> bool {
287 match &self.inner {
288 Inner::NoState(_)
289 | Inner::WithState {
290 activity: Activity::Stopped { .. },
291 ..
292 } => false,
293 Inner::WithState {
294 activity: Activity::Running,
295 ..
296 } => true,
297 Inner::Invalid => unreachable!(),
298 }
299 }
300
301 #[track_caller]
305 pub fn task(&self) -> &T {
306 self.get().0
307 }
308
309 #[track_caller]
313 pub fn task_mut(&mut self) -> &mut T {
314 self.get_mut().0
315 }
316
317 #[track_caller]
321 pub fn state(&self) -> Option<&S> {
322 self.get().1
323 }
324
325 #[track_caller]
329 pub fn state_mut(&mut self) -> Option<&mut S> {
330 self.get_mut().1
331 }
332
333 #[track_caller]
337 pub fn get(&self) -> (&T, Option<&S>) {
338 let task_and_state = match &self.inner {
339 Inner::NoState(task_and_state) => task_and_state,
340 Inner::WithState {
341 activity: Activity::Stopped(task_and_state),
342 ..
343 } => task_and_state,
344 Inner::WithState {
345 activity: Activity::Running,
346 ..
347 } => panic!("attempt to access running task"),
348 Inner::Invalid => unreachable!(),
349 };
350 (&task_and_state.task, task_and_state.state.as_ref())
351 }
352
353 #[track_caller]
357 pub fn get_mut(&mut self) -> (&mut T, Option<&mut S>) {
358 let task_and_state = match &mut self.inner {
359 Inner::NoState(task_and_state) => task_and_state,
360 Inner::WithState {
361 activity: Activity::Stopped(task_and_state),
362 ..
363 } => task_and_state,
364 Inner::WithState {
365 activity: Activity::Running,
366 ..
367 } => panic!("attempt to access running task"),
368 Inner::Invalid => unreachable!(),
369 };
370 (&mut task_and_state.task, task_and_state.state.as_mut())
371 }
372
373 #[track_caller]
377 pub fn into_inner(self) -> (T, Option<S>) {
378 let task_and_state = match self.inner {
379 Inner::NoState(task_and_state) => task_and_state,
380 Inner::WithState {
381 activity: Activity::Stopped(task_and_state),
382 ..
383 } => task_and_state,
384 Inner::WithState {
385 activity: Activity::Running,
386 ..
387 } => panic!("attempt to extract running task"),
388 Inner::Invalid => unreachable!(),
389 };
390 (task_and_state.task, task_and_state.state)
391 }
392
393 pub fn update_with(&mut self, f: impl 'static + Send + FnOnce(&mut T, Option<&mut S>)) {
398 let f = |task_and_state: &mut TaskAndState<T, S>| {
399 f(&mut task_and_state.task, task_and_state.state.as_mut())
400 };
401 match &mut self.inner {
402 Inner::NoState(task_and_state) => f(task_and_state),
403 Inner::WithState {
404 activity, shared, ..
405 } => match activity {
406 Activity::Stopped(task_and_state) => f(task_and_state),
407 Activity::Running => Shared::push_call(shared, Box::new(f)),
408 },
409 Inner::Invalid => unreachable!(),
410 }
411 }
412
413 #[track_caller]
416 pub fn insert(&mut self, spawn: impl Spawn, name: impl Into<Arc<str>>, state: S) -> &mut S {
417 self.inner = match std::mem::replace(&mut self.inner, Inner::Invalid) {
418 Inner::NoState(mut task_and_state) => {
419 task_and_state.state = Some(state);
420 task_and_state.done = false;
421 let shared = Arc::new(Mutex::new(Shared {
422 task_and_state: None,
423 calls: Vec::new(),
424 stop: true,
425 outer_waker: None,
426 inner_waker: None,
427 }));
428 let backing_task = spawn.spawn(name, Self::run(shared.clone()));
429 Inner::WithState {
430 activity: Activity::Stopped(task_and_state),
431 _backing_task: backing_task,
432 shared,
433 }
434 }
435 Inner::WithState { .. } => panic!("attempt to insert already-present state"),
436 Inner::Invalid => unreachable!(),
437 };
438 self.state_mut().unwrap()
439 }
440
441 pub fn start(&mut self) -> bool {
447 match &mut self.inner {
448 Inner::WithState {
449 activity, shared, ..
450 } => match std::mem::replace(activity, Activity::Running) {
451 Activity::Stopped(task_and_state) => {
452 if task_and_state.done {
453 *activity = Activity::Stopped(task_and_state);
454 return false;
455 }
456 let waker = {
457 let mut shared = shared.lock();
458 shared.task_and_state = Some(task_and_state);
459 shared.stop = false;
460 shared.inner_waker.take()
461 };
462 if let Some(waker) = waker {
463 waker.wake();
464 }
465 true
466 }
467 Activity::Running => true,
468 },
469 Inner::NoState(_) => false,
470 Inner::Invalid => {
471 unreachable!()
472 }
473 }
474 }
475
476 async fn run(shared: Arc<Mutex<Shared<T, S>>>) {
477 StopTask::run_with(StopTaskInner { shared: &shared }, async |stop_task| {
478 let mut calls = Vec::new();
479 loop {
480 let (mut task_and_state, stop) = poll_fn(|cx| {
481 let mut shared = shared.lock();
482 let has_work = shared
483 .task_and_state
484 .as_ref()
485 .is_some_and(|ts| !shared.calls.is_empty() || (!shared.stop && !ts.done));
486 if !has_work {
487 shared.inner_waker = Some(cx.waker().clone());
488 return Poll::Pending;
489 }
490 calls.append(&mut shared.calls);
491 Poll::Ready((shared.task_and_state.take().unwrap(), shared.stop))
492 })
493 .await;
494
495 for call in calls.drain(..) {
496 call(&mut task_and_state);
497 }
498
499 if !stop && !task_and_state.done {
500 task_and_state.done = task_and_state
501 .task
502 .run(&mut *stop_task, task_and_state.state.as_mut().unwrap())
503 .await
504 .is_ok();
505 }
506
507 let waker = {
508 let mut shared = shared.lock();
509 shared.task_and_state = Some(task_and_state);
510 shared.outer_waker.take()
511 };
512 if let Some(waker) = waker {
513 waker.wake();
514 }
515 }
516 })
517 .await
518 }
519
520 pub async fn stop(&mut self) -> bool {
525 match &mut self.inner {
526 Inner::WithState {
527 activity, shared, ..
528 } => match activity {
529 Activity::Running => {
530 let task_and_state = poll_fn(|cx| {
531 let mut shared = shared.lock();
532 shared.stop = true;
533 if shared.task_and_state.is_none() || !shared.calls.is_empty() {
534 shared.outer_waker = Some(cx.waker().clone());
535 let waker = shared.inner_waker.take();
536 drop(shared);
537 if let Some(waker) = waker {
538 waker.wake();
539 }
540 return Poll::Pending;
541 }
542 Poll::Ready(shared.task_and_state.take().unwrap())
543 })
544 .await;
545
546 let done = task_and_state.done;
547 *activity = Activity::Stopped(task_and_state);
548 !done
549 }
550 _ => false,
551 },
552 Inner::NoState(_) => false,
553 Inner::Invalid => unreachable!(),
554 }
555 }
556
557 #[track_caller]
561 pub fn remove(&mut self) -> S {
562 match std::mem::replace(&mut self.inner, Inner::Invalid) {
563 Inner::WithState {
564 activity: Activity::Stopped(mut task_and_state),
565 ..
566 } => {
567 let state = task_and_state.state.take().unwrap();
568 self.inner = Inner::NoState(task_and_state);
569 state
570 }
571 Inner::NoState(_) => panic!("attempt to remove missing state"),
572 Inner::WithState { .. } => panic!("attempt to remove state from running task"),
573 Inner::Invalid => {
574 unreachable!()
575 }
576 }
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::AsyncRun;
    use crate::Cancelled;
    use crate::StopTask;
    use crate::TaskControl;
    use futures::FutureExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::task::Poll;

    /// Test task whose counter is incremented each time `run` begins
    /// executing.
    ///
    /// With state `false`, the run never finishes on its own and returns
    /// `Err(Cancelled)` only when stopped. With state `true`, it completes
    /// immediately with `Ok(())`.
    struct Foo(u32);

    impl AsyncRun<bool> for Foo {
        async fn run(
            &mut self,
            stop: &mut StopTask<'_>,
            state: &mut bool,
        ) -> Result<(), Cancelled> {
            stop.until_stopped(async {
                self.0 += 1;
                if !*state {
                    std::future::pending::<()>().await;
                }
            })
            .await
        }
    }

    /// Returns `Pending` exactly once (self-waking), giving the backing task
    /// a chance to be polled.
    async fn yield_once() {
        let mut yielded = false;
        std::future::poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await
    }

    #[async_test]
    async fn test(driver: DefaultDriver) {
        let mut t = TaskControl::new(Foo(5));
        // Insert, remove, and re-insert state without ever starting.
        t.insert(&driver, "test", false);
        t.remove();
        t.insert(&driver, "test", false);
        assert_eq!(t.task().0, 5);
        // State `false`: the run starts (counter 6) and blocks until stopped,
        // so `stop` reports the task as not completed.
        assert!(t.start());
        yield_once().await;
        assert!(t.stop().await);
        assert_eq!(t.task().0, 6);
        // State `true`: the next run completes on its own, so `stop` returns
        // `false`.
        *t.state_mut().unwrap() = true;
        assert!(t.start());
        yield_once().await;
        assert!(!t.stop().await);
        assert_eq!(t.task().0, 7);
        // A completed task cannot be restarted and is not run again.
        assert!(!t.start());
        yield_once().await;
        assert!(!t.stop().await);
        assert_eq!(t.task().0, 7);
    }

    /// Dropping an in-progress `stop` future must not lose queued updates: a
    /// later `stop` still completes and every update is applied.
    #[async_test]
    async fn test_cancelled_stop(driver: DefaultDriver) {
        let mut t = TaskControl::new(Foo(5));
        t.insert(&driver, "test", false);
        assert!(t.start());
        yield_once().await;
        // Queued while running; applied on the backing task.
        t.update_with(|t, _| t.0 += 1);
        // Polled once, then dropped before it can complete.
        assert!(t.stop().now_or_never().is_none());
        t.update_with(|t, _| t.0 += 1);
        assert!(t.stop().await);
        // 5, plus 1 for the first run, plus 2 updates = 8.
        assert_eq!(t.task_mut().0, 8);
    }
}