Skip to main content

underhill_threadpool/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![cfg_attr(not(target_os = "linux"), expect(missing_docs))]
5#![cfg(target_os = "linux")]
6
7//! The Underhill per-CPU thread pool used to run async tasks and IO.
8//!
9//! This is built on top of [`pal_uring`] and [`pal_async`].
10
11// UNSAFETY: Implementing the `IoUringSubmit` trait, which has an `unsafe fn`
12// method signature. The implementation delegates to `IoInitiator` which
13// handles the actual unsafe io-uring interaction.
14#![expect(unsafe_code)]
15
16use cvm_tracing::CVM_ALLOWED;
17use inspect::Inspect;
18use loan_cell::LoanCell;
19use pal::unix::affinity::CpuSet;
20use pal_async::fd::FdReadyDriver;
21use pal_async::io_uring::IoUringSubmit;
22use pal_async::task::Runnable;
23use pal_async::task::Schedule;
24use pal_async::task::Spawn;
25use pal_async::task::SpawnLocal;
26use pal_async::task::TaskMetadata;
27use pal_async::timer::TimerDriver;
28use pal_async::wait::WaitDriver;
29use pal_uring::FdReady;
30use pal_uring::FdWait;
31use pal_uring::IdleControl;
32use pal_uring::Initiate;
33use pal_uring::IoInitiator;
34use pal_uring::IoUringPool;
35use pal_uring::PoolClient;
36use pal_uring::Timer;
37use parking_lot::Mutex;
38use std::future::poll_fn;
39use std::io;
40use std::marker::PhantomData;
41use std::os::fd::RawFd;
42use std::sync::Arc;
43use std::sync::OnceLock;
44use std::sync::atomic::AtomicBool;
45use std::sync::atomic::AtomicU32;
46use std::sync::atomic::Ordering::Relaxed;
47use std::task::Poll;
48use std::task::Waker;
49use thiserror::Error;
50
/// Represents the internal state of an `AffinitizedThreadpool`.
#[derive(Debug, Inspect)]
struct AffinitizedThreadpoolState {
    /// One driver per present CPU, indexed by CPU number (see `build`, which
    /// creates `max_present_cpu() + 1` entries).
    #[inspect(iter_by_index)]
    drivers: Vec<ThreadpoolDriver>,
}
57
/// A pool of affinitized worker threads.
///
/// Cloning is cheap: clones share the same underlying per-CPU driver state.
#[derive(Clone, Debug, Inspect)]
#[inspect(transparent)]
pub struct AffinitizedThreadpool {
    state: Arc<AffinitizedThreadpoolState>,
}
64
/// A builder for [`AffinitizedThreadpool`].
#[derive(Debug, Clone)]
pub struct ThreadpoolBuilder {
    /// Maximum bounded io-uring kernel workers per ring, per NUMA node.
    /// `None` keeps the kernel default.
    max_bounded_workers: Option<u32>,
    /// Maximum unbounded io-uring kernel workers per ring, per NUMA node.
    /// `None` keeps the kernel default.
    max_unbounded_workers: Option<u32>,
    /// io-uring submission queue size; defaults to 256.
    ring_size: u32,
}
72
impl ThreadpoolBuilder {
    /// Returns a new builder.
    pub fn new() -> Self {
        Self {
            max_bounded_workers: None,
            max_unbounded_workers: None,
            ring_size: 256,
        }
    }

    /// Sets the maximum number of bounded kernel workers for each worker ring,
    /// per NUMA node.
    ///
    /// This defaults in the kernel to `min(io_ring_size, cpu_count * 4)`.
    pub fn max_bounded_workers(&mut self, n: u32) -> &mut Self {
        self.max_bounded_workers = Some(n);
        self
    }

    /// Sets the maximum number of unbounded kernel workers for each worker
    /// ring, per NUMA node.
    ///
    /// This defaults to the process's `RLIMIT_NPROC` limit at time of
    /// threadpool creation.
    pub fn max_unbounded_workers(&mut self, n: u32) -> &mut Self {
        self.max_unbounded_workers = Some(n);
        self
    }

