tenferro_tensor/cpu/context.rs
1use std::collections::HashMap;
2use std::env;
3use std::sync::{Arc, Mutex, OnceLock};
4
5use crate::{Error, Result};
6
7/// Reusable CPU execution context backed by an owned rayon thread pool.
8///
9/// # Examples
10///
11/// ```ignore
12/// use tenferro_tensor::cpu::CpuContext;
13///
14/// let ctx = CpuContext::with_threads(1);
15/// let seen = ctx.install(|| rayon::current_num_threads());
16/// assert_eq!(seen, 1);
17/// ```
18#[derive(Clone)]
19pub struct CpuContext {
20 pool: Arc<rayon::ThreadPool>,
21}
22
23fn shared_pools() -> &'static Mutex<HashMap<usize, Arc<rayon::ThreadPool>>> {
24 static POOLS: OnceLock<Mutex<HashMap<usize, Arc<rayon::ThreadPool>>>> = OnceLock::new();
25 POOLS.get_or_init(|| Mutex::new(HashMap::new()))
26}
27
28pub(crate) fn get_or_create_pool(num_threads: usize) -> Arc<rayon::ThreadPool> {
29 let mut pools = shared_pools()
30 .lock()
31 .unwrap_or_else(|poisoned| poisoned.into_inner());
32 if let Some(pool) = pools.get(&num_threads) {
33 return Arc::clone(pool);
34 }
35
36 let pool = Arc::new(
37 rayon::ThreadPoolBuilder::new()
38 .num_threads(num_threads)
39 .build()
40 .unwrap_or_else(|e| panic!("failed to create rayon thread pool: {e}")),
41 );
42 pools.insert(num_threads, Arc::clone(&pool));
43 pool
44}
45
46impl CpuContext {
47 /// Create a CPU context from `RAYON_NUM_THREADS`, or fall back to the
48 /// process-visible CPU count.
49 ///
50 /// # Examples
51 ///
52 /// ```ignore
53 /// use tenferro_tensor::cpu::CpuContext;
54 ///
55 /// let ctx = CpuContext::from_env();
56 /// let _ = ctx.num_threads();
57 /// ```
58 pub fn from_env() -> Self {
59 Self::try_from_env()
60 .unwrap_or_else(|_| Self::with_threads(super::affinity::available_parallelism()))
61 }
62
63 /// Try to create a CPU context from `RAYON_NUM_THREADS`.
64 ///
65 /// # Examples
66 ///
67 /// ```ignore
68 /// use tenferro_tensor::cpu::CpuContext;
69 ///
70 /// let ctx = CpuContext::try_from_env().unwrap();
71 /// let _ = ctx.num_threads();
72 /// ```
73 pub fn try_from_env() -> Result<Self> {
74 match env::var("RAYON_NUM_THREADS") {
75 Ok(value) => {
76 let num_threads = value.parse::<usize>().map_err(|err| Error::InvalidConfig {
77 op: "CpuContext::try_from_env",
78 message: format!("invalid RAYON_NUM_THREADS value {value:?}: {err}"),
79 })?;
80 if num_threads == 0 {
81 return Err(Error::InvalidConfig {
82 op: "CpuContext::try_from_env",
83 message: "RAYON_NUM_THREADS must be at least 1".to_string(),
84 });
85 }
86 Ok(Self::with_threads(num_threads))
87 }
88 Err(env::VarError::NotPresent) => {
89 Ok(Self::with_threads(super::affinity::available_parallelism()))
90 }
91 Err(err) => Err(Error::InvalidConfig {
92 op: "CpuContext::try_from_env",
93 message: format!("failed to read RAYON_NUM_THREADS: {err}"),
94 }),
95 }
96 }
97
98 /// Create a CPU context with a fixed rayon thread-pool size.
99 ///
100 /// # Examples
101 ///
102 /// ```ignore
103 /// use tenferro_tensor::cpu::CpuContext;
104 ///
105 /// let ctx = CpuContext::with_threads(2);
106 /// assert_eq!(ctx.num_threads(), 2);
107 /// ```
108 pub fn with_threads(num_threads: usize) -> Self {
109 assert!(num_threads >= 1, "thread count must be >= 1");
110 Self {
111 pool: get_or_create_pool(num_threads),
112 }
113 }
114
115 /// Return the number of threads in this context's owned rayon pool.
116 ///
117 /// # Examples
118 ///
119 /// ```ignore
120 /// use tenferro_tensor::cpu::CpuContext;
121 ///
122 /// let ctx = CpuContext::with_threads(2);
123 /// assert_eq!(ctx.num_threads(), 2);
124 /// ```
125 pub fn num_threads(&self) -> usize {
126 self.pool.current_num_threads()
127 }
128
129 /// Run a closure inside this context's owned rayon thread pool.
130 ///
131 /// # Examples
132 ///
133 /// ```ignore
134 /// use tenferro_tensor::cpu::CpuContext;
135 ///
136 /// let ctx = CpuContext::with_threads(1);
137 /// let value = ctx.install(|| 1 + 1);
138 /// assert_eq!(value, 2);
139 /// ```
140 pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
141 where
142 R: Send,
143 {
144 self.pool.install(op)
145 }
146
147 /// Return the faer parallelism policy for this context.
148 #[cfg(feature = "cpu-faer")]
149 pub(crate) fn faer_par(&self) -> faer::Par {
150 if self.num_threads() == 1 {
151 faer::Par::Seq
152 } else {
153 faer::Par::rayon(0)
154 }
155 }
156}