pal_async/
task.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Task spawning support.
5
6// UNSAFETY: Managing information stored as pointers for debugging purposes.
7#![expect(unsafe_code)]
8
9use loan_cell::LoanCell;
10use parking_lot::Mutex;
11use slab::Slab;
12use std::fmt::Debug;
13use std::fmt::Display;
14use std::future::Future;
15use std::panic::Location;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::sync::Weak;
19use std::sync::atomic::AtomicBool;
20use std::sync::atomic::AtomicU32;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23
24/// A handle to a task.
25pub type Task<T> = async_task::Task<T, TaskMetadata>;
26
27/// A handle to a task that's ready to run.
28pub type Runnable = async_task::Runnable<TaskMetadata>;
29
/// Metadata about a spawned task.
///
/// This can be accessed via [`Task::metadata()`], [`Runnable::metadata()`], or
/// [`with_current_task_metadata()`].
///
/// Once registered, a pointer to this value is published in the global
/// [`TaskList`], which is why the type is `!Unpin` (via `_no_pin`) and is only
/// registered through a `Pin<&Self>`.
#[derive(Debug)]
pub struct TaskMetadata {
    /// The task name supplied at spawn time.
    name: Arc<str>,
    /// Where the task was spawned, captured via `#[track_caller]`.
    location: &'static Location<'static>,
    /// The ready/waiting/running/done state of the future itself. Tracked for
    /// diagnostics purposes.
    state: AtomicU32,
    /// Whether the task has been dropped. This could be a high bit in `state`
    /// or something, but keep this separate to make the codegen straightforward
    /// for all state updates.
    dropped: AtomicBool,
    /// The task's scheduler. Only a weak reference: the strong reference is
    /// owned by the task's future.
    scheduler: Weak<dyn Schedule>,
    /// Index of this task in the global task list, or `NO_ID` if the metadata
    /// has not been registered.
    id: AtomicUsize,
    /// Makes the type `!Unpin`, since its address is published globally.
    _no_pin: std::marker::PhantomPinned,
}
49
50impl TaskMetadata {
51    const NO_ID: usize = !0;
52
53    #[track_caller]
54    fn new(name: Arc<str>) -> Self {
55        Self {
56            name,
57            location: Location::caller(),
58            state: AtomicU32::new(TASK_STATE_READY),
59            dropped: AtomicBool::new(false),
60            scheduler: Weak::<Scheduler>::new(),
61            id: AtomicUsize::new(Self::NO_ID),
62            _no_pin: std::marker::PhantomPinned,
63        }
64    }
65
66    fn register(self: Pin<&Self>) {
67        assert_eq!(self.id.load(Ordering::Relaxed), Self::NO_ID);
68        // Insert a pointer into the global task list. This is safe because this
69        // object is known to be pinned, so its storage will not be deallocated
70        // without calling `drop` (which will remove it from the list).
71        let id = TASK_LIST
72            .slab
73            .lock()
74            .insert(TaskMetadataPtr(self.get_ref()));
75        self.id.store(id, Ordering::Relaxed);
76    }
77
78    fn pend(&self) {
79        self.state.store(TASK_STATE_WAITING, Ordering::Relaxed);
80    }
81
82    fn done(&self) {
83        self.state.store(TASK_STATE_DONE, Ordering::Relaxed);
84    }
85
86    fn run(&self) {
87        self.state.store(TASK_STATE_RUNNING, Ordering::Relaxed);
88    }
89
90    /// The name of the spawned task.
91    pub fn name(&self) -> &Arc<str> {
92        &self.name
93    }
94
95    /// The location where the task was spawned.
96    pub fn location(&self) -> &'static Location<'static> {
97        self.location
98    }
99
100    /// The current state of the task.
101    fn state(&self) -> TaskState {
102        let state = self.state.load(Ordering::Relaxed);
103        if self.dropped.load(Ordering::Relaxed) {
104            if state == TASK_STATE_DONE {
105                TaskState::Complete
106            } else {
107                TaskState::Cancelled
108            }
109        } else {
110            match self.state.load(Ordering::Relaxed) {
111                TASK_STATE_READY => TaskState::Ready,
112                TASK_STATE_WAITING => TaskState::Waiting,
113                TASK_STATE_RUNNING => TaskState::Running,
114                TASK_STATE_DONE => TaskState::Complete,
115                _ => unreachable!(),
116            }
117        }
118    }
119}
120
121impl Drop for TaskMetadata {
122    fn drop(&mut self) {
123        let id = self.id.load(Ordering::Relaxed);
124        if id != Self::NO_ID {
125            let _task = TASK_LIST.slab.lock().remove(id);
126        }
127    }
128}
129
/// A raw pointer to a registered [`TaskMetadata`], as stored in the global
/// task list.
///
/// The pointer is only dereferenced (to a shared reference) while the task
/// list lock is held. The pointee removes itself from the list under that same
/// lock in its `Drop` impl, so any pointer still present in the list is valid.
#[derive(Debug, Copy, Clone)]
struct TaskMetadataPtr(*const TaskMetadata);