    /// Sets the IO ring size. Defaults to 256.
    ///
    /// Panics if `ring_size` is zero.
    pub fn ring_size(&mut self, ring_size: u32) -> &mut Self {
        assert_ne!(ring_size, 0);
        self.ring_size = ring_size;
        self
    }

    /// Builds the thread pool.
    ///
    /// This only allocates per-CPU driver state; the worker thread for each
    /// CPU is spawned lazily on first use (see `ThreadpoolDriver::once`).
    pub fn build(&self) -> io::Result<AffinitizedThreadpool> {
        // One driver slot for every present CPU, online or not.
        let proc_count = pal::unix::affinity::max_present_cpu()? + 1;

        let builder = Arc::new(self.clone());
        let mut drivers = Vec::with_capacity(proc_count as usize);
        drivers.extend((0..proc_count).map(|processor| ThreadpoolDriver {
            inner: Arc::new(ThreadpoolDriverInner {
                once: OnceLock::new(),
                cpu: processor,
                builder: builder.clone(),
                name: format!("threadpool-{}", processor).into(),
                affinity_set: false.into(),
                state: Mutex::new(ThreadpoolDriverState {
                    notifier: None,
                    affinity: AffinityState::Waiting(Vec::new()),
                    spawned: false,
                }),
            }),
        }));

        let state = Arc::new(AffinitizedThreadpoolState { drivers });

        Ok(AffinitizedThreadpool { state })
    }

    // Spawn a pool on the specified CPU.
    //
    // If the specified CPU is present but not online, spawns a thread with
    // affinity set to all processors that are in the same package, if possible.
    //
    // Note that this sets affinity of the current thread and does not revert
    // it. Call this from a temporary thread to avoid permanently changing the
    // affinity of the current thread.
    fn spawn_pool(&self, cpu: u32, driver: ThreadpoolDriver) -> io::Result<PoolClient> {
        tracing::debug!(cpu, "starting threadpool thread");

        let online = is_cpu_online(cpu)?;
        let mut affinity = CpuSet::new();
        if online {
            affinity.set(cpu);
        } else {
            // The CPU is not online. Set the affinity to match the package.
            //
            // TODO: figure out how to do this (maybe pass in
            // ProcessorTopology)--the sysfs topology directory does not exist
            // for offline CPUs. For now, just allow all CPUs.
            let online_cpus = fs_err::read_to_string("/sys/devices/system/cpu/online")?;
            affinity
                .set_mask_list(&online_cpus)
                .map_err(io::Error::other)?;
        }

        // Set the current thread's affinity so that allocations for the worker
        // thread are performed in the correct node.
        let affinity_ok = match pal::unix::affinity::set_current_thread_affinity(&affinity) {
            Ok(()) => true,
            Err(err) if err.kind() == io::ErrorKind::InvalidInput && !online => {
                // None of the CPUs in the package are online. That's not ideal,
                // because the thread will probably get allocated with the wrong node,
                // but it's recoverable.
                tracing::warn!(
                    CVM_ALLOWED,
                    cpu,
                    error = &err as &dyn std::error::Error,
                    "could not set package affinity for thread pool thread"
                );
                false
            }
            Err(err) => return Err(err),
        };

        let this = self.clone();
        // Channel used to hand the pool client (or the creation error) back to
        // the spawning thread once the worker has initialized.
        let (send, recv) = std::sync::mpsc::channel();
        let thread = std::thread::Builder::new()
            .name("tp".to_owned())
            .spawn(move || {
                // Create the pool and report back the result. This must be done
                // on the thread so that the io-uring task context gets created.
                // If we create this back on the initiating thread, then the
                // task context gets created and then destroyed, and subsequent
                // calls to update the affinity fail until the task context gets
                // recreated (next time an IO is issued).
                //
                // FUTURE: take advantage of the per-thread task context and
                // pre-register the ring via IORING_REGISTER_RING_FDS.
                let pool = match this
                    .make_ring(driver.inner.name.clone(), affinity_ok.then_some(&affinity))
                {
                    Ok(pool) => pool,
                    Err(err) => {
                        send.send(Err(err)).ok();
                        return;
                    }
                };

                let driver = driver;
                let notifier = {
                    let mut state = driver.inner.state.lock();
                    state.spawned = true;
                    if online {
                        // There cannot be any waiters yet since they can only
                        // be registered from the current thread.
                        driver.inner.affinity_set.store(true, Relaxed);
                        state.affinity = AffinityState::Set;
                    }
                    state.notifier.take()
                };

                send.send(Ok(pool.client().clone())).ok();

                // Store the current thread's driver so that spawned tasks can
                // find it via `Thread::current()`. Do this via a loan instead
                // of storing it directly in TLS to avoid the overhead of
                // registering a destructor.
                CURRENT_THREAD_DRIVER.with(|current| {
                    current.lend(&driver, || {
                        if let Some(notifier) = notifier {
                            (notifier.0)(true);
                        }
                        // Runs the pool until shutdown; the loan is returned
                        // when this call completes.
                        pool.run()
                    });
                });
            })?;

        // Wait for the pool to be initialized.
        recv.recv().unwrap().inspect_err(|_| {
            // Wait for the child thread to exit to bound resource use.
            thread.join().unwrap();
        })
    }

