pal/unix/
affinity.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Thread affinity support for Linux.
5
6#![cfg(target_os = "linux")]
7
8use std::io;
9use std::sync::OnceLock;
10use thiserror::Error;
11
/// A dynamically sized stand-in for [`libc::cpu_set_t`].
///
/// The statically sized `cpu_set_t` only has room for 1024 processors. This
/// type instead stores the mask as a heap-allocated slice of 64-bit words
/// (bit `n % 64` of word `n / 64` represents processor `n`), sized according
/// to the number of processors this machine's kernel supports.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CpuSet(Box<[u64]>);
19
20impl Default for CpuSet {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl CpuSet {
27    /// Allocates a new empty CPU set.
28    pub fn new() -> Self {
29        // Size the buffer according to the maximum number of processors.
30        let size = (max_procs() + 63) as usize / 64;
31        Self(vec![0; size].into())
32    }
33
34    /// Gets the length of the buffer in bytes, for use with syscalls.
35    pub fn buffer_len(&self) -> usize {
36        self.0.len() * 8
37    }
38
39    /// Gets a pointer for use with syscalls.
40    pub fn as_ptr(&self) -> *const libc::cpu_set_t {
41        self.0.as_ptr().cast()
42    }
43
44    /// Gets a mutable pointer for use with syscalls.
45    pub fn as_mut_ptr(&mut self) -> *mut libc::cpu_set_t {
46        self.0.as_mut_ptr().cast()
47    }
48
49    /// Sets processor `index` in the CPU set.
50    ///
51    /// Panics if `index` is greater than or equal to [`max_procs`].
52    pub fn set(&mut self, index: u32) -> &mut Self {
53        assert!(index < max_procs());
54        // Can't use libc::CPU_SET because it assumes a statically sized
55        // cpu_set_t (which raises the question of why they bother to expose
56        // CPU_ALLOC_SIZE...).
57        self.0[index as usize / 64] |= 1 << (index % 64);
58        self
59    }
60
61    /// Sets all the CPUs in the linear bitmask `mask`, which is an ASCII
62    /// hexadecimal string.
63    ///
64    /// This is useful for parsing the output of `/sys/devices/system/cpu/topology`.
65    pub fn set_mask_hex_string(&mut self, string_mask: &[u8]) -> Result<(), InvalidHexString> {
66        let err = || InvalidHexString(String::from_utf8_lossy(string_mask).into_owned());
67        if string_mask.len() % 2 != 0 {
68            return Err(err());
69        }
70        let mask = string_mask
71            .chunks_exact(2)
72            .map(|s| u8::from_str_radix(std::str::from_utf8(s).ok()?, 16).ok());
73        for (i, byte) in mask.enumerate() {
74            let byte = byte.ok_or_else(err)?;
75            if byte != 0 {
76                *self.0.get_mut(i / 8).ok_or_else(err)? |= (byte as u64) << (i % 8);
77            }
78        }
79        Ok(())
80    }
81
82    /// Sets all the CPUs in the list `list`, which is a comma-separated list of
83    /// ranges and single CPUs.
84    ///
85    /// This is useful for parsing the output of `/sys/devices/system/cpu/online`.
86    pub fn set_mask_list(&mut self, list: &str) -> Result<(), InvalidCpuList> {
87        let err = || InvalidCpuList(list.to_owned());
88        for range in list.trim_end().split(',') {
89            let range = match range.split_once('-') {
90                Some((start, end)) => {
91                    start.parse().map_err(|_| err())?..=end.parse().map_err(|_| err())?
92                }
93                None => {
94                    let cpu = range.parse().map_err(|_| err())?;
95                    cpu..=cpu
96                }
97            };
98            for cpu in range {
99                self.set(cpu);
100            }
101        }
102        Ok(())
103    }
104
105    /// Returns whether processor `index` is set.
106    ///
107    /// Panics if `index` is greater than or equal to [`max_procs`].
108    pub fn is_set(&self, index: u32) -> bool {
109        assert!(index < max_procs());
110        self.0[index as usize / 64] & (1 << (index % 64)) != 0
111    }
112}
113
114#[derive(Debug, Error)]
115#[error("invalid hex string for bitmask: {0}")]
116pub struct InvalidHexString(String);
117
118#[derive(Debug, Error)]
119#[error("invalid CPU list: {0}")]
120pub struct InvalidCpuList(String);
121
122/// Sets the current thread's affinity.
123pub fn set_current_thread_affinity(cpu_set: &CpuSet) -> io::Result<()> {
124    // SAFETY: calling as documented, with an appropriately-sized buffer.
125    let r = unsafe { libc::sched_setaffinity(0, cpu_set.buffer_len(), cpu_set.as_ptr()) };
126    if r < 0 {
127        return Err(io::Error::last_os_error());
128    }
129    Ok(())
130}
131
132/// Gets the current thread's affinity.
133pub fn get_current_thread_affinity(cpu_set: &mut CpuSet) -> io::Result<()> {
134    // SAFETY: calling as documented, with an appropriately-sized buffer.
135    let r = unsafe { libc::sched_getaffinity(0, cpu_set.buffer_len(), cpu_set.as_mut_ptr()) };
136    if r < 0 {
137        return Err(io::Error::last_os_error());
138    }
139    Ok(())
140}
141
142/// Returns the number of the processor the current thread was running on during the call to this function.
143pub fn get_cpu_number() -> u32 {
144    // SAFETY: Calling external code.
145    unsafe { libc::sched_getcpu() as u32 }
146}
147
148/// Returns the total number of online processors.
149pub fn num_procs() -> u32 {
150    static NUM_PROCS: OnceLock<u32> = OnceLock::new();
151    *NUM_PROCS.get_or_init(|| {
152        // Get the number of bits in the current affinity set. This isn't
153        // perfect--what if we have set affinity to something else--but it's
154        // what `sysconf(_SC_NPROCESSORS_ONLN)` does, except that it works with
155        // more than 1024 processors.
156        //
157        // FUTURE: find callers of this and choose a different strategy
158        // accordingly.
159        let mut set = CpuSet::new();
160        get_current_thread_affinity(&mut set).unwrap();
161        set.0.iter().map(|x| x.count_ones()).sum()
162    })
163}
164
165/// Returns the maximum CPU number of any present (but not necessarily online)
166/// processor.
167pub fn max_present_cpu() -> io::Result<u32> {
168    let mut max_cpu = 0;
169    for entry in fs_err::read_dir("/sys/devices/system/cpu")? {
170        let entry = entry?;
171        let name = entry.file_name();
172        let Some(cpu) = name
173            .to_str()
174            .and_then(|s| s.strip_prefix("cpu"))
175            .and_then(|s| s.parse::<u32>().ok())
176        else {
177            continue;
178        };
179        max_cpu = cpu.max(max_cpu);
180    }
181    Ok(max_cpu)
182}
183
/// Returns the kernel compiled-in maximum number of processors.
///
/// The value is one more than the maximum CPU index published in
/// `/sys/devices/system/cpu/kernel_max`. It is read once and cached for all
/// later calls.
///
/// # Panics
///
/// Panics on first use if `kernel_max` cannot be read or parsed.
pub fn max_procs() -> u32 {
    static MAX_PROCS: OnceLock<u32> = OnceLock::new();

    fn read_kernel_max() -> u32 {
        let contents = std::fs::read_to_string("/sys/devices/system/cpu/kernel_max")
            .expect("failed to read kernel_max");
        let highest_index = contents
            .trim_end()
            .parse::<u32>()
            .expect("failed to parse kernel_max");

        // The file holds the largest index, so the count is one more.
        highest_index + 1
    }

    *MAX_PROCS.get_or_init(read_kernel_max)
}
197
#[cfg(test)]
mod tests {
    use super::max_procs;
    use super::CpuSet;

    /// The processor limit should be non-zero and below a sanity bound.
    #[test]
    fn test_max_procs() {
        let limit = max_procs();
        assert!((1..32768).contains(&limit));
    }

    /// A list mixing a range and a single CPU sets exactly those CPUs.
    #[test]
    fn test_cpu_list() {
        let mut cpus = CpuSet::new();
        cpus.set_mask_list("0-3,5").unwrap();
        // Expected membership for CPUs 0 through 6.
        let expected = [true, true, true, true, false, true, false];
        for (index, &want) in expected.iter().enumerate() {
            assert_eq!(cpus.is_set(index as u32), want, "cpu {index}");
        }
    }
}
220}