ad_tensors_rs/
runtime.rs

1use crate::context::{set_global_context, with_global_context, GlobalContextGuard};
2use crate::{Error, Result};
3use tenferro_prims::{CpuContext, CudaContext, RocmContext};
4
5/// Runtime execution context used by builder `.run()` entry points.
6///
7/// # Examples
8///
9/// ```rust
10/// use ad_tensors_rs::{set_default_runtime, RuntimeContext};
11/// use tenferro_prims::CpuContext;
12///
13/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
14/// ```
15pub enum RuntimeContext {
16    /// CPU runtime context.
17    Cpu(CpuContext),
18    /// CUDA runtime context.
19    Cuda(CudaContext),
20    /// ROCm runtime context.
21    Rocm(RocmContext),
22}
23
24impl RuntimeContext {
25    /// Returns the runtime name.
26    ///
27    /// # Examples
28    ///
29    /// ```rust
30    /// use ad_tensors_rs::RuntimeContext;
31    /// use tenferro_prims::CpuContext;
32    ///
33    /// let rt = RuntimeContext::Cpu(CpuContext::new(1));
34    /// assert_eq!(rt.name(), "cpu");
35    /// ```
36    pub fn name(&self) -> &'static str {
37        match self {
38            Self::Cpu(_) => "cpu",
39            Self::Cuda(_) => "cuda",
40            Self::Rocm(_) => "rocm",
41        }
42    }
43}
44
45impl From<CpuContext> for RuntimeContext {
46    fn from(value: CpuContext) -> Self {
47        Self::Cpu(value)
48    }
49}
50
51impl From<CudaContext> for RuntimeContext {
52    fn from(value: CudaContext) -> Self {
53        Self::Cuda(value)
54    }
55}
56
57impl From<RocmContext> for RuntimeContext {
58    fn from(value: RocmContext) -> Self {
59        Self::Rocm(value)
60    }
61}
62
63/// Sets the default runtime context for builder `.run()`.
64///
65/// # Examples
66///
67/// ```rust
68/// use ad_tensors_rs::{set_default_runtime, RuntimeContext};
69/// use tenferro_prims::CpuContext;
70///
71/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
72/// ```
73pub fn set_default_runtime(ctx: RuntimeContext) -> GlobalContextGuard<RuntimeContext> {
74    set_global_context(ctx)
75}
76
77/// Runs `f` with the default runtime context.
78///
79/// Returns [`Error::RuntimeNotConfigured`] when runtime is not configured.
80///
81/// # Examples
82///
83/// ```rust
84/// use ad_tensors_rs::{set_default_runtime, with_default_runtime, RuntimeContext};
85/// use tenferro_prims::CpuContext;
86///
87/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
88/// let name = with_default_runtime(|rt| Ok(rt.name())).unwrap();
89/// assert_eq!(name, "cpu");
90/// ```
91pub fn with_default_runtime<R>(f: impl FnOnce(&mut RuntimeContext) -> Result<R>) -> Result<R> {
92    with_global_context::<RuntimeContext, _>(f).map_err(|err| match err {
93        Error::MissingGlobalContext { .. } => Error::RuntimeNotConfigured,
94        other => other,
95    })
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Installing a CPU runtime makes it visible to `with_default_runtime`.
    #[test]
    fn default_runtime_roundtrip() {
        let _installed = set_default_runtime(CpuContext::new(1).into());
        let backend = with_default_runtime(|runtime| Ok(runtime.name()))
            .expect("a CPU runtime was installed just above");
        assert_eq!(backend, "cpu");
    }
}