// tensor4all_itensorlike/contract.rs
use tensor4all_treetn::treetn::contraction::{
    contract as treetn_contract, ContractionMethod, ContractionOptions as TreeTNContractionOptions,
};
use tensor4all_treetn::CanonicalForm;

use crate::error::{Result, TensorTrainError};
use crate::options::{validate_svd_truncation_options, ContractMethod, ContractOptions};
use crate::tensortrain::TensorTrain;

15pub fn contract(
34 a: &TensorTrain,
35 b: &TensorTrain,
36 options: &ContractOptions,
37) -> Result<TensorTrain> {
38 if a.is_empty() || b.is_empty() {
39 return Err(TensorTrainError::InvalidStructure {
40 message: "Cannot contract empty tensor trains".to_string(),
41 });
42 }
43
44 if a.len() != b.len() {
45 return Err(TensorTrainError::InvalidStructure {
46 message: format!(
47 "Tensor trains must have the same length for contraction: {} vs {}",
48 a.len(),
49 b.len()
50 ),
51 });
52 }
53
54 validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
55
56 if matches!(options.method(), ContractMethod::Fit) && !options.nhalfsweeps().is_multiple_of(2) {
57 return Err(TensorTrainError::OperationError {
58 message: format!(
59 "nhalfsweeps must be a multiple of 2 for Fit method, got {}",
60 options.nhalfsweeps()
61 ),
62 });
63 }
64
65 let treetn_method = match options.method() {
67 ContractMethod::Zipup => ContractionMethod::Zipup,
68 ContractMethod::Fit => ContractionMethod::Fit,
69 ContractMethod::Naive => ContractionMethod::Naive,
70 };
71
72 let nfullsweeps = options.nhalfsweeps() / 2;
74 let treetn_options = TreeTNContractionOptions::new(treetn_method).with_nfullsweeps(nfullsweeps);
75
76 let treetn_options = if let Some(max_rank) = options.max_rank() {
77 treetn_options.with_max_rank(max_rank)
78 } else {
79 treetn_options
80 };
81
82 let treetn_options = if let Some(policy) = options.svd_policy() {
83 treetn_options.with_svd_policy(policy)
84 } else {
85 treetn_options
86 };
87
88 let center = a.len() - 1;
90
91 let result_inner = if matches!(options.method(), ContractMethod::Zipup) {
92 a.as_treetn()
93 .contract_zipup_tree_accumulated(
94 b.as_treetn(),
95 ¢er,
96 CanonicalForm::Unitary,
97 options.svd_policy(),
98 options.max_rank(),
99 )
100 .map_err(|e| TensorTrainError::InvalidStructure {
101 message: format!("Zip-up contraction failed: {}", e),
102 })?
103 } else {
104 treetn_contract(a.as_treetn(), b.as_treetn(), ¢er, treetn_options).map_err(|e| {
105 TensorTrainError::InvalidStructure {
106 message: format!("TreeTN contraction failed: {}", e),
107 }
108 })?
109 };
110
111 TensorTrain::from_inner(result_inner, Some(CanonicalForm::Unitary))
112}
113
114impl TensorTrain {
115 pub fn contract(&self, other: &Self, options: &ContractOptions) -> Result<Self> {
133 contract(self, other, options)
134 }
135}
136
// Unit tests live in a separate, out-of-line `tests` module file; compiled only for test builds.
#[cfg(test)]
mod tests;