1#![cfg(target_os = "linux")]
7
8use std::io;
9use std::sync::OnceLock;
10use thiserror::Error;
11
/// A dynamically sized CPU bitmask compatible with the Linux affinity syscalls.
///
/// Bit `n` of the mask (bit `n % 64` of word `n / 64`) represents CPU `n`.
/// The number of words is fixed when the set is created.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CpuSet(Box<[u64]>);
19
20impl Default for CpuSet {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl CpuSet {
27 pub fn new() -> Self {
29 let size = (max_procs() + 63) as usize / 64;
31 Self(vec![0; size].into())
32 }
33
34 pub fn buffer_len(&self) -> usize {
36 self.0.len() * 8
37 }
38
39 pub fn as_ptr(&self) -> *const libc::cpu_set_t {
41 self.0.as_ptr().cast()
42 }
43
44 pub fn as_mut_ptr(&mut self) -> *mut libc::cpu_set_t {
46 self.0.as_mut_ptr().cast()
47 }
48
49 pub fn set(&mut self, index: u32) -> &mut Self {
53 assert!(index < max_procs());
54 self.0[index as usize / 64] |= 1 << (index % 64);
58 self
59 }
60
61 pub fn set_mask_hex_string(&mut self, string_mask: &[u8]) -> Result<(), InvalidHexString> {
66 let err = || InvalidHexString(String::from_utf8_lossy(string_mask).into_owned());
67 if string_mask.len() % 2 != 0 {
68 return Err(err());
69 }
70 let mask = string_mask
71 .chunks_exact(2)
72 .map(|s| u8::from_str_radix(std::str::from_utf8(s).ok()?, 16).ok());
73 for (i, byte) in mask.enumerate() {
74 let byte = byte.ok_or_else(err)?;
75 if byte != 0 {
76 *self.0.get_mut(i / 8).ok_or_else(err)? |= (byte as u64) << (i % 8);
77 }
78 }
79 Ok(())
80 }
81
82 pub fn set_mask_list(&mut self, list: &str) -> Result<(), InvalidCpuList> {
87 let err = || InvalidCpuList(list.to_owned());
88 for range in list.trim_end().split(',') {
89 let range = match range.split_once('-') {
90 Some((start, end)) => {
91 start.parse().map_err(|_| err())?..=end.parse().map_err(|_| err())?
92 }
93 None => {
94 let cpu = range.parse().map_err(|_| err())?;
95 cpu..=cpu
96 }
97 };
98 for cpu in range {
99 self.set(cpu);
100 }
101 }
102 Ok(())
103 }
104
105 pub fn is_set(&self, index: u32) -> bool {
109 assert!(index < max_procs());
110 self.0[index as usize / 64] & (1 << (index % 64)) != 0
111 }
112}
113
114#[derive(Debug, Error)]
115#[error("invalid hex string for bitmask: {0}")]
116pub struct InvalidHexString(String);
117
118#[derive(Debug, Error)]
119#[error("invalid CPU list: {0}")]
120pub struct InvalidCpuList(String);
121
122pub fn set_current_thread_affinity(cpu_set: &CpuSet) -> io::Result<()> {
124 let r = unsafe { libc::sched_setaffinity(0, cpu_set.buffer_len(), cpu_set.as_ptr()) };
126 if r < 0 {
127 return Err(io::Error::last_os_error());
128 }
129 Ok(())
130}
131
132pub fn get_current_thread_affinity(cpu_set: &mut CpuSet) -> io::Result<()> {
134 let r = unsafe { libc::sched_getaffinity(0, cpu_set.buffer_len(), cpu_set.as_mut_ptr()) };
136 if r < 0 {
137 return Err(io::Error::last_os_error());
138 }
139 Ok(())
140}
141
142pub fn get_cpu_number() -> u32 {
144 unsafe { libc::sched_getcpu() as u32 }
146}
147
148pub fn num_procs() -> u32 {
150 static NUM_PROCS: OnceLock<u32> = OnceLock::new();
151 *NUM_PROCS.get_or_init(|| {
152 let mut set = CpuSet::new();
160 get_current_thread_affinity(&mut set).unwrap();
161 set.0.iter().map(|x| x.count_ones()).sum()
162 })
163}
164
165pub fn max_present_cpu() -> io::Result<u32> {
168 let mut max_cpu = 0;
169 for entry in fs_err::read_dir("/sys/devices/system/cpu")? {
170 let entry = entry?;
171 let name = entry.file_name();
172 let Some(cpu) = name
173 .to_str()
174 .and_then(|s| s.strip_prefix("cpu"))
175 .and_then(|s| s.parse::<u32>().ok())
176 else {
177 continue;
178 };
179 max_cpu = cpu.max(max_cpu);
180 }
181 Ok(max_cpu)
182}
183
/// Number of CPU indices the kernel can address: the value of
/// `/sys/devices/system/cpu/kernel_max` plus one. The result is cached after the
/// first call.
///
/// Per the sysfs ABI, `kernel_max` is the highest CPU index the running kernel
/// supports, so every valid CPU index is `< max_procs()`.
///
/// # Panics
///
/// Panics if the file cannot be read or does not contain a `u32`.
pub fn max_procs() -> u32 {
    static MAX_PROCS: OnceLock<u32> = OnceLock::new();
    *MAX_PROCS.get_or_init(|| {
        let contents = std::fs::read_to_string("/sys/devices/system/cpu/kernel_max")
            .expect("failed to read kernel_max");
        let highest_index = contents
            .trim_end()
            .parse::<u32>()
            .expect("failed to parse kernel_max");
        highest_index + 1
    })
}
197
#[cfg(test)]
mod tests {
    use super::{max_procs, CpuSet};

    /// `max_procs` must report a plausible, non-zero number of CPU slots.
    #[test]
    fn test_max_procs() {
        let procs = max_procs();
        assert!((1..32768).contains(&procs));
    }

    /// A mixed range/single cpulist sets exactly the named CPUs.
    #[test]
    fn test_cpu_list() {
        let mut cpus = CpuSet::new();
        cpus.set_mask_list("0-3,5").unwrap();
        let expected = [true, true, true, true, false, true, false];
        for (cpu, &want) in (0u32..).zip(expected.iter()) {
            assert_eq!(cpus.is_set(cpu), want, "cpu {}", cpu);
        }
    }
}