Skip to main content

tenferro_cpu/
context.rs

1use std::env;
2use std::sync::Arc;
3
4use crate::{Error, Result};
5
6/// Reusable CPU execution context carrying CPU parallelism policy.
7///
8/// `CpuContext` stores the requested thread count as a kernel-level
9/// parallelism hint and owns the Rayon pool used by multi-threaded CPU work.
10///
11/// # Examples
12///
13/// ```
14/// use tenferro_cpu::CpuContext;
15///
16/// let ctx = CpuContext::with_threads(1).unwrap();
17/// let value = ctx.install(|| 1 + 1);
18/// assert_eq!(value, 2);
19/// assert_eq!(ctx.num_threads(), 1);
20/// ```
21#[derive(Clone, Debug)]
22pub struct CpuContext {
23    num_threads: usize,
24    pool: Option<Arc<rayon::ThreadPool>>,
25}
26
27impl CpuContext {
28    /// Create a CPU context from `RAYON_NUM_THREADS`, or fall back to the
29    /// process-visible CPU count.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use tenferro_cpu::CpuContext;
35    ///
36    /// let ctx = CpuContext::from_env();
37    /// assert!(ctx.num_threads() >= 1);
38    /// ```
39    pub fn from_env() -> Self {
40        Self::try_from_env().unwrap_or_else(|_| Self::single_threaded())
41    }
42
43    /// Try to create a CPU context from `RAYON_NUM_THREADS`.
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// use tenferro_cpu::CpuContext;
49    ///
50    /// let ctx = CpuContext::try_from_env()
51    ///     .unwrap_or_else(|_| CpuContext::with_threads(1).unwrap());
52    /// assert!(ctx.num_threads() >= 1);
53    /// ```
54    pub fn try_from_env() -> Result<Self> {
55        match env::var("RAYON_NUM_THREADS") {
56            Ok(value) => {
57                let num_threads = value.parse::<usize>().map_err(|err| Error::InvalidConfig {
58                    op: "CpuContext::try_from_env",
59                    message: format!("invalid RAYON_NUM_THREADS value {value:?}: {err}"),
60                })?;
61                Self::with_threads(num_threads).map_err(|err| match err {
62                    Error::InvalidConfig { message, .. } => Error::InvalidConfig {
63                        op: "CpuContext::try_from_env",
64                        message: format!("invalid RAYON_NUM_THREADS value {value:?}: {message}"),
65                    },
66                    err => err,
67                })
68            }
69            Err(env::VarError::NotPresent) => {
70                Self::with_threads(super::affinity::available_parallelism())
71            }
72            Err(err) => Err(Error::InvalidConfig {
73                op: "CpuContext::try_from_env",
74                message: format!("failed to read RAYON_NUM_THREADS: {err}"),
75            }),
76        }
77    }
78
79    /// Create a CPU context with a fixed parallelism hint.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use tenferro_cpu::CpuContext;
85    ///
86    /// let ctx = CpuContext::with_threads(2).unwrap();
87    /// assert_eq!(ctx.num_threads(), 2);
88    /// ```
89    ///
90    /// # Errors
91    ///
92    /// Returns an error when `num_threads` is zero or Rayon rejects the pool.
93    pub fn with_threads(num_threads: usize) -> Result<Self> {
94        if num_threads == 0 {
95            return Err(Error::InvalidConfig {
96                op: "CpuContext::with_threads",
97                message: "thread count must be at least 1".into(),
98            });
99        }
100        let pool = if num_threads == 1 {
101            None
102        } else {
103            Some(Arc::new(
104                rayon::ThreadPoolBuilder::new()
105                    .num_threads(num_threads)
106                    .build()
107                    .map_err(|err| Error::InvalidConfig {
108                        op: "CpuContext::with_threads",
109                        message: format!("failed to build CPU thread pool: {err}"),
110                    })?,
111            ))
112        };
113        Ok(Self { num_threads, pool })
114    }
115
116    fn single_threaded() -> Self {
117        Self {
118            num_threads: 1,
119            pool: None,
120        }
121    }
122
123    /// Return this context's CPU parallelism hint.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// use tenferro_cpu::CpuContext;
129    ///
130    /// let ctx = CpuContext::with_threads(2).unwrap();
131    /// assert_eq!(ctx.num_threads(), 2);
132    /// ```
133    pub fn num_threads(&self) -> usize {
134        self.num_threads
135    }
136
137    /// Run a closure inside this context's CPU execution scope.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use tenferro_cpu::CpuContext;
143    ///
144    /// let ctx = CpuContext::with_threads(1).unwrap();
145    /// let value = ctx.install(|| 1 + 1);
146    /// assert_eq!(value, 2);
147    /// ```
148    pub fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
149        match &self.pool {
150            Some(pool) => pool.install(op),
151            None => op(),
152        }
153    }
154
155    /// Return the faer parallelism policy for this context.
156    #[cfg(feature = "cpu-faer")]
157    #[doc(hidden)]
158    pub fn faer_par(&self) -> faer::Par {
159        if self.num_threads == 1 {
160            faer::Par::Seq
161        } else {
162            faer::Par::rayon(0)
163        }
164    }
165}
166
167#[cfg(test)]
168mod tests;