tensor4all_treetn/site_index_network.rs
1//! Site Index Network (inspired by ITensorNetworks.jl's IndsNetwork)
2//!
3//! Provides a structure combining:
4//! - **NodeNameNetwork**: Graph topology (node connections)
5//! - **Site space map**: Physical indices at each node (`HashMap<NodeName, HashSet<I>>`)
6//!
7//! This design separates the index structure from tensor data,
8//! enabling topology and site space comparison independent of tensor values.
9
10use crate::node_name_network::{CanonicalizeEdges, NodeNameNetwork};
11use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
12use petgraph::Undirected;
13use std::collections::{HashMap, HashSet};
14use std::fmt::Debug;
15use std::hash::Hash;
16use tensor4all_core::IndexLike;
17
18// Re-export CanonicalizeEdges for convenience
19pub use crate::node_name_network::CanonicalizeEdges as CanonicalizeEdgesType;
20
/// Site Index Network (inspired by ITensorNetworks.jl's IndsNetwork)
///
/// Represents the index structure of a tensor network:
/// - **Topology**: Graph structure via `NodeNameNetwork`
/// - **Site space**: Physical indices at each node via `HashMap`
///
/// This structure enables:
/// - Comparing topologies and site spaces independently of tensor data
/// - Extracting index information without accessing tensor values
/// - Validating network structure consistency
///
/// # Consistency
///
/// Besides the per-node site spaces, the network keeps a reverse lookup from
/// index ID to the owning node name. The mutating methods on this type
/// (`add_node`, `rename_node`, `add_site_index`, `remove_site_index`,
/// `replace_site_index`, `set_site_space`) update both maps together.
/// `site_space_mut` and `graph_mut` hand out raw mutable access and do not.
///
/// # Type Parameters
/// - `NodeName`: Node name type (must be Clone, Hash, Eq, Send, Sync, Debug)
/// - `I`: Index type (must implement `IndexLike`)
///
/// # Examples
///
/// ```
/// use std::collections::HashSet;
/// use tensor4all_core::index::{DynId, Index, TagSet};
/// use tensor4all_treetn::SiteIndexNetwork;
///
/// let mut net = SiteIndexNetwork::<String, Index<DynId, TagSet>>::new();
/// let idx_a = Index::new_dyn(2);
/// let idx_b = Index::new_dyn(3);
///
/// net.add_node("A".to_string(), HashSet::from([idx_a])).unwrap();
/// net.add_node("B".to_string(), HashSet::from([idx_b])).unwrap();
/// net.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
///
/// assert_eq!(net.node_count(), 2);
/// assert_eq!(net.edge_count(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct SiteIndexNetwork<NodeName, I>
where
    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
    I: IndexLike,
{
    /// Graph topology (node names and connections only).
    topology: NodeNameNetwork<NodeName>,
    /// Site space (physical indices) for each node, keyed by node name.
    site_spaces: HashMap<NodeName, HashSet<I>>,
    /// Reverse lookup: index ID → node name containing this index.
    /// Keyed by `I::Id`, so lookups need only the ID, not the full index.
    index_to_node: HashMap<I::Id, NodeName>,
}
67
68impl<NodeName, I> SiteIndexNetwork<NodeName, I>
69where
70 NodeName: Clone + Hash + Eq + Send + Sync + Debug,
71 I: IndexLike,
72{
73 /// Create a new empty SiteIndexNetwork.
74 pub fn new() -> Self {
75 Self {
76 topology: NodeNameNetwork::new(),
77 site_spaces: HashMap::new(),
78 index_to_node: HashMap::new(),
79 }
80 }
81
82 /// Create a new SiteIndexNetwork with initial capacity.
83 pub fn with_capacity(nodes: usize, edges: usize) -> Self {
84 Self {
85 topology: NodeNameNetwork::with_capacity(nodes, edges),
86 site_spaces: HashMap::with_capacity(nodes),
87 index_to_node: HashMap::new(),
88 }
89 }
90
91 /// Add a node with site space (physical indices).
92 ///
93 /// # Arguments
94 /// * `node_name` - The name of the node
95 /// * `site_space` - The physical indices at this node (order doesn't matter)
96 ///
97 /// Returns an error if the node already exists.
98 pub fn add_node(
99 &mut self,
100 node_name: NodeName,
101 site_space: impl Into<HashSet<I>>,
102 ) -> Result<NodeIndex, String> {
103 let node_idx = self.topology.add_node(node_name.clone())?;
104 let site_space_set = site_space.into();
105 // Update reverse lookup for all indices
106 for idx in &site_space_set {
107 self.index_to_node
108 .insert(idx.id().clone(), node_name.clone());
109 }
110 self.site_spaces.insert(node_name, site_space_set);
111 Ok(node_idx)
112 }
113
114 /// Check if a node exists.
115 pub fn has_node(&self, node_name: &NodeName) -> bool {
116 self.topology.has_node(node_name)
117 }
118
119 /// Rename an existing node and preserve its site-space metadata.
120 pub fn rename_node(&mut self, old_name: &NodeName, new_name: NodeName) -> Result<(), String> {
121 if old_name == &new_name {
122 return Ok(());
123 }
124
125 let site_space = self
126 .site_spaces
127 .remove(old_name)
128 .ok_or_else(|| format!("Node {:?} not found", old_name))?;
129 self.topology.rename_node(old_name, new_name.clone())?;
130 for index in &site_space {
131 self.index_to_node
132 .insert(index.id().clone(), new_name.clone());
133 }
134 self.site_spaces.insert(new_name, site_space);
135 Ok(())
136 }
137
138 /// Get the site space (physical indices) for a node.
139 pub fn site_space(&self, node_name: &NodeName) -> Option<&HashSet<I>> {
140 self.site_spaces.get(node_name)
141 }
142
143 /// Get a mutable reference to the site space for a node.
144 ///
145 /// **Warning**: Direct modification of site space via this method does NOT
146 /// update the reverse lookup (`index_to_node`). Use `add_site_index()`,
147 /// `remove_site_index()`, or `replace_site_index()` for modifications
148 /// that maintain consistency.
149 pub fn site_space_mut(&mut self, node_name: &NodeName) -> Option<&mut HashSet<I>> {
150 self.site_spaces.get_mut(node_name)
151 }
152
153 /// Find the node containing a given site index.
154 ///
155 /// # Arguments
156 /// * `index` - The index to look up
157 ///
158 /// # Returns
159 /// The node name containing this index, or None if not found.
160 pub fn find_node_by_index(&self, index: &I) -> Option<&NodeName> {
161 self.index_to_node.get(index.id())
162 }
163
164 /// Find the node containing an index by ID.
165 pub fn find_node_by_index_id(&self, id: &I::Id) -> Option<&NodeName> {
166 self.index_to_node.get(id)
167 }
168
169 /// Check if a site index is registered.
170 pub fn contains_index(&self, index: &I) -> bool {
171 self.index_to_node.contains_key(index.id())
172 }
173
174 /// Add a site index to a node's site space.
175 ///
176 /// Updates both the site space and the reverse lookup.
177 pub fn add_site_index(&mut self, node_name: &NodeName, index: I) -> Result<(), String> {
178 let site_space = self
179 .site_spaces
180 .get_mut(node_name)
181 .ok_or_else(|| format!("Node {:?} not found", node_name))?;
182 site_space.insert(index.clone());
183 self.index_to_node
184 .insert(index.id().clone(), node_name.clone());
185 Ok(())
186 }
187
188 /// Remove a site index from a node's site space.
189 ///
190 /// Updates both the site space and the reverse lookup.
191 pub fn remove_site_index(&mut self, node_name: &NodeName, index: &I) -> Result<bool, String> {
192 let site_space = self
193 .site_spaces
194 .get_mut(node_name)
195 .ok_or_else(|| format!("Node {:?} not found", node_name))?;
196 let removed = site_space.remove(index);
197 if removed {
198 self.index_to_node.remove(index.id());
199 }
200 Ok(removed)
201 }
202
203 /// Replace a site index in a node's site space.
204 ///
205 /// Updates both the site space and the reverse lookup.
206 pub fn replace_site_index(
207 &mut self,
208 node_name: &NodeName,
209 old_index: &I,
210 new_index: I,
211 ) -> Result<(), String> {
212 let site_space = self
213 .site_spaces
214 .get_mut(node_name)
215 .ok_or_else(|| format!("Node {:?} not found", node_name))?;
216 if !site_space.remove(old_index) {
217 return Err(format!(
218 "Index {:?} not found in node {:?}",
219 old_index.id(),
220 node_name
221 ));
222 }
223 self.index_to_node.remove(old_index.id());
224 site_space.insert(new_index.clone());
225 self.index_to_node
226 .insert(new_index.id().clone(), node_name.clone());
227 Ok(())
228 }
229
230 /// Replace all site indices for a node with a new set.
231 ///
232 /// Updates both the site space and the reverse lookup.
233 /// This is an atomic operation that removes all old indices and adds all new ones.
234 pub fn set_site_space(
235 &mut self,
236 node_name: &NodeName,
237 new_indices: HashSet<I>,
238 ) -> Result<(), String> {
239 let site_space = self
240 .site_spaces
241 .get_mut(node_name)
242 .ok_or_else(|| format!("Node {:?} not found", node_name))?;
243
244 // Remove old indices from index_to_node
245 for old_idx in site_space.iter() {
246 if self.index_to_node.get(old_idx.id()) == Some(node_name) {
247 self.index_to_node.remove(old_idx.id());
248 }
249 }
250
251 // Add new indices to index_to_node
252 for new_idx in &new_indices {
253 self.index_to_node
254 .insert(new_idx.id().clone(), node_name.clone());
255 }
256
257 // Replace site space
258 *site_space = new_indices;
259
260 Ok(())
261 }
262
263 /// Get the site space by NodeIndex.
264 pub fn site_space_by_index(&self, node: NodeIndex) -> Option<&HashSet<I>> {
265 let name = self.topology.node_name(node)?;
266 self.site_spaces.get(name)
267 }
268
269 /// Add an edge between two nodes.
270 ///
271 /// Returns an error if either node doesn't exist.
272 pub fn add_edge(&mut self, n1: &NodeName, n2: &NodeName) -> Result<EdgeIndex, String> {
273 self.topology.add_edge(n1, n2)
274 }
275
276 /// Get the NodeIndex for a node name.
277 pub fn node_index(&self, node_name: &NodeName) -> Option<NodeIndex> {
278 self.topology.node_index(node_name)
279 }
280
281 /// Get the node name for a NodeIndex.
282 pub fn node_name(&self, node: NodeIndex) -> Option<&NodeName> {
283 self.topology.node_name(node)
284 }
285
286 /// Get all node names.
287 pub fn node_names(&self) -> Vec<&NodeName> {
288 self.topology.node_names()
289 }
290
291 /// Get the number of nodes.
292 pub fn node_count(&self) -> usize {
293 self.topology.node_count()
294 }
295
296 /// Get the number of edges.
297 pub fn edge_count(&self) -> usize {
298 self.topology.edge_count()
299 }
300
301 /// Get a reference to the underlying topology (NodeNameNetwork).
302 pub fn topology(&self) -> &NodeNameNetwork<NodeName> {
303 &self.topology
304 }
305
306 /// Get all edges as pairs of node names.
307 ///
308 /// Returns an iterator of `(NodeName, NodeName)` pairs.
309 pub fn edges(&self) -> impl Iterator<Item = (NodeName, NodeName)> + '_ {
310 let graph = self.topology.graph();
311 graph.edge_indices().filter_map(move |edge| {
312 let (a, b) = graph.edge_endpoints(edge)?;
313 let name_a = self.topology.node_name(a)?.clone();
314 let name_b = self.topology.node_name(b)?.clone();
315 Some((name_a, name_b))
316 })
317 }
318
319 /// Get all neighbors of a node.
320 ///
321 /// Returns an iterator of neighbor node names.
322 pub fn neighbors(&self, node_name: &NodeName) -> impl Iterator<Item = NodeName> + '_ {
323 let node_idx = self.topology.node_index(node_name);
324 let graph = self.topology.graph();
325 let topology = &self.topology;
326
327 node_idx
328 .into_iter()
329 .flat_map(move |idx| graph.neighbors(idx))
330 .filter_map(move |n| topology.node_name(n).cloned())
331 }
332
333 /// Get a reference to the internal graph.
334 pub fn graph(&self) -> &StableGraph<(), (), Undirected> {
335 self.topology.graph()
336 }
337
338 /// Get a mutable reference to the internal graph.
339 ///
340 /// **Warning**: Directly modifying the internal graph can break consistency.
341 pub fn graph_mut(&mut self) -> &mut StableGraph<(), (), Undirected> {
342 self.topology.graph_mut()
343 }
344
345 /// Check if two SiteIndexNetworks share equivalent site index structure.
346 ///
347 /// Two networks are equivalent if:
348 /// - Same topology (nodes and edges)
349 /// - Same site space for each node
350 ///
351 /// This is used to verify that two TreeTNs can be added or contracted.
352 pub fn share_equivalent_site_index_network(&self, other: &Self) -> bool {
353 // Check topology
354 if !self.topology.same_topology(&other.topology) {
355 return false;
356 }
357
358 // Check site spaces
359 for name in self.node_names() {
360 match (self.site_space(name), other.site_space(name)) {
361 (Some(self_indices), Some(other_indices)) => {
362 if self_indices != other_indices {
363 return false;
364 }
365 }
366 (None, None) => continue,
367 _ => return false,
368 }
369 }
370
371 true
372 }
373
374 // =========================================================================
375 // Delegated graph algorithms (from NodeNameNetwork)
376 // =========================================================================
377
378 /// Perform a post-order DFS traversal starting from the given root node.
379 pub fn post_order_dfs(&self, root: &NodeName) -> Option<Vec<NodeName>> {
380 self.topology.post_order_dfs(root)
381 }
382
383 /// Perform a post-order DFS traversal starting from the given root NodeIndex.
384 pub fn post_order_dfs_by_index(&self, root: NodeIndex) -> Vec<NodeIndex> {
385 self.topology.post_order_dfs_by_index(root)
386 }
387
388 /// Find the shortest path between two nodes.
389 pub fn path_between(&self, from: NodeIndex, to: NodeIndex) -> Option<Vec<NodeIndex>> {
390 self.topology.path_between(from, to)
391 }
392
393 /// Compute the Steiner tree nodes spanning a set of terminal nodes.
394 ///
395 /// Delegates to [`NodeNameNetwork::steiner_tree_nodes`].
396 ///
397 /// # Examples
398 ///
399 /// ```
400 /// use std::collections::HashSet;
401 /// use tensor4all_core::DynIndex;
402 /// use tensor4all_treetn::SiteIndexNetwork;
403 ///
404 /// let mut net: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
405 /// let a = net.add_node("A".to_string(), HashSet::<DynIndex>::new()).unwrap();
406 /// let b = net.add_node("B".to_string(), HashSet::<DynIndex>::new()).unwrap();
407 /// let c = net.add_node("C".to_string(), HashSet::<DynIndex>::new()).unwrap();
408 /// net.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
409 /// net.add_edge(&"B".to_string(), &"C".to_string()).unwrap();
410 ///
411 /// let steiner = net.steiner_tree_nodes(&[a, c].into_iter().collect::<HashSet<_>>());
412 /// assert_eq!(steiner, [a, b, c].into_iter().collect());
413 /// ```
414 pub fn steiner_tree_nodes(&self, terminals: &HashSet<NodeIndex>) -> HashSet<NodeIndex> {
415 self.topology.steiner_tree_nodes(terminals)
416 }
417
418 /// Check if a subset of nodes forms a connected subgraph.
419 pub fn is_connected_subset(&self, nodes: &HashSet<NodeIndex>) -> bool {
420 self.topology.is_connected_subset(nodes)
421 }
422
423 /// Compute edges to canonicalize from current state to target.
424 pub fn edges_to_canonicalize(
425 &self,
426 current_region: Option<&HashSet<NodeIndex>>,
427 target: NodeIndex,
428 ) -> CanonicalizeEdges {
429 self.topology.edges_to_canonicalize(current_region, target)
430 }
431
432 /// Compute edges to canonicalize from leaves to target, returning node names.
433 ///
434 /// This is similar to `edges_to_canonicalize(None, target)` but returns
435 /// `(from_name, to_name)` pairs instead of `(NodeIndex, NodeIndex)`.
436 ///
437 /// See [`NodeNameNetwork::edges_to_canonicalize_by_names`] for details.
438 pub fn edges_to_canonicalize_by_names(
439 &self,
440 target: &NodeName,
441 ) -> Option<Vec<(NodeName, NodeName)>> {
442 self.topology.edges_to_canonicalize_by_names(target)
443 }
444
445 /// Compute edges to canonicalize from leaves towards a connected region (multiple centers).
446 ///
447 /// See [`NodeNameNetwork::edges_to_canonicalize_to_region`] for details.
448 pub fn edges_to_canonicalize_to_region(
449 &self,
450 target_region: &HashSet<NodeIndex>,
451 ) -> CanonicalizeEdges {
452 self.topology.edges_to_canonicalize_to_region(target_region)
453 }
454
455 /// Compute edges to canonicalize towards a region, returning node names.
456 ///
457 /// See [`NodeNameNetwork::edges_to_canonicalize_to_region_by_names`] for details.
458 pub fn edges_to_canonicalize_to_region_by_names(
459 &self,
460 target_region: &HashSet<NodeName>,
461 ) -> Option<Vec<(NodeName, NodeName)>> {
462 self.topology
463 .edges_to_canonicalize_to_region_by_names(target_region)
464 }
465
466 // =========================================================================
467 // Operator/state compatibility checking
468 // =========================================================================
469
470 /// Check if an operator can act on this state (as a ket).
471 ///
472 /// Returns `Ok(result_network)` if the operator can act on self,
473 /// where `result_network` is the SiteIndexNetwork of the output state.
474 ///
475 /// For an operator to act on a state:
476 /// - They must have the same topology (same nodes and edges)
477 /// - The operator must have indices that can contract with the state's site indices
478 ///
479 /// # Arguments
480 /// * `operator` - The operator's SiteIndexNetwork
481 ///
482 /// # Returns
483 /// - `Ok(SiteIndexNetwork)` - The resulting state's site index network after the operator acts
484 /// - `Err(String)` - Error message if the operator cannot act on this state
485 ///
486 /// # Note
487 /// This is a simplified version that assumes the operator's output indices
488 /// have the same structure as the input (i.e., the result has the same
489 /// site index structure as the original state). For more complex operators
490 /// with different input/output dimensions, a more sophisticated approach
491 /// would be needed.
492 pub fn apply_operator_topology(&self, operator: &Self) -> Result<Self, String> {
493 // Check topology match
494 if !self.topology.same_topology(&operator.topology) {
495 return Err(format!(
496 "Operator and state have different topologies. State nodes: {:?}, Operator nodes: {:?}",
497 self.node_names(),
498 operator.node_names()
499 ));
500 }
501
502 // For now, assume the operator preserves the site index structure
503 // (output has same site indices as input). This is the common case
504 // for Hamiltonians and other operators where H|ψ⟩ has the same
505 // index structure as |ψ⟩.
506 //
507 // A more complete implementation would:
508 // 1. Check that operator has compatible input indices
509 // 2. Return the actual output index structure
510 Ok(self.clone())
511 }
512
513 /// Check if this network has compatible site dimensions with another.
514 ///
515 /// Two networks have compatible site dimensions if:
516 /// - Same topology (nodes and edges)
517 /// - Each node has the same number of site indices
518 /// - Site index dimensions match (after sorting, order doesn't matter)
519 ///
520 /// This is useful for checking if two states can be added or if
521 /// a state matches the expected output of an operator.
522 pub fn compatible_site_dimensions(&self, other: &Self) -> bool {
523 // Check topology
524 if !self.topology.same_topology(&other.topology) {
525 return false;
526 }
527
528 // Check site dimensions for each node
529 for name in self.node_names() {
530 match (self.site_space(name), other.site_space(name)) {
531 (Some(self_indices), Some(other_indices)) => {
532 // Check same number of indices
533 if self_indices.len() != other_indices.len() {
534 return false;
535 }
536
537 // Get dimensions and sort for comparison
538 // Use IndexLike::dim() to get the dimension
539 let mut self_dims: Vec<_> = self_indices.iter().map(|idx| idx.dim()).collect();
540 let mut other_dims: Vec<_> =
541 other_indices.iter().map(|idx| idx.dim()).collect();
542 self_dims.sort();
543 other_dims.sort();
544
545 if self_dims != other_dims {
546 return false;
547 }
548 }
549 (None, None) => continue,
550 _ => return false,
551 }
552 }
553
554 true
555 }
556}
557
558impl<NodeName, I> Default for SiteIndexNetwork<NodeName, I>
559where
560 NodeName: Clone + Hash + Eq + Send + Sync + Debug,
561 I: IndexLike,
562{
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
568// ============================================================================
569// Type alias for backwards compatibility
570// ============================================================================
571
572use tensor4all_core::index::{DynId, Index};
573use tensor4all_core::DefaultTagSet;
574
/// Type alias for the default SiteIndexNetwork using DynId indices.
///
/// Fixes the index type to `Index<DynId, DefaultTagSet>`, leaving only the
/// node name type for the caller to choose.
///
/// This preserves backwards compatibility with existing code that uses
/// `SiteIndexNetwork<NodeName, Id, Symm, Tags>`.
// NOTE(review): the four-parameter form mentioned above is not visible in this
// file (the struct takes `<NodeName, I>`) — confirm the note is still accurate.
pub type DefaultSiteIndexNetwork<NodeName> =
    SiteIndexNetwork<NodeName, Index<DynId, DefaultTagSet>>;
581
582#[cfg(test)]
583mod tests;