// Assert `TaskMetadata` is `Send` and `Sync`.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<TaskMetadata>();
};

// SAFETY: `TaskMetadata` can be safely shared between threads (asserted above).
// The pointer is only ever turned into a `&TaskMetadata`, which requires
// `TaskMetadata: Sync` to be used from another thread.
unsafe impl Send for TaskMetadataPtr {}
// SAFETY: `TaskMetadata` can be safely shared between threads (asserted above).
// Sharing the wrapper only allows copying the pointer and producing shared
// references, as above.
unsafe impl Sync for TaskMetadataPtr {}
143
/// A queue of tasks that will be run on a single thread.
///
/// Created together with its [`Scheduler`] by [`task_queue`]. Runnables
/// scheduled through that scheduler arrive on this queue and are run by
/// [`TaskQueue::run`].
#[derive(Debug)]
pub struct TaskQueue {
    /// Receiving end of the channel the paired [`Scheduler`] sends into.
    tasks: async_channel::Receiver<Runnable>,
}
149
/// A task scheduler for a [`TaskQueue`].
///
/// Scheduling a task sends its [`Runnable`] to the paired queue.
#[derive(Debug)]
pub struct Scheduler {
    /// Sending end of the channel read by the paired [`TaskQueue`].
    send: async_channel::Sender<Runnable>,
    /// The executor name reported via [`Schedule::name`]; changeable with
    /// [`Scheduler::set_name`].
    name: Mutex<Arc<str>>,
}
156
157impl Scheduler {
158    /// Updates the name of the scheduler.
159    pub fn set_name(&self, name: impl Into<Arc<str>>) {
160        *self.name.lock() = name.into();
161    }
162}
163
164impl Schedule for Scheduler {
165    fn schedule(&self, runnable: Runnable) {
166        let _ = self.send.try_send(runnable);
167    }
168
169    fn name(&self) -> Arc<str> {
170        self.name.lock().clone()
171    }
172}
173
174/// Creates a new task queue and associated scheduler.
175pub fn task_queue(name: impl Into<Arc<str>>) -> (TaskQueue, Scheduler) {
176    let (send, recv) = async_channel::unbounded();
177    (
178        TaskQueue { tasks: recv },
179        Scheduler {
180            send,
181            name: Mutex::new(name.into()),
182        },
183    )
184}
185
186impl TaskQueue {
187    /// Runs tasks on the queue.
188    ///
189    /// Returns when the associated scheduler has been dropped.
190    pub async fn run(&mut self) {
191        while let Ok(task) = self.tasks.recv().await {
192            task.run();
193        }
194    }
195}
196
/// Trait for scheduling a task on an executor.
///
/// Implementations must be `Send + Sync`; each task's [`TaskMetadata`] (itself
/// `Send + Sync`) holds a `Weak<dyn Schedule>` to its scheduler.
pub trait Schedule: Send + Sync {
    /// Schedules a task to run.
    ///
    /// Receives the task's [`Runnable`], including once immediately after the
    /// task is spawned. See [`async_task::Runnable`] for what happens if the
    /// runnable is dropped instead of run.
    fn schedule(&self, runnable: Runnable);

