#![expect(unsafe_code)]
use parking_lot::Mutex;
use slab::Slab;
use std::cell::Cell;
use std::fmt::Debug;
use std::fmt::Display;
use std::future::Future;
use std::panic::Location;
use std::pin::Pin;
use std::ptr::null;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Weak;
/// Handle to a spawned task, as returned by `Spawn::spawn` and
/// `SpawnLocal::spawn_local`; carries the task's [`TaskMetadata`].
pub type Task<T> = async_task::Task<T, TaskMetadata>;
/// Schedulable unit of work for a spawned task; carries the task's
/// [`TaskMetadata`].
pub type Runnable = async_task::Runnable<TaskMetadata>;
/// Diagnostic bookkeeping attached to every task spawned through [`Spawn`] or
/// [`SpawnLocal`]; exposed through [`TaskList`] and
/// [`with_current_task_metadata`].
#[derive(Debug)]
pub struct TaskMetadata {
    /// Human-readable task name supplied at spawn time.
    name: Arc<str>,
    /// Source location of the spawn call (captured via `#[track_caller]`).
    location: &'static Location<'static>,
    /// One of the `TASK_STATE_*` constants.
    state: AtomicU32,
    /// Set once the wrapping `TaskFuture` has been dropped.
    dropped: AtomicBool,
    /// Scheduler the task is queued on. Weak, so the metadata alone does not
    /// keep the scheduler alive.
    scheduler: Weak<dyn Schedule>,
    /// Slot in the global task list, or `NO_ID` while unregistered.
    id: AtomicUsize,
    /// Makes the type `!Unpin`, so the `Pin<&Self>` that `register` requires
    /// guarantees the address stored in the task list stays stable.
    _no_pin: std::marker::PhantomPinned,
}
impl TaskMetadata {
    /// Sentinel stored in `id` while the metadata is not in the global task list.
    const NO_ID: usize = !0;

    /// Creates unregistered metadata named `name`, in the ready state.
    ///
    /// Because of `#[track_caller]`, `Location::caller()` reports the call site
    /// of the nearest caller that is not itself `#[track_caller]`. The
    /// scheduler starts out as an empty `Weak` for the caller to fill in.
    #[track_caller]
    fn new(name: Arc<str>) -> Self {
        Self {
            name,
            location: Location::caller(),
            state: AtomicU32::new(TASK_STATE_READY),
            dropped: AtomicBool::new(false),
            scheduler: Weak::<Scheduler>::new(),
            id: AtomicUsize::new(Self::NO_ID),
            _no_pin: std::marker::PhantomPinned,
        }
    }

    /// Adds a pointer to this metadata to the global task list and records the
    /// assigned slot id.
    ///
    /// Takes `Pin<&Self>` because the list stores a raw pointer to `self`.
    ///
    /// # Panics
    ///
    /// Panics if the metadata is already registered.
    fn register(self: Pin<&Self>) {
        assert_eq!(self.id.load(Ordering::Relaxed), Self::NO_ID);
        let id = TASK_LIST
            .slab
            .lock()
            .insert(TaskMetadataPtr(self.get_ref()));
        self.id.store(id, Ordering::Relaxed);
    }

    /// Marks the task as waiting and reinstates `old_task` as the thread's
    /// current task.
    ///
    /// # Panics
    ///
    /// Panics if this task is not the thread's current task.
    fn pend(&self, old_task: *const Self) {
        self.state.store(TASK_STATE_WAITING, Ordering::Relaxed);
        CURRENT_TASK.with(|task| {
            let this_task = task.replace(old_task);
            assert_eq!(this_task, std::ptr::from_ref(self));
        })
    }

    /// Marks the task as done. Does not touch `CURRENT_TASK`.
    fn done(&self) {
        self.state.store(TASK_STATE_DONE, Ordering::Relaxed);
    }

    /// Installs this task as the thread's current task and marks it running.
    ///
    /// Returns the previously current task so the caller can restore it.
    fn run(&self) -> *const Self {
        let old_task = CURRENT_TASK.with(|task| task.replace(std::ptr::from_ref(self)));
        self.state.store(TASK_STATE_RUNNING, Ordering::Relaxed);
        old_task
    }

