cache_topology/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Provides ways to describe a machine's cache topology and to query it from
5//! the current running machine.
6
7// UNSAFETY: needed to call Win32 functions to query cache topology
8#![cfg_attr(windows, expect(unsafe_code))]
9
10use thiserror::Error;
11
/// A machine's cache topology.
///
/// Values built by [`CacheTopology::from_host`] hold the host's caches in
/// sorted order with exact duplicates removed.
#[derive(Debug)]
pub struct CacheTopology {
    /// A list of caches.
    pub caches: Vec<Cache>,
}
18
/// A memory cache.
///
/// The derived ordering compares fields in declaration order, so caches sort
/// by `level` first, then `cache_type`, then the remaining fields.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Cache {
    /// The cache level, 1 being closest to the CPU.
    pub level: u8,
    /// The cache type.
    pub cache_type: CacheType,
    /// The CPUs that share this cache.
    pub cpus: Vec<u32>,
    /// The cache size in bytes.
    pub size: u32,
    /// The cache associativity.
    ///
    /// If `None`, this cache is fully associative.
    pub associativity: Option<u32>,
    /// The cache line size in bytes.
    pub line_size: u32,
}
35
/// A cache type.
///
/// The derived ordering follows declaration order: `Data`, then
/// `Instruction`, then `Unified`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum CacheType {
    /// A data cache.
    Data,
    /// An instruction cache.
    Instruction,
    /// A unified cache.
    Unified,
}
46
/// An error returned by [`CacheTopology::from_host`].
#[derive(Debug, Error)]
pub enum HostTopologyError {
    /// An error occurred while retrieving the cache topology.
    ///
    /// The wrapped [`std::io::Error`] is exposed as this error's source.
    #[error("os error retrieving cache topology")]
    Os(#[source] std::io::Error),
}
54
55impl CacheTopology {
56    /// Returns the cache topology of the current machine.
57    pub fn from_host() -> Result<Self, HostTopologyError> {
58        let mut caches = Self::host_caches().map_err(HostTopologyError::Os)?;
59        caches.sort();
60        caches.dedup();
61        Ok(Self { caches })
62    }
63}
64
/// Windows backend: queries the cache topology through
/// `GetLogicalProcessorInformationEx`.
#[cfg(windows)]
mod windows {
    use super::CacheTopology;
    use crate::Cache;
    use crate::CacheType;
    use std::mem::offset_of;
    use std::mem::size_of;
    use std::ptr::addr_of;
    use windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER;
    use windows_sys::Win32::System::SystemInformation;

    type ProcessorInfo = SystemInformation::SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX;
    type CacheInfo = SystemInformation::CACHE_RELATIONSHIP;
    type GroupAffinity = SystemInformation::GROUP_AFFINITY;

    /// Builds the error returned when the OS-provided buffer is not laid out
    /// the way this parser expects.
    fn malformed(what: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, what)
    }

