tenferro_internal_runtime/context.rs
1use std::cell::RefCell;
2
3use tenferro_internal_error::{Error, Result};
4use tenferro_prims::{CpuContext, CudaContext, RocmContext};
5
6thread_local! {
7 static DEFAULT_RUNTIME: RefCell<Option<RuntimeContext>> = const { RefCell::new(None) };
8}
9
10/// Runtime execution context used by builder `.run()` entry points.
11///
12/// Current status:
13///
14/// - `Cpu`: supported by builder `.run()` paths.
15/// - `Cuda`/`Rocm`: accepted as context values, but current builder execution
16/// paths may reject runtime-specific operations elsewhere.
17///
18/// # Examples
19///
20/// ```rust
21/// use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
22/// use tenferro_prims::CpuContext;
23///
24/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
25/// ```
26pub enum RuntimeContext {
27 /// CPU runtime context.
28 Cpu(CpuContext),
29 /// CUDA runtime context.
30 Cuda(CudaContext),
31 /// ROCm runtime context.
32 Rocm(RocmContext),
33}
34
35impl RuntimeContext {
36 /// Returns the runtime name.
37 ///
38 /// # Examples
39 ///
40 /// ```rust
41 /// use tenferro_internal_runtime::RuntimeContext;
42 /// use tenferro_prims::CpuContext;
43 ///
44 /// let rt = RuntimeContext::Cpu(CpuContext::new(1));
45 /// assert_eq!(rt.name(), "cpu");
46 /// ```
47 pub fn name(&self) -> &'static str {
48 match self {
49 Self::Cpu(_) => "cpu",
50 Self::Cuda(_) => "cuda",
51 Self::Rocm(_) => "rocm",
52 }
53 }
54}
55
56impl From<CpuContext> for RuntimeContext {
57 fn from(value: CpuContext) -> Self {
58 Self::Cpu(value)
59 }
60}
61
62impl From<CudaContext> for RuntimeContext {
63 fn from(value: CudaContext) -> Self {
64 Self::Cuda(value)
65 }
66}
67
68impl From<RocmContext> for RuntimeContext {
69 fn from(value: RocmContext) -> Self {
70 Self::Rocm(value)
71 }
72}
73
74/// Guard returned by [`set_default_runtime`].
75///
76/// When dropped, the previous runtime context is restored.
77///
78/// # Examples
79///
80/// ```rust
81/// use tenferro_internal_runtime::{set_default_runtime, with_default_runtime, RuntimeContext};
82/// use tenferro_prims::CpuContext;
83///
84/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
85/// let name = with_default_runtime(|ctx| Ok(ctx.name())).unwrap();
86/// assert_eq!(name, "cpu");
87/// ```
88pub struct DefaultRuntimeGuard {
89 previous: Option<RuntimeContext>,
90}
91
92impl Drop for DefaultRuntimeGuard {
93 fn drop(&mut self) {
94 DEFAULT_RUNTIME.with(|slot| {
95 *slot.borrow_mut() = self.previous.take();
96 });
97 }
98}
99
100/// Sets the default runtime context for builder `.run()`.
101///
102/// # Examples
103///
104/// ```rust
105/// use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
106/// use tenferro_prims::CpuContext;
107///
108/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
109/// ```
110pub fn set_default_runtime(ctx: RuntimeContext) -> DefaultRuntimeGuard {
111 let previous = DEFAULT_RUNTIME.with(|slot| slot.borrow_mut().replace(ctx));
112 DefaultRuntimeGuard { previous }
113}
114
115/// Runs `f` with the default runtime context.
116///
117/// Returns [`Error::RuntimeNotConfigured`] when runtime is not configured.
118///
119/// # Examples
120///
121/// ```rust
122/// use tenferro_internal_runtime::{set_default_runtime, with_default_runtime, RuntimeContext};
123/// use tenferro_prims::CpuContext;
124///
125/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
126/// let name = with_default_runtime(|ctx| Ok(ctx.name())).unwrap();
127/// assert_eq!(name, "cpu");
128/// ```
129pub fn with_default_runtime<R>(f: impl FnOnce(&mut RuntimeContext) -> Result<R>) -> Result<R> {
130 DEFAULT_RUNTIME.with(|slot| {
131 let mut slot = slot.borrow_mut();
132 let ctx = slot.as_mut().ok_or(Error::RuntimeNotConfigured)?;
133 f(ctx)
134 })
135}
136
137/// Runs `f` with an explicitly supplied runtime installed for the duration of
138/// the closure.
139///
140/// Any previously configured default runtime is restored afterwards, even when
141/// `f` returns an error.
142///
143/// # Examples
144///
145/// ```rust
146/// use tenferro_internal_runtime::{with_default_runtime, with_runtime, RuntimeContext};
147/// use tenferro_prims::CpuContext;
148///
149/// let name = with_runtime(RuntimeContext::Cpu(CpuContext::new(1)), || {
150/// with_default_runtime(|ctx| Ok(ctx.name()))
151/// })
152/// .unwrap();
153/// assert_eq!(name, "cpu");
154/// ```
155pub fn with_runtime<R>(ctx: RuntimeContext, f: impl FnOnce() -> Result<R>) -> Result<R> {
156 let _guard = set_default_runtime(ctx);
157 f()
158}