Skip to main content

tensor4all_treetn/
node_name_network.rs

1//! Node Name Network - Graph structure for node name connections.
2//!
3//! Provides a pure graph structure where:
4//! - Nodes are identified by names (generic type `NodeName`)
5//! - Edges represent connections between nodes (no data stored)
6//!
7//! This is a foundation for `SiteIndexNetwork` and can be used independently
8//! when only the graph structure (without index information) is needed.
9
10use crate::named_graph::NamedGraph;
11use petgraph::algo::astar;
12use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
13use petgraph::visit::DfsPostOrder;
14use petgraph::Undirected;
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::fmt::Debug;
17use std::hash::Hash;
18
19/// Ordered sequence of directed edges for canonicalization.
20///
21/// Each edge is `(from, to)` where:
22/// - `from` is the node being orthogonalized away from
23/// - `to` is the direction towards the orthogonality center
24///
25/// # Note on ordering
26/// - For path-based canonicalization (moving ortho center), edges are connected:
27///   each edge's `to` equals the next edge's `from`.
28/// - For full canonicalization (from scratch), edges represent parent edges in
29///   post-order DFS traversal, which may not be connected as a path but
30///   guarantees correct processing order (children before parents).
31///
32/// # Example
33/// For a chain A - B - C - D (canonizing towards D):
34/// ```text
35/// edges = [(A, B), (B, C), (C, D)]
36/// ```
37///
38/// For a star with center C (canonizing towards C):
39/// ```text
40///     A
41///     |
42/// B - C - D
43///     |
44///     E
45///
46/// edges = [(A, C), (B, C), (D, C), (E, C)]  (order depends on DFS)
47/// ```
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct CanonicalizeEdges {
50    edges: Vec<(NodeIndex, NodeIndex)>,
51}
52
53impl CanonicalizeEdges {
54    /// Create an empty edge sequence (no-op canonicalization).
55    pub fn empty() -> Self {
56        Self { edges: Vec::new() }
57    }
58
59    /// Create from a list of edges.
60    ///
61    /// Note: For path-based canonicalization, edges should be connected (each edge's `to`
62    /// equals next edge's `from`). For full canonicalization, edges may not be connected
63    /// but must be in correct processing order.
64    pub fn from_edges(edges: Vec<(NodeIndex, NodeIndex)>) -> Self {
65        Self { edges }
66    }
67
68    /// Check if empty (already at target, no work needed).
69    pub fn is_empty(&self) -> bool {
70        self.edges.is_empty()
71    }
72
73    /// Number of edges to process.
74    pub fn len(&self) -> usize {
75        self.edges.len()
76    }
77
78    /// Iterate over edges in order.
79    pub fn iter(&self) -> impl Iterator<Item = &(NodeIndex, NodeIndex)> {
80        self.edges.iter()
81    }
82
83    /// Get the final target node (orthogonality center).
84    ///
85    /// Returns `None` if empty.
86    pub fn target(&self) -> Option<NodeIndex> {
87        self.edges.last().map(|(_, to)| *to)
88    }
89
90    /// Get the starting node (first node to be factorized).
91    ///
92    /// Returns `None` if empty.
93    pub fn start(&self) -> Option<NodeIndex> {
94        self.edges.first().map(|(from, _)| *from)
95    }
96}
97
98impl IntoIterator for CanonicalizeEdges {
99    type Item = (NodeIndex, NodeIndex);
100    type IntoIter = std::vec::IntoIter<Self::Item>;
101
102    fn into_iter(self) -> Self::IntoIter {
103        self.edges.into_iter()
104    }
105}
106
107impl<'a> IntoIterator for &'a CanonicalizeEdges {
108    type Item = &'a (NodeIndex, NodeIndex);
109    type IntoIter = std::slice::Iter<'a, (NodeIndex, NodeIndex)>;
110
111    fn into_iter(self) -> Self::IntoIter {
112        self.edges.iter()
113    }
114}
115
116/// Node Name Network - Pure graph structure for node connections.
117///
118/// Represents the topology of a network without any data attached to nodes or edges.
119/// This is useful for graph algorithms that only need connectivity information.
120///
121/// # Type Parameters
122/// - `NodeName`: Node name type (must be Clone, Hash, Eq, Send, Sync, Debug)
123#[derive(Debug, Clone)]
124pub struct NodeNameNetwork<NodeName>
125where
126    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
127{
128    /// Named graph with unit node and edge data.
129    graph: NamedGraph<NodeName, (), ()>,
130}
131
132impl<NodeName> NodeNameNetwork<NodeName>
133where
134    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
135{
136    /// Create a new empty NodeNameNetwork.
137    pub fn new() -> Self {
138        Self {
139            graph: NamedGraph::new(),
140        }
141    }
142
143    /// Create a new NodeNameNetwork with initial capacity.
144    pub fn with_capacity(nodes: usize, edges: usize) -> Self {
145        Self {
146            graph: NamedGraph::with_capacity(nodes, edges),
147        }
148    }
149
150    /// Add a node to the network.
151    ///
152    /// Returns an error if the node already exists.
153    pub fn add_node(&mut self, node_name: NodeName) -> Result<NodeIndex, String> {
154        self.graph.add_node(node_name, ())
155    }
156
157    /// Check if a node exists.
158    pub fn has_node(&self, node_name: &NodeName) -> bool {
159        self.graph.has_node(node_name)
160    }
161
162    /// Add an edge between two nodes.
163    ///
164    /// Returns an error if either node doesn't exist.
165    pub fn add_edge(&mut self, n1: &NodeName, n2: &NodeName) -> Result<EdgeIndex, String> {
166        self.graph.add_edge(n1, n2, ())
167    }
168
169    /// Get the NodeIndex for a node name.
170    pub fn node_index(&self, node_name: &NodeName) -> Option<NodeIndex> {
171        self.graph.node_index(node_name)
172    }
173
174    /// Get the node name for a NodeIndex.
175    pub fn node_name(&self, node: NodeIndex) -> Option<&NodeName> {
176        self.graph.node_name(node)
177    }
178
179    /// Rename an existing node.
180    pub fn rename_node(&mut self, old_name: &NodeName, new_name: NodeName) -> Result<(), String> {
181        self.graph.rename_node(old_name, new_name)
182    }
183
184    /// Get all node names.
185    pub fn node_names(&self) -> Vec<&NodeName> {
186        self.graph.node_names()
187    }
188
189    /// Get the number of nodes.
190    pub fn node_count(&self) -> usize {
191        self.graph.node_count()
192    }
193
194    /// Get the number of edges.
195    pub fn edge_count(&self) -> usize {
196        self.graph.edge_count()
197    }
198
199    /// Get a reference to the internal graph.
200    pub fn graph(&self) -> &StableGraph<(), (), Undirected> {
201        self.graph.graph()
202    }
203
204    /// Get a mutable reference to the internal graph.
205    ///
206    /// **Warning**: Directly modifying the internal graph can break the node-name-to-index mapping.
207    pub fn graph_mut(&mut self) -> &mut StableGraph<(), (), Undirected> {
208        self.graph.graph_mut()
209    }
210
211    /// Perform a post-order DFS traversal starting from the given root node.
212    ///
213    /// Returns node names in post-order (children before parents, leaves first).
214    ///
215    /// # Arguments
216    /// * `root` - The node name to start traversal from
217    ///
218    /// # Returns
219    /// `Some(Vec<NodeName>)` with nodes in post-order, or `None` if root doesn't exist.
220    pub fn post_order_dfs(&self, root: &NodeName) -> Option<Vec<NodeName>> {
221        let root_idx = self.graph.node_index(root)?;
222        let g = self.graph.graph();
223
224        let mut dfs = DfsPostOrder::new(g, root_idx);
225        let mut result = Vec::new();
226
227        while let Some(node_idx) = dfs.next(g) {
228            if let Some(name) = self.graph.node_name(node_idx) {
229                result.push(name.clone());
230            }
231        }
232
233        Some(result)
234    }
235
236    /// Perform a post-order DFS traversal starting from the given root NodeIndex.
237    ///
238    /// Returns NodeIndex in post-order (children before parents, leaves first).
239    pub fn post_order_dfs_by_index(&self, root: NodeIndex) -> Vec<NodeIndex> {
240        let g = self.graph.graph();
241        let mut dfs = DfsPostOrder::new(g, root);
242        let mut result = Vec::new();
243
244        while let Some(node_idx) = dfs.next(g) {
245            result.push(node_idx);
246        }
247
248        result
249    }
250
251    /// Perform an Euler tour traversal starting from the given root node.
252    ///
253    /// Delegates to [`NamedGraph::euler_tour_edges`].
254    pub fn euler_tour_edges(&self, root: &NodeName) -> Option<Vec<(NodeIndex, NodeIndex)>> {
255        self.graph.euler_tour_edges(root)
256    }
257
258    /// Perform an Euler tour traversal starting from the given root NodeIndex.
259    ///
260    /// Delegates to [`NamedGraph::euler_tour_edges_by_index`].
261    pub fn euler_tour_edges_by_index(&self, root: NodeIndex) -> Vec<(NodeIndex, NodeIndex)> {
262        self.graph.euler_tour_edges_by_index(root)
263    }
264
265    /// Perform an Euler tour traversal and return the vertex sequence.
266    ///
267    /// Delegates to [`NamedGraph::euler_tour_vertices`].
268    pub fn euler_tour_vertices(&self, root: &NodeName) -> Option<Vec<NodeIndex>> {
269        self.graph.euler_tour_vertices(root)
270    }
271
272    /// Perform an Euler tour traversal and return the vertex sequence by NodeIndex.
273    ///
274    /// Delegates to [`NamedGraph::euler_tour_vertices_by_index`].
275    pub fn euler_tour_vertices_by_index(&self, root: NodeIndex) -> Vec<NodeIndex> {
276        self.graph.euler_tour_vertices_by_index(root)
277    }
278
279    /// Find the shortest path between two nodes using A* algorithm.
280    ///
281    /// Since this is an unweighted graph, we use unit edge weights.
282    ///
283    /// # Returns
284    /// `Some(Vec<NodeIndex>)` containing the path from `from` to `to` (inclusive),
285    /// or `None` if no path exists.
286    pub fn path_between(&self, from: NodeIndex, to: NodeIndex) -> Option<Vec<NodeIndex>> {
287        let g = self.graph.graph();
288
289        // Check if both nodes exist
290        if g.node_weight(from).is_none() || g.node_weight(to).is_none() {
291            return None;
292        }
293
294        // Same node case
295        if from == to {
296            return Some(vec![from]);
297        }
298
299        // Use A* with trivial heuristic (unit edge cost, zero estimate)
300        astar(
301            g,
302            from,
303            |n| n == to,
304            |_| 1usize, // Unit edge cost
305            |_| 0usize, // No heuristic (behaves like Dijkstra/BFS)
306        )
307        .map(|(_, path)| path)
308    }
309
310    /// Check if a subset of nodes forms a connected subgraph.
311    ///
312    /// Uses DFS to verify that all nodes in the subset are reachable from each other
313    /// within the induced subgraph.
314    ///
315    /// # Returns
316    /// `true` if the subset is connected (or empty), `false` otherwise.
317    pub fn is_connected_subset(&self, nodes: &HashSet<NodeIndex>) -> bool {
318        if nodes.is_empty() || nodes.len() == 1 {
319            return true;
320        }
321
322        let g = self.graph.graph();
323
324        // Start DFS from any node in the subset
325        let start = *nodes.iter().next().unwrap();
326        let mut seen = HashSet::new();
327        let mut stack = vec![start];
328        seen.insert(start);
329
330        while let Some(v) = stack.pop() {
331            for nb in g.neighbors(v) {
332                // Only follow edges within the subset
333                if nodes.contains(&nb) && seen.insert(nb) {
334                    stack.push(nb);
335                }
336            }
337        }
338
339        // Connected if we reached all nodes
340        seen.len() == nodes.len()
341    }
342
343    /// Compute the Steiner tree nodes spanning a set of terminal nodes.
344    ///
345    /// For tree graphs, the Steiner tree is the union of the unique paths from
346    /// one terminal to every other terminal.
347    ///
348    /// # Arguments
349    /// * `terminals` - Terminal node indices to span
350    ///
351    /// # Returns
352    /// The set of nodes in the minimal connected subtree spanning `terminals`.
353    ///
354    /// # Examples
355    ///
356    /// ```
357    /// use std::collections::HashSet;
358    /// use tensor4all_treetn::NodeNameNetwork;
359    ///
360    /// let mut net: NodeNameNetwork<String> = NodeNameNetwork::new();
361    /// let a = net.add_node("A".to_string()).unwrap();
362    /// let b = net.add_node("B".to_string()).unwrap();
363    /// let c = net.add_node("C".to_string()).unwrap();
364    /// net.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
365    /// net.add_edge(&"B".to_string(), &"C".to_string()).unwrap();
366    ///
367    /// let steiner = net.steiner_tree_nodes(&[a, c].into_iter().collect::<HashSet<_>>());
368    /// assert_eq!(steiner, [a, b, c].into_iter().collect());
369    /// ```
370    pub fn steiner_tree_nodes(&self, terminals: &HashSet<NodeIndex>) -> HashSet<NodeIndex> {
371        if terminals.len() <= 1 {
372            return terminals.clone();
373        }
374
375        let terminals_vec: Vec<NodeIndex> = terminals.iter().copied().collect();
376        let root = terminals_vec[0];
377        let mut result = HashSet::new();
378        result.insert(root);
379
380        for &terminal in &terminals_vec[1..] {
381            if let Some(path) = self.path_between(root, terminal) {
382                result.extend(path);
383            }
384        }
385
386        result
387    }
388
389    /// Convert a node sequence to an edge sequence.
390    fn nodes_to_edges(nodes: &[NodeIndex]) -> CanonicalizeEdges {
391        if nodes.len() < 2 {
392            return CanonicalizeEdges::empty();
393        }
394        let edges: Vec<_> = nodes.windows(2).map(|w| (w[0], w[1])).collect();
395        CanonicalizeEdges::from_edges(edges)
396    }
397
398    /// Compute edges to canonicalize from current state to target.
399    ///
400    /// # Arguments
401    /// * `current_region` - Current ortho region (`None` = not canonicalized)
402    /// * `target` - Target node for the orthogonality center
403    ///
404    /// # Returns
405    /// Ordered `CanonicalizeEdges` to process for canonicalization.
406    pub fn edges_to_canonicalize(
407        &self,
408        current_region: Option<&HashSet<NodeIndex>>,
409        target: NodeIndex,
410    ) -> CanonicalizeEdges {
411        match current_region {
412            None => {
413                // Not canonicalized: compute parent edges for each node in post-order.
414                let post_order = self.post_order_dfs_by_index(target);
415                self.compute_parent_edges(&post_order, target)
416            }
417            Some(current) if current.contains(&target) => {
418                // Already at target: no-op
419                CanonicalizeEdges::empty()
420            }
421            Some(current) => {
422                // Move from current to target: find path
423                if let Some(&start) = current.iter().next() {
424                    if let Some(path) = self.path_between(start, target) {
425                        Self::nodes_to_edges(&path)
426                    } else {
427                        CanonicalizeEdges::empty()
428                    }
429                } else {
430                    CanonicalizeEdges::empty()
431                }
432            }
433        }
434    }
435
436    /// Compute edges to canonicalize from leaves to target, returning node names.
437    ///
438    /// This is similar to `edges_to_canonicalize(None, target)` but returns
439    /// `(from_name, to_name)` pairs instead of `(NodeIndex, NodeIndex)`.
440    ///
441    /// Useful for operations that work with two networks that have the same
442    /// topology but different NodeIndex values (e.g., contract_zipup).
443    ///
444    /// # Arguments
445    /// * `target` - Target node name for the orthogonality center
446    ///
447    /// # Returns
448    /// `None` if target node doesn't exist, otherwise a vector of `(from, to)` pairs
449    /// where `from` is the node being processed and `to` is its parent (towards target).
450    pub fn edges_to_canonicalize_by_names(
451        &self,
452        target: &NodeName,
453    ) -> Option<Vec<(NodeName, NodeName)>> {
454        let target_idx = self.node_index(target)?;
455        let edges = self.edges_to_canonicalize(None, target_idx);
456
457        let result: Vec<_> = edges
458            .into_iter()
459            .filter_map(|(from_idx, to_idx)| {
460                let from_name = self.node_name(from_idx)?.clone();
461                let to_name = self.node_name(to_idx)?.clone();
462                Some((from_name, to_name))
463            })
464            .collect();
465
466        Some(result)
467    }
468
469    /// Compute parent edges for each node in the given order.
470    fn compute_parent_edges(&self, nodes: &[NodeIndex], root: NodeIndex) -> CanonicalizeEdges {
471        let g = self.graph.graph();
472        let mut edges = Vec::with_capacity(nodes.len().saturating_sub(1));
473
474        // Build parent map using BFS from root
475        let mut parent: HashMap<NodeIndex, NodeIndex> = HashMap::new();
476        let mut visited = HashSet::new();
477        let mut queue = VecDeque::new();
478        queue.push_back(root);
479        visited.insert(root);
480
481        while let Some(node) = queue.pop_front() {
482            for neighbor in g.neighbors(node) {
483                if visited.insert(neighbor) {
484                    parent.insert(neighbor, node);
485                    queue.push_back(neighbor);
486                }
487            }
488        }
489
490        // For each node in order, add edge to its parent
491        for &node in nodes {
492            if node != root {
493                if let Some(&p) = parent.get(&node) {
494                    edges.push((node, p));
495                }
496            }
497        }
498
499        CanonicalizeEdges::from_edges(edges)
500    }
501
502    /// Compute edges to canonicalize from leaves towards a connected region (multiple centers).
503    ///
504    /// Given a set of target nodes forming a connected region, this function returns
505    /// all edges (src, dst) where:
506    /// - `src` is a node outside the target region
507    /// - `dst` is the next node towards the target region
508    ///
509    /// The edges are ordered so that nodes farther from the target region are processed first
510    /// (children before parents), which is the correct order for canonicalization.
511    ///
512    /// # Arguments
513    /// * `target_region` - Set of NodeIndex that forms the canonical center region
514    ///   (must be non-empty and connected)
515    ///
516    /// # Returns
517    /// `CanonicalizeEdges` with all edges pointing towards the target region.
518    /// Returns empty edges if target_region is empty.
519    ///
520    /// # Panics
521    /// Does not panic, but if target_region is disconnected, behavior is undefined
522    /// (may return partial results).
523    pub fn edges_to_canonicalize_to_region(
524        &self,
525        target_region: &HashSet<NodeIndex>,
526    ) -> CanonicalizeEdges {
527        if target_region.is_empty() {
528            return CanonicalizeEdges::empty();
529        }
530
531        let g = self.graph.graph();
532
533        // Multi-source BFS from target_region to compute distances and parent pointers
534        let mut dist: HashMap<NodeIndex, usize> = HashMap::new();
535        let mut parent: HashMap<NodeIndex, NodeIndex> = HashMap::new();
536        let mut queue = VecDeque::new();
537
538        // Initialize all target region nodes at distance 0
539        for &node in target_region {
540            dist.insert(node, 0);
541            queue.push_back(node);
542        }
543
544        // BFS to find distances and parents
545        while let Some(node) = queue.pop_front() {
546            let d = dist[&node];
547            for neighbor in g.neighbors(node) {
548                let was_new = !dist.contains_key(&neighbor);
549                dist.entry(neighbor).or_insert_with(|| {
550                    parent.insert(neighbor, node);
551                    d + 1
552                });
553                if was_new {
554                    queue.push_back(neighbor);
555                }
556            }
557        }
558
559        // Collect edges from nodes outside target region towards their parent
560        // Sort by distance (descending) so farther nodes are processed first
561        let mut node_dist_pairs: Vec<(NodeIndex, usize)> = dist
562            .iter()
563            .filter(|(node, _)| !target_region.contains(node))
564            .map(|(&node, &d)| (node, d))
565            .collect();
566
567        node_dist_pairs.sort_by_key(|pair| std::cmp::Reverse(pair.1)); // Descending by distance
568
569        let edges: Vec<(NodeIndex, NodeIndex)> = node_dist_pairs
570            .iter()
571            .filter_map(|(node, _)| {
572                let p = parent.get(node)?;
573                Some((*node, *p))
574            })
575            .collect();
576
577        CanonicalizeEdges::from_edges(edges)
578    }
579
580    /// Compute edges to canonicalize towards a region, returning node names.
581    ///
582    /// This is similar to `edges_to_canonicalize_to_region` but takes and returns
583    /// node names instead of NodeIndex.
584    ///
585    /// # Arguments
586    /// * `target_region` - Set of node names that forms the canonical center region
587    ///
588    /// # Returns
589    /// `None` if any target node doesn't exist, otherwise `Some(Vec<(from, to)>)`
590    /// where edges point towards the target region.
591    pub fn edges_to_canonicalize_to_region_by_names(
592        &self,
593        target_region: &HashSet<NodeName>,
594    ) -> Option<Vec<(NodeName, NodeName)>> {
595        // Convert node names to NodeIndex
596        let target_indices: HashSet<NodeIndex> = target_region
597            .iter()
598            .map(|name| self.node_index(name))
599            .collect::<Option<HashSet<_>>>()?;
600
601        let edges = self.edges_to_canonicalize_to_region(&target_indices);
602
603        let result: Vec<_> = edges
604            .into_iter()
605            .filter_map(|(from_idx, to_idx)| {
606                let from_name = self.node_name(from_idx)?.clone();
607                let to_name = self.node_name(to_idx)?.clone();
608                Some((from_name, to_name))
609            })
610            .collect();
611
612        Some(result)
613    }
614
615    /// Check if two networks have the same topology (same nodes and edges).
616    pub fn same_topology(&self, other: &Self) -> bool {
617        if self.node_count() != other.node_count() {
618            return false;
619        }
620        if self.edge_count() != other.edge_count() {
621            return false;
622        }
623
624        // Check all nodes exist in both
625        for name in self.node_names() {
626            if !other.has_node(name) {
627                return false;
628            }
629        }
630
631        // Check edges match (by checking neighbors for each node)
632        let self_graph = self.graph.graph();
633        for name in self.node_names() {
634            let self_idx = self.node_index(name).unwrap();
635            let other_idx = match other.node_index(name) {
636                Some(idx) => idx,
637                None => return false,
638            };
639
640            let self_neighbors: HashSet<_> = self_graph
641                .neighbors(self_idx)
642                .filter_map(|n| self.node_name(n))
643                .collect();
644
645            let other_graph = other.graph.graph();
646            let other_neighbors: HashSet<_> = other_graph
647                .neighbors(other_idx)
648                .filter_map(|n| other.node_name(n))
649                .collect();
650
651            if self_neighbors != other_neighbors {
652                return false;
653            }
654        }
655
656        true
657    }
658}
659
660impl<NodeName> Default for NodeNameNetwork<NodeName>
661where
662    NodeName: Clone + Hash + Eq + Send + Sync + Debug,
663{
664    fn default() -> Self {
665        Self::new()
666    }
667}
668
// Unit tests for this module are kept in a separate `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;