Skip to main content

tensor4all_treetn/treetn/
decompose.rs

1//! TreeTN decomposition from dense tensor.
2//!
3//! This module provides functions to decompose a dense tensor into a TreeTN
4//! using factorization algorithms.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::hash::Hash;
8
9use anyhow::Result;
10
11use tensor4all_core::{Canonical, FactorizeOptions, IndexLike, TensorLike};
12
13use super::TreeTN;
14
15// ============================================================================
16// TreeTopology specification
17// ============================================================================
18
/// Describes the shape of a tree tensor network and where each physical index lives.
///
/// `V` is the node-name type and `I` is the index ID type (e.g., `DynId`).
/// Physical indices are referenced by ID rather than by position, so lookups
/// stay correct even when factorization reorders the indices of a tensor.
#[derive(Debug, Clone)]
pub struct TreeTopology<V, I> {
    /// Every node of the tree, keyed by name, with the IDs of the physical
    /// indices of the input tensor that belong to that node.
    pub nodes: HashMap<V, Vec<I>>,
    /// Undirected connections between nodes, stored as `(node_a, node_b)` pairs.
    pub edges: Vec<(V, V)>,
}
31
32impl<V: Clone + Hash + Eq, I: Clone + Eq> TreeTopology<V, I> {
33    /// Create a new tree topology with the given nodes and edges.
34    ///
35    /// # Arguments
36    /// * `nodes` - Map from node name to the index IDs belonging to that node
37    /// * `edges` - List of edges as (node_a, node_b) pairs
38    pub fn new(nodes: HashMap<V, Vec<I>>, edges: Vec<(V, V)>) -> Self {
39        Self { nodes, edges }
40    }
41
42    /// Validate that this topology describes a tree.
43    pub fn validate(&self) -> Result<()> {
44        let n = self.nodes.len();
45        if n == 0 {
46            return Err(anyhow::anyhow!("Tree topology must have at least one node"));
47        }
48        if n > 1 && self.edges.len() != n - 1 {
49            return Err(anyhow::anyhow!(
50                "Tree must have exactly n-1 edges: got {} nodes and {} edges",
51                n,
52                self.edges.len()
53            ));
54        }
55        // Check all edge endpoints are valid nodes
56        for (a, b) in &self.edges {
57            if !self.nodes.contains_key(a) {
58                return Err(anyhow::anyhow!("Edge refers to unknown node"));
59            }
60            if !self.nodes.contains_key(b) {
61                return Err(anyhow::anyhow!("Edge refers to unknown node"));
62            }
63        }
64        Ok(())
65    }
66}
67
68// ============================================================================
69// Decomposition functions
70// ============================================================================
71
72/// Decompose a dense tensor into a TreeTN using QR-based factorization.
73///
74/// This function takes a dense tensor and a tree topology specification, then
75/// recursively decomposes the tensor using QR factorization to create a TreeTN.
76///
77/// # Algorithm
78///
79/// 1. Start from a leaf node, factorize to separate that node's physical indices
80/// 2. Contract the right factor with remaining tensor, repeat for next edge
81/// 3. Continue until all edges are processed
82///
83/// # Arguments
84/// * `tensor` - The dense tensor to decompose
85/// * `topology` - Tree topology specifying nodes, edges, and physical index assignments
86///
87/// # Returns
88/// A TreeTN representing the decomposed tensor.
89///
90/// # Errors
91/// Returns an error if:
92/// - The topology is invalid
93/// - Physical index positions don't match the tensor
94/// - Factorization fails
95pub fn factorize_tensor_to_treetn<T, V>(
96    tensor: &T,
97    topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
98    root: &V,
99) -> Result<TreeTN<T, V>>
100where
101    T: TensorLike,
102    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
103    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
104{
105    factorize_tensor_to_treetn_with(tensor, topology, FactorizeOptions::qr(), root)
106}
107
108/// Factorize a dense tensor into a TreeTN using specified factorization options.
109///
110/// This function takes a dense tensor and a tree topology specification, then
111/// recursively decomposes the tensor using the specified algorithm to create a TreeTN.
112///
113/// # Algorithm
114///
115/// 1. Start from a leaf node, factorize to separate that node's physical indices
116/// 2. Contract the right factor with remaining tensor, repeat for next edge
117/// 3. Continue until all edges are processed
118///
119/// # Arguments
120/// * `tensor` - The dense tensor to decompose
121/// * `topology` - Tree topology specifying nodes, edges, and physical index assignments
122/// * `options` - Factorization options (algorithm, max_rank, rtol, etc.)
123///
124/// # Returns
125/// A TreeTN representing the decomposed tensor.
126///
127/// # Errors
128/// Returns an error if:
129/// - The topology is invalid
130/// - Physical index positions don't match the tensor
131/// - Factorization fails
132pub fn factorize_tensor_to_treetn_with<T, V>(
133    tensor: &T,
134    topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
135    options: FactorizeOptions,
136    root: &V,
137) -> Result<TreeTN<T, V>>
138where
139    T: TensorLike,
140    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
141    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
142{
143    factorize_tensor_to_treetn_with_root_impl(tensor, topology, options, root)
144}
145
146fn factorize_tensor_to_treetn_with_root_impl<T, V>(
147    tensor: &T,
148    topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
149    options: FactorizeOptions,
150    root: &V,
151) -> Result<TreeTN<T, V>>
152where
153    T: TensorLike,
154    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
155    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
156{
157    topology.validate()?;
158
159    let tensor_indices = tensor.external_indices();
160
161    if topology.nodes.len() == 1 {
162        // Single node - just wrap the tensor
163        let node_name = topology.nodes.keys().next().unwrap().clone();
164        if &node_name != root {
165            return Err(anyhow::anyhow!("Requested root node not found in topology"));
166        }
167        let mut tn = TreeTN::<T, V>::new();
168        tn.add_tensor(node_name.clone(), tensor.clone())?;
169        tn.set_canonical_region([node_name])?;
170        return Ok(tn);
171    }
172
173    // Validate that all index IDs exist in the tensor
174    let tensor_ids: HashSet<_> = tensor_indices.iter().map(|idx| idx.id().clone()).collect();
175    for (node, ids) in &topology.nodes {
176        for id in ids {
177            if !tensor_ids.contains(id) {
178                return Err(anyhow::anyhow!(
179                    "Index ID {:?} for node {:?} not found in tensor (tensor has {} indices)",
180                    id,
181                    node,
182                    tensor_indices.len()
183                ));
184            }
185        }
186    }
187
188    // Validate that each physical index ID is assigned to at most one node.
189    // Duplicate assignment is almost always a topology specification bug and will
190    // lead to ambiguous/missing node tensors during decomposition.
191    let mut assigned_ids: HashSet<<T::Index as IndexLike>::Id> = HashSet::new();
192    for (node, ids) in &topology.nodes {
193        for id in ids {
194            if !assigned_ids.insert(id.clone()) {
195                return Err(anyhow::anyhow!(
196                    "Index ID {:?} is assigned to multiple nodes (at least {:?})",
197                    id,
198                    node
199                ));
200            }
201        }
202    }
203
204    // Build adjacency list for the tree
205    let mut adj: HashMap<V, Vec<V>> = HashMap::new();
206    for node in topology.nodes.keys() {
207        adj.insert(node.clone(), Vec::new());
208    }
209    for (a, b) in &topology.edges {
210        adj.get_mut(a).unwrap().push(b.clone());
211        adj.get_mut(b).unwrap().push(a.clone());
212    }
213    // Sort each adjacency list to ensure deterministic traversal order
214    for neighbors in adj.values_mut() {
215        neighbors.sort();
216    }
217
218    // Root is required. This ensures the norm-carrying tensor (the final `S*Vh` in
219    // a left-canonical decomposition) ends up on a caller-chosen node.
220    if !adj.contains_key(root) {
221        return Err(anyhow::anyhow!("Requested root node not found in topology"));
222    }
223
224    // Build traversal order using BFS from root
225    let mut traversal_order: Vec<(V, Option<V>)> = Vec::new(); // (node, parent)
226    let mut visited: HashSet<V> = HashSet::new();
227    let mut queue = VecDeque::new();
228    queue.push_back((root.clone(), None::<V>));
229
230    while let Some((node, parent)) = queue.pop_front() {
231        if visited.contains(&node) {
232            continue;
233        }
234        visited.insert(node.clone());
235        traversal_order.push((node.clone(), parent));
236
237        for neighbor in adj.get(&node).unwrap() {
238            if !visited.contains(neighbor) {
239                queue.push_back((neighbor.clone(), Some(node.clone())));
240            }
241        }
242    }
243
244    let mut children_by_parent: HashMap<V, Vec<V>> = HashMap::new();
245    for (node, parent) in &traversal_order {
246        if let Some(parent) = parent {
247            children_by_parent
248                .entry(parent.clone())
249                .or_default()
250                .push(node.clone());
251        }
252    }
253    for children in children_by_parent.values_mut() {
254        children.sort();
255    }
256
257    // Reverse traversal order to process leaves first (post-order)
258    traversal_order.reverse();
259
260    // Store intermediate tensors as we decompose
261    let mut current_tensor = tensor.clone();
262
263    // Store the resulting node tensors
264    let mut node_tensors: HashMap<V, T> = HashMap::new();
265    // Store the bond each processed child uses to connect to its parent.
266    let mut child_bonds: HashMap<V, T::Index> = HashMap::new();
267
268    // Use provided factorization options with Left canonical direction
269    let factorize_options = FactorizeOptions {
270        canonical: Canonical::Left,
271        ..options
272    };
273
274    // Process nodes in post-order (leaves first)
275    #[allow(clippy::needless_range_loop)]
276    for i in 0..traversal_order.len() - 1 {
277        let (node, _parent) = &traversal_order[i];
278        // Get the index IDs for this node
279        let node_ids = topology.nodes.get(node).unwrap();
280
281        // Keep this node's physical indices and the bonds to already-factorized
282        // children on the left side. This preserves the requested tree topology
283        // instead of collapsing all processed children directly into the root.
284        let current_indices = current_tensor.external_indices();
285        let mut desired_ids: HashSet<<T::Index as IndexLike>::Id> =
286            node_ids.iter().cloned().collect();
287        if let Some(children) = children_by_parent.get(node) {
288            for child in children {
289                let bond = child_bonds.get(child).ok_or_else(|| {
290                    anyhow::anyhow!(
291                        "Missing child bond for node {:?} while processing parent {:?}",
292                        child,
293                        node
294                    )
295                })?;
296                desired_ids.insert(bond.id().clone());
297            }
298        }
299        let left_inds: Vec<_> = current_indices
300            .iter()
301            .filter(|idx| desired_ids.contains(idx.id()))
302            .cloned()
303            .collect();
304
305        if left_inds.is_empty() && current_indices.len() > 1 {
306            // This indicates an inconsistent topology specification for the current tensor.
307            // Previously we "skipped" such nodes, but that can lead to missing tensors and
308            // panics later when building the TreeTN.
309            return Err(anyhow::anyhow!(
310                "No physical indices found for node {:?} (requested ids={:?}) in current tensor indices={:?}",
311                node,
312                node_ids,
313                current_indices
314                    .iter()
315                    .map(|idx| idx.id().clone())
316                    .collect::<Vec<_>>()
317            ));
318        }
319
320        // Perform factorization using TensorLike::factorize
321        // left will have the node's physical indices + bond index
322        // right will have bond index + remaining indices
323        let factorize_result = current_tensor
324            .factorize(&left_inds, &factorize_options)
325            .map_err(|e| anyhow::anyhow!("Factorization failed: {:?}", e))?;
326
327        let left_indices = factorize_result.left.external_indices();
328        let right_indices = factorize_result.right.external_indices();
329        let shared_bonds =
330            tensor4all_core::index_ops::common_inds::<T::Index>(&left_indices, &right_indices);
331        if shared_bonds.len() != 1 {
332            return Err(anyhow::anyhow!(
333                "Expected exactly one parent bond for node {:?}, found {}",
334                node,
335                shared_bonds.len()
336            ));
337        }
338        child_bonds.insert(node.clone(), shared_bonds[0].clone());
339
340        // Store left as the node's tensor (with physical indices + bond to parent)
341        node_tensors.insert(node.clone(), factorize_result.left);
342
343        // right becomes the current tensor for the next iteration
344        current_tensor = factorize_result.right;
345    }
346
347    // The last node (root) gets the remaining tensor
348    let (root_node, _) = &traversal_order.last().unwrap();
349    node_tensors.insert(root_node.clone(), current_tensor);
350
351    // Build the TreeTN using from_tensors (auto-connection by matching index IDs)
352    // Since factorize() returns shared bond_index, tensors already have matching index IDs
353    // IMPORTANT: Sort node_names to ensure deterministic ordering (HashMap iteration is non-deterministic)
354    let mut node_names: Vec<V> = topology.nodes.keys().cloned().collect();
355    node_names.sort();
356    let tensors: Vec<T> = node_names
357        .iter()
358        .map(|name| node_tensors.get(name).cloned().unwrap())
359        .collect();
360
361    let mut tn = TreeTN::from_tensors(tensors, node_names)?;
362    tn.set_canonical_region([root.clone()])?;
363    Ok(tn)
364}
365
366#[cfg(test)]
367mod tests;