1use std::any::{type_name, Any, TypeId};
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::marker::PhantomData;
5
6use crate::{Error, Result};
7
thread_local! {
    /// Per-thread registry of installed contexts.
    ///
    /// Keys are the `TypeId` of the context type; each value is a type-erased
    /// box that `set_global_context` filled with a context of exactly that type.
    /// Because this is a `thread_local!`, every thread sees its own, initially
    /// empty, registry.
    static GLOBAL_CONTEXTS: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::default();
}
11
/// Guard returned by `set_global_context` that undoes the installation on drop.
///
/// It owns whatever context of type `C` was installed before (if any). When it
/// is dropped, that previous context is put back, or the entry for `C` is removed
/// if there was none. Nested installations of the same type therefore unwind
/// correctly only when guards are dropped in reverse order of creation. Dropping
/// an outer guard before an inner one can leave a stale context installed.
///
/// The guard stores a non-`Send` `Box<dyn Any>`, so it cannot be moved to, and
/// dropped on, a thread other than the one whose registry it modified.
#[must_use = "dropping the guard immediately restores the previous context"]
pub struct GlobalContextGuard<C: 'static> {
    /// Context that was installed for `C` before this guard's, if any.
    previous: Option<Box<dyn Any>>,
    /// Ties the guard to the context type `C` without storing a `C`.
    _marker: PhantomData<C>,
}
29
30impl<C: 'static> Drop for GlobalContextGuard<C> {
31 fn drop(&mut self) {
32 GLOBAL_CONTEXTS.with(|contexts| {
33 let mut contexts = contexts.borrow_mut();
34 let key = TypeId::of::<C>();
35 contexts.remove(&key);
36 if let Some(previous) = self.previous.take() {
37 contexts.insert(key, previous);
38 }
39 });
40 }
41}
42
43pub fn set_global_context<C: 'static>(ctx: C) -> GlobalContextGuard<C> {
58 let previous = GLOBAL_CONTEXTS.with(|contexts| {
59 contexts
60 .borrow_mut()
61 .insert(TypeId::of::<C>(), Box::new(ctx) as Box<dyn Any>)
62 });
63
64 GlobalContextGuard {
65 previous,
66 _marker: PhantomData,
67 }
68}
69
70pub fn with_global_context<C: 'static, R>(f: impl FnOnce(&mut C) -> Result<R>) -> Result<R> {
88 GLOBAL_CONTEXTS.with(|contexts| {
89 let mut contexts = contexts.borrow_mut();
90 let erased =
91 contexts
92 .get_mut(&TypeId::of::<C>())
93 .ok_or_else(|| Error::MissingGlobalContext {
94 type_name: type_name::<C>(),
95 })?;
96 let typed = erased
97 .downcast_mut::<C>()
98 .ok_or_else(|| Error::ContextTypeMismatch {
99 expected: type_name::<C>(),
100 })?;
101 f(typed)
102 })
103}
104
105pub fn try_with_global_context<C: 'static, R>(
116 f: impl FnOnce(&mut C) -> Result<R>,
117) -> Result<Option<R>> {
118 GLOBAL_CONTEXTS.with(|contexts| {
119 let mut contexts = contexts.borrow_mut();
120 let Some(erased) = contexts.get_mut(&TypeId::of::<C>()) else {
121 return Ok(None);
122 };
123 let typed = erased
124 .downcast_mut::<C>()
125 .ok_or_else(|| Error::ContextTypeMismatch {
126 expected: type_name::<C>(),
127 })?;
128 f(typed).map(Some)
129 })
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Current `u32` context on this thread, or `None` if none is installed.
    fn current_u32() -> Option<u32> {
        try_with_global_context::<u32, _>(|ctx| Ok(*ctx)).unwrap()
    }

    #[test]
    fn set_and_restore_context() {
        let guard0 = set_global_context::<u32>(7);
        let value = with_global_context::<u32, _>(|ctx| Ok(*ctx)).unwrap();
        assert_eq!(value, 7);

        let guard1 = set_global_context::<u32>(11);
        let value = with_global_context::<u32, _>(|ctx| Ok(*ctx)).unwrap();
        assert_eq!(value, 11);

        drop(guard1);
        let value = with_global_context::<u32, _>(|ctx| Ok(*ctx)).unwrap();
        assert_eq!(value, 7);

        drop(guard0);
        let missing = with_global_context::<u32, _>(|ctx| Ok(*ctx));
        assert!(matches!(missing, Err(Error::MissingGlobalContext { .. })));
    }

    #[test]
    fn try_with_global_context_when_missing() {
        let value = try_with_global_context::<usize, _>(|ctx| Ok(*ctx)).unwrap();
        assert_eq!(value, None);
    }

    #[test]
    fn try_with_global_context_when_present() {
        let _guard = set_global_context::<usize>(3);
        let value = try_with_global_context::<usize, _>(|ctx| Ok(*ctx)).unwrap();
        assert_eq!(value, Some(3));
    }

    #[test]
    fn mutations_through_callback_persist() {
        let _guard = set_global_context::<u32>(1);
        with_global_context::<u32, _>(|ctx| {
            *ctx += 41;
            Ok(())
        })
        .unwrap();
        assert_eq!(current_u32(), Some(42));
    }

    #[test]
    fn contexts_are_keyed_by_type() {
        let _number = set_global_context::<u32>(5);
        let _label = set_global_context::<String>("ctx".to_owned());
        assert_eq!(current_u32(), Some(5));
        let label = with_global_context::<String, _>(|ctx| Ok(ctx.clone())).unwrap();
        assert_eq!(label, "ctx");
    }

    #[test]
    fn callback_error_is_propagated() {
        let _guard = set_global_context::<u8>(0);
        let result = with_global_context::<u8, ()>(|_| {
            Err(Error::MissingGlobalContext {
                type_name: "from callback",
            })
        });
        assert!(matches!(
            result,
            Err(Error::MissingGlobalContext {
                type_name: "from callback"
            })
        ));
    }
}