Skip to main content

tensor4all_treetn/linsolve/common/
environment.rs

1//! Generic environment cache for tensor network computations.
2//!
//! Provides shared infrastructure for caching the environment tensors
//! used by various algorithms (linsolve, fit, etc.).
5
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::hash::Hash;
9
10use tensor4all_core::IndexLike;
11use tensor4all_core::TensorLike;
12
13use crate::SiteIndexNetwork;
14
/// Adjacency view of a network.
///
/// Cache invalidation uses it to walk outward from updated nodes.
pub trait NetworkTopology<V> {
    /// Iterator yielding the neighbors of a node by value.
    ///
    /// It may borrow from the topology that produced it.
    type Neighbors<'a>: Iterator<Item = V>
    where
        V: 'a,
        Self: 'a;

    /// Returns an iterator over every node adjacent to `node`.
    fn neighbors(&self, node: &V) -> Self::Neighbors<'_>;
}
26
/// Simple environment cache for tensor network computations.
///
/// This struct handles:
/// - Storing computed environment tensors, keyed by the directed edge `(from, to)`.
/// - Cache invalidation when tensors are updated.
///
/// The actual contraction logic is implemented in ProjectedOperator/ProjectedState.
/// This type only stores results.
///
/// The trait bounds this cache relies on live on the `impl` blocks that need them,
/// not on the type itself:
/// - `T: TensorLike`
/// - `V: Clone + Hash + Eq`
///
/// Because of this, the type can be named in generic code without repeating them.
#[derive(Debug, Clone)]
pub struct EnvironmentCache<T, V> {
    /// Cached environment tensors, keyed by directed edge `(from, to)`.
    envs: HashMap<(V, V), T>,
}
43
44impl<T, V> EnvironmentCache<T, V>
45where
46    T: TensorLike,
47    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
48{
49    /// Create a new empty environment cache.
50    pub fn new() -> Self {
51        Self {
52            envs: HashMap::new(),
53        }
54    }
55
56    /// Get a cached environment tensor if it exists.
57    pub fn get(&self, from: &V, to: &V) -> Option<&T> {
58        self.envs.get(&(from.clone(), to.clone()))
59    }
60
61    /// Insert an environment tensor.
62    pub fn insert(&mut self, from: V, to: V, env: T) {
63        self.envs.insert((from, to), env);
64    }
65
66    /// Check if environment exists for edge (from, to).
67    pub fn contains(&self, from: &V, to: &V) -> bool {
68        self.envs.contains_key(&(from.clone(), to.clone()))
69    }
70
71    /// Get the number of cached environments.
72    pub fn len(&self) -> usize {
73        self.envs.len()
74    }
75
76    /// Check if the cache is empty.
77    pub fn is_empty(&self) -> bool {
78        self.envs.is_empty()
79    }
80
81    /// Clear all cached environments.
82    pub fn clear(&mut self) {
83        self.envs.clear();
84    }
85
86    /// Invalidate all caches affected by updates to tensors in region.
87    ///
88    /// For each `t ∈ region`:
89    /// 1. Remove all `env[(t, *)]` (0th generation)
90    /// 2. Recursively remove caches propagating towards leaves
91    pub fn invalidate<'a, NT: NetworkTopology<V>>(
92        &mut self,
93        region: impl IntoIterator<Item = &'a V>,
94        topology: &NT,
95    ) where
96        V: 'a,
97    {
98        for t in region {
99            // Get all neighbors of t
100            let neighbors: Vec<V> = topology.neighbors(t).collect();
101
102            // Remove all env[(t, *)] and propagate recursively
103            for neighbor in neighbors {
104                self.invalidate_recursive(t, &neighbor, topology);
105            }
106        }
107    }
108
109    /// Recursively invalidate caches starting from env[(from, to)] towards leaves.
110    fn invalidate_recursive<NT: NetworkTopology<V>>(&mut self, from: &V, to: &V, topology: &NT) {
111        // Remove env[(from, to)] if it exists
112        if self.envs.remove(&(from.clone(), to.clone())).is_some() {
113            // Propagate to next generation: env[(to, x)] for all neighbors x of to, x ≠ from
114            let neighbors: Vec<V> = topology.neighbors(to).filter(|n| n != from).collect();
115
116            for neighbor in neighbors {
117                self.invalidate_recursive(to, &neighbor, topology);
118            }
119        }
120    }
121}
122
123impl<T, V> Default for EnvironmentCache<T, V>
124where
125    T: TensorLike,
126    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
127{
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133// ============================================================================
134// NetworkTopology implementations for SiteIndexNetwork
135// ============================================================================
136
137/// Implement NetworkTopology for SiteIndexNetwork.
138///
139/// This enables direct use of SiteIndexNetwork for cache invalidation
140/// and environment computation without needing adapter types like StaticTopology.
141impl<NodeName, I> NetworkTopology<NodeName> for SiteIndexNetwork<NodeName, I>
142where
143    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
144    I: IndexLike,
145{
146    type Neighbors<'a>
147        = Box<dyn Iterator<Item = NodeName> + 'a>
148    where
149        Self: 'a,
150        NodeName: 'a;
151
152    fn neighbors(&self, node: &NodeName) -> Self::Neighbors<'_> {
153        Box::new(SiteIndexNetwork::neighbors(self, node))
154    }
155}
156
157#[cfg(test)]
158mod tests;