Skip to main content

tensor4all_tensorbackend/
context.rs

1//! Process-global tenferro CPU execution helpers.
2//!
3//! tensor4all-rs routes tenferro CPU execution through one process-global
4//! `CpuContext`, matching tenferro's `cpu:0` default-global thread-pool model.
5//! Plain tensor operations, cached traced execution, and eager AD currently use
6//! separate `CpuBackend` values because tenferro does not yet expose a public
7//! API for borrowing the backend owned by an `EagerContext<CpuBackend>`. All
8//! backends are created from the same global CPU context, so thread-pool
9//! configuration is shared.
10
11use std::sync::{Arc, Mutex, OnceLock};
12
13use tenferro::{CpuBackend, EagerContext, Engine};
14use tenferro_tensor::buffer_pool::BufferPoolStats;
15use tenferro_tensor::cpu::CpuContext;
16
17static DEFAULT_CPU_CONTEXT: OnceLock<Arc<CpuContext>> = OnceLock::new();
18static DEFAULT_BACKEND: OnceLock<Mutex<CpuBackend>> = OnceLock::new();
19static DEFAULT_ENGINE: OnceLock<Mutex<Engine<CpuBackend>>> = OnceLock::new();
20static DEFAULT_EAGER_CTX: OnceLock<Arc<EagerContext<CpuBackend>>> = OnceLock::new();
21
22fn default_cpu_context() -> Arc<CpuContext> {
23    DEFAULT_CPU_CONTEXT
24        .get_or_init(|| Arc::new(CpuContext::from_env()))
25        .clone()
26}
27
28fn default_backend() -> &'static Mutex<CpuBackend> {
29    DEFAULT_BACKEND.get_or_init(|| Mutex::new(CpuBackend::from_context(default_cpu_context())))
30}
31
32fn default_engine() -> &'static Mutex<Engine<CpuBackend>> {
33    DEFAULT_ENGINE
34        .get_or_init(|| Mutex::new(Engine::new(CpuBackend::from_context(default_cpu_context()))))
35}
36
37fn lock_default_engine() -> std::sync::MutexGuard<'static, Engine<CpuBackend>> {
38    match default_engine().lock() {
39        Ok(guard) => guard,
40        Err(poisoned) => poisoned.into_inner(),
41    }
42}
43
44/// Run a closure against the process-global CPU backend.
45///
46/// This is the canonical entry point for typed and untyped tenferro tensor
47/// operations inside `tensor4all-tensorbackend`.
48pub fn with_default_backend<R>(f: impl FnOnce(&mut CpuBackend) -> R) -> R {
49    let mut backend = match default_backend().lock() {
50        Ok(guard) => guard,
51        Err(poisoned) => poisoned.into_inner(),
52    };
53    f(&mut backend)
54}
55
56/// Run a closure against the process-global tenferro execution engine.
57///
58/// This is used for native tensor operations that benefit from tenferro's
59/// persistent execution caches, such as N-ary einsum contraction paths.
60pub(crate) fn with_default_engine<R>(f: impl FnOnce(&mut Engine<CpuBackend>) -> R) -> R {
61    let mut engine = lock_default_engine();
62    f(&mut engine)
63}
64
65/// Return retained-buffer statistics for the process-global execution engine.
66pub(crate) fn default_engine_buffer_pool_stats() -> BufferPoolStats {
67    lock_default_engine().buffer_pool_stats()
68}
69
70/// Reset retained buffers in the process-global execution engine.
71pub(crate) fn reset_default_engine_buffer_pool() {
72    lock_default_engine().reset_buffer_pool();
73}
74
75/// Drop and recreate the process-global execution engine.
76///
77/// This releases tenferro's retained execution buffers and cached contraction
78/// paths. It is intended for diagnostics and memory-pressure recovery, not for
79/// normal hot loops where the caches are valuable.
80pub(crate) fn reset_default_engine() {
81    let mut engine = lock_default_engine();
82    *engine = Engine::new(CpuBackend::from_context(default_cpu_context()));
83}
84
85/// Return the process-global eager context used for reverse-mode AD.
86///
87/// This context owns a separate `CpuBackend` from [`with_default_backend`] and
88/// the cached execution engine, but all backends share the same process-global
89/// tenferro CPU context.
90pub fn default_eager_ctx() -> Arc<EagerContext<CpuBackend>> {
91    DEFAULT_EAGER_CTX
92        .get_or_init(|| EagerContext::with_backend(CpuBackend::from_context(default_cpu_context())))
93        .clone()
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Address of the backend handed out by `with_default_backend`, as an
    /// integer so it can be returned from a worker thread.
    fn default_backend_addr() -> usize {
        with_default_backend(|backend| backend as *mut CpuBackend as usize)
    }

    /// Address of the engine handed out by `with_default_engine`, as an
    /// integer so it can be returned from a worker thread.
    fn default_engine_addr() -> usize {
        with_default_engine(|engine| engine as *mut Engine<CpuBackend> as usize)
    }

    #[test]
    fn eager_context_is_process_global() {
        let first = default_eager_ctx();
        let second = default_eager_ctx();

        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn eager_context_is_shared_across_threads() {
        let main_context = default_eager_ctx();
        let worker_context = std::thread::spawn(default_eager_ctx)
            .join()
            .expect("worker thread should complete");

        assert!(Arc::ptr_eq(&main_context, &worker_context));
    }

    #[test]
    fn default_backend_is_shared_across_threads() {
        // Equal thread counts would also hold for per-thread backends built
        // from the same configuration, so compare object identity instead.
        let main_addr = default_backend_addr();
        let worker_addr = std::thread::spawn(default_backend_addr)
            .join()
            .expect("worker thread should complete");

        assert_eq!(main_addr, worker_addr);
    }

    #[test]
    fn default_engine_is_shared_across_threads() {
        let main_addr = default_engine_addr();
        let worker_addr = std::thread::spawn(default_engine_addr)
            .join()
            .expect("worker thread should complete");

        assert_eq!(main_addr, worker_addr);
    }

    #[test]
    fn backend_and_engine_share_thread_pool_config() {
        let backend_threads = with_default_backend(|backend| backend.num_threads());
        let engine_threads = with_default_engine(|engine| engine.backend().num_threads());

        assert_eq!(backend_threads, engine_threads);
    }
}