    impl CacheTopology {
        /// Returns every cache reported by the OS, in the order reported.
        ///
        /// Entries whose cache type is not data, instruction, or unified are
        /// skipped.
        ///
        /// # Errors
        ///
        /// Returns the OS error if either query fails, or an
        /// [`std::io::ErrorKind::InvalidData`] error if the returned buffer
        /// contains an entry that is not a well-formed cache record.
        pub(crate) fn host_caches() -> std::io::Result<Vec<Cache>> {
            let mut len = 0;
            // SAFETY: passing a zero-length buffer as allowed by this routine.
            let r = unsafe {
                SystemInformation::GetLogicalProcessorInformationEx(
                    SystemInformation::RelationCache,
                    std::ptr::null_mut(),
                    &mut len,
                )
            };
            if r != 0 {
                // The call succeeded with an empty buffer, so there is no
                // cache information to report.
                return Ok(Vec::new());
            }
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() != Some(ERROR_INSUFFICIENT_BUFFER as i32) {
                return Err(err);
            }

            let mut buf = vec![0u8; len as usize];
            // SAFETY: passing a buffer of the size requested by the previous
            // call, together with that size.
            let r = unsafe {
                SystemInformation::GetLogicalProcessorInformationEx(
                    SystemInformation::RelationCache,
                    buf.as_mut_ptr().cast(),
                    &mut len,
                )
            };
            if r == 0 {
                return Err(std::io::Error::last_os_error());
            }
            // Only parse the bytes the OS reports as written.
            buf.truncate(len as usize);

            // Byte layout of one entry: a fixed header, then the cache
            // record, whose tail is a variable-length array of group masks.
            let header_len = offset_of!(ProcessorInfo, Anonymous);
            let min_entry_len = header_len + size_of::<CacheInfo>();
            let masks_offset = header_len + offset_of!(CacheInfo, Anonymous);

            let mut caches = Vec::new();
            let mut rest = buf.as_slice();
            while !rest.is_empty() {
                if rest.len() < header_len {
                    return Err(malformed("truncated processor information header"));
                }
                let entry = rest.as_ptr().cast::<ProcessorInfo>();
                // SAFETY: at least `header_len` bytes are readable at `entry`,
                // which covers both header fields. Only raw-pointer field
                // projections and unaligned reads are used, so no alignment
                // is assumed and no reference to the full (possibly larger)
                // structure is created.
                let (relationship, entry_len) = unsafe {
                    (
                        addr_of!((*entry).Relationship).read_unaligned(),
                        addr_of!((*entry).Size).read_unaligned() as usize,
                    )
                };
                if relationship != SystemInformation::RelationCache {
                    return Err(malformed("unexpected non-cache relationship entry"));
                }
                // Rejecting undersized entries also guarantees forward
                // progress; rejecting oversized ones keeps the split in range.
                if entry_len < min_entry_len || entry_len > rest.len() {
                    return Err(malformed("cache relationship entry has an invalid size"));
                }
                let (entry_bytes, tail) = rest.split_at(entry_len);
                rest = tail;

                // SAFETY: `entry_bytes` holds at least `min_entry_len` bytes,
                // so a whole cache record is readable at `header_len`. The
                // record consists solely of integer fields, so any bit
                // pattern is valid, and the read is unaligned.
                let cache = unsafe {
                    entry_bytes
                        .as_ptr()
                        .add(header_len)
                        .cast::<CacheInfo>()
                        .read_unaligned()
                };

                let cache_type = match cache.Type {
                    SystemInformation::CacheUnified => CacheType::Unified,
                    SystemInformation::CacheInstruction => CacheType::Instruction,
                    SystemInformation::CacheData => CacheType::Data,
                    _ => continue,
                };

                let group_count = usize::from(cache.GroupCount);
                if masks_offset + group_count * size_of::<GroupAffinity>() > entry_len {
                    return Err(malformed("cache group masks extend past their entry"));
                }

                let mut cpus = Vec::new();
                for index in 0..group_count {
                    // SAFETY: the check above guarantees that all
                    // `group_count` masks lie within `entry_bytes`; the read
                    // is unaligned and the type has only integer fields.
                    let group = unsafe {
                        entry_bytes
                            .as_ptr()
                            .add(masks_offset)
                            .cast::<GroupAffinity>()
                            .add(index)
                            .read_unaligned()
                    };
                    // CPU numbers are assigned per group, one per mask bit.
                    let first_cpu = u32::from(group.Group) * usize::BITS;
                    cpus.extend(
                        (0..usize::BITS)
                            .filter(|&bit| group.Mask & (1usize << bit) != 0)
                            .map(|bit| first_cpu + bit),
                    );
                }

                caches.push(Cache {
                    cpus,
                    level: cache.Level,
                    cache_type,
                    size: cache.CacheSize,
                    // An all-ones associativity value marks a fully
                    // associative cache.
                    associativity: if cache.Associativity == u8::MAX {
                        None
                    } else {
                        Some(cache.Associativity.into())
                    },
                    line_size: cache.LineSize.into(),
                });
            }

            Ok(caches)
        }
    }
}
163
164#[cfg(target_os = "linux")]
165mod linux {
166    use super::Cache;
167    use super::CacheTopology;
168
169    impl CacheTopology {
170        pub(crate) fn host_caches() -> std::io::Result<Vec<Cache>> {
171            let mut caches = Vec::new();
172            for cpu_entry in fs_err::read_dir("/sys/devices/system/cpu")? {
173                let cpu_path = cpu_entry?.path();
174                if cpu_path
175                    .file_name()
176                    .unwrap()
177                    .to_str()
178                    .unwrap()
179                    .strip_prefix("cpu")
180                    .and_then(|s| s.parse::<u32>().ok())
181                    .is_none()
182                {
183                    continue;
184                }
185                for entry in fs_err::read_dir(cpu_path.join("cache"))? {
186                    let entry = entry?;
187                    let path = entry.path();
188                    if !path
189                        .file_name()
190                        .unwrap()
191                        .to_str()
192                        .is_some_and(|s| s.starts_with("index"))
193                    {
194                        continue;
195                    }
196
197                    let associativity = fs_err::read_to_string(path.join("ways_of_associativity"))?
198                        .trim_end()
199                        .parse()
200                        .unwrap();
201
202                    let mut cpus = Vec::new();
203                    for range in fs_err::read_to_string(path.join("shared_cpu_list"))?
204                        .trim_end()
205                        .split(',')
206                    {
207                        if let Some((start, end)) = range.split_once('-') {
208                            cpus.extend(
209                                start.parse::<u32>().unwrap()..=end.parse::<u32>().unwrap(),
210                            );
211                        } else {
212                            cpus.push(range.parse().unwrap());
213                        }
214                    }
215
216                    let line_size_result = fs_err::read_to_string(path.join("coherency_line_size"));
217                    let line_size = match line_size_result {
218                        Ok(s) => s.trim_end().parse::<u32>().unwrap(),
219                        Err(e) => match e.kind() {
220                            std::io::ErrorKind::NotFound => 64,
221                            _ => return std::io::Result::Err(e),
222                        },
223                    };
224                    caches.push(Cache {
225                        cpus,
226                        level: fs_err::read_to_string(path.join("level"))?
227                            .trim_end()
228                            .parse()
229                            .unwrap(),
230                        cache_type: match fs_err::read_to_string(path.join("type"))?.trim_end() {
231                            "Data" => super::CacheType::Data,
232                            "Instruction" => super::CacheType::Instruction,
233                            "Unified" => super::CacheType::Unified,
234                            _ => continue,
235                        },
236                        size: fs_err::read_to_string(path.join("size"))?
237                            .strip_suffix("K\n")
238                            .unwrap()
239                            .parse::<u32>()
240                            .unwrap()
241                            * 1024,
242                        associativity: if associativity == 0 {
243                            None
244                        } else {
245                            Some(associativity)
246                        },
247                        line_size,
248                    });
249                }
250            }
251            Ok(caches)
252        }
253    }
254}
255
/// macOS backend: not implemented yet.
#[cfg(target_os = "macos")]
mod macos {
    use super::Cache;
    use super::CacheTopology;

    impl CacheTopology {
        /// Placeholder that always succeeds and reports no caches.
        pub(crate) fn host_caches() -> std::io::Result<Vec<Cache>> {
            // TODO
            let caches = Vec::new();
            Ok(caches)
        }
    }
}
268
#[cfg(test)]
mod tests {
    /// Smoke test: querying the host topology succeeds and, where a real
    /// backend exists, reports at least one cache.
    #[test]
    fn test_host_cache_topology() {
        let topology = super::CacheTopology::from_host().unwrap();
        // The macOS backend is still a stub that reports no caches, so the
        // non-empty check only applies to the other platforms.
        if cfg!(not(target_os = "macos")) {
            assert!(!topology.caches.is_empty());
        }
        println!("{topology:?}");
    }
}