tenferro_cpu/context.rs
1use std::env;
2use std::sync::Arc;
3
4use crate::{Error, Result};
5
6/// Reusable CPU execution context carrying CPU parallelism policy.
7///
8/// `CpuContext` stores the requested thread count as a kernel-level
9/// parallelism hint and owns the Rayon pool used by multi-threaded CPU work.
10///
11/// # Examples
12///
13/// ```
14/// use tenferro_cpu::CpuContext;
15///
16/// let ctx = CpuContext::with_threads(1).unwrap();
17/// let value = ctx.install(|| 1 + 1);
18/// assert_eq!(value, 2);
19/// assert_eq!(ctx.num_threads(), 1);
20/// ```
21#[derive(Clone, Debug)]
22pub struct CpuContext {
23 num_threads: usize,
24 pool: Option<Arc<rayon::ThreadPool>>,
25}
26
27impl CpuContext {
28 /// Create a CPU context from `RAYON_NUM_THREADS`, or fall back to the
29 /// process-visible CPU count.
30 ///
31 /// # Examples
32 ///
33 /// ```
34 /// use tenferro_cpu::CpuContext;
35 ///
36 /// let ctx = CpuContext::from_env();
37 /// assert!(ctx.num_threads() >= 1);
38 /// ```
39 pub fn from_env() -> Self {
40 Self::try_from_env().unwrap_or_else(|_| Self::single_threaded())
41 }
42
43 /// Try to create a CPU context from `RAYON_NUM_THREADS`.
44 ///
45 /// # Examples
46 ///
47 /// ```
48 /// use tenferro_cpu::CpuContext;
49 ///
50 /// let ctx = CpuContext::try_from_env()
51 /// .unwrap_or_else(|_| CpuContext::with_threads(1).unwrap());
52 /// assert!(ctx.num_threads() >= 1);
53 /// ```
54 pub fn try_from_env() -> Result<Self> {
55 match env::var("RAYON_NUM_THREADS") {
56 Ok(value) => {
57 let num_threads = value.parse::<usize>().map_err(|err| Error::InvalidConfig {
58 op: "CpuContext::try_from_env",
59 message: format!("invalid RAYON_NUM_THREADS value {value:?}: {err}"),
60 })?;
61 Self::with_threads(num_threads).map_err(|err| match err {
62 Error::InvalidConfig { message, .. } => Error::InvalidConfig {
63 op: "CpuContext::try_from_env",
64 message: format!("invalid RAYON_NUM_THREADS value {value:?}: {message}"),
65 },
66 err => err,
67 })
68 }
69 Err(env::VarError::NotPresent) => {
70 Self::with_threads(super::affinity::available_parallelism())
71 }
72 Err(err) => Err(Error::InvalidConfig {
73 op: "CpuContext::try_from_env",
74 message: format!("failed to read RAYON_NUM_THREADS: {err}"),
75 }),
76 }
77 }
78
79 /// Create a CPU context with a fixed parallelism hint.
80 ///
81 /// # Examples
82 ///
83 /// ```
84 /// use tenferro_cpu::CpuContext;
85 ///
86 /// let ctx = CpuContext::with_threads(2).unwrap();
87 /// assert_eq!(ctx.num_threads(), 2);
88 /// ```
89 ///
90 /// # Errors
91 ///
92 /// Returns an error when `num_threads` is zero or Rayon rejects the pool.
93 pub fn with_threads(num_threads: usize) -> Result<Self> {
94 if num_threads == 0 {
95 return Err(Error::InvalidConfig {
96 op: "CpuContext::with_threads",
97 message: "thread count must be at least 1".into(),
98 });
99 }
100 let pool = if num_threads == 1 {
101 None
102 } else {
103 Some(Arc::new(
104 rayon::ThreadPoolBuilder::new()
105 .num_threads(num_threads)
106 .build()
107 .map_err(|err| Error::InvalidConfig {
108 op: "CpuContext::with_threads",
109 message: format!("failed to build CPU thread pool: {err}"),
110 })?,
111 ))
112 };
113 Ok(Self { num_threads, pool })
114 }
115
116 fn single_threaded() -> Self {
117 Self {
118 num_threads: 1,
119 pool: None,
120 }
121 }
122
123 /// Return this context's CPU parallelism hint.
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// use tenferro_cpu::CpuContext;
129 ///
130 /// let ctx = CpuContext::with_threads(2).unwrap();
131 /// assert_eq!(ctx.num_threads(), 2);
132 /// ```
133 pub fn num_threads(&self) -> usize {
134 self.num_threads
135 }
136
137 /// Run a closure inside this context's CPU execution scope.
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// use tenferro_cpu::CpuContext;
143 ///
144 /// let ctx = CpuContext::with_threads(1).unwrap();
145 /// let value = ctx.install(|| 1 + 1);
146 /// assert_eq!(value, 2);
147 /// ```
148 pub fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
149 match &self.pool {
150 Some(pool) => pool.install(op),
151 None => op(),
152 }
153 }
154
155 /// Return the faer parallelism policy for this context.
156 #[cfg(feature = "cpu-faer")]
157 #[doc(hidden)]
158 pub fn faer_par(&self) -> faer::Par {
159 if self.num_threads == 1 {
160 faer::Par::Seq
161 } else {
162 faer::Par::rayon(0)
163 }
164 }
165}
166
// Unit tests live in an out-of-line `tests` module, compiled only for test builds.
#[cfg(test)]
mod tests;