    /// Gets the executor name.
    ///
    /// Reported as [`TaskData::executor`] in task list snapshots.
    fn name(&self) -> Arc<str>;
}
205
/// The future actually spawned on the executor.
///
/// It wraps the caller's future, records poll progress in the task's
/// [`TaskMetadata`], and makes that metadata available through
/// [`with_current_task_metadata`] while polling.
struct TaskFuture<'a, Fut> {
    /// The owning task's metadata, borrowed from the metadata stored by
    /// `async_task::Builder` (see `new_for_async_task` for the lifetime).
    metadata: &'a TaskMetadata,
    _scheduler: Arc<dyn Schedule>, // Keep the scheduler alive until the future is dropped.
    /// The caller's future.
    future: Fut,
}
211
impl<'a, Fut: Future> TaskFuture<'a, Fut> {
    /// Registers `metadata` in the global task list and wraps `future`.
    ///
    /// The returned value owns `scheduler`, keeping it alive until the future
    /// is dropped.
    ///
    /// # Panics
    /// Panics if `metadata` has already been registered.
    fn new(metadata: Pin<&'a TaskMetadata>, scheduler: Arc<dyn Schedule>, future: Fut) -> Self {
        metadata.register();
        Self {
            metadata: metadata.get_ref(),
            _scheduler: scheduler,
            future,
        }
    }

    /// Wrapper around `new` for passing to [`async_task::Builder::spawn`]
    /// (and `spawn_local`).
    ///
    /// # Safety
    /// The caller guarantees that the incoming `metadata` pointer is pinned and
    /// that the future will not be used beyond the lifetime of `metadata`
    /// (despite having a static lifetime). This is guaranteed by
    /// [`async_task::Builder::spawn`] API, which unfortunately is missing the
    /// explicit `Pin` and the appropriate lifetime on the future.
    ///
    /// See <https://github.com/smol-rs/async-task/issues/76>.
    unsafe fn new_for_async_task(
        metadata: &'a TaskMetadata,
        scheduler: Arc<dyn Schedule>,
        future: Fut,
    ) -> TaskFuture<'static, Fut> {
        // SAFETY: `metadata` is pinned by `async_task::Builder::spawn`, and the
        // caller guarantees this function will only be used in that context.
        let metadata = unsafe { Pin::new_unchecked(metadata) };
        let this = Self::new(metadata, scheduler, future);
        // Transmute to static lifetime, as required by
        // `async_task::Builder::spawn`.
        //
        // SAFETY: the caller guarantees this future will only be passed to
        // `spawn`, which will guarantee the metadata is not dropped before the
        // future is dropped.
        unsafe { std::mem::transmute::<TaskFuture<'a, Fut>, TaskFuture<'static, Fut>>(this) }
    }
}
250
impl<Fut: Future> Future for TaskFuture<'_, Fut> {
    type Output = Fut::Output;

    /// Polls the wrapped future, recording running/waiting/done in the task
    /// metadata and lending the metadata to `CURRENT_TASK` during the poll.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: projecting this type for pinned access to the future. The
        // future will not be moved or dropped.
        let this = unsafe { self.get_unchecked_mut() };
        this.metadata.run();
        // SAFETY: the future is pinned since `self` is pinned.
        let future = unsafe { Pin::new_unchecked(&mut this.future) };
        // Lend this task's metadata to `CURRENT_TASK` while polling the inner
        // future so `with_current_task_metadata` can observe it. If the inner
        // poll unwinds, neither `pend` nor `done` below runs, so the recorded
        // state stays at "running".
        let r = CURRENT_TASK.with(|task| task.lend(this.metadata, || future.poll(cx)));
        if r.is_pending() {
            this.metadata.pend();
        } else {
            this.metadata.done();
        }
        r
    }
}
273
274impl<Fut> Drop for TaskFuture<'_, Fut> {
275    fn drop(&mut self) {
276        self.metadata.dropped.store(true, Ordering::Relaxed);
277    }
278}
279
280fn schedule(runnable: Runnable) {
281    let metadata = runnable.metadata();
282    metadata.state.store(TASK_STATE_READY, Ordering::Relaxed);
283    if let Some(scheduler) = metadata.scheduler.upgrade() {
284        scheduler.schedule(runnable);
285    }
286}
287
/// Trait for spawning a task on an executor.
pub trait Spawn: Send + Sync {
    /// Gets a scheduler for a new task.
    ///
    /// Called once per spawn with the new task's not-yet-registered metadata.
    /// The task's future keeps the returned `Arc`, while the metadata keeps
    /// only a `Weak` reference to it.
    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule>;

