Skip to main content

tensor4all_treetn/
site_index_network.rs

//! Site Index Network (inspired by ITensorNetworks.jl's IndsNetwork)
//!
//! Provides a structure combining:
//! - **NodeNameNetwork**: Graph topology (node connections)
//! - **Site space map**: Physical indices at each node (`HashMap<NodeName, HashSet<I>>`)
//!
//! This design separates the index structure from tensor data,
//! enabling topology and site space comparison independent of tensor values.
9
10use crate::node_name_network::{CanonicalizeEdges, NodeNameNetwork};
11use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
12use petgraph::Undirected;
13use std::collections::{HashMap, HashSet};
14use std::fmt::Debug;
15use std::hash::Hash;
16use tensor4all_core::IndexLike;
17
18// Re-export CanonicalizeEdges for convenience
19pub use crate::node_name_network::CanonicalizeEdges as CanonicalizeEdgesType;
20
21/// Site Index Network (inspired by ITensorNetworks.jl's IndsNetwork)
22///
23/// Represents the index structure of a tensor network:
24/// - **Topology**: Graph structure via `NodeNameNetwork`
25/// - **Site space**: Physical indices at each node via `HashMap`
26///
27/// This structure enables:
28/// - Comparing topologies and site spaces independently of tensor data
29/// - Extracting index information without accessing tensor values
30/// - Validating network structure consistency
31///
32/// # Type Parameters
33/// - `NodeName`: Node name type (must be Clone, Hash, Eq, Send, Sync, Debug)
34/// - `I`: Index type (must implement `IndexLike`)
35///
36/// # Examples
37///
38/// ```
39/// use std::collections::HashSet;
40/// use tensor4all_core::index::{DynId, Index, TagSet};
41/// use tensor4all_treetn::SiteIndexNetwork;
42///
43/// let mut net = SiteIndexNetwork::<String, Index<DynId, TagSet>>::new();
44/// let idx_a = Index::new_dyn(2);
45/// let idx_b = Index::new_dyn(3);
46///
47/// net.add_node("A".to_string(), HashSet::from([idx_a])).unwrap();
48/// net.add_node("B".to_string(), HashSet::from([idx_b])).unwrap();
49/// net.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
50///
51/// assert_eq!(net.node_count(), 2);
52/// assert_eq!(net.edge_count(), 1);
53/// ```
54#[derive(Debug, Clone)]
55pub struct SiteIndexNetwork<NodeName, I>
56where
57    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
58    I: IndexLike,
59{
60    /// Graph topology (node names and connections only).
61    topology: NodeNameNetwork<NodeName>,
62    /// Site space (physical indices) for each node.
63    site_spaces: HashMap<NodeName, HashSet<I>>,
64    /// Reverse lookup: index ID → node name containing this index.
65    index_to_node: HashMap<I::Id, NodeName>,
66}
67
68impl<NodeName, I> SiteIndexNetwork<NodeName, I>
69where
70    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
71    I: IndexLike,
72{
73    /// Create a new empty SiteIndexNetwork.
74    pub fn new() -> Self {
75        Self {
76            topology: NodeNameNetwork::new(),
77            site_spaces: HashMap::new(),
78            index_to_node: HashMap::new(),
79        }
80    }
81
82    /// Create a new SiteIndexNetwork with initial capacity.
83    pub fn with_capacity(nodes: usize, edges: usize) -> Self {
84        Self {
85            topology: NodeNameNetwork::with_capacity(nodes, edges),
86            site_spaces: HashMap::with_capacity(nodes),
87            index_to_node: HashMap::new(),
88        }
89    }
90
91    /// Add a node with site space (physical indices).
92    ///
93    /// # Arguments
94    /// * `node_name` - The name of the node
95    /// * `site_space` - The physical indices at this node (order doesn't matter)
96    ///
97    /// Returns an error if the node already exists.
98    pub fn add_node(
99        &mut self,
100        node_name: NodeName,
101        site_space: impl Into<HashSet<I>>,
102    ) -> Result<NodeIndex, String> {
103        let node_idx = self.topology.add_node(node_name.clone())?;
104        let site_space_set = site_space.into();
105        // Update reverse lookup for all indices
106        for idx in &site_space_set {
107            self.index_to_node
108                .insert(idx.id().clone(), node_name.clone());
109        }
110        self.site_spaces.insert(node_name, site_space_set);
111        Ok(node_idx)
112    }
113
114    /// Check if a node exists.
115    pub fn has_node(&self, node_name: &NodeName) -> bool {
116        self.topology.has_node(node_name)
117    }
118
119    /// Rename an existing node and preserve its site-space metadata.
120    pub fn rename_node(&mut self, old_name: &NodeName, new_name: NodeName) -> Result<(), String> {
121        if old_name == &new_name {
122            return Ok(());
123        }
124
125        let site_space = self
126            .site_spaces
127            .remove(old_name)
128            .ok_or_else(|| format!("Node {:?} not found", old_name))?;
129        self.topology.rename_node(old_name, new_name.clone())?;
130        for index in &site_space {
131            self.index_to_node
132                .insert(index.id().clone(), new_name.clone());
133        }
134        self.site_spaces.insert(new_name, site_space);
135        Ok(())
136    }
137
138    /// Get the site space (physical indices) for a node.
139    pub fn site_space(&self, node_name: &NodeName) -> Option<&HashSet<I>> {
140        self.site_spaces.get(node_name)
141    }
142
143    /// Get a mutable reference to the site space for a node.
144    ///
145    /// **Warning**: Direct modification of site space via this method does NOT
146    /// update the reverse lookup (`index_to_node`). Use `add_site_index()`,
147    /// `remove_site_index()`, or `replace_site_index()` for modifications
148    /// that maintain consistency.
149    pub fn site_space_mut(&mut self, node_name: &NodeName) -> Option<&mut HashSet<I>> {
150        self.site_spaces.get_mut(node_name)
151    }
152
153    /// Find the node containing a given site index.
154    ///
155    /// # Arguments
156    /// * `index` - The index to look up
157    ///
158    /// # Returns
159    /// The node name containing this index, or None if not found.
160    pub fn find_node_by_index(&self, index: &I) -> Option<&NodeName> {
161        self.index_to_node.get(index.id())
162    }
163
164    /// Find the node containing an index by ID.
165    pub fn find_node_by_index_id(&self, id: &I::Id) -> Option<&NodeName> {
166        self.index_to_node.get(id)
167    }
168
169    /// Check if a site index is registered.
170    pub fn contains_index(&self, index: &I) -> bool {
171        self.index_to_node.contains_key(index.id())
172    }
173
174    /// Add a site index to a node's site space.
175    ///
176    /// Updates both the site space and the reverse lookup.
177    pub fn add_site_index(&mut self, node_name: &NodeName, index: I) -> Result<(), String> {
178        let site_space = self
179            .site_spaces
180            .get_mut(node_name)
181            .ok_or_else(|| format!("Node {:?} not found", node_name))?;
182        site_space.insert(index.clone());
183        self.index_to_node
184            .insert(index.id().clone(), node_name.clone());
185        Ok(())
186    }
187
188    /// Remove a site index from a node's site space.
189    ///
190    /// Updates both the site space and the reverse lookup.
191    pub fn remove_site_index(&mut self, node_name: &NodeName, index: &I) -> Result<bool, String> {
192        let site_space = self
193            .site_spaces
194            .get_mut(node_name)
195            .ok_or_else(|| format!("Node {:?} not found", node_name))?;
196        let removed = site_space.remove(index);
197        if removed {
198            self.index_to_node.remove(index.id());
199        }
200        Ok(removed)
201    }
202
203    /// Replace a site index in a node's site space.
204    ///
205    /// Updates both the site space and the reverse lookup.
206    pub fn replace_site_index(
207        &mut self,
208        node_name: &NodeName,
209        old_index: &I,
210        new_index: I,
211    ) -> Result<(), String> {
212        let site_space = self
213            .site_spaces
214            .get_mut(node_name)
215            .ok_or_else(|| format!("Node {:?} not found", node_name))?;
216        if !site_space.remove(old_index) {
217            return Err(format!(
218                "Index {:?} not found in node {:?}",
219                old_index.id(),
220                node_name
221            ));
222        }
223        self.index_to_node.remove(old_index.id());
224        site_space.insert(new_index.clone());
225        self.index_to_node
226            .insert(new_index.id().clone(), node_name.clone());
227        Ok(())
228    }
229
230    /// Replace all site indices for a node with a new set.
231    ///
232    /// Updates both the site space and the reverse lookup.
233    /// This is an atomic operation that removes all old indices and adds all new ones.
234    pub fn set_site_space(
235        &mut self,
236        node_name: &NodeName,
237        new_indices: HashSet<I>,
238    ) -> Result<(), String> {
239        let site_space = self
240            .site_spaces
241            .get_mut(node_name)
242            .ok_or_else(|| format!("Node {:?} not found", node_name))?;
243
244        // Remove old indices from index_to_node
245        for old_idx in site_space.iter() {
246            if self.index_to_node.get(old_idx.id()) == Some(node_name) {
247                self.index_to_node.remove(old_idx.id());
248            }
249        }
250
251        // Add new indices to index_to_node
252        for new_idx in &new_indices {
253            self.index_to_node
254                .insert(new_idx.id().clone(), node_name.clone());
255        }
256
257        // Replace site space
258        *site_space = new_indices;
259
260        Ok(())
261    }
262
263    /// Get the site space by NodeIndex.
264    pub fn site_space_by_index(&self, node: NodeIndex) -> Option<&HashSet<I>> {
265        let name = self.topology.node_name(node)?;
266        self.site_spaces.get(name)
267    }
268
269    /// Add an edge between two nodes.
270    ///
271    /// Returns an error if either node doesn't exist.
272    pub fn add_edge(&mut self, n1: &NodeName, n2: &NodeName) -> Result<EdgeIndex, String> {
273        self.topology.add_edge(n1, n2)
274    }
275
276    /// Get the NodeIndex for a node name.
277    pub fn node_index(&self, node_name: &NodeName) -> Option<NodeIndex> {
278        self.topology.node_index(node_name)
279    }
280
281    /// Get the node name for a NodeIndex.
282    pub fn node_name(&self, node: NodeIndex) -> Option<&NodeName> {
283        self.topology.node_name(node)
284    }
285
286    /// Get all node names.
287    pub fn node_names(&self) -> Vec<&NodeName> {
288        self.topology.node_names()
289    }
290
291    /// Get the number of nodes.
292    pub fn node_count(&self) -> usize {
293        self.topology.node_count()
294    }
295
296    /// Get the number of edges.
297    pub fn edge_count(&self) -> usize {
298        self.topology.edge_count()
299    }
300
301    /// Get a reference to the underlying topology (NodeNameNetwork).
302    pub fn topology(&self) -> &NodeNameNetwork<NodeName> {
303        &self.topology
304    }
305
306    /// Get all edges as pairs of node names.
307    ///
308    /// Returns an iterator of `(NodeName, NodeName)` pairs.
309    pub fn edges(&self) -> impl Iterator<Item = (NodeName, NodeName)> + '_ {
310        let graph = self.topology.graph();
311        graph.edge_indices().filter_map(move |edge| {
312            let (a, b) = graph.edge_endpoints(edge)?;
313            let name_a = self.topology.node_name(a)?.clone();
314            let name_b = self.topology.node_name(b)?.clone();
315            Some((name_a, name_b))
316        })
317    }
318
319    /// Get all neighbors of a node.
320    ///
321    /// Returns an iterator of neighbor node names.
322    pub fn neighbors(&self, node_name: &NodeName) -> impl Iterator<Item = NodeName> + '_ {
323        let node_idx = self.topology.node_index(node_name);
324        let graph = self.topology.graph();
325        let topology = &self.topology;
326
327        node_idx
328            .into_iter()
329            .flat_map(move |idx| graph.neighbors(idx))
330            .filter_map(move |n| topology.node_name(n).cloned())
331    }
332
333    /// Get a reference to the internal graph.
334    pub fn graph(&self) -> &StableGraph<(), (), Undirected> {
335        self.topology.graph()
336    }
337
338    /// Get a mutable reference to the internal graph.
339    ///
340    /// **Warning**: Directly modifying the internal graph can break consistency.
341    pub fn graph_mut(&mut self) -> &mut StableGraph<(), (), Undirected> {
342        self.topology.graph_mut()
343    }
344
345    /// Check if two SiteIndexNetworks share equivalent site index structure.
346    ///
347    /// Two networks are equivalent if:
348    /// - Same topology (nodes and edges)
349    /// - Same site space for each node
350    ///
351    /// This is used to verify that two TreeTNs can be added or contracted.
352    pub fn share_equivalent_site_index_network(&self, other: &Self) -> bool {
353        // Check topology
354        if !self.topology.same_topology(&other.topology) {
355            return false;
356        }
357
358        // Check site spaces
359        for name in self.node_names() {
360            match (self.site_space(name), other.site_space(name)) {
361                (Some(self_indices), Some(other_indices)) => {
362                    if self_indices != other_indices {
363                        return false;
364                    }
365                }
366                (None, None) => continue,
367                _ => return false,
368            }
369        }
370
371        true
372    }
373
374    // =========================================================================
375    // Delegated graph algorithms (from NodeNameNetwork)
376    // =========================================================================
377
378    /// Perform a post-order DFS traversal starting from the given root node.
379    pub fn post_order_dfs(&self, root: &NodeName) -> Option<Vec<NodeName>> {
380        self.topology.post_order_dfs(root)
381    }
382
383    /// Perform a post-order DFS traversal starting from the given root NodeIndex.
384    pub fn post_order_dfs_by_index(&self, root: NodeIndex) -> Vec<NodeIndex> {
385        self.topology.post_order_dfs_by_index(root)
386    }
387
388    /// Find the shortest path between two nodes.
389    pub fn path_between(&self, from: NodeIndex, to: NodeIndex) -> Option<Vec<NodeIndex>> {
390        self.topology.path_between(from, to)
391    }
392
393    /// Compute the Steiner tree nodes spanning a set of terminal nodes.
394    ///
395    /// Delegates to [`NodeNameNetwork::steiner_tree_nodes`].
396    ///
397    /// # Examples
398    ///
399    /// ```
400    /// use std::collections::HashSet;
401    /// use tensor4all_core::DynIndex;
402    /// use tensor4all_treetn::SiteIndexNetwork;
403    ///
404    /// let mut net: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
405    /// let a = net.add_node("A".to_string(), HashSet::<DynIndex>::new()).unwrap();
406    /// let b = net.add_node("B".to_string(), HashSet::<DynIndex>::new()).unwrap();
407    /// let c = net.add_node("C".to_string(), HashSet::<DynIndex>::new()).unwrap();
408    /// net.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
409    /// net.add_edge(&"B".to_string(), &"C".to_string()).unwrap();
410    ///
411    /// let steiner = net.steiner_tree_nodes(&[a, c].into_iter().collect::<HashSet<_>>());
412    /// assert_eq!(steiner, [a, b, c].into_iter().collect());
413    /// ```
414    pub fn steiner_tree_nodes(&self, terminals: &HashSet<NodeIndex>) -> HashSet<NodeIndex> {
415        self.topology.steiner_tree_nodes(terminals)
416    }
417
418    /// Check if a subset of nodes forms a connected subgraph.
419    pub fn is_connected_subset(&self, nodes: &HashSet<NodeIndex>) -> bool {
420        self.topology.is_connected_subset(nodes)
421    }
422
423    /// Compute edges to canonicalize from current state to target.
424    pub fn edges_to_canonicalize(
425        &self,
426        current_region: Option<&HashSet<NodeIndex>>,
427        target: NodeIndex,
428    ) -> CanonicalizeEdges {
429        self.topology.edges_to_canonicalize(current_region, target)
430    }
431
432    /// Compute edges to canonicalize from leaves to target, returning node names.
433    ///
434    /// This is similar to `edges_to_canonicalize(None, target)` but returns
435    /// `(from_name, to_name)` pairs instead of `(NodeIndex, NodeIndex)`.
436    ///
437    /// See [`NodeNameNetwork::edges_to_canonicalize_by_names`] for details.
438    pub fn edges_to_canonicalize_by_names(
439        &self,
440        target: &NodeName,
441    ) -> Option<Vec<(NodeName, NodeName)>> {
442        self.topology.edges_to_canonicalize_by_names(target)
443    }
444
445    /// Compute edges to canonicalize from leaves towards a connected region (multiple centers).
446    ///
447    /// See [`NodeNameNetwork::edges_to_canonicalize_to_region`] for details.
448    pub fn edges_to_canonicalize_to_region(
449        &self,
450        target_region: &HashSet<NodeIndex>,
451    ) -> CanonicalizeEdges {
452        self.topology.edges_to_canonicalize_to_region(target_region)
453    }
454
455    /// Compute edges to canonicalize towards a region, returning node names.
456    ///
457    /// See [`NodeNameNetwork::edges_to_canonicalize_to_region_by_names`] for details.
458    pub fn edges_to_canonicalize_to_region_by_names(
459        &self,
460        target_region: &HashSet<NodeName>,
461    ) -> Option<Vec<(NodeName, NodeName)>> {
462        self.topology
463            .edges_to_canonicalize_to_region_by_names(target_region)
464    }
465
466    // =========================================================================
467    // Operator/state compatibility checking
468    // =========================================================================
469
470    /// Check if an operator can act on this state (as a ket).
471    ///
472    /// Returns `Ok(result_network)` if the operator can act on self,
473    /// where `result_network` is the SiteIndexNetwork of the output state.
474    ///
475    /// For an operator to act on a state:
476    /// - They must have the same topology (same nodes and edges)
477    /// - The operator must have indices that can contract with the state's site indices
478    ///
479    /// # Arguments
480    /// * `operator` - The operator's SiteIndexNetwork
481    ///
482    /// # Returns
483    /// - `Ok(SiteIndexNetwork)` - The resulting state's site index network after the operator acts
484    /// - `Err(String)` - Error message if the operator cannot act on this state
485    ///
486    /// # Note
487    /// This is a simplified version that assumes the operator's output indices
488    /// have the same structure as the input (i.e., the result has the same
489    /// site index structure as the original state). For more complex operators
490    /// with different input/output dimensions, a more sophisticated approach
491    /// would be needed.
492    pub fn apply_operator_topology(&self, operator: &Self) -> Result<Self, String> {
493        // Check topology match
494        if !self.topology.same_topology(&operator.topology) {
495            return Err(format!(
496                "Operator and state have different topologies. State nodes: {:?}, Operator nodes: {:?}",
497                self.node_names(),
498                operator.node_names()
499            ));
500        }
501
502        // For now, assume the operator preserves the site index structure
503        // (output has same site indices as input). This is the common case
504        // for Hamiltonians and other operators where H|ψ⟩ has the same
505        // index structure as |ψ⟩.
506        //
507        // A more complete implementation would:
508        // 1. Check that operator has compatible input indices
509        // 2. Return the actual output index structure
510        Ok(self.clone())
511    }
512
513    /// Check if this network has compatible site dimensions with another.
514    ///
515    /// Two networks have compatible site dimensions if:
516    /// - Same topology (nodes and edges)
517    /// - Each node has the same number of site indices
518    /// - Site index dimensions match (after sorting, order doesn't matter)
519    ///
520    /// This is useful for checking if two states can be added or if
521    /// a state matches the expected output of an operator.
522    pub fn compatible_site_dimensions(&self, other: &Self) -> bool {
523        // Check topology
524        if !self.topology.same_topology(&other.topology) {
525            return false;
526        }
527
528        // Check site dimensions for each node
529        for name in self.node_names() {
530            match (self.site_space(name), other.site_space(name)) {
531                (Some(self_indices), Some(other_indices)) => {
532                    // Check same number of indices
533                    if self_indices.len() != other_indices.len() {
534                        return false;
535                    }
536
537                    // Get dimensions and sort for comparison
538                    // Use IndexLike::dim() to get the dimension
539                    let mut self_dims: Vec<_> = self_indices.iter().map(|idx| idx.dim()).collect();
540                    let mut other_dims: Vec<_> =
541                        other_indices.iter().map(|idx| idx.dim()).collect();
542                    self_dims.sort();
543                    other_dims.sort();
544
545                    if self_dims != other_dims {
546                        return false;
547                    }
548                }
549                (None, None) => continue,
550                _ => return false,
551            }
552        }
553
554        true
555    }
556}
557
558impl<NodeName, I> Default for SiteIndexNetwork<NodeName, I>
559where
560    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
561    I: IndexLike,
562{
563    fn default() -> Self {
564        Self::new()
565    }
566}
567
568// ============================================================================
569// Type alias for backwards compatibility
570// ============================================================================
571
572use tensor4all_core::index::{DynId, Index};
573use tensor4all_core::DefaultTagSet;
574
575/// Type alias for the default SiteIndexNetwork using DynId indices.
576///
577/// This preserves backwards compatibility with existing code that uses
578/// `SiteIndexNetwork<NodeName, Id, Symm, Tags>`.
579pub type DefaultSiteIndexNetwork<NodeName> =
580    SiteIndexNetwork<NodeName, Index<DynId, DefaultTagSet>>;
581
582#[cfg(test)]
583mod tests;