    /// Returns the task's name.
    pub fn name(&self) -> &Arc<str> {
        &self.name
    }

    /// Returns the source location the task was spawned from.
    pub fn location(&self) -> &'static Location<'static> {
        self.location
    }

    /// Returns the externally visible state of the task.
    ///
    /// Once the task's future has been dropped, the task reports `Complete` if
    /// it had finished and `Cancelled` otherwise.
    fn state(&self) -> TaskState {
        // Load the atomic exactly once so both branches classify the same
        // snapshot (previously the non-dropped branch re-loaded it).
        let state = self.state.load(Ordering::Relaxed);
        if self.dropped.load(Ordering::Relaxed) {
            if state == TASK_STATE_DONE {
                TaskState::Complete
            } else {
                TaskState::Cancelled
            }
        } else {
            match state {
                TASK_STATE_READY => TaskState::Ready,
                TASK_STATE_WAITING => TaskState::Waiting,
                TASK_STATE_RUNNING => TaskState::Running,
                TASK_STATE_DONE => TaskState::Complete,
                _ => unreachable!(),
            }
        }
    }
}
impl Drop for TaskMetadata {
    /// Removes this task from the global task list if it was ever registered.
    fn drop(&mut self) {
        match self.id.load(Ordering::Relaxed) {
            Self::NO_ID => {}
            slot => {
                let _removed = TASK_LIST.slab.lock().remove(slot);
            }
        }
    }
}
/// Raw pointer to a registered [`TaskMetadata`], as stored in the global task
/// list.
#[derive(Debug, Copy, Clone)]
struct TaskMetadataPtr(*const TaskMetadata);
// Compile-time check that `TaskMetadata` itself is `Send + Sync`; the unsafe
// impls below rely on it.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<TaskMetadata>();
};
// SAFETY: the pointee is `Send + Sync` (checked above). In this module the
// pointers are dereferenced only while the task-list lock is held, and
// `TaskMetadata`'s `Drop` removes the entry under that same lock.
unsafe impl Send for TaskMetadataPtr {}
// SAFETY: same reasoning as the `Send` impl.
unsafe impl Sync for TaskMetadataPtr {}
/// Receiving half of a queue created by [`task_queue`]; drive it with
/// [`TaskQueue::run`].
#[derive(Debug)]
pub struct TaskQueue {
    /// Runnables sent by the paired [`Scheduler`].
    tasks: async_channel::Receiver<Runnable>,
}
/// Sending half of a queue created by [`task_queue`]; implements [`Schedule`]
/// by forwarding runnables to the paired [`TaskQueue`].
#[derive(Debug)]
pub struct Scheduler {
    /// Channel into the paired queue.
    send: async_channel::Sender<Runnable>,
    /// Name reported by `Schedule::name`; changeable via `set_name`.
    name: Mutex<Arc<str>>,
}
impl Scheduler {
    /// Replaces the name this scheduler reports through `Schedule::name`.
    pub fn set_name(&self, name: impl Into<Arc<str>>) {
        let new_name: Arc<str> = name.into();
        let mut current = self.name.lock();
        *current = new_name;
    }
}
impl Schedule for Scheduler {
    /// Sends `runnable` to the paired `TaskQueue`.
    ///
    /// Best effort: if the send fails, the rejected runnable is dropped. Since
    /// `task_queue` creates an unbounded channel, this presumably happens only
    /// once the queue has been dropped.
    fn schedule(&self, runnable: Runnable) {
        if let Err(rejected) = self.send.try_send(runnable) {
            drop(rejected);
        }
    }