    /// Creates the io-uring pool and applies the builder's worker limits and,
    /// if provided, the io-wq affinity.
    fn make_ring(&self, name: Arc<str>, affinity: Option<&CpuSet>) -> io::Result<IoUringPool> {
        let pool = IoUringPool::new(name, self.ring_size)?;
        let client = pool.client();
        client.set_iowq_max_workers(self.max_bounded_workers, self.max_unbounded_workers)?;
        if let Some(affinity) = affinity {
            client.set_iowq_affinity(affinity)?
        }
        Ok(pool)
    }
}
251
252/// Returns whether the specified CPU is online.
253pub fn is_cpu_online(cpu: u32) -> io::Result<bool> {
254    // Depending at the very minimum on whether the kernel has been built with
255    // `CONFIG_HOTPLUG_CPU` or not, the individual `online` pseudo-files will be
256    // present or absent.
257    //
258    // The other factors at play are the firmware-reported system properties and
259    // the `cpu_ops` structures defined for the platform. All these lead ultimately
260    // to setting the `hotpluggable` property on the cpu device in the kernel.
261    // If that property is set, the `online` file will be present for the given CPU.
262    //
263    // If that file is absent for the CPU in question, that means it is online, and
264    // due to various factors (e.g. BSP on x86_64, missing `cpu_die` handler, etc)
265    // the CPU is not allowed to be offlined.
266    //
267    // The well-established cross-platform tools (e.g. `perf`) in the kernel repo
268    // rely on the same: if the `online` file is missing, assume the CPU is online
269    // provided the CPU "home" directory is present (although they don't have
270    // comments like this one :)).
271
272    let cpu_sysfs_home = format!("/sys/devices/system/cpu/cpu{cpu}");
273    let cpu_sysfs_home = std::path::Path::new(cpu_sysfs_home.as_str());
274    let online = cpu_sysfs_home.join("online");
275    match fs_err::read_to_string(online) {
276        Ok(s) => Ok(s.trim() == "1"),
277        Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(cpu_sysfs_home.exists()),
278        Err(err) => Err(err),
279    }
280}
281
282/// Sets the specified CPU online, if it is not already online.
283pub fn set_cpu_online(cpu: u32) -> io::Result<()> {
284    let online = format!("/sys/devices/system/cpu/cpu{cpu}/online");
285    match fs_err::read_to_string(&online) {
286        Ok(s) if s.trim() == "0" => {
287            fs_err::write(&online, "1")?;
288        }
289        Ok(_) => {
290            // Already online.
291        }
292        Err(err) if err.kind() == io::ErrorKind::NotFound => {
293            // The file doesn't exist, so the processor is always online.
294        }
295        Err(err) => return Err(err),
296    }
297    Ok(())
298}
299
300impl AffinitizedThreadpool {
301    /// Creates a new threadpool with the specified ring size.
302    pub fn new(io_ring_size: u32) -> io::Result<Self> {
303        ThreadpoolBuilder::new().ring_size(io_ring_size).build()
304    }
305
306    /// Returns an object that can be used to submit IOs or spawn tasks to the
307    /// current processor's ring.
308    ///
309    /// Spawned tasks will remain affinitized to the current thread. Spawn
310    /// directly on the threadpool object to get a task that will run on any
311    /// thread.
312    pub fn current_driver(&self) -> &ThreadpoolDriver {
313        self.driver(pal::unix::affinity::get_cpu_number())
314    }
315
316    /// Returns an object that can be used to submit IOs to the specified ring
317    /// in the pool, or to spawn tasks on the specified thread.
318    ///
319    /// Spawned tasks will remain affinitized to the specified thread. Spawn
320    /// directly on the threadpool object to get a task that will run on any
321    /// thread.
322    pub fn driver(&self, ring_id: u32) -> &ThreadpoolDriver {
323        &self.state.drivers[ring_id as usize]
324    }
325
326    /// Returns an iterator of drivers for threads that are running and have
327    /// their affinity set.
328    ///
329    /// This is useful for getting a set of drivers that can be used to
330    /// parallelize work.
331    pub fn active_drivers(&self) -> impl Iterator<Item = &ThreadpoolDriver> + Clone {
332        self.state
333            .drivers
334            .iter()
335            .filter(|driver| driver.is_affinity_set())
336    }
337}
338
339impl Schedule for AffinitizedThreadpoolState {
340    fn schedule(&self, runnable: Runnable) {
341        self.drivers[pal::unix::affinity::get_cpu_number() as usize]
342            .client(Some(runnable.metadata()))
343            .schedule(runnable);
344    }
345
346    fn name(&self) -> Arc<str> {
347        static NAME: OnceLock<Arc<str>> = OnceLock::new();
348        NAME.get_or_init(|| "tp".into()).clone()
349    }
350}
351
352impl Spawn for AffinitizedThreadpool {
353    fn scheduler(&self, _metadata: &TaskMetadata) -> Arc<dyn Schedule> {
354        self.state.clone()
355    }
356}
357
358/// Initiate IOs to the current CPU's thread.
359impl Initiate for AffinitizedThreadpool {
360    fn initiator(&self) -> &IoInitiator {
361        self.current_driver().initiator()
362    }
363}
364
impl IoUringSubmit for AffinitizedThreadpool {
    // Probes opcode support against the current CPU's ring; all rings in the
    // pool are created identically, so the answer applies pool-wide.
    fn probe(&self, opcode: u8) -> bool {
        self.current_driver().initiator().probe(opcode)
    }

