Skip to main content

tensor4all_treetn/treetn/
contraction.rs

1//! Contraction and operations for TreeTN.
2//!
3//! This module provides methods for:
4//! - Replacing internal indices with fresh IDs (`sim_internal_inds`)
5//! - Contracting TreeTN to tensor (`contract_to_tensor`)
6//! - Zip-up contraction (`contract_zipup`)
7//! - Naive contraction (`contract_naive`)
8//! - Validation (`validate_ortho_consistency`)
9
10use petgraph::stable_graph::{EdgeIndex, NodeIndex};
11use std::collections::{HashMap, HashSet};
12use std::hash::Hash;
13
14use anyhow::{Context, Result};
15
16use crate::algorithm::CanonicalForm;
17use tensor4all_core::{
18    AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, SvdTruncationPolicy,
19    TensorLike,
20};
21
22use super::TreeTN;
23
24impl<T, V> TreeTN<T, V>
25where
26    T: TensorLike,
27    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
28{
29    /// Create a copy with all internal (link/bond) indices replaced by fresh IDs.
30    ///
31    /// External (site/physical) indices remain unchanged. This is useful when
32    /// contracting two TreeTNs that might have overlapping internal index IDs.
33    ///
34    /// # Returns
35    /// A new TreeTN with all bond indices replaced by `sim` indices (same dimension,
36    /// new unique ID).
37    pub fn sim_internal_inds(&self) -> Self {
38        // Clone the structure
39        let mut result = self.clone();
40
41        // For each edge, create a sim index and update both the edge and tensors
42        let edges: Vec<EdgeIndex> = result.graph.graph().edge_indices().collect();
43
44        for edge in edges {
45            // Get the current bond index
46            let old_bond_idx = match result.bond_index(edge) {
47                Some(idx) => idx.clone(),
48                None => continue,
49            };
50
51            // Create a new sim index (same dimension, new ID)
52            let new_bond_idx = old_bond_idx.sim();
53
54            // Get the endpoint nodes
55            let (node_a, node_b) = match result.graph.graph().edge_endpoints(edge) {
56                Some(endpoints) => endpoints,
57                None => continue,
58            };
59
60            // Update the edge weight
61            if let Some(edge_weight) = result.graph.graph_mut().edge_weight_mut(edge) {
62                *edge_weight = new_bond_idx.clone();
63            }
64
65            // Update tensor at node_a
66            if let Some(tensor_a) = result.graph.graph_mut().node_weight_mut(node_a) {
67                if let Ok(new_tensor) = tensor_a.replaceind(&old_bond_idx, &new_bond_idx) {
68                    *tensor_a = new_tensor;
69                }
70            }
71
72            // Update tensor at node_b
73            if let Some(tensor_b) = result.graph.graph_mut().node_weight_mut(node_b) {
74                if let Ok(new_tensor) = tensor_b.replaceind(&old_bond_idx, &new_bond_idx) {
75                    *tensor_b = new_tensor;
76                }
77            }
78        }
79
80        result
81    }
82
83    /// Contract the TreeTN to a single tensor.
84    ///
85    /// This method contracts all tensors in the network into a single tensor
86    /// containing all physical indices. The contraction is performed using
87    /// an edge-based order (post-order DFS edges towards root), processing
88    /// each edge in sequence and using Connection information to identify
89    /// which indices to contract.
90    ///
91    /// The result has only site (physical) indices; all bond indices are summed out.
92    /// See also [`to_dense`](Self::to_dense), which is an alias for this method.
93    ///
94    /// # Returns
95    /// A single tensor representing the full contraction of the network.
96    ///
97    /// # Errors
98    /// Returns an error if:
99    /// - The network is empty
100    /// - The graph is not a valid tree
101    /// - Tensor contraction fails
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use tensor4all_treetn::TreeTN;
107    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorIndex, TensorLike};
108    ///
109    /// let s0 = DynIndex::new_dyn(2);
110    /// let bond = DynIndex::new_dyn(2);
111    /// let s1 = DynIndex::new_dyn(2);
112    ///
113    /// let t0 = TensorDynLen::from_dense(
114    ///     vec![s0.clone(), bond.clone()],
115    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
116    /// ).unwrap();
117    /// let t1 = TensorDynLen::from_dense(
118    ///     vec![bond, s1.clone()],
119    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
120    /// ).unwrap();
121    ///
122    /// let tn = TreeTN::<_, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap();
123    /// let dense = tn.contract_to_tensor().unwrap();
124    ///
125    /// // Result has only site indices
126    /// assert_eq!(dense.num_external_indices(), 2);
127    /// ```
128    pub fn contract_to_tensor(&self) -> Result<T>
129    where
130        V: Ord,
131    {
132        if self.node_count() == 0 {
133            return Err(anyhow::anyhow!("Cannot contract empty TreeTN"));
134        }
135
136        if self.node_count() == 1 {
137            // Single node - just return a clone of its tensor
138            let node = self
139                .graph
140                .graph()
141                .node_indices()
142                .next()
143                .ok_or_else(|| anyhow::anyhow!("No nodes found"))?;
144            return self
145                .tensor(node)
146                .cloned()
147                .ok_or_else(|| anyhow::anyhow!("Tensor not found"));
148        }
149
150        // Validate tree structure
151        self.validate_tree()
152            .context("contract_to_tensor: graph must be a tree")?;
153
154        // Choose a deterministic root (minimum node name)
155        let root_name = self
156            .graph
157            .graph()
158            .node_indices()
159            .filter_map(|idx| self.graph.node_name(idx).cloned())
160            .min()
161            .ok_or_else(|| anyhow::anyhow!("No nodes found"))?;
162        let root = self
163            .graph
164            .node_index(&root_name)
165            .ok_or_else(|| anyhow::anyhow!("Root node not found"))?;
166
167        // Get edges to process (post-order DFS edges towards root)
168        let edges = self.site_index_network.edges_to_canonicalize(None, root);
169
170        // Initialize with original tensors
171        let mut tensors: HashMap<NodeIndex, T> = self
172            .graph
173            .graph()
174            .node_indices()
175            .filter_map(|n| self.tensor(n).cloned().map(|t| (n, t)))
176            .collect();
177
178        // Process each edge: contract tensor at `from` into tensor at `to`
179        for (from, to) in edges {
180            let from_tensor = tensors
181                .remove(&from)
182                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", from))?;
183            let to_tensor = tensors
184                .remove(&to)
185                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", to))?;
186
187            // Contract and store result at `to`
188            // (bond indices are auto-detected via is_contractable)
189            let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All)
190                .context("Failed to contract along edge")?;
191            tensors.insert(to, contracted);
192        }
193
194        // The root's tensor is the final result
195        let result = tensors
196            .remove(&root)
197            .ok_or_else(|| anyhow::anyhow!("Contraction produced no result"))?;
198
199        // Permute result indices to match canonical site index order:
200        // node names sorted, then site indices per node in consistent order
201        let mut expected_indices: Vec<T::Index> = Vec::new();
202        let mut node_names: Vec<V> = self.node_names();
203        node_names.sort();
204        for node_name in &node_names {
205            if let Some(site_space) = self.site_space(node_name) {
206                // Use site indices in insertion order (deterministic from site_space)
207                expected_indices.extend(site_space.iter().cloned());
208            }
209        }
210
211        // Get current index order from result tensor
212        let current_indices = result.external_indices();
213
214        // Check if permutation is needed
215        if current_indices.len() != expected_indices.len() {
216            // This shouldn't happen, but return as-is if sizes don't match
217            return Ok(result);
218        }
219
220        // Check if already in correct order
221        let already_ordered = current_indices
222            .iter()
223            .zip(expected_indices.iter())
224            .all(|(c, e)| c == e);
225
226        if already_ordered {
227            return Ok(result);
228        }
229
230        // Build permutation: for each expected index, find its position in current indices
231        // Then use replaceind to reorder (permuting indices)
232        result.permuteinds(&expected_indices)
233    }
234
235    /// Contract two TreeTNs with the same topology using the zip-up algorithm.
236    ///
237    /// The zip-up algorithm traverses from leaves towards the center, contracting
238    /// corresponding nodes from both networks and optionally truncating at each step.
239    ///
240    /// # Algorithm
241    /// 1. Replace internal (bond) indices of both networks with fresh IDs to avoid collision
242    /// 2. Traverse from leaves towards center
243    /// 3. At each edge (child → parent):
244    ///    - Contract the child tensors from both networks (along their shared site indices)
245    ///    - Factorize, keeping site indices + parent bond on left (canonical form)
246    ///    - Store left factor as child tensor in result
247    ///    - Contract right factor into parent tensor
248    /// 4. Contract the final center tensors
249    ///
250    /// # Arguments
251    /// * `other` - The other TreeTN to contract with (must have same topology)
252    /// * `center` - The center node name towards which to contract
253    /// * `svd_policy` - Optional SVD truncation policy
254    /// * `max_rank` - Optional maximum bond dimension
255    ///
256    /// # Returns
257    /// The contracted TreeTN result, or an error if topologies don't match or contraction fails.
258    pub fn contract_zipup(
259        &self,
260        other: &Self,
261        center: &V,
262        svd_policy: Option<SvdTruncationPolicy>,
263        max_rank: Option<usize>,
264    ) -> Result<Self>
265    where
266        V: Ord,
267        <T::Index as IndexLike>::Id:
268            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
269    {
270        self.contract_zipup_with(other, center, CanonicalForm::Unitary, svd_policy, max_rank)
271    }
272
273    /// Contract two TreeTNs with the same topology using the zip-up algorithm with a specified form.
274    ///
275    /// See [`contract_zipup`](Self::contract_zipup) for details.
276    pub fn contract_zipup_with(
277        &self,
278        other: &Self,
279        center: &V,
280        form: CanonicalForm,
281        svd_policy: Option<SvdTruncationPolicy>,
282        max_rank: Option<usize>,
283    ) -> Result<Self>
284    where
285        V: Ord,
286        <T::Index as IndexLike>::Id:
287            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
288    {
289        self.contract_zipup_tree_accumulated(other, center, form, svd_policy, max_rank)
290    }
291
292    /// Contract two TreeTNs using zip-up algorithm with accumulated intermediate tensors.
293    ///
294    /// This is an improved version of zip-up contraction that maintains intermediate tensors
295    /// (environment tensors) as it processes from leaves towards the root, similar to
296    /// ITensors.jl's MPO zip-up algorithm.
297    ///
298    /// # Algorithm
299    /// 1. Process leaves: contract `A[leaf] * B[leaf]`, factorize, store R at parent
300    /// 2. Process internal nodes: contract `[R_accumulated..., A[node], B[node]]`, factorize, store R\_new at parent
301    /// 3. Process root: contract `[R_list..., A[root], B[root]]`, store as final result
302    /// 4. Set canonical center
303    ///
304    /// # Arguments
305    /// * `other` - The other TreeTN to contract with (must have same topology)
306    /// * `center` - The center node name towards which to contract
307    /// * `form` - Canonical form (Unitary/LU/CI)
308    /// * `svd_policy` - Optional SVD truncation policy
309    /// * `max_rank` - Optional maximum bond dimension
310    ///
311    /// # Returns
312    /// The contracted TreeTN result, or an error if topologies don't match or contraction fails.
313    pub fn contract_zipup_tree_accumulated(
314        &self,
315        other: &Self,
316        center: &V,
317        form: CanonicalForm,
318        svd_policy: Option<SvdTruncationPolicy>,
319        max_rank: Option<usize>,
320    ) -> Result<Self>
321    where
322        V: Ord,
323        <T::Index as IndexLike>::Id:
324            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
325    {
326        // 1. Verify topologies are compatible
327        if !self.same_topology(other) {
328            return Err(anyhow::anyhow!(
329                "contract_zipup_tree_accumulated: networks have incompatible topologies"
330            ));
331        }
332
333        // 2. Replace internal indices with fresh IDs to avoid collision
334        let tn_a = self.sim_internal_inds();
335        let tn_b = other.sim_internal_inds();
336
337        // 3. Get traversal edges from leaves to center (post-order DFS)
338        let edges = tn_a.edges_to_canonicalize_by_names(center).ok_or_else(|| {
339            anyhow::anyhow!(
340                "contract_zipup_tree_accumulated: center node {:?} not found",
341                center
342            )
343        })?;
344
345        // 4. Handle single node case
346        if edges.is_empty() && self.node_count() == 1 {
347            let node_idx = tn_a.graph.graph().node_indices().next().ok_or_else(|| {
348                anyhow::anyhow!("contract_zipup_tree_accumulated: no nodes found")
349            })?;
350            let t_a = tn_a.tensor(node_idx).ok_or_else(|| {
351                anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_a")
352            })?;
353            let t_b = tn_b
354                .tensor(tn_b.graph.graph().node_indices().next().ok_or_else(|| {
355                    anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b")
356                })?)
357                .ok_or_else(|| {
358                    anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b")
359                })?;
360
361            let contracted = T::contract(&[t_a, t_b], AllowedPairs::All)?;
362            let node_name = tn_a.graph.node_name(node_idx).ok_or_else(|| {
363                anyhow::anyhow!("contract_zipup_tree_accumulated: node name not found")
364            })?;
365
366            let mut result = TreeTN::new();
367            result.add_tensor(node_name.clone(), contracted)?;
368            result.set_canonical_region(std::iter::once(center.clone()))?;
369            return Ok(result);
370        }
371
372        // 5. Initialize intermediate tensors storage: HashMap<node_name, Vec<intermediate_tensor>>
373        let mut intermediate_tensors: HashMap<V, Vec<T>> = HashMap::new();
374
375        // 6. Initialize result tensors: HashMap<node_name, tensor>
376        let mut result_tensors: HashMap<V, T> = HashMap::new();
377
378        // 7. Determine which nodes are leaves (for processing logic)
379        let root_name = center.clone();
380
381        // Helper: Get bond index between two nodes
382        let get_bond_index = |tn: &TreeTN<T, V>, node_a: &V, node_b: &V| -> Result<T::Index> {
383            let edge = tn.edge_between(node_a, node_b).ok_or_else(|| {
384                anyhow::anyhow!("Edge not found between {:?} and {:?}", node_a, node_b)
385            })?;
386            tn.bond_index(edge)
387                .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))
388                .cloned()
389        };
390
391        // 8. Set up factorization options based on form
392        let alg = match form {
393            CanonicalForm::Unitary => FactorizeAlg::SVD,
394            CanonicalForm::LU => FactorizeAlg::LU,
395            CanonicalForm::CI => FactorizeAlg::CI,
396        };
397
398        let mut factorize_options = match alg {
399            FactorizeAlg::SVD => FactorizeOptions::svd(),
400            FactorizeAlg::QR => FactorizeOptions::qr(),
401            FactorizeAlg::LU => FactorizeOptions::lu(),
402            FactorizeAlg::CI => FactorizeOptions::ci(),
403        }
404        .with_canonical(Canonical::Left);
405
406        if let Some(max_rank) = max_rank {
407            factorize_options = factorize_options.with_max_rank(max_rank);
408        }
409        if let Some(policy) = svd_policy {
410            factorize_options = factorize_options.with_svd_policy(policy);
411        }
412        factorize_options
413            .validate()
414            .map_err(|err| anyhow::anyhow!("invalid zipup factorization options: {err}"))?;
415
416        // 9. Process edges from leaves towards root
417        for (source_name, destination_name) in &edges {
418            // Get tensors from both networks
419            let node_a_idx = tn_a
420                .node_index(source_name)
421                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in tn_a", source_name))?;
422            let node_b_idx = tn_b
423                .node_index(source_name)
424                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in tn_b", source_name))?;
425
426            let tensor_a = tn_a
427                .tensor(node_a_idx)
428                .ok_or_else(|| {
429                    anyhow::anyhow!("Tensor not found for node {:?} in tn_a", source_name)
430                })?
431                .clone();
432            let tensor_b = tn_b
433                .tensor(node_b_idx)
434                .ok_or_else(|| {
435                    anyhow::anyhow!("Tensor not found for node {:?} in tn_b", source_name)
436                })?
437                .clone();
438
439            // Check if this is a leaf node (no intermediate tensors accumulated yet)
440            let is_leaf = !intermediate_tensors.contains_key(source_name)
441                || intermediate_tensors
442                    .get(source_name)
443                    .map(|v| v.is_empty())
444                    .unwrap_or(true);
445
446            let c_temp = if is_leaf {
447                // Leaf node: contract A[source] * B[source]
448                T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All)
449                    .context("Failed to contract leaf tensors")?
450            } else {
451                // Internal node: contract [R_accumulated..., A[source], B[source]]
452                let mut tensor_list = Vec::new();
453                if let Some(r_list) = intermediate_tensors.remove(source_name) {
454                    tensor_list.extend(r_list);
455                }
456                tensor_list.push(tensor_a);
457                tensor_list.push(tensor_b);
458                let tensor_refs: Vec<&T> = tensor_list.iter().collect();
459                T::contract(&tensor_refs, AllowedPairs::All)
460                    .context("Failed to contract internal node tensors")?
461            };
462
463            // Factorize child tensor and pass the right factor to destination (even if destination is root)
464            let bond_to_dest_a = get_bond_index(&tn_a, source_name, destination_name)
465                .context("Failed to get bond index to destination in tn_a")?;
466            let bond_to_dest_b = get_bond_index(&tn_b, source_name, destination_name)
467                .context("Failed to get bond index to destination in tn_b")?;
468
469            // left_inds = all indices except the two parent bonds (keep site + child bonds)
470            let left_inds: Vec<_> = c_temp
471                .external_indices()
472                .into_iter()
473                .filter(|idx| {
474                    *idx.id() != *bond_to_dest_a.id() && *idx.id() != *bond_to_dest_b.id()
475                })
476                .collect();
477
478            if left_inds.is_empty() {
479                // If no left indices remain, pass the tensor directly to destination
480                intermediate_tensors
481                    .entry(destination_name.clone())
482                    .or_default()
483                    .push(c_temp);
484                continue;
485            }
486
487            let factorize_result = c_temp
488                .factorize(&left_inds, &factorize_options)
489                .context("Failed to factorize")?;
490
491            // Store left factor as result tensor for source node
492            result_tensors.insert(source_name.clone(), factorize_result.left);
493
494            // Store right factor (intermediate tensor R) at destination
495            intermediate_tensors
496                .entry(destination_name.clone())
497                .or_default()
498                .push(factorize_result.right);
499
500            // Note: bond index update will be handled when building the result TreeTN
501        }
502
503        // 9.5. Process root node (if it has intermediate tensors accumulated)
504        if let Some(r_list) = intermediate_tensors.remove(&root_name) {
505            // Get root tensors from both networks
506            let root_a_idx = tn_a
507                .node_index(&root_name)
508                .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found in tn_a", root_name))?;
509            let root_b_idx = tn_b
510                .node_index(&root_name)
511                .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found in tn_b", root_name))?;
512
513            let root_tensor_a = tn_a
514                .tensor(root_a_idx)
515                .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_a"))?
516                .clone();
517            let root_tensor_b = tn_b
518                .tensor(root_b_idx)
519                .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_b"))?
520                .clone();
521
522            // Contract [R_list..., A[root], B[root]]
523            let mut tensor_list = r_list;
524            tensor_list.push(root_tensor_a);
525            tensor_list.push(root_tensor_b);
526            let tensor_refs: Vec<&T> = tensor_list.iter().collect();
527            let root_result = T::contract(&tensor_refs, AllowedPairs::All)
528                .context("Failed to contract root node tensors")?;
529
530            // Store root result (no factorization needed)
531            result_tensors.insert(root_name.clone(), root_result);
532        } else {
533            // No intermediate tensors: root is a single node or already processed
534            // Check if root tensors need to be contracted
535            if !result_tensors.contains_key(&root_name) {
536                let root_a_idx = tn_a.node_index(&root_name).ok_or_else(|| {
537                    anyhow::anyhow!("Root node {:?} not found in tn_a", root_name)
538                })?;
539                let root_b_idx = tn_b.node_index(&root_name).ok_or_else(|| {
540                    anyhow::anyhow!("Root node {:?} not found in tn_b", root_name)
541                })?;
542
543                let root_tensor_a = tn_a
544                    .tensor(root_a_idx)
545                    .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_a"))?;
546                let root_tensor_b = tn_b
547                    .tensor(root_b_idx)
548                    .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_b"))?;
549
550                let root_result = T::contract(&[root_tensor_a, root_tensor_b], AllowedPairs::All)
551                    .context("Failed to contract root node tensors")?;
552
553                result_tensors.insert(root_name.clone(), root_result);
554            }
555        }
556
557        // 10. Build result TreeTN
558        let mut result = TreeTN::new();
559
560        // Add all result tensors
561        for (node_name, tensor) in result_tensors {
562            result.add_tensor(node_name, tensor)?;
563        }
564
565        // Connect nodes based on original topology
566        for (source_name, destination_name) in &edges {
567            if let (Some(node_a_idx), Some(node_b_idx)) = (
568                result.node_index(source_name),
569                result.node_index(destination_name),
570            ) {
571                let tensor_a = result.tensor(node_a_idx).unwrap();
572                let tensor_b = result.tensor(node_b_idx).unwrap();
573
574                // Find the common index (should be the bond index)
575                use tensor4all_core::index_ops::common_inds;
576                let indices_a = tensor_a.external_indices();
577                let indices_b = tensor_b.external_indices();
578                let common = common_inds::<T::Index>(&indices_a, &indices_b);
579                if let Some(bond_idx) = common.first() {
580                    result.connect_internal(node_a_idx, bond_idx, node_b_idx, bond_idx)?;
581                }
582            }
583        }
584
585        // 11. Set canonical center
586        if result.node_index(center).is_some() {
587            result.set_canonical_region(std::iter::once(center.clone()))?;
588        }
589
590        Ok(result)
591    }
592
593    /// Contract two TreeTNs using naive full contraction.
594    ///
595    /// This is a reference implementation that:
596    /// 1. Replaces internal indices with fresh IDs (sim_internal_inds)
597    /// 2. Converts both TreeTNs to full tensors
598    /// 3. Contracts along common site indices
599    ///
600    /// The result is a single tensor, not a TreeTN. This is useful for:
601    /// - Testing correctness of more sophisticated algorithms like `contract_zipup`
602    /// - Computing exact results for small networks
603    ///
604    /// # Arguments
605    /// * `other` - The other TreeTN to contract with (must have same topology)
606    ///
607    /// # Returns
608    /// A tensor representing the contracted result.
609    ///
610    /// # Note
611    /// This method is O(exp(n)) in both time and memory where n is the number of nodes.
612    /// Use `contract_zipup` for efficient contraction of large networks.
613    pub fn contract_naive(&self, other: &Self) -> Result<T>
614    where
615        V: Ord,
616        <T::Index as IndexLike>::Id:
617            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
618    {
619        // 1. Verify topologies are compatible
620        if !self.same_topology(other) {
621            return Err(anyhow::anyhow!(
622                "contract_naive: networks have incompatible topologies"
623            ));
624        }
625
626        // 2. Replace internal indices with fresh IDs to avoid collision
627        let tn1 = self.sim_internal_inds();
628        let tn2 = other.sim_internal_inds();
629
630        // 3. Convert both networks to full tensors
631        let tensor1 = tn1
632            .contract_to_tensor()
633            .map_err(|e| anyhow::anyhow!("contract_naive: failed to contract tn1: {}", e))?;
634        let tensor2 = tn2
635            .contract_to_tensor()
636            .map_err(|e| anyhow::anyhow!("contract_naive: failed to contract tn2: {}", e))?;
637
638        // 4. Contract along common indices
639        // T::contract auto-contracts all is_contractable pairs
640        T::contract(&[&tensor1, &tensor2], AllowedPairs::All)
641    }
642
643    /// Validate that `canonical_region` and edge `ortho_towards` are consistent.
644    ///
645    /// Rules:
646    /// - If `canonical_region` is empty (not canonicalized), all indices must have `ortho_towards == None`.
647    /// - If `canonical_region` is non-empty:
648    ///   - It must form a connected subtree
649    ///   - All edges from outside the center region must have `ortho_towards` pointing towards the center
650    ///   - Edges entirely inside the center region may have `ortho_towards == None`
651    pub fn validate_ortho_consistency(&self) -> Result<()> {
652        // If not canonicalized, require no ortho_towards at all
653        if self.canonical_region.is_empty() {
654            if !self.ortho_towards.is_empty() {
655                return Err(anyhow::anyhow!(
656                    "Found {} ortho_towards entries but canonical_region is empty",
657                    self.ortho_towards.len()
658                ))
659                .context(
660                    "validate_ortho_consistency: canonical_region empty implies no ortho_towards",
661                );
662            }
663            return Ok(());
664        }
665
666        // Validate all canonical_region nodes exist and convert to NodeIndex
667        let mut center_indices = HashSet::new();
668        for c in &self.canonical_region {
669            let idx = self
670                .graph
671                .node_index(c)
672                .ok_or_else(|| anyhow::anyhow!("canonical_region node {:?} does not exist", c))?;
673            center_indices.insert(idx);
674        }
675
676        // Check canonical_region connectivity
677        if !self.site_index_network.is_connected_subset(&center_indices) {
678            return Err(anyhow::anyhow!("canonical_region is not connected")).context(
679                "validate_ortho_consistency: canonical_region must form a connected subtree",
680            );
681        }
682
683        // Get expected edges from edges_to_canonicalize_to_region
684        let expected_edges = self
685            .site_index_network
686            .edges_to_canonicalize_to_region(&center_indices);
687
688        // Build a set of expected (bond, expected_direction) pairs
689        let mut expected_directions: HashMap<T::Index, V> = HashMap::new();
690        for (src, dst) in expected_edges.iter() {
691            // Find the edge between src and dst
692            let edge = self
693                .graph
694                .graph()
695                .find_edge(*src, *dst)
696                .or_else(|| self.graph.graph().find_edge(*dst, *src))
697                .ok_or_else(|| anyhow::anyhow!("Edge not found between {:?} and {:?}", src, dst))?;
698
699            let bond = self
700                .bond_index(edge)
701                .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))?
702                .clone();
703
704            // The expected ortho_towards direction is dst (towards center)
705            let dst_name = self
706                .graph
707                .node_name(*dst)
708                .ok_or_else(|| anyhow::anyhow!("Node name not found for {:?}", dst))?
709                .clone();
710
711            expected_directions.insert(bond, dst_name);
712        }
713
714        // Verify all expected directions are present in ortho_towards
715        for (bond, expected_dir) in &expected_directions {
716            match self.ortho_towards.get(bond) {
717                Some(actual_dir) => {
718                    if actual_dir != expected_dir {
719                        return Err(anyhow::anyhow!(
720                            "ortho_towards for bond {:?} points to {:?} but expected {:?}",
721                            bond,
722                            actual_dir,
723                            expected_dir
724                        ))
725                        .context("validate_ortho_consistency: wrong direction");
726                    }
727                }
728                None => {
729                    return Err(anyhow::anyhow!(
730                        "ortho_towards for bond {:?} is missing, expected to point to {:?}",
731                        bond,
732                        expected_dir
733                    ))
734                    .context("validate_ortho_consistency: missing ortho_towards");
735                }
736            }
737        }
738
739        // Verify no unexpected bond ortho_towards entries
740        // (site index ortho_towards are allowed even if not in expected_directions)
741        let bond_indices: HashSet<T::Index> = self
742            .graph
743            .graph()
744            .edge_indices()
745            .filter_map(|e| self.bond_index(e))
746            .cloned()
747            .collect();
748
749        for idx in self.ortho_towards.keys() {
750            if bond_indices.contains(idx) && !expected_directions.contains_key(idx) {
751                // This is a bond inside the canonical_region - should not have ortho_towards
752                return Err(anyhow::anyhow!(
753                    "Unexpected ortho_towards for bond {:?} (inside canonical_region)",
754                    idx
755                ))
756                .context(
757                    "validate_ortho_consistency: bonds inside center should not have ortho_towards",
758                );
759            }
760        }
761
762        Ok(())
763    }
764}
765
766// ============================================================================
767// Helper functions
768// ============================================================================
769
770/// Find common indices between two tensors (by ID).
771fn find_common_indices<T: TensorLike>(a: &T, b: &T) -> Vec<T::Index>
772where
773    <T::Index as IndexLike>::Id: Eq + std::hash::Hash,
774{
775    let a_ids: HashSet<_> = a
776        .external_indices()
777        .iter()
778        .map(|i| i.id().clone())
779        .collect();
780    b.external_indices()
781        .into_iter()
782        .filter(|i| a_ids.contains(i.id()))
783        .collect()
784}
785
786// ============================================================================
787// Contraction Method Dispatcher
788// ============================================================================
789
790use super::fit::FitContractionOptions;
791
/// Algorithm used to contract two TreeTNs.
///
/// `Zipup` is the `Default` variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContractionMethod {
    /// Zip-up contraction: a single pass, the faster choice.
    #[default]
    Zipup,
    /// Fit (variational) contraction, refined by iterative optimization.
    Fit,
    /// Naive contraction: form the full dense tensor, then decompose it back into
    /// a TreeTN. Memory grows as O(exp(n)), so this is meant for debugging/tests.
    Naive,
}
804
805/// Options for the generic contract function.
806#[derive(Debug, Clone)]
807pub struct ContractionOptions {
808    /// Contraction method to use.
809    pub method: ContractionMethod,
810    /// Maximum bond dimension (optional).
811    pub max_rank: Option<usize>,
812    /// Explicit SVD truncation policy (optional).
813    pub svd_policy: Option<SvdTruncationPolicy>,
814    /// QR-specific relative tolerance (optional).
815    pub qr_rtol: Option<f64>,
816    /// Number of full sweeps for Fit method.
817    ///
818    /// A full sweep visits each edge twice (forward and backward) using an Euler tour.
819    pub nfullsweeps: usize,
820    /// Convergence tolerance for Fit method (None = fixed sweeps).
821    pub convergence_tol: Option<f64>,
822    /// Factorization algorithm for Fit method.
823    pub factorize_alg: FactorizeAlg,
824}
825
826impl Default for ContractionOptions {
827    fn default() -> Self {
828        Self {
829            method: ContractionMethod::default(),
830            max_rank: None,
831            svd_policy: None,
832            qr_rtol: None,
833            nfullsweeps: 1,
834            convergence_tol: None,
835            factorize_alg: FactorizeAlg::default(),
836        }
837    }
838}
839
840impl ContractionOptions {
841    /// Create options with specified method.
842    pub fn new(method: ContractionMethod) -> Self {
843        Self {
844            method,
845            ..Default::default()
846        }
847    }
848
849    /// Create options for zipup contraction.
850    pub fn zipup() -> Self {
851        Self::new(ContractionMethod::Zipup)
852    }
853
854    /// Create options for fit contraction.
855    pub fn fit() -> Self {
856        Self::new(ContractionMethod::Fit)
857    }
858
859    /// Set maximum bond dimension.
860    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
861        self.max_rank = Some(max_rank);
862        self
863    }
864
865    /// Set the SVD truncation policy.
866    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
867        self.svd_policy = Some(policy);
868        self
869    }
870
871    /// Set the QR-specific relative tolerance.
872    pub fn with_qr_rtol(mut self, rtol: f64) -> Self {
873        self.qr_rtol = Some(rtol);
874        self
875    }
876
877    /// Set number of full sweeps for Fit method.
878    pub fn with_nfullsweeps(mut self, nfullsweeps: usize) -> Self {
879        self.nfullsweeps = nfullsweeps;
880        self
881    }
882
883    /// Set convergence tolerance for Fit method.
884    pub fn with_convergence_tol(mut self, tol: f64) -> Self {
885        self.convergence_tol = Some(tol);
886        self
887    }
888
889    /// Set factorization algorithm for Fit method.
890    pub fn with_factorize_alg(mut self, alg: FactorizeAlg) -> Self {
891        self.factorize_alg = alg;
892        self
893    }
894}
895
896/// Contract two TreeTNs using the specified method.
897///
898/// This is the main entry point for TreeTN contraction. It dispatches to the
899/// appropriate algorithm based on the options.
900pub fn contract<T, V>(
901    tn_a: &TreeTN<T, V>,
902    tn_b: &TreeTN<T, V>,
903    center: &V,
904    options: ContractionOptions,
905) -> Result<TreeTN<T, V>>
906where
907    T: TensorLike,
908    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
909    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
910{
911    match options.method {
912        ContractionMethod::Zipup => {
913            tn_a.contract_zipup(tn_b, center, options.svd_policy, options.max_rank)
914        }
915        ContractionMethod::Fit => {
916            let fit_options = FitContractionOptions::new(options.nfullsweeps)
917                .with_factorize_alg(options.factorize_alg);
918            let fit_options = if let Some(max_rank) = options.max_rank {
919                fit_options.with_max_rank(max_rank)
920            } else {
921                fit_options
922            };
923            let fit_options = if let Some(policy) = options.svd_policy {
924                fit_options.with_svd_policy(policy)
925            } else {
926                fit_options
927            };
928            let fit_options = if let Some(qr_rtol) = options.qr_rtol {
929                fit_options.with_qr_rtol(qr_rtol)
930            } else {
931                fit_options
932            };
933            let fit_options = if let Some(tol) = options.convergence_tol {
934                fit_options.with_convergence_tol(tol)
935            } else {
936                fit_options
937            };
938            super::fit::contract_fit(tn_a, tn_b, center, fit_options)
939        }
940        ContractionMethod::Naive => contract_naive_to_treetn(
941            tn_a,
942            tn_b,
943            center,
944            options.max_rank,
945            options.svd_policy,
946            options.qr_rtol,
947        ),
948    }
949}
950
951/// Contract two TreeTNs using naive contraction, then decompose back to TreeTN.
952///
953/// This method:
954/// 1. Contracts both networks to full tensors
955/// 2. Contracts the tensors along common (site) indices
956/// 3. Decomposes the result back to a TreeTN using the original topology
957///
958/// This is O(exp(n)) in memory and is primarily useful for debugging and testing.
959#[allow(clippy::too_many_arguments)]
960pub fn contract_naive_to_treetn<T, V>(
961    tn_a: &TreeTN<T, V>,
962    tn_b: &TreeTN<T, V>,
963    center: &V,
964    _max_rank: Option<usize>,
965    _svd_policy: Option<SvdTruncationPolicy>,
966    _qr_rtol: Option<f64>,
967) -> Result<TreeTN<T, V>>
968where
969    T: TensorLike,
970    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
971    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
972{
973    // 1. Contract to full tensor using existing contract_naive
974    let contracted_tensor = tn_a.contract_naive(tn_b)?;
975
976    // Handle rank-0 (scalar) result: wrap directly in a single-node TreeTN
977    if contracted_tensor.external_indices().is_empty() {
978        let mut tn = TreeTN::<T, V>::new();
979        tn.add_tensor(center.clone(), contracted_tensor)?;
980        tn.set_canonical_region([center.clone()])?;
981        return Ok(tn);
982    }
983
984    // 2. Build topology from tn_a's structure and decompose
985    use super::decompose::factorize_tensor_to_treetn_with;
986
987    // Build topology using index IDs (not positions).
988    // Consider site indices from BOTH tn_a and tn_b, since the contracted result
989    // may contain indices from either network (non-contracted ones remain).
990    let mut nodes: HashMap<V, Vec<<T::Index as IndexLike>::Id>> = HashMap::new();
991    let contracted_indices = contracted_tensor.external_indices();
992    let contracted_ids: HashSet<_> = contracted_indices
993        .iter()
994        .map(|ci| ci.id().clone())
995        .collect();
996
997    // Collect node names in sorted order for deterministic assignment
998    let mut node_names: Vec<_> = tn_a.node_names();
999    node_names.sort();
1000
1001    for node_name in &node_names {
1002        let mut ids: Vec<<T::Index as IndexLike>::Id> = Vec::new();
1003
1004        // Collect remaining site indices from tn_a at this node
1005        if let Some(site_space_a) = tn_a.site_index_network.site_space(node_name) {
1006            for site_idx in site_space_a {
1007                if contracted_ids.contains(site_idx.id()) {
1008                    ids.push(site_idx.id().clone());
1009                }
1010            }
1011        }
1012
1013        // Also collect remaining site indices from tn_b at the same node
1014        if let Some(site_space_b) = tn_b.site_index_network.site_space(node_name) {
1015            for site_idx in site_space_b {
1016                if contracted_ids.contains(site_idx.id()) && !ids.contains(site_idx.id()) {
1017                    ids.push(site_idx.id().clone());
1018                }
1019            }
1020        }
1021
1022        nodes.insert(node_name.clone(), ids);
1023    }
1024
1025    // Get edges from the graph
1026    let edges: Vec<(V, V)> = tn_a
1027        .graph
1028        .graph()
1029        .edge_indices()
1030        .filter_map(|e| {
1031            let (src, dst) = tn_a.graph.graph().edge_endpoints(e)?;
1032            let src_name = tn_a.graph.node_name(src)?;
1033            let dst_name = tn_a.graph.node_name(dst)?;
1034            Some((src_name.clone(), dst_name.clone()))
1035        })
1036        .collect();
1037
1038    let topology = super::decompose::TreeTopology::new(nodes, edges);
1039
1040    // 3. Decompose back to TreeTN
1041    let result = factorize_tensor_to_treetn_with(
1042        &contracted_tensor,
1043        &topology,
1044        FactorizeOptions::svd(),
1045        center,
1046    )?;
1047
1048    Ok(result)
1049}
1050
1051#[cfg(test)]
1052mod tests;