    /// Returns a clone of the current scheduler name.
    fn name(&self) -> Arc<str> {
        let current = self.name.lock();
        Arc::clone(&*current)
    }
}
/// Creates a connected queue/scheduler pair named `name`.
///
/// Runnables scheduled through the returned [`Scheduler`] are executed by
/// [`TaskQueue::run`] on the returned [`TaskQueue`].
pub fn task_queue(name: impl Into<Arc<str>>) -> (TaskQueue, Scheduler) {
    let (sender, receiver) = async_channel::unbounded();
    let scheduler = Scheduler {
        send: sender,
        name: Mutex::new(name.into()),
    };
    (TaskQueue { tasks: receiver }, scheduler)
}
impl TaskQueue {
    /// Runs queued runnables one at a time until receiving fails (with
    /// async_channel, once the channel is closed and drained).
    pub async fn run(&mut self) {
        loop {
            let Ok(runnable) = self.tasks.recv().await else {
                break;
            };
            runnable.run();
        }
    }
}
/// A target that spawned tasks can be queued on.
pub trait Schedule: Send + Sync {
    /// Queues `runnable` to be run.
    fn schedule(&self, runnable: Runnable);
    /// Name of this scheduler; reported as the executor in [`TaskData`].
    fn name(&self) -> Arc<str>;
}
/// Wrapper future that keeps a task's [`TaskMetadata`] up to date as the
/// inner future is polled.
///
/// Fields drop in declaration order (after `Drop::drop`), so `_scheduler` is
/// released before `future`.
struct TaskFuture<'a, Fut> {
    /// Metadata of the task this future belongs to.
    metadata: &'a TaskMetadata,
    /// Strong reference that keeps the scheduler alive while this wrapper exists.
    _scheduler: Arc<dyn Schedule>,
    /// The user's future.
    future: Fut,
}
impl<'a, Fut: Future> TaskFuture<'a, Fut> {
    /// Registers `metadata` in the global task list and wraps `future`.
    ///
    /// `scheduler` is held for the lifetime of the wrapper.
    fn new(metadata: Pin<&'a TaskMetadata>, scheduler: Arc<dyn Schedule>, future: Fut) -> Self {
        metadata.register();
        Self {
            metadata: metadata.get_ref(),
            _scheduler: scheduler,
            future,
        }
    }
    /// Like `TaskFuture::new`, but erases the metadata lifetime to `'static` so
    /// the result can be returned from the async_task spawn closure.
    ///
    /// # Safety
    ///
    /// The caller must guarantee two things:
    /// - `metadata` is never moved, because its address is registered in the
    ///   global task list.
    /// - `metadata` outlives the returned future, which keeps a reference to it
    ///   despite the `'static` lifetime.
    unsafe fn new_for_async_task(
        metadata: &'a TaskMetadata,
        scheduler: Arc<dyn Schedule>,
        future: Fut,
    ) -> TaskFuture<'static, Fut> {
        // SAFETY: the caller guarantees `metadata` will not move.
        let metadata = unsafe { Pin::new_unchecked(metadata) };
        let this = Self::new(metadata, scheduler, future);
        // SAFETY: only the lifetime parameter changes; the caller guarantees
        // `metadata` outlives the returned future.
        unsafe { std::mem::transmute::<TaskFuture<'a, Fut>, TaskFuture<'static, Fut>>(this) }
    }
}
impl<Fut: Future> Future for TaskFuture<'_, Fut> {
type Output = Fut::Output;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let old_task = this.metadata.run();
let future = unsafe { Pin::new_unchecked(&mut this.future) };
let r = future.poll(cx);
if r.is_pending() {
this.metadata.pend(old_task);
} else {
this.metadata.done();
}
r
}
}
impl<Fut> Drop for TaskFuture<'_, Fut> {
    /// Flags the task's metadata as dropped. `TaskMetadata::state` then reports
    /// `Complete` if the future had finished and `Cancelled` otherwise.
    fn drop(&mut self) {
        let dropped_flag = &self.metadata.dropped;
        dropped_flag.store(true, Ordering::Relaxed);
    }
}
/// async_task schedule callback: marks the task ready and hands it to its
/// scheduler.
///
/// If the scheduler has already been dropped, `runnable` is dropped here
/// instead (per async_task's `Runnable` semantics, this cancels the task).
fn schedule(runnable: Runnable) {
    let meta = runnable.metadata();
    meta.state.store(TASK_STATE_READY, Ordering::Relaxed);
    let Some(target) = meta.scheduler.upgrade() else {
        return;
    };
    target.schedule(runnable);
}
/// Something that can spawn `Send` tasks onto a [`Schedule`].
pub trait Spawn: Send + Sync {
    /// Returns the scheduler that a new task with `metadata` should be queued on.
    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule>;
    /// Spawns `fut` as a task named `name` and schedules it for its first poll.
    ///
    /// The spawn call site is recorded in the task's metadata (via
    /// `#[track_caller]`).
    #[track_caller]
    fn spawn<T: 'static + Send>(
        &self,
        name: impl Into<Arc<str>>,
        fut: impl Future<Output = T> + Send + 'static,
    ) -> Task<T>
    where
        Self: Sized,
    {
        // `TaskMetadata::new` is `#[track_caller]`, so it records our caller.
        let mut metadata = TaskMetadata::new(name.into());
        let scheduler = self.scheduler(&metadata);
        // The metadata keeps only a weak reference; the future built below
        // holds the strong one.
        metadata.scheduler = Arc::downgrade(&scheduler);
        let (runnable, task) = async_task::Builder::new().metadata(metadata).spawn(
            |metadata| {
                // SAFETY: `metadata` refers to the copy stored by async_task,
                // which is assumed not to move and to outlive the future built
                // from it. NOTE(review): confirm against async_task's
                // `Builder::spawn` documentation.
                unsafe { TaskFuture::new_for_async_task(metadata, scheduler, fut) }
            },
            schedule,
        );
        runnable.schedule();
        task
    }
}
/// Something that can spawn tasks that need not be `Send` onto a [`Schedule`].
pub trait SpawnLocal {
    /// Returns the scheduler that a new local task with `metadata` should be
    /// queued on.
    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule>;
    /// Spawns `fut` (which need not be `Send`) as a task named `name` and
    /// schedules it for its first poll.
    ///
    /// The spawn call site is recorded in the task's metadata (via
    /// `#[track_caller]`).
    ///
    /// NOTE(review): async_task's `spawn_local` is documented to panic if the
    /// task is polled or dropped on a thread other than the spawning one.
    /// Confirm that the chosen scheduler runs on this thread.
    #[track_caller]
    fn spawn_local<T: 'static>(
        &self,
        name: impl Into<Arc<str>>,
        fut: impl Future<Output = T> + 'static,
    ) -> Task<T>
    where
        Self: Sized,
    {
        // `TaskMetadata::new` is `#[track_caller]`, so it records our caller.
        let mut metadata = TaskMetadata::new(name.into());
        let scheduler = self.scheduler_local(&metadata);
        // The metadata keeps only a weak reference; the future built below
        // holds the strong one.
        metadata.scheduler = Arc::downgrade(&scheduler);
        let (runnable, task) = async_task::Builder::new().metadata(metadata).spawn_local(
            |metadata| {
                // SAFETY: `metadata` refers to the copy stored by async_task,
                // which is assumed not to move and to outlive the future built
                // from it. NOTE(review): confirm against async_task's
                // `Builder::spawn_local` documentation.
                unsafe { TaskFuture::new_for_async_task(metadata, scheduler, fut) }
            },
            schedule,
        );
        runnable.schedule();
        task
    }
}
thread_local! {
    /// Metadata of the task most recently installed on this thread by
    /// `TaskMetadata::run`; starts out null.
    static CURRENT_TASK: Cell<*const TaskMetadata> = const { Cell::new(null()) };
}
/// Calls `f` with the metadata currently installed in `CURRENT_TASK` (set while
/// a task is polled on this thread), or with `None` if the pointer is null.
pub fn with_current_task_metadata<F: FnOnce(Option<&TaskMetadata>) -> R, R>(f: F) -> R {
    CURRENT_TASK.with(|current| {
        let ptr = current.get();
        // SAFETY: relies on the polling code only installing pointers to live
        // metadata and restoring the previous value when a poll ends.
        // `as_ref` maps null to `None`.
        f(unsafe { ptr.as_ref() })
    })
}
/// Forwards to the referenced spawner.
impl<T: ?Sized + Spawn> Spawn for &'_ T {
    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler(&**self, metadata)
    }
}

