1use petgraph::stable_graph::{EdgeIndex, NodeIndex};
11use std::collections::{HashMap, HashSet};
12use std::hash::Hash;
13
14use anyhow::{Context, Result};
15
16use crate::algorithm::CanonicalForm;
17use tensor4all_core::{
18 AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, SvdTruncationPolicy,
19 TensorLike,
20};
21
22use super::TreeTN;
23
24impl<T, V> TreeTN<T, V>
25where
26 T: TensorLike,
27 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
28{
29 pub fn sim_internal_inds(&self) -> Self {
38 let mut result = self.clone();
40
41 let edges: Vec<EdgeIndex> = result.graph.graph().edge_indices().collect();
43
44 for edge in edges {
45 let old_bond_idx = match result.bond_index(edge) {
47 Some(idx) => idx.clone(),
48 None => continue,
49 };
50
51 let new_bond_idx = old_bond_idx.sim();
53
54 let (node_a, node_b) = match result.graph.graph().edge_endpoints(edge) {
56 Some(endpoints) => endpoints,
57 None => continue,
58 };
59
60 if let Some(edge_weight) = result.graph.graph_mut().edge_weight_mut(edge) {
62 *edge_weight = new_bond_idx.clone();
63 }
64
65 if let Some(tensor_a) = result.graph.graph_mut().node_weight_mut(node_a) {
67 if let Ok(new_tensor) = tensor_a.replaceind(&old_bond_idx, &new_bond_idx) {
68 *tensor_a = new_tensor;
69 }
70 }
71
72 if let Some(tensor_b) = result.graph.graph_mut().node_weight_mut(node_b) {
74 if let Ok(new_tensor) = tensor_b.replaceind(&old_bond_idx, &new_bond_idx) {
75 *tensor_b = new_tensor;
76 }
77 }
78 }
79
80 result
81 }
82
83 pub fn contract_to_tensor(&self) -> Result<T>
129 where
130 V: Ord,
131 {
132 if self.node_count() == 0 {
133 return Err(anyhow::anyhow!("Cannot contract empty TreeTN"));
134 }
135
136 if self.node_count() == 1 {
137 let node = self
139 .graph
140 .graph()
141 .node_indices()
142 .next()
143 .ok_or_else(|| anyhow::anyhow!("No nodes found"))?;
144 return self
145 .tensor(node)
146 .cloned()
147 .ok_or_else(|| anyhow::anyhow!("Tensor not found"));
148 }
149
150 self.validate_tree()
152 .context("contract_to_tensor: graph must be a tree")?;
153
154 let root_name = self
156 .graph
157 .graph()
158 .node_indices()
159 .filter_map(|idx| self.graph.node_name(idx).cloned())
160 .min()
161 .ok_or_else(|| anyhow::anyhow!("No nodes found"))?;
162 let root = self
163 .graph
164 .node_index(&root_name)
165 .ok_or_else(|| anyhow::anyhow!("Root node not found"))?;
166
167 let edges = self.site_index_network.edges_to_canonicalize(None, root);
169
170 let mut tensors: HashMap<NodeIndex, T> = self
172 .graph
173 .graph()
174 .node_indices()
175 .filter_map(|n| self.tensor(n).cloned().map(|t| (n, t)))
176 .collect();
177
178 for (from, to) in edges {
180 let from_tensor = tensors
181 .remove(&from)
182 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", from))?;
183 let to_tensor = tensors
184 .remove(&to)
185 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", to))?;
186
187 let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All)
190 .context("Failed to contract along edge")?;
191 tensors.insert(to, contracted);
192 }
193
194 let result = tensors
196 .remove(&root)
197 .ok_or_else(|| anyhow::anyhow!("Contraction produced no result"))?;
198
199 let mut expected_indices: Vec<T::Index> = Vec::new();
202 let mut node_names: Vec<V> = self.node_names();
203 node_names.sort();
204 for node_name in &node_names {
205 if let Some(site_space) = self.site_space(node_name) {
206 expected_indices.extend(site_space.iter().cloned());
208 }
209 }
210
211 let current_indices = result.external_indices();
213
214 if current_indices.len() != expected_indices.len() {
216 return Ok(result);
218 }
219
220 let already_ordered = current_indices
222 .iter()
223 .zip(expected_indices.iter())
224 .all(|(c, e)| c == e);
225
226 if already_ordered {
227 return Ok(result);
228 }
229
230 result.permuteinds(&expected_indices)
233 }
234
235 pub fn contract_zipup(
259 &self,
260 other: &Self,
261 center: &V,
262 svd_policy: Option<SvdTruncationPolicy>,
263 max_rank: Option<usize>,
264 ) -> Result<Self>
265 where
266 V: Ord,
267 <T::Index as IndexLike>::Id:
268 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
269 {
270 self.contract_zipup_with(other, center, CanonicalForm::Unitary, svd_policy, max_rank)
271 }
272
273 pub fn contract_zipup_with(
277 &self,
278 other: &Self,
279 center: &V,
280 form: CanonicalForm,
281 svd_policy: Option<SvdTruncationPolicy>,
282 max_rank: Option<usize>,
283 ) -> Result<Self>
284 where
285 V: Ord,
286 <T::Index as IndexLike>::Id:
287 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
288 {
289 self.contract_zipup_tree_accumulated(other, center, form, svd_policy, max_rank)
290 }
291
292 pub fn contract_zipup_tree_accumulated(
314 &self,
315 other: &Self,
316 center: &V,
317 form: CanonicalForm,
318 svd_policy: Option<SvdTruncationPolicy>,
319 max_rank: Option<usize>,
320 ) -> Result<Self>
321 where
322 V: Ord,
323 <T::Index as IndexLike>::Id:
324 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
325 {
326 if !self.same_topology(other) {
328 return Err(anyhow::anyhow!(
329 "contract_zipup_tree_accumulated: networks have incompatible topologies"
330 ));
331 }
332
333 let tn_a = self.sim_internal_inds();
335 let tn_b = other.sim_internal_inds();
336
337 let edges = tn_a.edges_to_canonicalize_by_names(center).ok_or_else(|| {
339 anyhow::anyhow!(
340 "contract_zipup_tree_accumulated: center node {:?} not found",
341 center
342 )
343 })?;
344
345 if edges.is_empty() && self.node_count() == 1 {
347 let node_idx = tn_a.graph.graph().node_indices().next().ok_or_else(|| {
348 anyhow::anyhow!("contract_zipup_tree_accumulated: no nodes found")
349 })?;
350 let t_a = tn_a.tensor(node_idx).ok_or_else(|| {
351 anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_a")
352 })?;
353 let t_b = tn_b
354 .tensor(tn_b.graph.graph().node_indices().next().ok_or_else(|| {
355 anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b")
356 })?)
357 .ok_or_else(|| {
358 anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b")
359 })?;
360
361 let contracted = T::contract(&[t_a, t_b], AllowedPairs::All)?;
362 let node_name = tn_a.graph.node_name(node_idx).ok_or_else(|| {
363 anyhow::anyhow!("contract_zipup_tree_accumulated: node name not found")
364 })?;
365
366 let mut result = TreeTN::new();
367 result.add_tensor(node_name.clone(), contracted)?;
368 result.set_canonical_region(std::iter::once(center.clone()))?;
369 return Ok(result);
370 }
371
372 let mut intermediate_tensors: HashMap<V, Vec<T>> = HashMap::new();
374
375 let mut result_tensors: HashMap<V, T> = HashMap::new();
377
378 let root_name = center.clone();
380
381 let get_bond_index = |tn: &TreeTN<T, V>, node_a: &V, node_b: &V| -> Result<T::Index> {
383 let edge = tn.edge_between(node_a, node_b).ok_or_else(|| {
384 anyhow::anyhow!("Edge not found between {:?} and {:?}", node_a, node_b)
385 })?;
386 tn.bond_index(edge)
387 .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))
388 .cloned()
389 };
390
391 let alg = match form {
393 CanonicalForm::Unitary => FactorizeAlg::SVD,
394 CanonicalForm::LU => FactorizeAlg::LU,
395 CanonicalForm::CI => FactorizeAlg::CI,
396 };
397
398 let mut factorize_options = match alg {
399 FactorizeAlg::SVD => FactorizeOptions::svd(),
400 FactorizeAlg::QR => FactorizeOptions::qr(),
401 FactorizeAlg::LU => FactorizeOptions::lu(),
402 FactorizeAlg::CI => FactorizeOptions::ci(),
403 }
404 .with_canonical(Canonical::Left);
405
406 if let Some(max_rank) = max_rank {
407 factorize_options = factorize_options.with_max_rank(max_rank);
408 }
409 if let Some(policy) = svd_policy {
410 factorize_options = factorize_options.with_svd_policy(policy);
411 }
412 factorize_options
413 .validate()
414 .map_err(|err| anyhow::anyhow!("invalid zipup factorization options: {err}"))?;
415
416 for (source_name, destination_name) in &edges {
418 let node_a_idx = tn_a
420 .node_index(source_name)
421 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in tn_a", source_name))?;
422 let node_b_idx = tn_b
423 .node_index(source_name)
424 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in tn_b", source_name))?;
425
426 let tensor_a = tn_a
427 .tensor(node_a_idx)
428 .ok_or_else(|| {
429 anyhow::anyhow!("Tensor not found for node {:?} in tn_a", source_name)
430 })?
431 .clone();
432 let tensor_b = tn_b
433 .tensor(node_b_idx)
434 .ok_or_else(|| {
435 anyhow::anyhow!("Tensor not found for node {:?} in tn_b", source_name)
436 })?
437 .clone();
438
439 let is_leaf = !intermediate_tensors.contains_key(source_name)
441 || intermediate_tensors
442 .get(source_name)
443 .map(|v| v.is_empty())
444 .unwrap_or(true);
445
446 let c_temp = if is_leaf {
447 T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All)
449 .context("Failed to contract leaf tensors")?
450 } else {
451 let mut tensor_list = Vec::new();
453 if let Some(r_list) = intermediate_tensors.remove(source_name) {
454 tensor_list.extend(r_list);
455 }
456 tensor_list.push(tensor_a);
457 tensor_list.push(tensor_b);
458 let tensor_refs: Vec<&T> = tensor_list.iter().collect();
459 T::contract(&tensor_refs, AllowedPairs::All)
460 .context("Failed to contract internal node tensors")?
461 };
462
463 let bond_to_dest_a = get_bond_index(&tn_a, source_name, destination_name)
465 .context("Failed to get bond index to destination in tn_a")?;
466 let bond_to_dest_b = get_bond_index(&tn_b, source_name, destination_name)
467 .context("Failed to get bond index to destination in tn_b")?;
468
469 let left_inds: Vec<_> = c_temp
471 .external_indices()
472 .into_iter()
473 .filter(|idx| {
474 *idx.id() != *bond_to_dest_a.id() && *idx.id() != *bond_to_dest_b.id()
475 })
476 .collect();
477
478 if left_inds.is_empty() {
479 intermediate_tensors
481 .entry(destination_name.clone())
482 .or_default()
483 .push(c_temp);
484 continue;
485 }
486
487 let factorize_result = c_temp
488 .factorize(&left_inds, &factorize_options)
489 .context("Failed to factorize")?;
490
491 result_tensors.insert(source_name.clone(), factorize_result.left);
493
494 intermediate_tensors
496 .entry(destination_name.clone())
497 .or_default()
498 .push(factorize_result.right);
499
500 }
502
503 if let Some(r_list) = intermediate_tensors.remove(&root_name) {
505 let root_a_idx = tn_a
507 .node_index(&root_name)
508 .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found in tn_a", root_name))?;
509 let root_b_idx = tn_b
510 .node_index(&root_name)
511 .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found in tn_b", root_name))?;
512
513 let root_tensor_a = tn_a
514 .tensor(root_a_idx)
515 .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_a"))?
516 .clone();
517 let root_tensor_b = tn_b
518 .tensor(root_b_idx)
519 .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_b"))?
520 .clone();
521
522 let mut tensor_list = r_list;
524 tensor_list.push(root_tensor_a);
525 tensor_list.push(root_tensor_b);
526 let tensor_refs: Vec<&T> = tensor_list.iter().collect();
527 let root_result = T::contract(&tensor_refs, AllowedPairs::All)
528 .context("Failed to contract root node tensors")?;
529
530 result_tensors.insert(root_name.clone(), root_result);
532 } else {
533 if !result_tensors.contains_key(&root_name) {
536 let root_a_idx = tn_a.node_index(&root_name).ok_or_else(|| {
537 anyhow::anyhow!("Root node {:?} not found in tn_a", root_name)
538 })?;
539 let root_b_idx = tn_b.node_index(&root_name).ok_or_else(|| {
540 anyhow::anyhow!("Root node {:?} not found in tn_b", root_name)
541 })?;
542
543 let root_tensor_a = tn_a
544 .tensor(root_a_idx)
545 .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_a"))?;
546 let root_tensor_b = tn_b
547 .tensor(root_b_idx)
548 .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_b"))?;
549
550 let root_result = T::contract(&[root_tensor_a, root_tensor_b], AllowedPairs::All)
551 .context("Failed to contract root node tensors")?;
552
553 result_tensors.insert(root_name.clone(), root_result);
554 }
555 }
556
557 let mut result = TreeTN::new();
559
560 for (node_name, tensor) in result_tensors {
562 result.add_tensor(node_name, tensor)?;
563 }
564
565 for (source_name, destination_name) in &edges {
567 if let (Some(node_a_idx), Some(node_b_idx)) = (
568 result.node_index(source_name),
569 result.node_index(destination_name),
570 ) {
571 let tensor_a = result.tensor(node_a_idx).unwrap();
572 let tensor_b = result.tensor(node_b_idx).unwrap();
573
574 use tensor4all_core::index_ops::common_inds;
576 let indices_a = tensor_a.external_indices();
577 let indices_b = tensor_b.external_indices();
578 let common = common_inds::<T::Index>(&indices_a, &indices_b);
579 if let Some(bond_idx) = common.first() {
580 result.connect_internal(node_a_idx, bond_idx, node_b_idx, bond_idx)?;
581 }
582 }
583 }
584
585 if result.node_index(center).is_some() {
587 result.set_canonical_region(std::iter::once(center.clone()))?;
588 }
589
590 Ok(result)
591 }
592
593 pub fn contract_naive(&self, other: &Self) -> Result<T>
614 where
615 V: Ord,
616 <T::Index as IndexLike>::Id:
617 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
618 {
619 if !self.same_topology(other) {
621 return Err(anyhow::anyhow!(
622 "contract_naive: networks have incompatible topologies"
623 ));
624 }
625
626 let tn1 = self.sim_internal_inds();
628 let tn2 = other.sim_internal_inds();
629
630 let tensor1 = tn1
632 .contract_to_tensor()
633 .map_err(|e| anyhow::anyhow!("contract_naive: failed to contract tn1: {}", e))?;
634 let tensor2 = tn2
635 .contract_to_tensor()
636 .map_err(|e| anyhow::anyhow!("contract_naive: failed to contract tn2: {}", e))?;
637
638 T::contract(&[&tensor1, &tensor2], AllowedPairs::All)
641 }
642
643 pub fn validate_ortho_consistency(&self) -> Result<()> {
652 if self.canonical_region.is_empty() {
654 if !self.ortho_towards.is_empty() {
655 return Err(anyhow::anyhow!(
656 "Found {} ortho_towards entries but canonical_region is empty",
657 self.ortho_towards.len()
658 ))
659 .context(
660 "validate_ortho_consistency: canonical_region empty implies no ortho_towards",
661 );
662 }
663 return Ok(());
664 }
665
666 let mut center_indices = HashSet::new();
668 for c in &self.canonical_region {
669 let idx = self
670 .graph
671 .node_index(c)
672 .ok_or_else(|| anyhow::anyhow!("canonical_region node {:?} does not exist", c))?;
673 center_indices.insert(idx);
674 }
675
676 if !self.site_index_network.is_connected_subset(¢er_indices) {
678 return Err(anyhow::anyhow!("canonical_region is not connected")).context(
679 "validate_ortho_consistency: canonical_region must form a connected subtree",
680 );
681 }
682
683 let expected_edges = self
685 .site_index_network
686 .edges_to_canonicalize_to_region(¢er_indices);
687
688 let mut expected_directions: HashMap<T::Index, V> = HashMap::new();
690 for (src, dst) in expected_edges.iter() {
691 let edge = self
693 .graph
694 .graph()
695 .find_edge(*src, *dst)
696 .or_else(|| self.graph.graph().find_edge(*dst, *src))
697 .ok_or_else(|| anyhow::anyhow!("Edge not found between {:?} and {:?}", src, dst))?;
698
699 let bond = self
700 .bond_index(edge)
701 .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))?
702 .clone();
703
704 let dst_name = self
706 .graph
707 .node_name(*dst)
708 .ok_or_else(|| anyhow::anyhow!("Node name not found for {:?}", dst))?
709 .clone();
710
711 expected_directions.insert(bond, dst_name);
712 }
713
714 for (bond, expected_dir) in &expected_directions {
716 match self.ortho_towards.get(bond) {
717 Some(actual_dir) => {
718 if actual_dir != expected_dir {
719 return Err(anyhow::anyhow!(
720 "ortho_towards for bond {:?} points to {:?} but expected {:?}",
721 bond,
722 actual_dir,
723 expected_dir
724 ))
725 .context("validate_ortho_consistency: wrong direction");
726 }
727 }
728 None => {
729 return Err(anyhow::anyhow!(
730 "ortho_towards for bond {:?} is missing, expected to point to {:?}",
731 bond,
732 expected_dir
733 ))
734 .context("validate_ortho_consistency: missing ortho_towards");
735 }
736 }
737 }
738
739 let bond_indices: HashSet<T::Index> = self
742 .graph
743 .graph()
744 .edge_indices()
745 .filter_map(|e| self.bond_index(e))
746 .cloned()
747 .collect();
748
749 for idx in self.ortho_towards.keys() {
750 if bond_indices.contains(idx) && !expected_directions.contains_key(idx) {
751 return Err(anyhow::anyhow!(
753 "Unexpected ortho_towards for bond {:?} (inside canonical_region)",
754 idx
755 ))
756 .context(
757 "validate_ortho_consistency: bonds inside center should not have ortho_towards",
758 );
759 }
760 }
761
762 Ok(())
763 }
764}
765
766fn find_common_indices<T: TensorLike>(a: &T, b: &T) -> Vec<T::Index>
772where
773 <T::Index as IndexLike>::Id: Eq + std::hash::Hash,
774{
775 let a_ids: HashSet<_> = a
776 .external_indices()
777 .iter()
778 .map(|i| i.id().clone())
779 .collect();
780 b.external_indices()
781 .into_iter()
782 .filter(|i| a_ids.contains(i.id()))
783 .collect()
784}
785
786use super::fit::FitContractionOptions;
791
/// Algorithm that `contract` uses to contract two tree tensor networks.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ContractionMethod {
    /// Zip-up contraction (`TreeTN::contract_zipup`); the default.
    #[default]
    Zipup,
    /// Fitting contraction (`fit::contract_fit`).
    Fit,
    /// Exact dense contraction re-factorized into a tree
    /// (`contract_naive_to_treetn`).
    Naive,
}
804
805#[derive(Debug, Clone)]
807pub struct ContractionOptions {
808 pub method: ContractionMethod,
810 pub max_rank: Option<usize>,
812 pub svd_policy: Option<SvdTruncationPolicy>,
814 pub qr_rtol: Option<f64>,
816 pub nfullsweeps: usize,
820 pub convergence_tol: Option<f64>,
822 pub factorize_alg: FactorizeAlg,
824}
825
826impl Default for ContractionOptions {
827 fn default() -> Self {
828 Self {
829 method: ContractionMethod::default(),
830 max_rank: None,
831 svd_policy: None,
832 qr_rtol: None,
833 nfullsweeps: 1,
834 convergence_tol: None,
835 factorize_alg: FactorizeAlg::default(),
836 }
837 }
838}
839
840impl ContractionOptions {
841 pub fn new(method: ContractionMethod) -> Self {
843 Self {
844 method,
845 ..Default::default()
846 }
847 }
848
849 pub fn zipup() -> Self {
851 Self::new(ContractionMethod::Zipup)
852 }
853
854 pub fn fit() -> Self {
856 Self::new(ContractionMethod::Fit)
857 }
858
859 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
861 self.max_rank = Some(max_rank);
862 self
863 }
864
865 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
867 self.svd_policy = Some(policy);
868 self
869 }
870
871 pub fn with_qr_rtol(mut self, rtol: f64) -> Self {
873 self.qr_rtol = Some(rtol);
874 self
875 }
876
877 pub fn with_nfullsweeps(mut self, nfullsweeps: usize) -> Self {
879 self.nfullsweeps = nfullsweeps;
880 self
881 }
882
883 pub fn with_convergence_tol(mut self, tol: f64) -> Self {
885 self.convergence_tol = Some(tol);
886 self
887 }
888
889 pub fn with_factorize_alg(mut self, alg: FactorizeAlg) -> Self {
891 self.factorize_alg = alg;
892 self
893 }
894}
895
896pub fn contract<T, V>(
901 tn_a: &TreeTN<T, V>,
902 tn_b: &TreeTN<T, V>,
903 center: &V,
904 options: ContractionOptions,
905) -> Result<TreeTN<T, V>>
906where
907 T: TensorLike,
908 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
909 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
910{
911 match options.method {
912 ContractionMethod::Zipup => {
913 tn_a.contract_zipup(tn_b, center, options.svd_policy, options.max_rank)
914 }
915 ContractionMethod::Fit => {
916 let fit_options = FitContractionOptions::new(options.nfullsweeps)
917 .with_factorize_alg(options.factorize_alg);
918 let fit_options = if let Some(max_rank) = options.max_rank {
919 fit_options.with_max_rank(max_rank)
920 } else {
921 fit_options
922 };
923 let fit_options = if let Some(policy) = options.svd_policy {
924 fit_options.with_svd_policy(policy)
925 } else {
926 fit_options
927 };
928 let fit_options = if let Some(qr_rtol) = options.qr_rtol {
929 fit_options.with_qr_rtol(qr_rtol)
930 } else {
931 fit_options
932 };
933 let fit_options = if let Some(tol) = options.convergence_tol {
934 fit_options.with_convergence_tol(tol)
935 } else {
936 fit_options
937 };
938 super::fit::contract_fit(tn_a, tn_b, center, fit_options)
939 }
940 ContractionMethod::Naive => contract_naive_to_treetn(
941 tn_a,
942 tn_b,
943 center,
944 options.max_rank,
945 options.svd_policy,
946 options.qr_rtol,
947 ),
948 }
949}
950
951#[allow(clippy::too_many_arguments)]
960pub fn contract_naive_to_treetn<T, V>(
961 tn_a: &TreeTN<T, V>,
962 tn_b: &TreeTN<T, V>,
963 center: &V,
964 _max_rank: Option<usize>,
965 _svd_policy: Option<SvdTruncationPolicy>,
966 _qr_rtol: Option<f64>,
967) -> Result<TreeTN<T, V>>
968where
969 T: TensorLike,
970 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
971 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
972{
973 let contracted_tensor = tn_a.contract_naive(tn_b)?;
975
976 if contracted_tensor.external_indices().is_empty() {
978 let mut tn = TreeTN::<T, V>::new();
979 tn.add_tensor(center.clone(), contracted_tensor)?;
980 tn.set_canonical_region([center.clone()])?;
981 return Ok(tn);
982 }
983
984 use super::decompose::factorize_tensor_to_treetn_with;
986
987 let mut nodes: HashMap<V, Vec<<T::Index as IndexLike>::Id>> = HashMap::new();
991 let contracted_indices = contracted_tensor.external_indices();
992 let contracted_ids: HashSet<_> = contracted_indices
993 .iter()
994 .map(|ci| ci.id().clone())
995 .collect();
996
997 let mut node_names: Vec<_> = tn_a.node_names();
999 node_names.sort();
1000
1001 for node_name in &node_names {
1002 let mut ids: Vec<<T::Index as IndexLike>::Id> = Vec::new();
1003
1004 if let Some(site_space_a) = tn_a.site_index_network.site_space(node_name) {
1006 for site_idx in site_space_a {
1007 if contracted_ids.contains(site_idx.id()) {
1008 ids.push(site_idx.id().clone());
1009 }
1010 }
1011 }
1012
1013 if let Some(site_space_b) = tn_b.site_index_network.site_space(node_name) {
1015 for site_idx in site_space_b {
1016 if contracted_ids.contains(site_idx.id()) && !ids.contains(site_idx.id()) {
1017 ids.push(site_idx.id().clone());
1018 }
1019 }
1020 }
1021
1022 nodes.insert(node_name.clone(), ids);
1023 }
1024
1025 let edges: Vec<(V, V)> = tn_a
1027 .graph
1028 .graph()
1029 .edge_indices()
1030 .filter_map(|e| {
1031 let (src, dst) = tn_a.graph.graph().edge_endpoints(e)?;
1032 let src_name = tn_a.graph.node_name(src)?;
1033 let dst_name = tn_a.graph.node_name(dst)?;
1034 Some((src_name.clone(), dst_name.clone()))
1035 })
1036 .collect();
1037
1038 let topology = super::decompose::TreeTopology::new(nodes, edges);
1039
1040 let result = factorize_tensor_to_treetn_with(
1042 &contracted_tensor,
1043 &topology,
1044 FactorizeOptions::svd(),
1045 center,
1046 )?;
1047
1048 Ok(result)
1049}
1050
1051#[cfg(test)]
1052mod tests;