ad_tensors_rs/runtime.rs
1use crate::context::{set_global_context, with_global_context, GlobalContextGuard};
2use crate::{Error, Result};
3use tenferro_prims::{CpuContext, CudaContext, RocmContext};
4
5/// Runtime execution context used by builder `.run()` entry points.
6///
7/// # Examples
8///
9/// ```rust
10/// use ad_tensors_rs::{set_default_runtime, RuntimeContext};
11/// use tenferro_prims::CpuContext;
12///
13/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
14/// ```
15pub enum RuntimeContext {
16 /// CPU runtime context.
17 Cpu(CpuContext),
18 /// CUDA runtime context.
19 Cuda(CudaContext),
20 /// ROCm runtime context.
21 Rocm(RocmContext),
22}
23
24impl RuntimeContext {
25 /// Returns the runtime name.
26 ///
27 /// # Examples
28 ///
29 /// ```rust
30 /// use ad_tensors_rs::RuntimeContext;
31 /// use tenferro_prims::CpuContext;
32 ///
33 /// let rt = RuntimeContext::Cpu(CpuContext::new(1));
34 /// assert_eq!(rt.name(), "cpu");
35 /// ```
36 pub fn name(&self) -> &'static str {
37 match self {
38 Self::Cpu(_) => "cpu",
39 Self::Cuda(_) => "cuda",
40 Self::Rocm(_) => "rocm",
41 }
42 }
43}
44
45impl From<CpuContext> for RuntimeContext {
46 fn from(value: CpuContext) -> Self {
47 Self::Cpu(value)
48 }
49}
50
51impl From<CudaContext> for RuntimeContext {
52 fn from(value: CudaContext) -> Self {
53 Self::Cuda(value)
54 }
55}
56
57impl From<RocmContext> for RuntimeContext {
58 fn from(value: RocmContext) -> Self {
59 Self::Rocm(value)
60 }
61}
62
63/// Sets the default runtime context for builder `.run()`.
64///
65/// # Examples
66///
67/// ```rust
68/// use ad_tensors_rs::{set_default_runtime, RuntimeContext};
69/// use tenferro_prims::CpuContext;
70///
71/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
72/// ```
73pub fn set_default_runtime(ctx: RuntimeContext) -> GlobalContextGuard<RuntimeContext> {
74 set_global_context(ctx)
75}
76
77/// Runs `f` with the default runtime context.
78///
79/// Returns [`Error::RuntimeNotConfigured`] when runtime is not configured.
80///
81/// # Examples
82///
83/// ```rust
84/// use ad_tensors_rs::{set_default_runtime, with_default_runtime, RuntimeContext};
85/// use tenferro_prims::CpuContext;
86///
87/// let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
88/// let name = with_default_runtime(|rt| Ok(rt.name())).unwrap();
89/// assert_eq!(name, "cpu");
90/// ```
91pub fn with_default_runtime<R>(f: impl FnOnce(&mut RuntimeContext) -> Result<R>) -> Result<R> {
92 with_global_context::<RuntimeContext, _>(f).map_err(|err| match err {
93 Error::MissingGlobalContext { .. } => Error::RuntimeNotConfigured,
94 other => other,
95 })
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_runtime_roundtrip() {
        let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
        let runtime = with_default_runtime(|ctx| Ok(ctx.name())).unwrap();
        assert_eq!(runtime, "cpu");
    }

    #[test]
    fn cpu_context_converts_into_cpu_runtime() {
        // Exercises `From<CpuContext>` and `name()` without touching the global
        // default runtime, so it cannot interfere with other tests.
        let runtime = RuntimeContext::from(CpuContext::new(1));
        assert!(matches!(runtime, RuntimeContext::Cpu(_)));
        assert_eq!(runtime.name(), "cpu");
    }
}