Skip to main content

tensor4all_treetn/treetn/
transform.rs

1//! Structural transformation operations for TreeTN.
2//!
3//! This module provides methods to transform a TreeTN's structure:
4//! - [`fuse_to`](TreeTN::fuse_to): Merge adjacent nodes to match a target structure
5//! - [`split_to`](TreeTN::split_to): Split nodes to match a target structure
6
7use std::collections::{HashMap, HashSet};
8use std::hash::Hash;
9
10use anyhow::{Context, Result};
11use petgraph::stable_graph::NodeIndex;
12
13use tensor4all_core::{
14    AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike,
15};
16
17use super::TreeTN;
18use crate::options::SplitOptions;
19use crate::site_index_network::SiteIndexNetwork;
20
21impl<T, V> TreeTN<T, V>
22where
23    T: TensorLike,
24    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
25    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
26{
27    /// Fuse (merge) adjacent nodes to match the target structure.
28    ///
29    /// This operation contracts adjacent nodes that should be merged according to
30    /// the target `SiteIndexNetwork`. The target structure must be a "coarsening"
31    /// of the current structure: each target node should contain the site indices
32    /// of one or more adjacent current nodes.
33    ///
34    /// # Algorithm
35    ///
36    /// 1. Compare current structure with target structure
37    /// 2. Map each current node to its target node (by matching site indices)
38    /// 3. For each group of current nodes mapping to the same target node:
39    ///    - Contract all nodes in the group into a single node
40    /// 4. Build the new TreeTN with the fused structure
41    ///
42    /// # Arguments
43    /// * `target` - The target `SiteIndexNetwork` defining the desired structure
44    ///
45    /// # Returns
46    /// A new TreeTN with the fused structure, or an error if:
47    /// - The target structure is incompatible with the current structure
48    /// - Nodes to be fused are not connected
49    ///
50    /// # Properties
51    /// - **Bond dimension**: Unchanged (pure contraction, no truncation)
52    /// - **Commutative**: Non-overlapping groups can be merged in any order
53    ///
54    /// # Example
55    /// ```text
56    /// Before: x1_1---x2_1---x1_2---x2_2---x1_3---x2_3  (6 nodes)
57    /// After:  {x1_1,x2_1}---{x1_2,x2_2}---{x1_3,x2_3}  (3 nodes)
58    /// ```
59    pub fn fuse_to<TargetV>(
60        &self,
61        target: &SiteIndexNetwork<TargetV, T::Index>,
62    ) -> Result<TreeTN<T, TargetV>>
63    where
64        TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
65    {
66        // Step 1: Build a mapping from site index ID to current node name
67        let mut site_to_current_node: HashMap<<T::Index as IndexLike>::Id, V> = HashMap::new();
68        for current_node_name in self.node_names() {
69            if let Some(site_space) = self.site_space(&current_node_name) {
70                for site_idx in site_space {
71                    site_to_current_node.insert(site_idx.id().clone(), current_node_name.clone());
72                }
73            }
74        }
75
76        // Step 2: For each target node, find which current nodes should be merged
77        // Map: target node name -> set of current node names
78        let mut target_to_current: HashMap<TargetV, HashSet<V>> = HashMap::new();
79
80        for target_node_name in target.node_names() {
81            let target_site_space = target.site_space(target_node_name).ok_or_else(|| {
82                anyhow::anyhow!("Target node {:?} has no site space", target_node_name)
83            })?;
84
85            let mut current_nodes_for_target: HashSet<V> = HashSet::new();
86            for target_site_idx in target_site_space {
87                if let Some(current_node) = site_to_current_node.get(target_site_idx.id()) {
88                    current_nodes_for_target.insert(current_node.clone());
89                }
90            }
91
92            if current_nodes_for_target.is_empty() {
93                return Err(anyhow::anyhow!(
94                    "Target node {:?} has site indices not found in current TreeTN",
95                    target_node_name
96                ))
97                .context("fuse_to: incompatible target structure");
98            }
99
100            target_to_current.insert(target_node_name.clone(), current_nodes_for_target);
101        }
102
103        // Step 3: Validate that every current node maps to exactly one target node
104        let mut current_to_target: HashMap<V, TargetV> = HashMap::new();
105        for (target_name, current_nodes) in &target_to_current {
106            for current_node in current_nodes {
107                if let Some(existing_target) = current_to_target.get(current_node) {
108                    return Err(anyhow::anyhow!(
109                        "Current node {:?} maps to multiple target nodes: {:?} and {:?}",
110                        current_node,
111                        existing_target,
112                        target_name
113                    ))
114                    .context("fuse_to: ambiguous mapping");
115                }
116                current_to_target.insert(current_node.clone(), target_name.clone());
117            }
118        }
119
120        // Check all current nodes are accounted for
121        for current_name in self.node_names() {
122            if !current_to_target.contains_key(&current_name) {
123                return Err(anyhow::anyhow!(
124                    "Current node {:?} has no corresponding target node",
125                    current_name
126                ))
127                .context("fuse_to: missing target for current node");
128            }
129        }
130
131        // Step 4: For each target node, contract all its current nodes into one tensor
132        let mut result_tensors: HashMap<TargetV, T> = HashMap::new();
133
134        for (target_name, current_nodes) in &target_to_current {
135            let contracted = self.contract_node_group(current_nodes).with_context(|| {
136                format!(
137                    "fuse_to: failed to contract nodes for target {:?}",
138                    target_name
139                )
140            })?;
141            result_tensors.insert(target_name.clone(), contracted);
142        }
143
144        // Step 5: Build the new TreeTN
145        // Sort target node names for deterministic ordering
146        let mut target_names: Vec<TargetV> = target.node_names().into_iter().cloned().collect();
147        target_names.sort();
148
149        let tensors: Vec<T> = target_names
150            .iter()
151            .map(|name| result_tensors.remove(name).unwrap())
152            .collect();
153
154        let result = TreeTN::<T, TargetV>::from_tensors(tensors, target_names)
155            .context("fuse_to: failed to build result TreeTN")?;
156
157        Ok(result)
158    }
159
160    /// Contract a group of nodes into a single tensor.
161    ///
162    /// The nodes must form a connected subtree in the current TreeTN.
163    /// Contracts all internal bonds (bonds between nodes in the group),
164    /// keeping external bonds and site indices.
165    fn contract_node_group(&self, nodes: &HashSet<V>) -> Result<T>
166    where
167        V: Ord,
168    {
169        if nodes.is_empty() {
170            return Err(anyhow::anyhow!("Cannot contract empty node group"));
171        }
172
173        // Convert node names to NodeIndex
174        let node_indices: HashSet<NodeIndex> = nodes
175            .iter()
176            .filter_map(|name| self.graph.node_index(name))
177            .collect();
178
179        if node_indices.len() != nodes.len() {
180            return Err(anyhow::anyhow!(
181                "Some nodes not found in graph: expected {} nodes, found {}",
182                nodes.len(),
183                node_indices.len()
184            ));
185        }
186
187        // Single node case: just clone the tensor
188        if nodes.len() == 1 {
189            let node_name = nodes.iter().next().unwrap();
190            let node_idx = self.graph.node_index(node_name).unwrap();
191            return self
192                .tensor(node_idx)
193                .cloned()
194                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name));
195        }
196
197        // Validate connectivity
198        if !self.site_index_network.is_connected_subset(&node_indices) {
199            return Err(anyhow::anyhow!(
200                "Nodes to contract do not form a connected subtree"
201            ));
202        }
203
204        // Pick a root (smallest node name for determinism)
205        let root_name = nodes.iter().min().unwrap();
206        let root_idx = self.graph.node_index(root_name).unwrap();
207
208        // Get edges within the group, ordered from leaves to root
209        let edges = self
210            .site_index_network
211            .edges_to_canonicalize(None, root_idx);
212
213        // Filter to only edges within our group
214        let internal_edges: Vec<(NodeIndex, NodeIndex)> = edges
215            .iter()
216            .filter(|(from, to)| node_indices.contains(from) && node_indices.contains(to))
217            .cloned()
218            .collect();
219
220        // Initialize with cloned tensors
221        let mut tensors: HashMap<NodeIndex, T> = node_indices
222            .iter()
223            .filter_map(|&idx| self.tensor(idx).cloned().map(|t| (idx, t)))
224            .collect();
225
226        // Contract along each internal edge (from leaves to root)
227        for (from, to) in internal_edges {
228            let from_tensor = tensors
229                .remove(&from)
230                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", from))?;
231            let to_tensor = tensors
232                .remove(&to)
233                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", to))?;
234
235            // Contract using TensorLike::contract
236            // (bond indices are auto-detected via is_contractable)
237            let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All)
238                .map_err(|e| anyhow::anyhow!("Failed to contract tensors: {}", e))?;
239
240            tensors.insert(to, contracted);
241        }
242
243        // The root tensor is the result
244        tensors
245            .remove(&root_idx)
246            .ok_or_else(|| anyhow::anyhow!("Contraction produced no result at root"))
247    }
248
249    /// Split nodes to match the target structure.
250    ///
251    /// This operation splits nodes that contain site indices belonging to multiple
252    /// target nodes. The target structure must be a "refinement" of the current
253    /// structure: each current node's site indices should map to one or more
254    /// target nodes.
255    ///
256    /// # Algorithm (Two-Phase Approach)
257    ///
258    /// **Phase 1: Exact factorization (no truncation)**
259    /// 1. Build mapping: site index ID -> target node name
260    /// 2. For each current node, check if its site indices map to multiple target nodes
261    /// 3. If so, split the node using QR factorization
262    /// 4. Repeat until all nodes match the target structure
263    ///
264    /// **Phase 2: Truncation sweep (optional)**
265    /// If `options.final_sweep` is true, perform a truncation sweep to optimize
266    /// bond dimensions globally.
267    ///
268    /// # Arguments
269    /// * `target` - The target `SiteIndexNetwork` defining the desired structure
270    /// * `options` - Options controlling truncation and final sweep
271    ///
272    /// # Returns
273    /// A new TreeTN with the split structure, or an error if:
274    /// - The target structure is incompatible with the current structure
275    /// - Factorization fails
276    ///
277    /// # Properties
278    /// - **Bond dimension**: May increase during split, controlled by truncation
279    /// - **Exact (Phase 1)**: Without truncation, represents the same tensor
280    ///
281    /// # Example
282    /// ```text
283    /// Before: {x1_1,x2_1}---{x1_2,x2_2}---{x1_3,x2_3}  (3 nodes, fused)
284    /// After:  x1_1---x2_1---x1_2---x2_2---x1_3---x2_3  (6 nodes, interleaved)
285    /// ```
286    pub fn split_to<TargetV>(
287        &self,
288        target: &SiteIndexNetwork<TargetV, T::Index>,
289        options: &SplitOptions,
290    ) -> Result<TreeTN<T, TargetV>>
291    where
292        TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
293    {
294        // Step 1: Build mapping from site index ID to target node name
295        let mut site_to_target: HashMap<<T::Index as IndexLike>::Id, TargetV> = HashMap::new();
296        for target_node_name in target.node_names() {
297            if let Some(site_space) = target.site_space(target_node_name) {
298                for site_idx in site_space {
299                    site_to_target.insert(site_idx.id().clone(), target_node_name.clone());
300                }
301            }
302        }
303
304        // Step 2: For each current node, determine which target nodes it maps to
305        // and validate the mapping
306        let mut current_to_targets: HashMap<V, HashSet<TargetV>> = HashMap::new();
307        for current_node_name in self.node_names() {
308            if let Some(site_space) = self.site_space(&current_node_name) {
309                let mut targets_for_node: HashSet<TargetV> = HashSet::new();
310                for site_idx in site_space {
311                    if let Some(target_name) = site_to_target.get(site_idx.id()) {
312                        targets_for_node.insert(target_name.clone());
313                    } else {
314                        return Err(anyhow::anyhow!(
315                            "Site index {:?} in current node {:?} has no corresponding target node",
316                            site_idx.id(),
317                            current_node_name
318                        ))
319                        .context("split_to: incompatible target structure");
320                    }
321                }
322                current_to_targets.insert(current_node_name.clone(), targets_for_node);
323            }
324        }
325
326        // Step 3: Phase 1 - Split all nodes that need splitting
327        // Collect all resulting tensors with their target node names
328        let mut result_tensors: Vec<(TargetV, T)> = Vec::new();
329
330        for current_node_name in self.node_names() {
331            let node_idx = self
332                .node_index(&current_node_name)
333                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", current_node_name))?;
334            let tensor = self
335                .tensor(node_idx)
336                .ok_or_else(|| anyhow::anyhow!("Tensor not found for {:?}", current_node_name))?;
337
338            let targets_for_node = current_to_targets.get(&current_node_name).ok_or_else(|| {
339                anyhow::anyhow!("No target mapping for node {:?}", current_node_name)
340            })?;
341
342            if targets_for_node.len() == 1 {
343                // No split needed - just relabel
344                let target_name = targets_for_node.iter().next().unwrap().clone();
345                result_tensors.push((target_name, tensor.clone()));
346            } else {
347                // Need to split this node
348                let split_tensors = self
349                    .split_tensor_for_targets(tensor, &site_to_target)
350                    .with_context(|| {
351                        format!("split_to: failed to split node {:?}", current_node_name)
352                    })?;
353                result_tensors.extend(split_tensors);
354            }
355        }
356
357        // Step 4: Build the result TreeTN with target node names
358        // Sort by target name for deterministic ordering
359        result_tensors.sort_by(|(a, _), (b, _)| a.cmp(b));
360
361        let names: Vec<TargetV> = result_tensors
362            .iter()
363            .map(|(name, _)| name.clone())
364            .collect();
365        let tensors: Vec<T> = result_tensors.into_iter().map(|(_, t)| t).collect();
366
367        let result = TreeTN::<T, TargetV>::from_tensors(tensors, names)
368            .context("split_to: failed to build result TreeTN")?;
369
370        // Step 5: Phase 2 - Optional truncation sweep
371        if options.final_sweep {
372            // Find a center node for truncation
373            let center = result.node_names().into_iter().min().ok_or_else(|| {
374                anyhow::anyhow!("split_to: no nodes in result for truncation sweep")
375            })?;
376
377            let truncation_options = crate::TruncationOptions {
378                form: options.form,
379                truncation: options.truncation,
380            };
381
382            return result
383                .truncate([center], truncation_options)
384                .context("split_to: truncation sweep failed");
385        }
386
387        Ok(result)
388    }
389
390    /// Split a tensor into multiple tensors, one for each target node.
391    ///
392    /// This uses QR factorization to iteratively separate site indices
393    /// belonging to different target nodes.
394    ///
395    /// Returns a vector of (target_name, tensor) pairs.
396    fn split_tensor_for_targets<TargetV>(
397        &self,
398        tensor: &T,
399        site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
400    ) -> Result<Vec<(TargetV, T)>>
401    where
402        TargetV: Clone + Hash + Eq + Ord + std::fmt::Debug,
403    {
404        // Group tensor's site indices by their target node
405        let mut partition: HashMap<TargetV, HashSet<<T::Index as IndexLike>::Id>> = HashMap::new();
406        for idx in tensor.external_indices() {
407            if let Some(target_name) = site_to_target.get(idx.id()) {
408                partition
409                    .entry(target_name.clone())
410                    .or_default()
411                    .insert(idx.id().clone());
412            }
413            // Note: bond indices (not in site_to_target) are handled by factorize
414        }
415
416        // Sort target names for deterministic processing
417        let mut target_names: Vec<TargetV> = partition.keys().cloned().collect();
418        target_names.sort();
419
420        if target_names.len() <= 1 {
421            // Should not happen if called correctly, but handle gracefully
422            let target_name = target_names
423                .first()
424                .cloned()
425                .ok_or_else(|| anyhow::anyhow!("No site indices found in tensor"))?;
426            return Ok(vec![(target_name, tensor.clone())]);
427        }
428
429        // Split iteratively: separate first target's indices, then next, etc.
430        let mut remaining_tensor = tensor.clone();
431        let mut result: Vec<(TargetV, T)> = Vec::new();
432
433        // Process all but the last target (the last one gets the remaining tensor)
434        for target_name in target_names.iter().take(target_names.len() - 1) {
435            let site_ids_for_target = partition.get(target_name).unwrap();
436
437            // Find the actual Index objects for these site IDs
438            let left_inds: Vec<_> = remaining_tensor
439                .external_indices()
440                .iter()
441                .filter(|idx| site_ids_for_target.contains(idx.id()))
442                .cloned()
443                .collect();
444
445            if left_inds.is_empty() {
446                continue;
447            }
448
449            // Factorize: separate these site indices
450            let factorize_options = FactorizeOptions {
451                alg: FactorizeAlg::QR,
452                canonical: Canonical::Left,
453                max_rank: None,
454                svd_policy: None,
455                qr_rtol: None,
456            };
457
458            let factorize_result = remaining_tensor
459                .factorize(&left_inds, &factorize_options)
460                .map_err(|e| anyhow::anyhow!("Factorization failed: {:?}", e))?;
461
462            // Left tensor gets the separated indices
463            result.push((target_name.clone(), factorize_result.left));
464
465            // Right tensor becomes the remaining tensor for next iteration
466            remaining_tensor = factorize_result.right;
467        }
468
469        // The last target gets the remaining tensor
470        let last_target = target_names.last().unwrap().clone();
471        result.push((last_target, remaining_tensor));
472
473        Ok(result)
474    }
475}
476
// Unit tests for this module. NOTE: some tests may be disabled until the random module is refactored.
478#[cfg(test)]
479mod tests;