// tenferro_prims/cpu/context.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use tenferro_algebra::{Conjugate, Scalar};
5use tenferro_device::{Error, Result};
6use tenferro_tensor::Tensor;
7
8use crate::{infra::plan_cache::PlanCache, TensorTempPoolContext};
9
10#[cfg(feature = "gemm-blas")]
11use super::scratch::{ScratchBuf, ScratchPool};
12use super::temp_pool::TempPool;
13use crate::cpu::common;
14
/// CPU execution context.
///
/// Encapsulates CPU-side execution resources, analogous to cuTENSOR's
/// `cutensorHandle_t`. Holds a rayon thread pool, a [`PlanCache`] for plan
/// reuse, and reusable temporary buffers for host-side execution helpers.
///
/// # Examples
///
/// ```
/// use tenferro_prims::CpuContext;
///
/// # fn demo() -> tenferro_device::Result<()> {
/// let mut ctx = CpuContext::try_new(4)?; // 4-thread pool
/// assert_eq!(ctx.num_threads(), 4);
/// # Ok(())
/// # }
/// ```
pub struct CpuContext {
    /// Rayon pool shared process-wide per thread count (see
    /// `shared_thread_pool`), so contexts with equal `num_threads`
    /// reuse one pool instead of spawning duplicate thread sets.
    pub(super) pool: Arc<rayon::ThreadPool>,
    /// Cache of previously constructed execution plans.
    pub(super) plan_cache: PlanCache,
    /// Reusable host-side temporary buffers; accessed through
    /// `TensorTempPoolContext` and `temp_pool_mut`.
    #[allow(dead_code)] // not exercised by every feature combination
    temp_pool: TempPool,
    /// Scratch-buffer pool for the BLAS GEMM path; only compiled in
    /// when the `gemm-blas` feature is enabled.
    #[cfg(feature = "gemm-blas")]
    scratch: ScratchPool,
}
40
41fn shared_thread_pools() -> &'static Mutex<HashMap<usize, Arc<rayon::ThreadPool>>> {
42    static SHARED_THREAD_POOLS: OnceLock<Mutex<HashMap<usize, Arc<rayon::ThreadPool>>>> =
43        OnceLock::new();
44    SHARED_THREAD_POOLS.get_or_init(|| Mutex::new(HashMap::new()))
45}
46
47fn shared_thread_pool(num_threads: usize) -> Result<Arc<rayon::ThreadPool>> {
48    let mut pools = shared_thread_pools()
49        .lock()
50        .unwrap_or_else(|poisoned| poisoned.into_inner());
51    if let Some(pool) = pools.get(&num_threads) {
52        return Ok(Arc::clone(pool));
53    }
54
55    let pool = Arc::new(
56        rayon::ThreadPoolBuilder::new()
57            .num_threads(num_threads)
58            .build()
59            .map_err(|e| Error::DeviceError(format!("failed to build rayon thread pool: {e}")))?,
60    );
61    pools.insert(num_threads, Arc::clone(&pool));
62    Ok(pool)
63}
64
/// Number of CPUs in the current thread's affinity mask (Linux only).
///
/// Returns `None` when `sched_getaffinity` fails or reports zero CPUs, so
/// the caller can fall back to [`std::thread::available_parallelism`].
#[cfg(target_os = "linux")]
fn affinity_thread_count() -> Option<usize> {
    let mut set = std::mem::MaybeUninit::<libc::cpu_set_t>::zeroed();
    // SAFETY: pid 0 means "the calling thread"; `set` points to storage of
    // exactly the size we pass, and the kernel writes at most that many bytes.
    let rc = unsafe {
        libc::sched_getaffinity(0, std::mem::size_of::<libc::cpu_set_t>(), set.as_mut_ptr())
    };
    if rc != 0 {
        // Syscall failed (e.g. exotic runtime environment): report "unknown".
        return None;
    }
    // SAFETY: on the rc == 0 path the kernel has initialized `set`.
    let set = unsafe { set.assume_init() };
    // SAFETY: `set` is a valid, initialized cpu_set_t.
    let count = unsafe { libc::CPU_COUNT(&set) as usize };
    // Guard against a nonsensical zero count; treat it as "unknown".
    (count > 0).then_some(count)
}
78
/// Affinity-mask probing is not implemented off Linux; always `None`, so
/// callers fall through to [`std::thread::available_parallelism`].
#[cfg(not(target_os = "linux"))]
fn affinity_thread_count() -> Option<usize> {
    None
}
83
84impl CpuContext {
85    /// Return the backend-defined default CPU thread count.
86    ///
87    /// On Linux this prefers the current process CPU affinity mask. Other
88    /// platforms fall back to [`std::thread::available_parallelism`].
89    pub fn default_num_threads() -> usize {
90        affinity_thread_count()
91            .or_else(|| std::thread::available_parallelism().ok().map(usize::from))
92            .unwrap_or(1)
93            .max(1)
94    }
95
96    /// Create a new CPU context with the given number of threads.
97    ///
98    /// # Errors
99    ///
100    /// Returns [`tenferro_device::Error::InvalidArgument`] when
101    /// `num_threads == 0`, or [`tenferro_device::Error::DeviceError`] when the
102    /// underlying Rayon thread-pool construction fails.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// # fn demo() -> tenferro_device::Result<()> {
108    /// use tenferro_prims::CpuContext;
109    ///
110    /// let ctx = CpuContext::try_new(2)?;
111    /// assert_eq!(ctx.num_threads(), 2);
112    /// # Ok(())
113    /// # }
114    /// ```
115    pub fn try_new(num_threads: usize) -> Result<Self> {
116        if num_threads == 0 {
117            return Err(Error::InvalidArgument(
118                "CpuContext::try_new requires num_threads >= 1".into(),
119            ));
120        }
121        Ok(Self {
122            pool: shared_thread_pool(num_threads)?,
123            plan_cache: PlanCache::new(),
124            temp_pool: TempPool::default(),
125            #[cfg(feature = "gemm-blas")]
126            scratch: ScratchPool::default(),
127        })
128    }
129
130    /// Create a new CPU context using the backend-defined default thread count.
131    pub fn try_new_default() -> Result<Self> {
132        Self::try_new(Self::default_num_threads())
133    }
134
135    /// Create a new CPU context with the given number of threads.
136    ///
137    /// This is a convenience wrapper around [`CpuContext::try_new`]. Production
138    /// code should generally prefer the fallible constructor so context setup
139    /// errors stay in the normal `Result` flow.
140    ///
141    /// # Panics
142    ///
143    /// Panics if [`CpuContext::try_new`] returns an error.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// use tenferro_prims::CpuContext;
149    ///
150    /// let ctx = CpuContext::new(1);
151    /// assert_eq!(ctx.num_threads(), 1);
152    /// ```
153    pub fn new(num_threads: usize) -> Self {
154        Self::try_new(num_threads)
155            .unwrap_or_else(|e| panic!("failed to initialize CpuContext: {e}"))
156    }
157
158    /// Create a new CPU context using the backend-defined default thread count.
159    pub fn new_default() -> Self {
160        Self::try_new_default()
161            .unwrap_or_else(|e| panic!("failed to initialize CpuContext with defaults: {e}"))
162    }
163
164    /// Returns the number of threads in the pool.
165    pub fn num_threads(&self) -> usize {
166        self.pool.current_num_threads()
167    }
168
169    /// Returns a reference to the underlying rayon thread pool.
170    pub fn thread_pool(&self) -> &rayon::ThreadPool {
171        self.pool.as_ref()
172    }
173
174    /// Run a closure inside the owned rayon thread pool.
175    pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
176    where
177        R: Send,
178    {
179        self.pool.install(op)
180    }
181
182    /// Returns a mutable reference to the plan cache.
183    pub fn plan_cache_mut(&mut self) -> &mut PlanCache {
184        &mut self.plan_cache
185    }
186
187    #[allow(dead_code)]
188    pub(crate) fn temp_pool_mut(&mut self) -> &mut TempPool {
189        &mut self.temp_pool
190    }
191
192    #[cfg(feature = "gemm-blas")]
193    pub(super) fn take_scratch<T>(&mut self, len: usize) -> Result<ScratchBuf<T>> {
194        self.scratch.take(len)
195    }
196
197    #[cfg(feature = "gemm-blas")]
198    pub(super) fn put_scratch<T>(&mut self, buf: ScratchBuf<T>) {
199        self.scratch.put(buf);
200    }
201}
202
203impl Default for CpuContext {
204    fn default() -> Self {
205        Self::new_default()
206    }
207}
208
209impl TensorTempPoolContext for CpuContext {
210    fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
211        self.temp_pool_mut().take_vec::<T>(len)
212    }
213
214    fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
215        self.temp_pool_mut().put_vec(vec);
216    }
217}
218
/// CPU backend using strided-kernel and GEMM.
///
/// Dispatched automatically when tensors reside on
/// [`LogicalMemorySpace::MainMemory`](tenferro_device::LogicalMemorySpace::MainMemory).
/// Implements the semiring core and semiring fast-path families for
/// [`Standard<T>`](tenferro_algebra::Standard).
///
/// Carries no state of its own; per-call resources live in [`CpuContext`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_algebra::Standard;
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::try_new(4).unwrap();
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a_base = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let a = a_base.permute(&[1, 0]).unwrap();
/// let mut b = Tensor::<f64>::zeros(&[4, 3], mem, col).unwrap();
/// let plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(
///     &mut ctx,
///     &SemiringCoreDescriptor::MakeContiguous,
///     &[&[4, 3], &[4, 3]],
/// )
/// .unwrap();
/// <CpuBackend as TensorSemiringCore<Standard<f64>>>::execute(
///     &mut ctx,
///     &plan,
///     1.0,
///     &[&a],
///     0.0,
///     &mut b,
/// )
/// .unwrap();
/// ```
pub struct CpuBackend;
257
258impl CpuBackend {
    /// Whether scalar type `T` is accepted by the batched-GEMM path
    /// (delegates to `common::is_supported_scalar_type`).
    pub(super) fn supports_batched_gemm_type<T: Scalar>() -> bool {
        common::is_supported_scalar_type::<T>()
    }
262
    /// Materialize a lazily-conjugated tensor.
    ///
    /// If `src.is_conjugated()` is `false`, returns a shallow clone.
    /// If `true`, routes through the tensor-layer logical combine substrate so
    /// the result is resolved (`conjugated = false`) without reimplementing the
    /// copy logic here.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_prims::{CpuBackend, CpuContext};
    ///
    /// let mut ctx = CpuContext::try_new(1).unwrap();
    /// let a_conj = a.into_conj(); // lazy
    /// let a_resolved = CpuBackend::resolve_conj(&mut ctx, &a_conj);
    /// assert!(!a_resolved.is_conjugated());
    /// ```
    pub fn resolve_conj<T: Scalar + Conjugate>(
        // Context currently unused; kept in the signature for API symmetry
        // with other backend entry points.
        _ctx: &mut CpuContext,
        src: &Tensor<T>,
    ) -> Tensor<T> {
        if !src.is_conjugated() {
            // Nothing to resolve — return a clone of the source as-is.
            return src.clone();
        }

        // stack(&[src], 0) + squeeze_dim(0) round-trips the data through the
        // tensor layer's combine machinery, which materializes the values and
        // clears the lazy-conjugation flag.
        Tensor::stack(&[src], 0)
            .and_then(|tensor| tensor.squeeze_dim(0))
            // NOTE(review): if stack/squeeze fails, this silently returns a
            // still-conjugated clone, contradicting the documented
            // postcondition (`conjugated = false`). Consider propagating the
            // error instead of swallowing it — confirm intended semantics.
            .unwrap_or_else(|_| src.clone())
    }
292}