// tenferro_prims/cpu/context.rs
1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use tenferro_algebra::{Conjugate, Scalar};
5use tenferro_device::{Error, Result};
6use tenferro_tensor::Tensor;
7
8use crate::{infra::plan_cache::PlanCache, TensorTempPoolContext};
9
10#[cfg(feature = "gemm-blas")]
11use super::scratch::{ScratchBuf, ScratchPool};
12use super::temp_pool::TempPool;
13use crate::cpu::common;
14
/// CPU execution context.
///
/// Encapsulates CPU-side execution resources, analogous to cuTENSOR's
/// `cutensorHandle_t`. Holds a rayon thread pool, a [`PlanCache`] for plan
/// reuse, and reusable temporary buffers for host-side execution helpers.
///
/// # Examples
///
/// ```
/// use tenferro_prims::CpuContext;
///
/// # fn demo() -> tenferro_device::Result<()> {
/// let mut ctx = CpuContext::try_new(4)?; // 4-thread pool
/// assert_eq!(ctx.num_threads(), 4);
/// # Ok(())
/// # }
/// ```
pub struct CpuContext {
    /// Rayon thread pool, shared process-wide per thread count (see
    /// `shared_thread_pool`).
    pub(super) pool: Arc<rayon::ThreadPool>,
    /// Cache of built execution plans for reuse across calls.
    pub(super) plan_cache: PlanCache,
    /// Reusable host-side temporary buffers, exposed via
    /// [`TensorTempPoolContext`].
    #[allow(dead_code)]
    temp_pool: TempPool,
    /// Reusable scratch buffers for the BLAS GEMM path.
    #[cfg(feature = "gemm-blas")]
    scratch: ScratchPool,
}
40
41fn shared_thread_pools() -> &'static Mutex<HashMap<usize, Arc<rayon::ThreadPool>>> {
42 static SHARED_THREAD_POOLS: OnceLock<Mutex<HashMap<usize, Arc<rayon::ThreadPool>>>> =
43 OnceLock::new();
44 SHARED_THREAD_POOLS.get_or_init(|| Mutex::new(HashMap::new()))
45}
46
47fn shared_thread_pool(num_threads: usize) -> Result<Arc<rayon::ThreadPool>> {
48 let mut pools = shared_thread_pools()
49 .lock()
50 .unwrap_or_else(|poisoned| poisoned.into_inner());
51 if let Some(pool) = pools.get(&num_threads) {
52 return Ok(Arc::clone(pool));
53 }
54
55 let pool = Arc::new(
56 rayon::ThreadPoolBuilder::new()
57 .num_threads(num_threads)
58 .build()
59 .map_err(|e| Error::DeviceError(format!("failed to build rayon thread pool: {e}")))?,
60 );
61 pools.insert(num_threads, Arc::clone(&pool));
62 Ok(pool)
63}
64
/// Number of CPUs in the calling process's affinity mask, or `None` when the
/// query fails or reports zero CPUs.
///
/// Querying affinity (pid 0 = current process) lets the default thread count
/// respect taskset/cgroup restrictions instead of the raw core count.
#[cfg(target_os = "linux")]
fn affinity_thread_count() -> Option<usize> {
    let mut set = std::mem::MaybeUninit::<libc::cpu_set_t>::zeroed();
    // SAFETY: `set.as_mut_ptr()` points to writable storage of exactly the
    // size passed as the second argument; the kernel writes at most that many
    // bytes into it.
    let rc = unsafe {
        libc::sched_getaffinity(0, std::mem::size_of::<libc::cpu_set_t>(), set.as_mut_ptr())
    };
    if rc != 0 {
        // Syscall failed; let the caller fall back to other detection.
        return None;
    }
    // SAFETY: sched_getaffinity returned 0, so `set` has been initialized by
    // the kernel (and was zeroed beforehand in any case).
    let set = unsafe { set.assume_init() };
    // SAFETY: `&set` is a valid, initialized cpu_set_t.
    let count = unsafe { libc::CPU_COUNT(&set) as usize };
    // Treat an empty mask as "unknown" rather than reporting zero threads.
    (count > 0).then_some(count)
}
78
/// Affinity masks are only queried on Linux; on every other platform report
/// `None` so callers fall back to [`std::thread::available_parallelism`].
#[cfg(not(target_os = "linux"))]
fn affinity_thread_count() -> Option<usize> {
    None
}
83
84impl CpuContext {
85 /// Return the backend-defined default CPU thread count.
86 ///
87 /// On Linux this prefers the current process CPU affinity mask. Other
88 /// platforms fall back to [`std::thread::available_parallelism`].
89 pub fn default_num_threads() -> usize {
90 affinity_thread_count()
91 .or_else(|| std::thread::available_parallelism().ok().map(usize::from))
92 .unwrap_or(1)
93 .max(1)
94 }
95
96 /// Create a new CPU context with the given number of threads.
97 ///
98 /// # Errors
99 ///
100 /// Returns [`tenferro_device::Error::InvalidArgument`] when
101 /// `num_threads == 0`, or [`tenferro_device::Error::DeviceError`] when the
102 /// underlying Rayon thread-pool construction fails.
103 ///
104 /// # Examples
105 ///
106 /// ```
107 /// # fn demo() -> tenferro_device::Result<()> {
108 /// use tenferro_prims::CpuContext;
109 ///
110 /// let ctx = CpuContext::try_new(2)?;
111 /// assert_eq!(ctx.num_threads(), 2);
112 /// # Ok(())
113 /// # }
114 /// ```
115 pub fn try_new(num_threads: usize) -> Result<Self> {
116 if num_threads == 0 {
117 return Err(Error::InvalidArgument(
118 "CpuContext::try_new requires num_threads >= 1".into(),
119 ));
120 }
121 Ok(Self {
122 pool: shared_thread_pool(num_threads)?,
123 plan_cache: PlanCache::new(),
124 temp_pool: TempPool::default(),
125 #[cfg(feature = "gemm-blas")]
126 scratch: ScratchPool::default(),
127 })
128 }
129
130 /// Create a new CPU context using the backend-defined default thread count.
131 pub fn try_new_default() -> Result<Self> {
132 Self::try_new(Self::default_num_threads())
133 }
134
135 /// Create a new CPU context with the given number of threads.
136 ///
137 /// This is a convenience wrapper around [`CpuContext::try_new`]. Production
138 /// code should generally prefer the fallible constructor so context setup
139 /// errors stay in the normal `Result` flow.
140 ///
141 /// # Panics
142 ///
143 /// Panics if [`CpuContext::try_new`] returns an error.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// use tenferro_prims::CpuContext;
149 ///
150 /// let ctx = CpuContext::new(1);
151 /// assert_eq!(ctx.num_threads(), 1);
152 /// ```
153 pub fn new(num_threads: usize) -> Self {
154 Self::try_new(num_threads)
155 .unwrap_or_else(|e| panic!("failed to initialize CpuContext: {e}"))
156 }
157
158 /// Create a new CPU context using the backend-defined default thread count.
159 pub fn new_default() -> Self {
160 Self::try_new_default()
161 .unwrap_or_else(|e| panic!("failed to initialize CpuContext with defaults: {e}"))
162 }
163
164 /// Returns the number of threads in the pool.
165 pub fn num_threads(&self) -> usize {
166 self.pool.current_num_threads()
167 }
168
169 /// Returns a reference to the underlying rayon thread pool.
170 pub fn thread_pool(&self) -> &rayon::ThreadPool {
171 self.pool.as_ref()
172 }
173
174 /// Run a closure inside the owned rayon thread pool.
175 pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
176 where
177 R: Send,
178 {
179 self.pool.install(op)
180 }
181
182 /// Returns a mutable reference to the plan cache.
183 pub fn plan_cache_mut(&mut self) -> &mut PlanCache {
184 &mut self.plan_cache
185 }
186
187 #[allow(dead_code)]
188 pub(crate) fn temp_pool_mut(&mut self) -> &mut TempPool {
189 &mut self.temp_pool
190 }
191
192 #[cfg(feature = "gemm-blas")]
193 pub(super) fn take_scratch<T>(&mut self, len: usize) -> Result<ScratchBuf<T>> {
194 self.scratch.take(len)
195 }
196
197 #[cfg(feature = "gemm-blas")]
198 pub(super) fn put_scratch<T>(&mut self, buf: ScratchBuf<T>) {
199 self.scratch.put(buf);
200 }
201}
202
203impl Default for CpuContext {
204 fn default() -> Self {
205 Self::new_default()
206 }
207}
208
209impl TensorTempPoolContext for CpuContext {
210 fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
211 self.temp_pool_mut().take_vec::<T>(len)
212 }
213
214 fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
215 self.temp_pool_mut().put_vec(vec);
216 }
217}
218
/// CPU backend using strided-kernel and GEMM.
///
/// A stateless marker type: all execution state lives in the [`CpuContext`]
/// passed to each operation.
///
/// Dispatched automatically when tensors reside on
/// [`LogicalMemorySpace::MainMemory`](tenferro_device::LogicalMemorySpace::MainMemory).
/// Implements the semiring core and semiring fast-path families for
/// [`Standard<T>`](tenferro_algebra::Standard).
///
/// # Examples
///
/// ```ignore
/// use tenferro_algebra::Standard;
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::try_new(4).unwrap();
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a_base = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let a = a_base.permute(&[1, 0]).unwrap();
/// let mut b = Tensor::<f64>::zeros(&[4, 3], mem, col).unwrap();
/// let plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(
///     &mut ctx,
///     &SemiringCoreDescriptor::MakeContiguous,
///     &[&[4, 3], &[4, 3]],
/// )
/// .unwrap();
/// <CpuBackend as TensorSemiringCore<Standard<f64>>>::execute(
///     &mut ctx,
///     &plan,
///     1.0,
///     &[&a],
///     0.0,
///     &mut b,
/// )
/// .unwrap();
/// ```
pub struct CpuBackend;
257
impl CpuBackend {
    /// Whether `T` is a scalar type the batched-GEMM path supports, as
    /// reported by the shared CPU `common` helpers.
    pub(super) fn supports_batched_gemm_type<T: Scalar>() -> bool {
        common::is_supported_scalar_type::<T>()
    }