/// Forwards to the boxed spawner.
impl<T: ?Sized + Spawn> Spawn for Box<T> {
    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler(&**self, metadata)
    }
}

/// Forwards to the shared spawner.
impl<T: ?Sized + Spawn> Spawn for Arc<T> {
    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler(&**self, metadata)
    }
}
/// Forwards to the referenced local spawner.
impl<T: ?Sized + SpawnLocal> SpawnLocal for &'_ T {
    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler_local(&**self, metadata)
    }
}

/// Forwards to the boxed local spawner.
impl<T: ?Sized + SpawnLocal> SpawnLocal for Box<T> {
    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler_local(&**self, metadata)
    }
}

/// Forwards to the shared local spawner.
impl<T: ?Sized + SpawnLocal> SpawnLocal for Arc<T> {
    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
        T::scheduler_local(&**self, metadata)
    }
}
// Raw values stored in `TaskMetadata::state`; translated to `TaskState` by
// `TaskMetadata::state()`.
/// Initial state, also set whenever the task is scheduled.
const TASK_STATE_READY: u32 = 0;
/// The most recent poll returned `Pending`.
const TASK_STATE_WAITING: u32 = 1;
/// The task is being polled.
const TASK_STATE_RUNNING: u32 = 2;
/// A poll returned `Ready`.
const TASK_STATE_DONE: u32 = 3;
/// Externally visible lifecycle state of a task, as reported in `TaskData`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare and filter states
/// (e.g. `data.state() == TaskState::Running`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u64)]
pub enum TaskState {
    /// Created or rescheduled and waiting to be polled.
    Ready,
    /// The last poll returned `Pending`.
    Waiting,
    /// Currently being polled.
    Running,
    /// The future returned `Ready`.
    Complete,
    /// The future was dropped before it finished.
    Cancelled,
}

