// tensor4all_treetn/linsolve/common/environment.rs
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;

use tensor4all_core::IndexLike;
use tensor4all_core::TensorLike;

use crate::SiteIndexNetwork;

/// Read-only view of a network's connectivity.
///
/// `V` is the node identifier type. An implementor reports, for any node,
/// which nodes are directly adjacent to it.
pub trait NetworkTopology<V> {
    /// Iterator produced by [`NetworkTopology::neighbors`].
    ///
    /// It yields node identifiers by value; the lifetime parameter allows the
    /// iterator to borrow from the topology it was obtained from.
    type Neighbors<'a>: Iterator<Item = V>
    where
        V: 'a,
        Self: 'a;

    /// Returns an iterator over the nodes adjacent to `node`.
    fn neighbors(&self, node: &V) -> Self::Neighbors<'_>;
}
26
27#[derive(Debug, Clone)]
35pub struct EnvironmentCache<T, V>
36where
37 T: TensorLike,
38 V: Clone + Hash + Eq,
39{
40 envs: HashMap<(V, V), T>,
42}
43
44impl<T, V> EnvironmentCache<T, V>
45where
46 T: TensorLike,
47 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
48{
49 pub fn new() -> Self {
51 Self {
52 envs: HashMap::new(),
53 }
54 }
55
56 pub fn get(&self, from: &V, to: &V) -> Option<&T> {
58 self.envs.get(&(from.clone(), to.clone()))
59 }
60
61 pub fn insert(&mut self, from: V, to: V, env: T) {
63 self.envs.insert((from, to), env);
64 }
65
66 pub fn contains(&self, from: &V, to: &V) -> bool {
68 self.envs.contains_key(&(from.clone(), to.clone()))
69 }
70
71 pub fn len(&self) -> usize {
73 self.envs.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.envs.is_empty()
79 }
80
81 pub fn clear(&mut self) {
83 self.envs.clear();
84 }
85
86 pub fn invalidate<'a, NT: NetworkTopology<V>>(
92 &mut self,
93 region: impl IntoIterator<Item = &'a V>,
94 topology: &NT,
95 ) where
96 V: 'a,
97 {
98 for t in region {
99 let neighbors: Vec<V> = topology.neighbors(t).collect();
101
102 for neighbor in neighbors {
104 self.invalidate_recursive(t, &neighbor, topology);
105 }
106 }
107 }
108
109 fn invalidate_recursive<NT: NetworkTopology<V>>(&mut self, from: &V, to: &V, topology: &NT) {
111 if self.envs.remove(&(from.clone(), to.clone())).is_some() {
113 let neighbors: Vec<V> = topology.neighbors(to).filter(|n| n != from).collect();
115
116 for neighbor in neighbors {
117 self.invalidate_recursive(to, &neighbor, topology);
118 }
119 }
120 }
121}
122
123impl<T, V> Default for EnvironmentCache<T, V>
124where
125 T: TensorLike,
126 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
127{
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl<NodeName, I> NetworkTopology<NodeName> for SiteIndexNetwork<NodeName, I>
142where
143 NodeName: Clone + Hash + Eq + Send + Sync + Debug,
144 I: IndexLike,
145{
146 type Neighbors<'a>
147 = Box<dyn Iterator<Item = NodeName> + 'a>
148 where
149 Self: 'a,
150 NodeName: 'a;
151
152 fn neighbors(&self, node: &NodeName) -> Self::Neighbors<'_> {
153 Box::new(SiteIndexNetwork::neighbors(self, node))
154 }
155}
156
157#[cfg(test)]
158mod tests;