// ad_tensors_rs/context.rs

1use std::any::{type_name, Any, TypeId};
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::marker::PhantomData;
5
6use crate::{Error, Result};
7
thread_local! {
    /// Per-thread registry of global contexts, keyed by the `TypeId` of the context type.
    ///
    /// Each value is a type-erased `Box<C>` stored under `TypeId::of::<C>()`, so at most one
    /// context per type is visible at a time on a given thread.
    static GLOBAL_CONTEXTS: RefCell<HashMap<TypeId, Box<dyn Any>>> =
        RefCell::new(HashMap::default());
}
11
/// Guard returned by [`set_global_context`].
///
/// When dropped, the previously installed context value for the same type `C` is restored.
/// If there was none, the entry for `C` is removed.
///
/// Guards for the same type must be dropped in reverse order of creation (stack discipline).
/// Suppose an outer guard is dropped before an inner one. The outer guard first removes the
/// inner value and restores its own predecessor. The inner guard then reinstates the outer
/// guard's value, which stays installed after both guards are gone.
///
/// The guard owns a `Box<dyn Any>`, which is neither `Send` nor `Sync`. The guard therefore
/// cannot be moved to another thread, which matches the thread-local registry it restores into.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::{set_global_context, with_global_context};
///
/// let _guard = set_global_context::<u32>(7);
/// let v = with_global_context::<u32, _>(|ctx| Ok(*ctx)).unwrap();
/// assert_eq!(v, 7);
/// ```
#[must_use = "dropping the guard immediately restores the previous context"]
pub struct GlobalContextGuard<C: 'static> {
    /// The value displaced by the `set_global_context` call that created this guard, if any.
    previous: Option<Box<dyn Any>>,
    _marker: PhantomData<C>,
}
29
30impl<C: 'static> Drop for GlobalContextGuard<C> {
31    fn drop(&mut self) {
32        GLOBAL_CONTEXTS.with(|contexts| {
33            let mut contexts = contexts.borrow_mut();
34            let key = TypeId::of::<C>();
35            contexts.remove(&key);
36            if let Some(previous) = self.previous.take() {
37                contexts.insert(key, previous);
38            }
39        });
40    }
41}
42
43/// Sets a thread-local global context for type `C`.
44///
45/// Returns a guard that restores the previous context on drop.
46///
47/// # Examples
48///
49/// ```rust
50/// use ad_tensors_rs::{set_global_context, with_global_context};
51///
52/// let guard = set_global_context::<usize>(123);
53/// let value = with_global_context::<usize, _>(|ctx| Ok(*ctx)).unwrap();
54/// assert_eq!(value, 123);
55/// drop(guard);
56/// ```
57pub fn set_global_context<C: 'static>(ctx: C) -> GlobalContextGuard<C> {
58    let previous = GLOBAL_CONTEXTS.with(|contexts| {
59        contexts
60            .borrow_mut()
61            .insert(TypeId::of::<C>(), Box::new(ctx) as Box<dyn Any>)
62    });
63
64    GlobalContextGuard {
65        previous,
66        _marker: PhantomData,
67    }
68}
69
70/// Runs `f` with a mutable reference to thread-local global context `C`.
71///
72/// Returns [`Error::MissingGlobalContext`] if no context is registered.
73///
74/// # Examples
75///
76/// ```rust
77/// use ad_tensors_rs::{set_global_context, with_global_context};
78///
79/// let _guard = set_global_context::<usize>(11);
80/// let value = with_global_context::<usize, _>(|ctx| {
81///     *ctx += 1;
82///     Ok(*ctx)
83/// })
84/// .unwrap();
85/// assert_eq!(value, 12);
86/// ```
87pub fn with_global_context<C: 'static, R>(f: impl FnOnce(&mut C) -> Result<R>) -> Result<R> {
88    GLOBAL_CONTEXTS.with(|contexts| {
89        let mut contexts = contexts.borrow_mut();
90        let erased =
91            contexts
92                .get_mut(&TypeId::of::<C>())
93                .ok_or_else(|| Error::MissingGlobalContext {
94                    type_name: type_name::<C>(),
95                })?;
96        let typed = erased
97            .downcast_mut::<C>()
98            .ok_or_else(|| Error::ContextTypeMismatch {
99                expected: type_name::<C>(),
100            })?;
101        f(typed)
102    })
103}
104
105/// Like [`with_global_context`] but returns `Ok(None)` when context is missing.
106///
107/// # Examples
108///
109/// ```rust
110/// use ad_tensors_rs::try_with_global_context;
111///
112/// let value = try_with_global_context::<u64, _>(|ctx| Ok(*ctx)).unwrap();
113/// assert_eq!(value, None);
114/// ```
115pub fn try_with_global_context<C: 'static, R>(
116    f: impl FnOnce(&mut C) -> Result<R>,
117) -> Result<Option<R>> {
118    GLOBAL_CONTEXTS.with(|contexts| {
119        let mut contexts = contexts.borrow_mut();
120        let Some(erased) = contexts.get_mut(&TypeId::of::<C>()) else {
121            return Ok(None);
122        };
123        let typed = erased
124            .downcast_mut::<C>()
125            .ok_or_else(|| Error::ContextTypeMismatch {
126                expected: type_name::<C>(),
127            })?;
128        f(typed).map(Some)
129    })
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the current `u32` context, propagating any lookup error.
    fn current_u32() -> Result<u32> {
        with_global_context::<u32, _>(|ctx| Ok(*ctx))
    }

    #[test]
    fn set_and_restore_context() {
        let outer = set_global_context::<u32>(7);
        assert_eq!(current_u32().unwrap(), 7);

        let inner = set_global_context::<u32>(11);
        assert_eq!(current_u32().unwrap(), 11);

        // Releasing the inner guard brings back the outer value.
        drop(inner);
        assert_eq!(current_u32().unwrap(), 7);

        // Releasing the outer guard leaves no context at all.
        drop(outer);
        assert!(matches!(
            current_u32(),
            Err(Error::MissingGlobalContext { .. })
        ));
    }

    #[test]
    fn try_with_global_context_when_missing() {
        let observed = try_with_global_context::<usize, _>(|ctx| Ok(*ctx)).unwrap();
        assert!(observed.is_none());
    }
}
160}