tenferro_internal_runtime/
context.rs

1use std::cell::RefCell;
2
3use tenferro_internal_error::{Error, Result};
4use tenferro_prims::{CpuContext, CudaContext, RocmContext};
5
6thread_local! {
7    static DEFAULT_RUNTIME: RefCell<Option<RuntimeContext>> = const { RefCell::new(None) };
8}
9
10/// Runtime execution context used by builder `.run()` entry points.
11///
12/// Current status:
13///
14/// - `Cpu`: supported by builder `.run()` paths.
15/// - `Cuda`/`Rocm`: accepted as context values, but current builder execution
16///   paths may reject runtime-specific operations elsewhere.
17///
18/// # Examples
19///
20/// ```rust
21/// use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
22/// use tenferro_prims::CpuContext;
23///
24/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
25/// ```
26pub enum RuntimeContext {
27    /// CPU runtime context.
28    Cpu(CpuContext),
29    /// CUDA runtime context.
30    Cuda(CudaContext),
31    /// ROCm runtime context.
32    Rocm(RocmContext),
33}
34
35impl RuntimeContext {
36    /// Returns the runtime name.
37    ///
38    /// # Examples
39    ///
40    /// ```rust
41    /// use tenferro_internal_runtime::RuntimeContext;
42    /// use tenferro_prims::CpuContext;
43    ///
44    /// let rt = RuntimeContext::Cpu(CpuContext::new(1));
45    /// assert_eq!(rt.name(), "cpu");
46    /// ```
47    pub fn name(&self) -> &'static str {
48        match self {
49            Self::Cpu(_) => "cpu",
50            Self::Cuda(_) => "cuda",
51            Self::Rocm(_) => "rocm",
52        }
53    }
54}
55
56impl From<CpuContext> for RuntimeContext {
57    fn from(value: CpuContext) -> Self {
58        Self::Cpu(value)
59    }
60}
61
62impl From<CudaContext> for RuntimeContext {
63    fn from(value: CudaContext) -> Self {
64        Self::Cuda(value)
65    }
66}
67
68impl From<RocmContext> for RuntimeContext {
69    fn from(value: RocmContext) -> Self {
70        Self::Rocm(value)
71    }
72}
73
74/// Guard returned by [`set_default_runtime`].
75///
76/// When dropped, the previous runtime context is restored.
77///
78/// # Examples
79///
80/// ```rust
81/// use tenferro_internal_runtime::{set_default_runtime, with_default_runtime, RuntimeContext};
82/// use tenferro_prims::CpuContext;
83///
84/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
85/// let name = with_default_runtime(|ctx| Ok(ctx.name())).unwrap();
86/// assert_eq!(name, "cpu");
87/// ```
88pub struct DefaultRuntimeGuard {
89    previous: Option<RuntimeContext>,
90}
91
92impl Drop for DefaultRuntimeGuard {
93    fn drop(&mut self) {
94        DEFAULT_RUNTIME.with(|slot| {
95            *slot.borrow_mut() = self.previous.take();
96        });
97    }
98}
99
100/// Sets the default runtime context for builder `.run()`.
101///
102/// # Examples
103///
104/// ```rust
105/// use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
106/// use tenferro_prims::CpuContext;
107///
108/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
109/// ```
110pub fn set_default_runtime(ctx: RuntimeContext) -> DefaultRuntimeGuard {
111    let previous = DEFAULT_RUNTIME.with(|slot| slot.borrow_mut().replace(ctx));
112    DefaultRuntimeGuard { previous }
113}
114
115/// Runs `f` with the default runtime context.
116///
117/// Returns [`Error::RuntimeNotConfigured`] when runtime is not configured.
118///
119/// # Examples
120///
121/// ```rust
122/// use tenferro_internal_runtime::{set_default_runtime, with_default_runtime, RuntimeContext};
123/// use tenferro_prims::CpuContext;
124///
125/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
126/// let name = with_default_runtime(|ctx| Ok(ctx.name())).unwrap();
127/// assert_eq!(name, "cpu");
128/// ```
129pub fn with_default_runtime<R>(f: impl FnOnce(&mut RuntimeContext) -> Result<R>) -> Result<R> {
130    DEFAULT_RUNTIME.with(|slot| {
131        let mut slot = slot.borrow_mut();
132        let ctx = slot.as_mut().ok_or(Error::RuntimeNotConfigured)?;
133        f(ctx)
134    })
135}
136
137/// Runs `f` with an explicitly supplied runtime installed for the duration of
138/// the closure.
139///
140/// Any previously configured default runtime is restored afterwards, even when
141/// `f` returns an error.
142///
143/// # Examples
144///
145/// ```rust
146/// use tenferro_internal_runtime::{with_default_runtime, with_runtime, RuntimeContext};
147/// use tenferro_prims::CpuContext;
148///
149/// let name = with_runtime(RuntimeContext::Cpu(CpuContext::new(1)), || {
150///     with_default_runtime(|ctx| Ok(ctx.name()))
151/// })
152/// .unwrap();
153/// assert_eq!(name, "cpu");
154/// ```
155pub fn with_runtime<R>(ctx: RuntimeContext, f: impl FnOnce() -> Result<R>) -> Result<R> {
156    let _guard = set_default_runtime(ctx);
157    f()
158}