    unsafe fn submit(
        &self,
        sqe: pal_async::io_uring::Entry,
    ) -> impl Future<Output = io::Result<i32>> + Send + '_ {
        // SAFETY: caller guarantees the SQE references valid memory.
        unsafe { self.current_driver().initiator().submit(sqe) }
    }
}
378
/// The state for the thread pool thread for the currently running CPU.
#[derive(Debug, Copy, Clone)]
pub struct Thread {
    // Marker that keeps this type `!Send + !Sync`: the value is only
    // meaningful on the thread pool thread it was created on.
    _not_send_sync: PhantomData<*const ()>,
}
384
impl Thread {
    /// Returns an instance for the current CPU.
    ///
    /// Returns `None` when the calling thread is not a thread pool thread
    /// (i.e. no driver loan is active in thread-local storage).
    pub fn current() -> Option<Self> {
        if !CURRENT_THREAD_DRIVER.with(|current| current.is_lent()) {
            return None;
        }
        Some(Self {
            _not_send_sync: PhantomData,
        })
    }

    /// Calls `f` with the driver for the current thread.
    pub fn with_driver<R>(&self, f: impl FnOnce(&ThreadpoolDriver) -> R) -> R {
        // The unwrap is safe: `current()` verified the loan is active, and
        // `Thread` cannot move to another thread (`!Send`).
        CURRENT_THREAD_DRIVER.with(|current| current.borrow(|driver| f(driver.unwrap())))
    }

