Skip to main content

tensor4all_itensorlike/
contract.rs

//! Contraction operations for tensor trains.
//!
//! This module provides both a free function [`contract`] and an impl method
//! [`TensorTrain::contract`] for contracting two tensor trains.
5
6use tensor4all_treetn::treetn::contraction::{
7    contract as treetn_contract, ContractionMethod, ContractionOptions as TreeTNContractionOptions,
8};
9use tensor4all_treetn::CanonicalForm;
10
11use crate::error::{Result, TensorTrainError};
12use crate::options::{validate_svd_truncation_options, ContractMethod, ContractOptions};
13use crate::tensortrain::TensorTrain;
14
15/// Contract two tensor trains, returning a new tensor train.
16///
17/// This performs element-wise contraction of corresponding sites,
18/// similar to MPO-MPO contraction in ITensor.
19///
20/// # Arguments
21/// * `a` - The first tensor train
22/// * `b` - The second tensor train
23/// * `options` - Contraction options (method, max_rank, rtol, nhalfsweeps)
24///
25/// # Returns
26/// A new tensor train resulting from the contraction.
27///
28/// # Errors
29/// Returns an error if:
30/// - Either tensor train is empty
31/// - The tensor trains have different lengths
32/// - The contraction algorithm fails
33pub fn contract(
34    a: &TensorTrain,
35    b: &TensorTrain,
36    options: &ContractOptions,
37) -> Result<TensorTrain> {
38    if a.is_empty() || b.is_empty() {
39        return Err(TensorTrainError::InvalidStructure {
40            message: "Cannot contract empty tensor trains".to_string(),
41        });
42    }
43
44    if a.len() != b.len() {
45        return Err(TensorTrainError::InvalidStructure {
46            message: format!(
47                "Tensor trains must have the same length for contraction: {} vs {}",
48                a.len(),
49                b.len()
50            ),
51        });
52    }
53
54    validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
55
56    if matches!(options.method(), ContractMethod::Fit) && !options.nhalfsweeps().is_multiple_of(2) {
57        return Err(TensorTrainError::OperationError {
58            message: format!(
59                "nhalfsweeps must be a multiple of 2 for Fit method, got {}",
60                options.nhalfsweeps()
61            ),
62        });
63    }
64
65    // Convert ContractOptions to TreeTN ContractionOptions
66    let treetn_method = match options.method() {
67        ContractMethod::Zipup => ContractionMethod::Zipup,
68        ContractMethod::Fit => ContractionMethod::Fit,
69        ContractMethod::Naive => ContractionMethod::Naive,
70    };
71
72    // Convert nhalfsweeps to nfullsweeps (nhalfsweeps / 2)
73    let nfullsweeps = options.nhalfsweeps() / 2;
74    let treetn_options = TreeTNContractionOptions::new(treetn_method).with_nfullsweeps(nfullsweeps);
75
76    let treetn_options = if let Some(max_rank) = options.max_rank() {
77        treetn_options.with_max_rank(max_rank)
78    } else {
79        treetn_options
80    };
81
82    let treetn_options = if let Some(policy) = options.svd_policy() {
83        treetn_options.with_svd_policy(policy)
84    } else {
85        treetn_options
86    };
87
88    // Use the last site as the canonical center (consistent with existing behavior)
89    let center = a.len() - 1;
90
91    let result_inner = if matches!(options.method(), ContractMethod::Zipup) {
92        a.as_treetn()
93            .contract_zipup_tree_accumulated(
94                b.as_treetn(),
95                &center,
96                CanonicalForm::Unitary,
97                options.svd_policy(),
98                options.max_rank(),
99            )
100            .map_err(|e| TensorTrainError::InvalidStructure {
101                message: format!("Zip-up contraction failed: {}", e),
102            })?
103    } else {
104        treetn_contract(a.as_treetn(), b.as_treetn(), &center, treetn_options).map_err(|e| {
105            TensorTrainError::InvalidStructure {
106                message: format!("TreeTN contraction failed: {}", e),
107            }
108        })?
109    };
110
111    TensorTrain::from_inner(result_inner, Some(CanonicalForm::Unitary))
112}
113
114impl TensorTrain {
115    /// Contract two tensor trains, returning a new tensor train.
116    ///
117    /// This performs element-wise contraction of corresponding sites,
118    /// similar to MPO-MPO contraction in ITensor.
119    ///
120    /// # Arguments
121    /// * `other` - The other tensor train to contract with
122    /// * `options` - Contraction options (method, max_rank, rtol, nhalfsweeps)
123    ///
124    /// # Returns
125    /// A new tensor train resulting from the contraction.
126    ///
127    /// # Errors
128    /// Returns an error if:
129    /// - Either tensor train is empty
130    /// - The tensor trains have different lengths
131    /// - The contraction algorithm fails
132    pub fn contract(&self, other: &Self, options: &ContractOptions) -> Result<Self> {
133        contract(self, other, options)
134    }
135}
136
// Unit tests for this module. The module body lives in a separate file,
// since this is an out-of-line `mod` declaration, and it is compiled only
// for test builds.
#[cfg(test)]
mod tests;