    /// Spawns a task.
    ///
    /// The task is named `name`, runs `fut`, and is scheduled immediately.
    /// The spawn location recorded in its metadata is the caller of this
    /// method (`#[track_caller]`). See [`async_task::Task`] for the semantics
    /// of the returned handle.
    #[track_caller]
    fn spawn<T: 'static + Send>(
        &self,
        name: impl Into<Arc<str>>,
        fut: impl Future<Output = T> + Send + 'static,
    ) -> Task<T>
    where
        Self: Sized,
    {
        let mut metadata = TaskMetadata::new(name.into());
        let scheduler = self.scheduler(&metadata);
        // The metadata only holds a weak reference; the strong one moves into
        // the task's future below.
        metadata.scheduler = Arc::downgrade(&scheduler);
        let (runnable, task) = async_task::Builder::new().metadata(metadata).spawn(
            |metadata| {
                // SAFETY: calling from the async_task::Builder::spawn closure, as required.
                unsafe { TaskFuture::new_for_async_task(metadata, scheduler, fut) }
            },
            schedule,
        );
        // Queue the first poll.
        runnable.schedule();
        task
    }
}
317
/// Trait for spawning a non-`Send` task on an executor.
pub trait SpawnLocal {
    /// Gets a scheduler for a new task.
    ///
    /// Called once per spawn with the new task's not-yet-registered metadata.
    /// The task's future keeps the returned `Arc`, while the metadata keeps
    /// only a `Weak` reference to it.
    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule>;

    /// Spawns a task.
    ///
    /// Like [`Spawn::spawn`], but `fut` need not be `Send`; the task is
    /// created with `async_task::Builder::spawn_local` and scheduled
    /// immediately.
    #[track_caller]
    fn spawn_local<T: 'static>(
        &self,
        name: impl Into<Arc<str>>,
        fut: impl Future<Output = T> + 'static,
    ) -> Task<T>
    where
        Self: Sized,
    {
        let mut metadata = TaskMetadata::new(name.into());
        let scheduler = self.scheduler_local(&metadata);
        // The metadata only holds a weak reference; the strong one moves into
        // the task's future below.
        metadata.scheduler = Arc::downgrade(&scheduler);
        let (runnable, task) = async_task::Builder::new().metadata(metadata).spawn_local(
            |metadata| {
                // SAFETY: calling from the async_task::Builder::spawn_local
                // closure, which passes the metadata the same way as the
                // `spawn` closure required by `new_for_async_task`.
                unsafe { TaskFuture::new_for_async_task(metadata, scheduler, fut) }
            },
            schedule,
        );
        // Queue the first poll.
        runnable.schedule();
        task
    }
}
347
thread_local! {
    /// Metadata of the task whose future is being polled on this thread. It is
    /// lent by `TaskFuture::poll` around the inner poll and read by
    /// [`with_current_task_metadata`].
    static CURRENT_TASK: LoanCell<TaskMetadata> = const { LoanCell::new() };
}
351
352/// Calls `f` with the current task metadata, if there is a current task.
353pub fn with_current_task_metadata<F: FnOnce(Option<&TaskMetadata>) -> R, R>(f: F) -> R {
354    CURRENT_TASK.with(|task| task.borrow(f))
355}
356
357impl<T: ?Sized + Spawn> Spawn for &'_ T {
358    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
359        (*self).scheduler(metadata)
360    }
361}
362
363impl<T: ?Sized + Spawn> Spawn for Box<T> {
364    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
365        self.as_ref().scheduler(metadata)
366    }
367}
368
369impl<T: ?Sized + Spawn> Spawn for Arc<T> {
370    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
371        self.as_ref().scheduler(metadata)
372    }
373}
374
375impl<T: ?Sized + SpawnLocal> SpawnLocal for &'_ T {
376    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
377        (*self).scheduler_local(metadata)
378    }
379}
380
381impl<T: ?Sized + SpawnLocal> SpawnLocal for Box<T> {
382    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
383        self.as_ref().scheduler_local(metadata)
384    }
385}
386
387impl<T: ?Sized + SpawnLocal> SpawnLocal for Arc<T> {
388    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
389        self.as_ref().scheduler_local(metadata)
390    }
391}
392
// Raw values stored in `TaskMetadata::state`. `TaskMetadata::state()` decodes
// them, together with the `dropped` flag, into a `TaskState`.
/// Set on creation and whenever the task is scheduled.
const TASK_STATE_READY: u32 = 0;
/// The last poll returned `Poll::Pending`.
const TASK_STATE_WAITING: u32 = 1;
/// The future is currently being polled.
const TASK_STATE_RUNNING: u32 = 2;
/// The future returned `Poll::Ready`.
const TASK_STATE_DONE: u32 = 3;
397
/// The state of a task, as reported in task list snapshots.
///
/// Implements `PartialEq`/`Eq`/`Hash` so callers can compare and group
/// states directly.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u64)]
pub enum TaskState {
    /// The task is ready to run.
    Ready,
    /// The task is waiting on some condition.
    Waiting,
    /// The task is running on an executor.
    Running,
    /// The task has completed.
    Complete,
    /// The task was cancelled before it completed.
    Cancelled,
}
413
414impl Display for TaskState {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        let s = match self {
417            TaskState::Ready => "ready",
418            TaskState::Waiting => "waiting",
419            TaskState::Running => "running",
420            TaskState::Complete => "complete",
421            TaskState::Cancelled => "cancelled",
422        };
423        f.pad(s)
424    }
425}
426
/// A list of tasks.
///
/// Holds pointers to the metadata of every registered task. Get the
/// process-wide instance with [`TaskList::global`] and take snapshots with
/// [`TaskList::tasks`].
pub struct TaskList {
    /// Pointers to registered metadata, keyed by task ID. Entries are inserted
    /// by `TaskMetadata::register` and removed by `TaskMetadata`'s `Drop` impl.
    slab: Mutex<Slab<TaskMetadataPtr>>,
}