    /// Materialize a lazily-conjugated tensor.
    ///
    /// If `src.is_conjugated()` is `false`, returns a shallow clone.
    /// If `true`, routes through the tensor-layer logical combine substrate so
    /// the result is resolved (`conjugated = false`) without reimplementing the
    /// copy logic here.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_prims::{CpuBackend, CpuContext};
    ///
    /// let mut ctx = CpuContext::try_new(1).unwrap();
    /// let a_conj = a.into_conj(); // lazy
    /// let a_resolved = CpuBackend::resolve_conj(&mut ctx, &a_conj);
    /// assert!(!a_resolved.is_conjugated());
    /// ```
    pub fn resolve_conj<T: Scalar + Conjugate>(
        // Context is currently unused; kept for signature parity with other
        // backend entry points.
        _ctx: &mut CpuContext,
        src: &Tensor<T>,
    ) -> Tensor<T> {
        if !src.is_conjugated() {
            // Nothing to resolve: a shallow clone suffices.
            return src.clone();
        }

        // Round-trip through stack(axis 0) + squeeze(axis 0) so the tensor
        // layer's combine substrate performs the conjugating copy.
        // NOTE(review): on failure the fallback returns `src.clone()`, which
        // is still lazily conjugated — that contradicts the documented
        // postcondition above. Confirm whether stack/squeeze can actually
        // fail for a valid `src`, and whether the silent fallback is intended.
        Tensor::stack(&[src], 0)
            .and_then(|tensor| tensor.squeeze_dim(0))
            .unwrap_or_else(|_| src.clone())
    }
}
292}