pal_async/
io_pool.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Single-threaded task pools backed by platform-specific IO backends.
5
6use crate::task::Schedule;
7use crate::task::Scheduler;
8use crate::task::Spawn;
9use crate::task::TaskMetadata;
10use crate::task::TaskQueue;
11use crate::task::task_queue;
12use std::future::Future;
13use std::future::poll_fn;
14use std::pin::pin;
15use std::sync::Arc;
16use std::task::Poll;
17
/// A single-threaded task pool backed by IO backend `T`.
///
/// The pool pairs a driver (backend plus scheduler) with the task queue that
/// the scheduler feeds.
#[derive(Debug)]
pub struct IoPool<T> {
    /// Driver for this pool; clones of it are handed out by `driver()`.
    driver: IoDriver<T>,
    /// Task queue created together with the driver's scheduler by
    /// `task_queue`; it is polled via `tasks.run()` while the pool runs.
    tasks: TaskQueue,
}
24
/// A driver to spawn tasks and IO objects on [`IoPool`].
///
/// Cloning a driver is cheap: both fields are reference-counted handles.
#[derive(Debug)]
pub struct IoDriver<T> {
    /// The shared IO backend instance.
    pub(crate) inner: Arc<T>,
    /// The scheduler half of the `task_queue` pair created with the pool;
    /// returned from the `Spawn` implementation.
    scheduler: Arc<Scheduler>,
}
31
32impl<T> Clone for IoDriver<T> {
33    fn clone(&self) -> Self {
34        Self {
35            inner: self.inner.clone(),
36            scheduler: self.scheduler.clone(),
37        }
38    }
39}
40
/// Trait implemented by IO backends.
pub trait IoBackend: Send + Sync {
    /// The name of the backend.
    ///
    /// Used as the default name for pools created on this backend.
    fn name() -> &'static str;

    /// Drives `fut` to completion on this backend and returns its output.
    fn run<Fut>(self: &Arc<Self>, fut: Fut) -> Fut::Output
    where
        Fut: Future;
}
48
49impl<T: IoBackend + Default> IoPool<T> {
50    /// Creates a new task pool.
51    pub fn new() -> Self {
52        Self::named(T::name().to_owned())
53    }
54
55    fn named(name: impl Into<Arc<str>>) -> Self {
56        let (tasks, scheduler) = task_queue(name);
57        Self {
58            driver: IoDriver {
59                inner: Arc::new(T::default()),
60                scheduler: Arc::new(scheduler),
61            },
62            tasks,
63        }
64    }
65
66    /// Creates and runs a task pool, seeding it with an initial future
67    /// `f(driver)`, until all tasks have completed.
68    pub fn run_with<F, R>(f: F) -> R
69    where
70        F: AsyncFnOnce(IoDriver<T>) -> R,
71    {
72        let mut pool = Self::named(std::thread::current().name().unwrap_or_else(|| T::name()));
73        let fut = f(pool.driver.clone());
74        drop(pool.driver.scheduler);
75        pool.driver
76            .inner
77            .run(async { futures::future::join(fut, pool.tasks.run()).await.0 })
78    }
79
80    /// Creates a new pool and runs it on a newly spawned thread with the given
81    /// name. Returns the thread handle and the pool's driver.
82    pub fn spawn_on_thread(name: impl Into<String>) -> (std::thread::JoinHandle<()>, IoDriver<T>)
83    where
84        T: 'static,
85    {
86        let pool = Self::new();
87        let driver = pool.driver.clone();
88        let thread = std::thread::Builder::new()
89            .name(name.into())
90            .spawn(move || pool.run())
91            .unwrap();
92        (thread, driver)
93    }
94}
95
96impl<T: IoBackend> IoPool<T> {
97    /// Returns the IO driver.
98    pub fn driver(&self) -> IoDriver<T> {
99        self.driver.clone()
100    }
101
102    /// Runs `f` and the task pool until `f` completes.
103    pub fn run_until<Fut: Future>(&mut self, f: Fut) -> Fut::Output {
104        let mut tasks = pin!(self.tasks.run());
105        let mut f = pin!(f);
106        self.driver.inner.run(poll_fn(|cx| {
107            if let Poll::Ready(r) = f.as_mut().poll(cx) {
108                Poll::Ready(r)
109            } else {
110                assert!(tasks.as_mut().poll(cx).is_pending());
111                Poll::Pending
112            }
113        }))
114    }
115
116    /// Runs the task pool until all tasks are completed.
117    pub fn run(mut self) {
118        // Update the executor name with the current thread's name.
119        if let Some(name) = std::thread::current().name() {
120            self.driver.scheduler.set_name(name);
121        }
122        drop(self.driver.scheduler);
123        self.driver.inner.run(self.tasks.run())
124    }
125}
126
127impl<T: IoBackend> Spawn for IoDriver<T> {
128    fn scheduler(&self, _metadata: &TaskMetadata) -> Arc<dyn Schedule> {
129        self.scheduler.clone()
130    }
131}