    // Calls `f` with the driver and its once-initialized state. The `once`
    // state is always populated by the time the pool thread is running.
    fn with_once<R>(&self, f: impl FnOnce(&ThreadpoolDriver, &ThreadpoolDriverOnce) -> R) -> R {
        self.with_driver(|driver| f(driver, driver.inner.once.get().unwrap()))
    }

    /// Sets the idle task to run. The task is returned by `f`, which receives
    /// the file descriptor of the IO ring.
    ///
    /// The idle task is run before waiting on the IO ring. The idle task can
    /// block synchronously by first calling [`IdleControl::pre_block`], and
    /// then by polling on the IO ring while the task blocks.
    pub fn set_idle_task<F>(&self, f: F)
    where
        F: 'static + Send + AsyncFnOnce(IdleControl),
    {
        self.with_once(|_, once| once.client.set_idle_task(f))
    }

    /// Tries to set the affinity to this thread's intended CPU, if it has not
    /// already been set. Returns `Ok(false)` if the intended CPU is still
    /// offline.
    pub fn try_set_affinity(&self) -> Result<bool, SetAffinityError> {
        self.with_once(|driver, once| {
            let mut state = driver.inner.state.lock();
            if matches!(state.affinity, AffinityState::Set) {
                return Ok(true);
            }
            if !is_cpu_online(driver.inner.cpu).map_err(SetAffinityError::Online)? {
                return Ok(false);
            }

            let mut affinity = CpuSet::new();
            affinity.set(driver.inner.cpu);

            // Affinitize both the thread itself and the ring's kernel workers.
            pal::unix::affinity::set_current_thread_affinity(&affinity)
                .map_err(SetAffinityError::Thread)?;
            once.client
                .set_iowq_affinity(&affinity)
                .map_err(SetAffinityError::Ring)?;

            // Publish the new state before waking waiters, and drop the lock
            // first so woken tasks can observe `AffinityState::Set` without
            // contending on it.
            let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
            driver.inner.affinity_set.store(true, Relaxed);
            drop(state);

            match old_affinity_state {
                AffinityState::Waiting(wakers) => {
                    for waker in wakers {
                        waker.wake();
                    }
                }
                // Unreachable: the early return above handled `Set`.
                AffinityState::Set => unreachable!(),
            }
            Ok(true)
        })
    }

