Skip to main content

tenferro_tensor/cpu/
context.rs

1use std::collections::HashMap;
2use std::env;
3use std::sync::{Arc, Mutex, OnceLock};
4
5use crate::{Error, Result};
6
7/// Reusable CPU execution context backed by an owned rayon thread pool.
8///
9/// # Examples
10///
11/// ```ignore
12/// use tenferro_tensor::cpu::CpuContext;
13///
14/// let ctx = CpuContext::with_threads(1);
15/// let seen = ctx.install(|| rayon::current_num_threads());
16/// assert_eq!(seen, 1);
17/// ```
18#[derive(Clone)]
19pub struct CpuContext {
20    pool: Arc<rayon::ThreadPool>,
21}
22
23fn shared_pools() -> &'static Mutex<HashMap<usize, Arc<rayon::ThreadPool>>> {
24    static POOLS: OnceLock<Mutex<HashMap<usize, Arc<rayon::ThreadPool>>>> = OnceLock::new();
25    POOLS.get_or_init(|| Mutex::new(HashMap::new()))
26}
27
28pub(crate) fn get_or_create_pool(num_threads: usize) -> Arc<rayon::ThreadPool> {
29    let mut pools = shared_pools()
30        .lock()
31        .unwrap_or_else(|poisoned| poisoned.into_inner());
32    if let Some(pool) = pools.get(&num_threads) {
33        return Arc::clone(pool);
34    }
35
36    let pool = Arc::new(
37        rayon::ThreadPoolBuilder::new()
38            .num_threads(num_threads)
39            .build()
40            .unwrap_or_else(|e| panic!("failed to create rayon thread pool: {e}")),
41    );
42    pools.insert(num_threads, Arc::clone(&pool));
43    pool
44}
45
46impl CpuContext {
47    /// Create a CPU context from `RAYON_NUM_THREADS`, or fall back to the
48    /// process-visible CPU count.
49    ///
50    /// # Examples
51    ///
52    /// ```ignore
53    /// use tenferro_tensor::cpu::CpuContext;
54    ///
55    /// let ctx = CpuContext::from_env();
56    /// let _ = ctx.num_threads();
57    /// ```
58    pub fn from_env() -> Self {
59        Self::try_from_env()
60            .unwrap_or_else(|_| Self::with_threads(super::affinity::available_parallelism()))
61    }
62
63    /// Try to create a CPU context from `RAYON_NUM_THREADS`.
64    ///
65    /// # Examples
66    ///
67    /// ```ignore
68    /// use tenferro_tensor::cpu::CpuContext;
69    ///
70    /// let ctx = CpuContext::try_from_env().unwrap();
71    /// let _ = ctx.num_threads();
72    /// ```
73    pub fn try_from_env() -> Result<Self> {
74        match env::var("RAYON_NUM_THREADS") {
75            Ok(value) => {
76                let num_threads = value.parse::<usize>().map_err(|err| Error::InvalidConfig {
77                    op: "CpuContext::try_from_env",
78                    message: format!("invalid RAYON_NUM_THREADS value {value:?}: {err}"),
79                })?;
80                if num_threads == 0 {
81                    return Err(Error::InvalidConfig {
82                        op: "CpuContext::try_from_env",
83                        message: "RAYON_NUM_THREADS must be at least 1".to_string(),
84                    });
85                }
86                Ok(Self::with_threads(num_threads))
87            }
88            Err(env::VarError::NotPresent) => {
89                Ok(Self::with_threads(super::affinity::available_parallelism()))
90            }
91            Err(err) => Err(Error::InvalidConfig {
92                op: "CpuContext::try_from_env",
93                message: format!("failed to read RAYON_NUM_THREADS: {err}"),
94            }),
95        }
96    }
97
98    /// Create a CPU context with a fixed rayon thread-pool size.
99    ///
100    /// # Examples
101    ///
102    /// ```ignore
103    /// use tenferro_tensor::cpu::CpuContext;
104    ///
105    /// let ctx = CpuContext::with_threads(2);
106    /// assert_eq!(ctx.num_threads(), 2);
107    /// ```
108    pub fn with_threads(num_threads: usize) -> Self {
109        assert!(num_threads >= 1, "thread count must be >= 1");
110        Self {
111            pool: get_or_create_pool(num_threads),
112        }
113    }
114
115    /// Return the number of threads in this context's owned rayon pool.
116    ///
117    /// # Examples
118    ///
119    /// ```ignore
120    /// use tenferro_tensor::cpu::CpuContext;
121    ///
122    /// let ctx = CpuContext::with_threads(2);
123    /// assert_eq!(ctx.num_threads(), 2);
124    /// ```
125    pub fn num_threads(&self) -> usize {
126        self.pool.current_num_threads()
127    }
128
129    /// Run a closure inside this context's owned rayon thread pool.
130    ///
131    /// # Examples
132    ///
133    /// ```ignore
134    /// use tenferro_tensor::cpu::CpuContext;
135    ///
136    /// let ctx = CpuContext::with_threads(1);
137    /// let value = ctx.install(|| 1 + 1);
138    /// assert_eq!(value, 2);
139    /// ```
140    pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
141    where
142        R: Send,
143    {
144        self.pool.install(op)
145    }
146
147    /// Return the faer parallelism policy for this context.
148    #[cfg(feature = "cpu-faer")]
149    pub(crate) fn faer_par(&self) -> faer::Par {
150        if self.num_threads() == 1 {
151            faer::Par::Seq
152        } else {
153            faer::Par::rayon(0)
154        }
155    }
156}