tenferro_tensor/cpu/
affinity.rs1use std::num::NonZeroUsize;
2
3pub fn available_parallelism() -> usize {
15 process_cpu_affinity_count()
16 .or_else(standard_available_parallelism)
17 .unwrap_or(1)
18}
19
20pub fn process_cpu_affinity_count() -> Option<usize> {
34 platform_process_cpu_affinity_count()
35}
36
/// Returns the standard library's parallelism estimate as a plain `usize`.
///
/// Any error from `std::thread::available_parallelism` is mapped to `None`.
pub(crate) fn standard_available_parallelism() -> Option<usize> {
    match std::thread::available_parallelism() {
        Ok(threads) => Some(NonZeroUsize::get(threads)),
        Err(_) => None,
    }
}
42
/// Counts the CPUs set in the caller's scheduler affinity mask using
/// `sched_getaffinity` with pid `0`.
///
/// The mask buffer starts at `INITIAL_MASK_BYTES` and is doubled each time the
/// call fails with `EINVAL` (reported when the buffer is too small for the
/// kernel's CPU mask), up to `MAX_MASK_BYTES`.
///
/// Returns `None` when the call fails with any other error, when the returned
/// mask has no bits set, or when the size cap is reached without success.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    unsafe extern "C" {
        fn sched_getaffinity(pid: i32, cpusetsize: usize, mask: *mut core::ffi::c_void) -> i32;
    }

    const INITIAL_MASK_BYTES: usize = 128;
    // 64 KiB of mask covers 524,288 CPUs. The cap keeps a persistent EINVAL
    // from doubling the allocation until `vec!` panics or aborts the process.
    const MAX_MASK_BYTES: usize = 1 << 16;
    const EINVAL: i32 = 22;

    let mut mask_bytes = INITIAL_MASK_BYTES;
    while mask_bytes <= MAX_MASK_BYTES {
        let mut mask = vec![0u8; mask_bytes];
        // SAFETY: `mask` is a live, writable allocation of exactly `mask.len()`
        // bytes, and that same length is passed as the buffer size.
        let rc = unsafe {
            sched_getaffinity(0, mask.len(), mask.as_mut_ptr().cast::<core::ffi::c_void>())
        };
        if rc == 0 {
            let count = mask
                .iter()
                .map(|byte| byte.count_ones() as usize)
                .sum::<usize>();
            return (count > 0).then_some(count);
        }

        // Read errno immediately after the failed call, before anything else
        // (such as dropping `mask`) has a chance to overwrite it.
        if std::io::Error::last_os_error().raw_os_error() != Some(EINVAL) {
            return None;
        }
        mask_bytes = mask_bytes.checked_mul(2)?;
    }

    None
}
71
/// Counts the logical processors available to the current process on Windows.
///
/// Strategy, in order:
/// 1. With at most one active processor group, count the bits of the process
///    affinity mask; if that query fails, use the active processor count of
///    group 0.
/// 2. Otherwise ask which processor groups the process belongs to. If that
///    query misbehaves, use the active processor count across all groups.
/// 3. If the process belongs to exactly one group, count the bits of the
///    process affinity mask when that query succeeds.
/// 4. Otherwise sum the active processor counts of every group the process
///    belongs to.
///
/// Returns `None` whenever the selected count is zero.
#[cfg(target_os = "windows")]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    type Handle = *mut core::ffi::c_void;
    type DwordPtr = usize;
    type Word = u16;

    // Group number that makes `GetActiveProcessorCount` count every group
    // (the Win32 `ALL_PROCESSOR_GROUPS` value, 0xffff).
    const ALL_PROCESSOR_GROUPS: Word = u16::MAX;

    unsafe extern "system" {
        fn GetCurrentProcess() -> Handle;
        fn GetProcessAffinityMask(
            process: Handle,
            process_affinity_mask: *mut DwordPtr,
            system_affinity_mask: *mut DwordPtr,
        ) -> i32;
        fn GetActiveProcessorGroupCount() -> Word;
        fn GetActiveProcessorCount(group_number: Word) -> u32;
        fn GetProcessGroupAffinity(
            process: Handle,
            group_count: *mut Word,
            group_array: *mut Word,
        ) -> i32;
    }

    // SAFETY: both functions take no arguments and only return a value.
    let process = unsafe { GetCurrentProcess() };
    let system_group_count = unsafe { GetActiveProcessorGroupCount() };

    // Maps a raw count onto the return convention: zero becomes `None`.
    let non_zero = |count: usize| (count > 0).then_some(count);

    // Active processor count of one group (or all groups for the sentinel).
    // SAFETY: the call takes a plain integer and only returns a value.
    let active_in_group = |group: Word| unsafe { GetActiveProcessorCount(group) } as usize;

    // Number of bits set in the process affinity mask, or `None` when
    // `GetProcessAffinityMask` itself reports failure.
    let process_mask_bits = || -> Option<usize> {
        let mut process_mask: DwordPtr = 0;
        let mut system_mask: DwordPtr = 0;
        // SAFETY: both out-pointers refer to live, writable locals.
        let ok = unsafe {
            GetProcessAffinityMask(
                process,
                std::ptr::addr_of_mut!(process_mask),
                std::ptr::addr_of_mut!(system_mask),
            )
        };
        (ok != 0).then(|| process_mask.count_ones() as usize)
    };

    if system_group_count <= 1 {
        return match process_mask_bits() {
            Some(bits) => non_zero(bits),
            None => non_zero(active_in_group(0)),
        };
    }

    // Probe call with a zero-capacity array. Per the Win32 documentation this
    // is expected to fail while writing the required element count, so a
    // "successful" probe or a zero count is treated as unusable.
    let mut group_count: Word = 0;
    // SAFETY: `group_count` is a live local; the null array is paired with a
    // capacity of zero, so nothing may be written through it.
    let ok = unsafe {
        GetProcessGroupAffinity(
            process,
            std::ptr::addr_of_mut!(group_count),
            std::ptr::null_mut(),
        )
    };
    if ok != 0 || group_count == 0 {
        return non_zero(active_in_group(ALL_PROCESSOR_GROUPS));
    }

    let mut groups = vec![0u16; usize::from(group_count)];
    // SAFETY: `groups` holds exactly `group_count` writable elements, which is
    // the capacity passed in through `group_count`.
    let ok = unsafe {
        GetProcessGroupAffinity(
            process,
            std::ptr::addr_of_mut!(group_count),
            groups.as_mut_ptr(),
        )
    };
    if ok == 0 || group_count == 0 {
        return non_zero(active_in_group(ALL_PROCESSOR_GROUPS));
    }
    // On output `group_count` is the number of entries actually written. Drop
    // any unwritten trailing zeros so they are not counted as group 0 when the
    // process's group set shrank between the two calls.
    groups.truncate(usize::from(group_count));

    if group_count == 1 {
        if let Some(bits) = process_mask_bits() {
            return non_zero(bits);
        }
    }

    non_zero(groups.into_iter().map(active_in_group).sum::<usize>())
}
163
/// Fallback for targets other than Linux, Android and Windows.
///
/// No affinity query is implemented for these targets, so the result is
/// always `None`.
#[cfg(not(any(target_os = "linux", target_os = "android", target_os = "windows")))]
fn platform_process_cpu_affinity_count() -> Option<usize> {
    Option::None
}