tensor4all_tensorbackend/
context.rs1use std::sync::{Arc, Mutex, OnceLock};
12
13use tenferro::{CpuBackend, EagerContext, Engine};
14use tenferro_tensor::buffer_pool::BufferPoolStats;
15use tenferro_tensor::cpu::CpuContext;
16
17static DEFAULT_CPU_CONTEXT: OnceLock<Arc<CpuContext>> = OnceLock::new();
18static DEFAULT_BACKEND: OnceLock<Mutex<CpuBackend>> = OnceLock::new();
19static DEFAULT_ENGINE: OnceLock<Mutex<Engine<CpuBackend>>> = OnceLock::new();
20static DEFAULT_EAGER_CTX: OnceLock<Arc<EagerContext<CpuBackend>>> = OnceLock::new();
21
22fn default_cpu_context() -> Arc<CpuContext> {
23 DEFAULT_CPU_CONTEXT
24 .get_or_init(|| Arc::new(CpuContext::from_env()))
25 .clone()
26}
27
28fn default_backend() -> &'static Mutex<CpuBackend> {
29 DEFAULT_BACKEND.get_or_init(|| Mutex::new(CpuBackend::from_context(default_cpu_context())))
30}
31
32fn default_engine() -> &'static Mutex<Engine<CpuBackend>> {
33 DEFAULT_ENGINE
34 .get_or_init(|| Mutex::new(Engine::new(CpuBackend::from_context(default_cpu_context()))))
35}
36
37fn lock_default_engine() -> std::sync::MutexGuard<'static, Engine<CpuBackend>> {
38 match default_engine().lock() {
39 Ok(guard) => guard,
40 Err(poisoned) => poisoned.into_inner(),
41 }
42}
43
44pub fn with_default_backend<R>(f: impl FnOnce(&mut CpuBackend) -> R) -> R {
49 let mut backend = match default_backend().lock() {
50 Ok(guard) => guard,
51 Err(poisoned) => poisoned.into_inner(),
52 };
53 f(&mut backend)
54}
55
56pub(crate) fn with_default_engine<R>(f: impl FnOnce(&mut Engine<CpuBackend>) -> R) -> R {
61 let mut engine = lock_default_engine();
62 f(&mut engine)
63}
64
65pub(crate) fn default_engine_buffer_pool_stats() -> BufferPoolStats {
67 lock_default_engine().buffer_pool_stats()
68}
69
70pub(crate) fn reset_default_engine_buffer_pool() {
72 lock_default_engine().reset_buffer_pool();
73}
74
75pub(crate) fn reset_default_engine() {
81 let mut engine = lock_default_engine();
82 *engine = Engine::new(CpuBackend::from_context(default_cpu_context()));
83}
84
85pub fn default_eager_ctx() -> Arc<EagerContext<CpuBackend>> {
91 DEFAULT_EAGER_CTX
92 .get_or_init(|| EagerContext::with_backend(CpuBackend::from_context(default_cpu_context())))
93 .clone()
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `probe` on a newly spawned thread and returns its result,
    /// panicking if that thread panicked.
    fn on_worker_thread<T, F>(probe: F) -> T
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        std::thread::spawn(probe)
            .join()
            .expect("worker thread should complete")
    }

    #[test]
    fn eager_context_is_process_global() {
        let (first, second) = (default_eager_ctx(), default_eager_ctx());
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn eager_context_is_shared_across_threads() {
        let on_main = default_eager_ctx();
        let on_worker = on_worker_thread(default_eager_ctx);
        assert!(Arc::ptr_eq(&on_main, &on_worker));
    }

    #[test]
    fn default_backend_is_shared_across_threads() {
        // Non-capturing closure: `Copy`, so it can be called here and then moved.
        let threads_of = || with_default_backend(|backend| backend.num_threads());
        let on_main = threads_of();
        let on_worker = on_worker_thread(threads_of);
        assert_eq!(on_main, on_worker);
    }

    #[test]
    fn default_engine_is_shared_across_threads() {
        let threads_of = || with_default_engine(|engine| engine.backend().num_threads());
        let on_main = threads_of();
        let on_worker = on_worker_thread(threads_of);
        assert_eq!(on_main, on_worker);
    }
}