/// The global task list that spawned tasks register themselves in.
static TASK_LIST: TaskList = TaskList::new();
433
434impl TaskList {
435    const fn new() -> Self {
436        Self {
437            slab: Mutex::new(Slab::new()),
438        }
439    }
440
441    /// Gets the global task list.
442    pub fn global() -> &'static Self {
443        &TASK_LIST
444    }
445
446    /// Gets a snapshot of the current tasks.
447    pub fn tasks(&self) -> Vec<TaskData> {
448        let tasks = self.slab.lock();
449        tasks
450            .iter()
451            .map(|(id, task)| {
452                // SAFETY: the pointer is guaranteed to be valid while the lock
453                // is held, since it was published via
454                // [`TaskMetadata::register`] and will be unpublished by
455                // [`TaskMetadata::drop`].
456                let task = unsafe { &*task.0 };
457                let scheduler = task.scheduler.upgrade().map(|s| s.name());
458                TaskData {
459                    id,
460                    name: task.name.clone(),
461                    location: task.location,
462                    state: task.state(),
463                    executor: scheduler,
464                }
465            })
466            .collect()
467    }
468}
469
/// Information about a task.
///
/// A point-in-time snapshot produced by [`TaskList::tasks`].
#[derive(Debug)]
pub struct TaskData {
    /// Slab index of the task in the task list; may be reused.
    id: usize,
    /// The task name given at spawn time.
    name: Arc<str>,
    /// Where the task was spawned.
    location: &'static Location<'static>,
    /// The task state at snapshot time.
    state: TaskState,
    /// Name of the task's scheduler, or `None` if it no longer exists.
    executor: Option<Arc<str>>,
}
479
480impl TaskData {
481    /// The task's unique ID.
482    ///
483    /// This ID may be reused.
484    pub fn id(&self) -> usize {
485        self.id
486    }
487
488    /// The executor's name.
489    pub fn executor(&self) -> Option<&str> {
490        self.executor.as_deref()
491    }
492
493    /// The task's metadata.
494    pub fn name(&self) -> &str {
495        &self.name
496    }
497
498    /// The location where the task was spawned.
499    pub fn location(&self) -> &'static Location<'static> {
500        self.location
501    }
502
503    /// The state of the task.
504    pub fn state(&self) -> TaskState {
505        self.state
506    }
507}