    /// Returns the task that caused this thread to spawn.
    ///
    /// Returns `None` if the thread was spawned to issue IO.
    pub fn first_task(&self) -> Option<TaskInfo> {
        self.with_once(|_, once| once.first_task.clone())
    }
}
463
/// An error that can occur when setting the affinity of a thread.
#[derive(Debug, Error)]
pub enum SetAffinityError {
    /// An error occurred while checking if the CPU is online.
    #[error("failed to check if CPU is online")]
    Online(#[source] io::Error),
    /// An error occurred while setting the thread affinity.
    #[error("failed to set thread affinity")]
    Thread(#[source] io::Error),
    /// An error occurred while setting the IO ring affinity.
    #[error("failed to set io-uring affinity")]
    Ring(#[source] io::Error),
}
477
thread_local! {
    // Loan of the running pool thread's driver, active for the duration of
    // `pool.run()` (see `spawn_pool`). A `LoanCell` is used rather than an
    // owned TLS value to avoid registering a TLS destructor.
    static CURRENT_THREAD_DRIVER: LoanCell<ThreadpoolDriver> = const { LoanCell::new() };
}
481
482impl SpawnLocal for Thread {
483    fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
484        self.with_driver(|driver| driver.scheduler(metadata).clone())
485    }
486}
487
/// A driver for [`AffinitizedThreadpool`] that is targeted at a specific
/// CPU.
///
/// Cloning is cheap: clones share the same per-CPU state.
#[derive(Debug, Clone, Inspect)]
#[inspect(transparent)]
pub struct ThreadpoolDriver {
    inner: Arc<ThreadpoolDriverInner>,
}
495
#[derive(Debug, Inspect)]
struct ThreadpoolDriverInner {
    /// Lazily-initialized pool thread state; populated the first time the
    /// driver is used (see `ThreadpoolDriver::once`).
    #[inspect(flatten)]
    once: OnceLock<ThreadpoolDriverOnce>,
    /// Builder settings used to spawn this driver's pool thread.
    #[inspect(skip)]
    builder: Arc<ThreadpoolBuilder>,
    /// The CPU this driver targets.
    cpu: u32,
    /// Thread/ring name, `threadpool-<cpu>`.
    name: Arc<str>,
    /// Whether the thread's affinity has been set to the intended CPU;
    /// mirrors `state.affinity` for lock-free reads.
    affinity_set: AtomicBool,
    #[inspect(flatten)]
    state: Mutex<ThreadpoolDriverState>,
}
508
/// State established once the pool thread has been spawned.
#[derive(Debug, Inspect)]
struct ThreadpoolDriverOnce {
    /// Client for the spawned thread's io-uring pool.
    #[inspect(skip)]
    client: PoolClient,
    /// The task that triggered the spawn, if any (None when the thread was
    /// spawned to issue IO).
    first_task: Option<TaskInfo>,
}
515
/// Information about a task that caused a thread to spawn.
#[derive(Debug, Clone, Inspect)]
pub struct TaskInfo {
    /// The name of the task.
    pub name: Arc<str>,
    /// The location of the task.
    #[inspect(display)]
    pub location: &'static std::panic::Location<'static>,
}
525
/// Mutable driver state, guarded by `ThreadpoolDriverInner::state`.
#[derive(Debug, Inspect)]
struct ThreadpoolDriverState {
    /// Whether the intended-CPU affinity has been applied, plus any tasks
    /// waiting for it.
    affinity: AffinityState,
    /// Callback invoked when the pool thread starts running; only settable
    /// before the thread spawns.
    #[inspect(with = "|x| x.is_some()")]
    notifier: Option<AffinityNotifier>,
    /// Whether the pool thread has been spawned.
    spawned: bool,
}
533
#[derive(Debug, Inspect)]
#[inspect(external_tag)]
enum AffinityState {
    /// Affinity not yet applied; holds the wakers of tasks blocked in
    /// `wait_for_affinity`.
    #[inspect(transparent)]
    Waiting(#[inspect(with = "|x| x.len()")] Vec<Waker>),
    /// Affinity has been set to the intended CPU.
    Set,
}
541
/// One-shot spawn notifier; the `bool` indicates whether any outstanding run
/// should be cancelled (the threadpool itself passes `true`).
struct AffinityNotifier(Box<dyn FnOnce(bool) + Send>);

impl std::fmt::Debug for AffinityNotifier {
    // The boxed closure is opaque, so just emit the type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad("AffinityNotifier")
    }
}
549
impl ThreadpoolDriver {
    // Returns the once-initialized pool thread state, spawning the pool
    // thread on first use. `metadata`, when provided, records the task that
    // triggered the spawn for diagnostics.
    fn once(&self, metadata: Option<&TaskMetadata>) -> &ThreadpoolDriverOnce {
        self.inner.once.get_or_init(|| {
            let this = self.clone();
            // Spawn from a temporary thread because `spawn_pool` changes the
            // calling thread's affinity and does not restore it.
            let client = std::thread::spawn(move || {
                let inner = this.inner.clone();
                inner.builder.spawn_pool(inner.cpu, this)
            })
            .join()
            .unwrap()
            .expect("failed to spawn thread pool thread");

            // If no task metadata was provided (because the thread is being
            // spawned to issue IO) use the current task's metadata as the
            // initiating task.
            pal_async::task::with_current_task_metadata(|current_metadata| {
                let metadata = metadata.or(current_metadata);
                ThreadpoolDriverOnce {
                    client,
                    first_task: metadata.map(|metadata| TaskInfo {
                        name: metadata.name().clone(),
                        location: metadata.location(),
                    }),
                }
            })
        })
    }

    // Returns the pool client, spawning the pool thread if necessary.
    fn client(&self, metadata: Option<&TaskMetadata>) -> &PoolClient {
        &self.once(metadata).client
    }

    /// Returns the target CPU number for this thread.
    ///
    /// This may be different from the CPU tasks actually run on if the affinity
    /// has not yet been set for the thread.
    pub fn target_cpu(&self) -> u32 {
        self.inner.cpu
    }

    /// Returns whether this thread's CPU affinity has been set to the intended
    /// CPU.
    pub fn is_affinity_set(&self) -> bool {
        self.inner.affinity_set.load(Relaxed)
    }

    /// Waits for the affinity to be set to this thread's intended CPU. If the
    /// CPU was not online when the thread was created, then this will block
    /// until the CPU is online and someone calls `try_set_affinity`.
    pub async fn wait_for_affinity(&self) {
        // Ensure the thread has been spawned and that the notifier has been
        // called. Use the calling task as the initiating task for diagnostics
        // purposes.
        pal_async::task::with_current_task_metadata(|metadata| self.once(metadata));
        poll_fn(|cx| {
            let mut state = self.inner.state.lock();
            match &mut state.affinity {
                AffinityState::Waiting(wakers) => {
                    // Avoid stacking duplicate wakers across repeated polls.
                    if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
                        wakers.push(cx.waker().clone());
                    }
                    Poll::Pending
                }
                AffinityState::Set => Poll::Ready(()),
            }
        })
        .await
    }

    /// Sets a function to be called when the thread gets spawned. The function
    /// accepts a single `bool` parameter that indicates that the notifier
    /// should cancel any outstanding run or not. When called by the threadpool,
    /// the function will receive `true`.
    ///
    /// Return `Err(f)` if the thread is already spawned.
    pub fn set_spawn_notifier<F: 'static + Send + FnOnce(bool)>(&self, f: F) -> Result<(), F> {
        let mut state = self.inner.state.lock();
        if !state.spawned {
            state.notifier = Some(AffinityNotifier(Box::new(f)));
            Ok(())
        } else {
            Err(f)
        }
    }
}
635
636impl Initiate for ThreadpoolDriver {
637    fn initiator(&self) -> &IoInitiator {
638        self.client(None).initiator()
639    }
640}
641
642impl Spawn for ThreadpoolDriver {
643    fn scheduler(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
644        self.client(Some(metadata)).initiator().scheduler(metadata)
645    }
646}
647
648impl FdReadyDriver for ThreadpoolDriver {
649    type FdReady = FdReady<Self>;
650
651    fn new_fd_ready(&self, fd: RawFd) -> io::Result<Self::FdReady> {
652        Ok(FdReady::new(self.clone(), fd))
653    }
654}
655
656impl WaitDriver for ThreadpoolDriver {
657    type Wait = FdWait<Self>;
658
659    fn new_wait(&self, fd: RawFd, read_size: usize) -> io::Result<Self::Wait> {
660        Ok(FdWait::new(self.clone(), fd, read_size))
661    }
662}
663
664impl TimerDriver for ThreadpoolDriver {
665    type Timer = Timer<Self>;
666
667    fn new_timer(&self) -> Self::Timer {
668        Timer::new(self.clone())
669    }
670}
671
672impl pal_async::io_uring::IoUringDriver for ThreadpoolDriver {
673    type Submitter = IoInitiator;
674
675    fn io_uring_submitter(&self) -> Option<&IoInitiator> {
676        Some(self.client(None).initiator())
677    }
678}
679
/// A driver for [`AffinitizedThreadpool`] that can be retargeted to different
/// CPUs.
///
/// Cloning is cheap: clones share the same target CPU, so retargeting one
/// clone retargets them all.
#[derive(Debug, Clone)]
pub struct RetargetableDriver {
    inner: Arc<RetargetableDriverInner>,
}
686
#[derive(Debug)]
struct RetargetableDriverInner {
    threadpool: AffinitizedThreadpool,
    /// The CPU to direct new IOs and tasks to; updated by `retarget`.
    target_cpu: AtomicU32,
}
692
693impl RetargetableDriver {
694    /// Returns a new driver, initially targeted to `target_cpu`.
695    pub fn new(threadpool: AffinitizedThreadpool, target_cpu: u32) -> Self {
696        Self {
697            inner: Arc::new(RetargetableDriverInner {
698                threadpool,
699                target_cpu: target_cpu.into(),
700            }),
701        }
702    }
703
704    /// Retargets the driver to `target_cpu`.
705    ///
706    /// In-flight IOs will not be retargeted.
707    pub fn retarget(&self, target_cpu: u32) {
708        self.inner.target_cpu.store(target_cpu, Relaxed);
709    }
710
711    /// Returns the current target CPU.
712    pub fn current_target_cpu(&self) -> u32 {
713        self.inner.target_cpu.load(Relaxed)
714    }
715
716    /// Returns the current driver.
717    pub fn current_driver(&self) -> &ThreadpoolDriver {
718        self.inner.current_driver()
719    }
720}
721
722impl Initiate for RetargetableDriver {
723    fn initiator(&self) -> &IoInitiator {
724        self.inner.current_driver().initiator()
725    }
726}
727
728impl Spawn for RetargetableDriver {
729    fn scheduler(&self, _metadata: &TaskMetadata) -> Arc<dyn Schedule> {
730        self.inner.clone()
731    }
732}
733
734impl RetargetableDriverInner {
735    fn current_driver(&self) -> &ThreadpoolDriver {
736        self.threadpool.driver(self.target_cpu.load(Relaxed))
737    }
738}
739
740impl Schedule for RetargetableDriverInner {
741    fn schedule(&self, runnable: Runnable) {
742        self.current_driver()
743            .client(Some(runnable.metadata()))
744            .schedule(runnable)
745    }
746
747    fn name(&self) -> Arc<str> {
748        self.current_driver().inner.name.clone()
749    }
750}
751
752impl FdReadyDriver for RetargetableDriver {
753    type FdReady = FdReady<Self>;
754
755    fn new_fd_ready(&self, fd: RawFd) -> io::Result<Self::FdReady> {
756        Ok(FdReady::new(self.clone(), fd))
757    }
758}
759
760impl WaitDriver for RetargetableDriver {
761    type Wait = FdWait<Self>;
762
763    fn new_wait(&self, fd: RawFd, read_size: usize) -> io::Result<Self::Wait> {
764        Ok(FdWait::new(self.clone(), fd, read_size))
765    }
766}
767
768impl TimerDriver for RetargetableDriver {
769    type Timer = Timer<Self>;
770
771    fn new_timer(&self) -> Self::Timer {
772        Timer::new(self.clone())
773    }
774}
775
776impl pal_async::io_uring::IoUringDriver for RetargetableDriver {
777    type Submitter = IoInitiator;
778
779    fn io_uring_submitter(&self) -> Option<&IoInitiator> {
780        Some(self.inner.current_driver().initiator())
781    }
782}