impl Display for TaskState {
    /// Writes the lowercase state name via `Formatter::pad`, so width, fill,
    /// and alignment flags are honored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(match self {
            Self::Ready => "ready",
            Self::Waiting => "waiting",
            Self::Running => "running",
            Self::Complete => "complete",
            Self::Cancelled => "cancelled",
        })
    }
}
/// Global registry of spawned tasks, used for diagnostics.
pub struct TaskList {
    /// Pointers to registered task metadata; the slot index is the task id.
    slab: Mutex<Slab<TaskMetadataPtr>>,
}
/// The single process-wide task list (returned by `TaskList::global`).
static TASK_LIST: TaskList = TaskList::new();
impl TaskList {
const fn new() -> Self {
Self {
slab: Mutex::new(Slab::new()),
}
}
pub fn global() -> &'static Self {
&TASK_LIST
}
pub fn tasks(&self) -> Vec<TaskData> {
let tasks = self.slab.lock();
tasks
.iter()
.map(|(id, task)| {
let task = unsafe { &*task.0 };
let scheduler = task.scheduler.upgrade().map(|s| s.name());
TaskData {
id,
name: task.name.clone(),
location: task.location,
state: task.state(),
executor: scheduler,
}
})
.collect()
}
}
/// Point-in-time description of one task, produced by `TaskList::tasks`.
#[derive(Debug)]
pub struct TaskData {
    /// Slot of the task in the global task list.
    id: usize,
    /// Name given at spawn time.
    name: Arc<str>,
    /// Source location of the spawn call.
    location: &'static Location<'static>,
    /// State at snapshot time.
    state: TaskState,
    /// Scheduler name, or `None` if the scheduler was gone.
    executor: Option<Arc<str>>,
}
impl TaskData {
    /// Slot of the task in the global task list when the snapshot was taken.
    pub fn id(&self) -> usize {
        let Self { id, .. } = self;
        *id
    }

    /// Name of the scheduler the task was queued on, if that scheduler was
    /// still alive when the snapshot was taken.
    pub fn executor(&self) -> Option<&str> {
        self.executor.as_ref().map(|name| &**name)
    }

    /// Name given to the task when it was spawned.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    /// Source location of the spawn call.
    pub fn location(&self) -> &'static Location<'static> {
        let Self { location, .. } = self;
        location
    }

    /// Task state when the snapshot was taken.
    pub fn state(&self) -> TaskState {
        let Self { state, .. } = self;
        *state
    }
}