Skip to main content

tenferro_tensor/cpu/
affinity.rs

1use std::num::NonZeroUsize;
2
3/// Return a best-effort CPU count available to the current process.
4///
5/// This first tries an OS-standard process-affinity query when supported, then
6/// falls back to `std::thread::available_parallelism()`, and finally to `1`.
7///
8/// # Examples
9///
10/// ```
11/// let available = tenferro_tensor::cpu::available_parallelism();
12/// assert!(available >= 1);
13/// ```
14pub fn available_parallelism() -> usize {
15    process_cpu_affinity_count()
16        .or_else(standard_available_parallelism)
17        .unwrap_or(1)
18}
19
20/// Return the current process affinity mask size when the platform exposes a
21/// standard affinity API.
22///
23/// Platforms without an affinity query return `None`.
24///
25/// # Examples
26///
27/// ```
28/// let count = tenferro_tensor::cpu::process_cpu_affinity_count();
29/// if let Some(count) = count {
30///     assert!(count >= 1);
31/// }
32/// ```
33pub fn process_cpu_affinity_count() -> Option<usize> {
34    platform_process_cpu_affinity_count()
35}
36
/// Query `std::thread::available_parallelism()` and flatten the result.
///
/// Returns the reported parallelism as a plain `usize`, or `None` when the
/// standard library reports an error (the error itself is discarded).
pub(crate) fn standard_available_parallelism() -> Option<usize> {
    let detected: NonZeroUsize = std::thread::available_parallelism().ok()?;
    Some(detected.get())
}
42
/// Linux/Android implementation: count the CPUs set in the mask returned by
/// `sched_getaffinity(0, ...)`.
///
/// The mask buffer starts at `INITIAL_MASK_BYTES` and is doubled each time the
/// call fails with `EINVAL` (which the man page documents as "`cpusetsize` is
/// smaller than the kernel's affinity mask"). Growth is capped at
/// `MAX_MASK_BYTES`, so a persistent `EINVAL` from some other cause makes this
/// best-effort query return `None` instead of allocating without bound.
///
/// Returns `None` when the call fails with any other error, when the size cap
/// is exceeded, or when the returned mask has no bits set.
///
/// NOTE(review): per `sched_getaffinity(2)`, pid `0` selects the *calling
/// thread*, so the result reflects the caller's mask — confirm this matches
/// the "process" wording used by the public API.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    unsafe extern "C" {
        fn sched_getaffinity(pid: i32, cpusetsize: usize, mask: *mut core::ffi::c_void) -> i32;
    }

    /// First buffer size tried: 128 bytes = 1024 CPU bits.
    const INITIAL_MASK_BYTES: usize = 128;
    /// Upper bound on the buffer (1 MiB = 8,388,608 CPU bits); beyond this the
    /// query is abandoned rather than retried.
    const MAX_MASK_BYTES: usize = 1 << 20;
    /// Linux errno value for "invalid argument".
    const EINVAL: i32 = 22;

    let mut mask_bytes = INITIAL_MASK_BYTES;
    while mask_bytes <= MAX_MASK_BYTES {
        let mut mask = vec![0u8; mask_bytes];
        // SAFETY: `mask` is a live, writable allocation of exactly `mask.len()`
        // bytes, and that same length is passed as `cpusetsize`, so the callee
        // is told the true extent of the buffer.
        let rc = unsafe {
            sched_getaffinity(0, mask.len(), mask.as_mut_ptr().cast::<core::ffi::c_void>())
        };
        if rc == 0 {
            let count: usize = mask.iter().map(|byte| byte.count_ones() as usize).sum();
            return (count > 0).then_some(count);
        }

        // Only a too-small buffer is worth retrying; anything else is final.
        if std::io::Error::last_os_error().raw_os_error() != Some(EINVAL) {
            return None;
        }
        mask_bytes = mask_bytes.checked_mul(2)?;
    }

    // The size cap was hit without a successful call.
    None
}
71
/// Windows implementation: derive a CPU count from the process affinity mask
/// and, on systems with several processor groups, from the groups the process
/// is assigned to.
///
/// Strategy:
///
/// * System has at most one processor group: count the bits of the process
///   affinity mask; if that query fails, use the active processor count of
///   group 0.
/// * Otherwise, ask which groups the process belongs to (two-step call: first
///   to learn the required array length, then to fill it). If either step does
///   not behave as expected, fall back to the active processor count across
///   all groups.
/// * Process in exactly one group: count the bits of the process affinity
///   mask when that query succeeds.
/// * Process in several groups (or the mask query failed): sum the active
///   processor counts of the groups the process belongs to.
///
/// A resulting count of zero is reported as `None`.
#[cfg(target_os = "windows")]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    type Handle = *mut core::ffi::c_void;
    type DwordPtr = usize;
    type Word = u16;

    /// Win32 `ALL_PROCESSOR_GROUPS` (0xffff): asks `GetActiveProcessorCount`
    /// for the total across every group.
    const ALL_PROCESSOR_GROUPS: Word = u16::MAX;

    unsafe extern "system" {
        fn GetCurrentProcess() -> Handle;
        fn GetProcessAffinityMask(
            process: Handle,
            process_affinity_mask: *mut DwordPtr,
            system_affinity_mask: *mut DwordPtr,
        ) -> i32;
        fn GetActiveProcessorGroupCount() -> Word;
        fn GetActiveProcessorCount(group_number: Word) -> u32;
        fn GetProcessGroupAffinity(
            process: Handle,
            group_count: *mut Word,
            group_array: *mut Word,
        ) -> i32;
    }

    /// Map a zero count to `None`, anything else to `Some(count)`.
    fn non_zero(count: usize) -> Option<usize> {
        (count > 0).then_some(count)
    }

    /// Active processor count of one group (or of all groups for
    /// `ALL_PROCESSOR_GROUPS`).
    fn active_processors(group: Word) -> usize {
        // SAFETY: the call takes a plain integer and touches no caller memory.
        unsafe { GetActiveProcessorCount(group) as usize }
    }

    /// Number of bits set in the process affinity mask, or `None` when the
    /// query itself fails. A successful query may still yield `Some(0)`.
    fn affinity_mask_bits(process: Handle) -> Option<usize> {
        let mut process_mask: DwordPtr = 0;
        let mut system_mask: DwordPtr = 0;
        // SAFETY: both out-pointers refer to live, writable locals of the
        // pointer-sized integer type the API expects.
        let ok = unsafe {
            GetProcessAffinityMask(
                process,
                std::ptr::addr_of_mut!(process_mask),
                std::ptr::addr_of_mut!(system_mask),
            )
        };
        (ok != 0).then(|| process_mask.count_ones() as usize)
    }

    // SAFETY: neither call takes arguments or writes through pointers.
    let process = unsafe { GetCurrentProcess() };
    let system_group_count = unsafe { GetActiveProcessorGroupCount() };

    if system_group_count <= 1 {
        // Single-group system: the affinity mask is authoritative when it can
        // be read; a failed query falls back to the size of group 0.
        let count = affinity_mask_bits(process).unwrap_or_else(|| active_processors(0));
        return non_zero(count);
    }

    // Step 1: probe with an empty array. The call is expected to fail and
    // report the required element count through `group_count`.
    let mut group_count: Word = 0;
    // SAFETY: `group_count` is a live local; a null array is paired with a
    // declared length of zero, so nothing is written through it.
    let probe_ok = unsafe {
        GetProcessGroupAffinity(
            process,
            std::ptr::addr_of_mut!(group_count),
            std::ptr::null_mut(),
        )
    };
    if probe_ok != 0 || group_count == 0 {
        return non_zero(active_processors(ALL_PROCESSOR_GROUPS));
    }

    // Step 2: fetch the group numbers into a buffer of the reported length.
    let mut groups: Vec<Word> = vec![0; usize::from(group_count)];
    // SAFETY: `groups` holds exactly `group_count` writable elements, which is
    // the capacity advertised to the callee via `group_count`.
    let fill_ok = unsafe {
        GetProcessGroupAffinity(
            process,
            std::ptr::addr_of_mut!(group_count),
            groups.as_mut_ptr(),
        )
    };
    if fill_ok == 0 || group_count == 0 {
        return non_zero(active_processors(ALL_PROCESSOR_GROUPS));
    }

    // On output `group_count` is the number of entries actually written. If it
    // shrank between the two calls, the tail of the buffer is still zeroed and
    // must not be counted as extra occurrences of group 0.
    groups.truncate(usize::from(group_count));

    if group_count == 1 {
        // Within a single group the affinity mask gives the precise answer.
        if let Some(bits) = affinity_mask_bits(process) {
            return non_zero(bits);
        }
    }

    let total: usize = groups.into_iter().map(active_processors).sum();
    non_zero(total)
}
163
/// Fallback for targets other than Linux, Android, and Windows: no affinity
/// query is implemented here, so the answer is always `None`.
#[cfg(not(any(target_os = "linux", target_os = "android", target_os = "windows")))]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    Option::None
}