Skip to main content

tensor4all_simplett/
contraction.rs

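//! Contraction operations for tensor trains.
//!
//! This module provides:
//! - `dot`: inner product of two tensor trains (returns a scalar), available both
//!   as the method [`TensorTrain::dot`] and as the free function [`dot`]
//! - [`ContractionOptions`]: options for MPO-MPO contraction with on-the-fly compression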
5
6use crate::compression::CompressionMethod;
7use crate::einsum_helper::{
8    einsum_tensors, tensor_to_row_major_vec, typed_tensor_from_row_major_slice, EinsumScalar,
9};
10use crate::error::{Result, TensorTrainError};
11use crate::tensortrain::TensorTrain;
12use crate::traits::{AbstractTensorTrain, TTScalar};
13use crate::types::Tensor3Ops;
14use tensor4all_tcicore::matrix::Matrix;
15use tensor4all_tcicore::Scalar;
16
17/// Options for MPO-MPO contraction with on-the-fly compression.
18///
19/// # Examples
20///
21/// ```
22/// use tensor4all_simplett::ContractionOptions;
23///
24/// let opts = ContractionOptions::default();
25/// assert!((opts.tolerance - 1e-12).abs() < 1e-15);
26/// assert_eq!(opts.max_bond_dim, usize::MAX);
27/// ```
28#[derive(Debug, Clone)]
29pub struct ContractionOptions {
30    /// Relative truncation tolerance during contraction.
31    pub tolerance: f64,
32    /// Hard upper bound on bond dimension.
33    pub max_bond_dim: usize,
34    /// Decomposition method for intermediate compression.
35    pub method: CompressionMethod,
36}
37
38impl Default for ContractionOptions {
39    fn default() -> Self {
40        Self {
41            tolerance: 1e-12,
42            max_bond_dim: usize::MAX,
43            method: CompressionMethod::LU,
44        }
45    }
46}
47
48impl<T: TTScalar + Scalar + Default + EinsumScalar> TensorTrain<T> {
49    /// Inner product (dot product) of two tensor trains.
50    ///
51    /// Computes `sum_i self[i] * other[i]` by contracting the site tensors
52    /// from left to right. Both tensor trains must have the same length and
53    /// matching site dimensions.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if lengths or site dimensions do not match.
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
63    ///
64    /// let a = TensorTrain::<f64>::constant(&[2, 3], 1.0);
65    /// let b = TensorTrain::<f64>::constant(&[2, 3], 2.0);
66    ///
67    /// // dot = sum_ij a[i,j]*b[i,j] = 1*2 * 2*3 = 12
68    /// let d = a.dot(&b).unwrap();
69    /// assert!((d - 12.0).abs() < 1e-10);
70    /// ```
71    pub fn dot(&self, other: &Self) -> Result<T> {
72        if self.len() != other.len() {
73            return Err(TensorTrainError::InvalidOperation {
74                message: format!(
75                    "Cannot compute dot product of tensor trains with different lengths: {} vs {}",
76                    self.len(),
77                    other.len()
78                ),
79            });
80        }
81
82        if self.is_empty() {
83            return Ok(T::zero());
84        }
85
86        let n = self.len();
87
88        // Start with contraction of first site
89        // result[ra, rb] = sum_s a[0, s, ra] * b[0, s, rb]
90        let a0 = self.site_tensor(0);
91        let b0 = other.site_tensor(0);
92
93        if a0.site_dim() != b0.site_dim() {
94            return Err(TensorTrainError::InvalidOperation {
95                message: format!(
96                    "Site dimensions mismatch at site 0: {} vs {}",
97                    a0.site_dim(),
98                    b0.site_dim()
99                ),
100            });
101        }
102
103        let mut result = Matrix::from_raw_vec(
104            a0.right_dim(),
105            b0.right_dim(),
106            tensor_to_row_major_vec(&einsum_tensors(
107                "asr,ast->rt",
108                &[a0.as_inner(), b0.as_inner()],
109            )),
110        );
111
112        // Contract through remaining sites
113        for i in 1..n {
114            let a = self.site_tensor(i);
115            let b = other.site_tensor(i);
116
117            if a.site_dim() != b.site_dim() {
118                return Err(TensorTrainError::InvalidOperation {
119                    message: format!(
120                        "Site dimensions mismatch at site {}: {} vs {}",
121                        i,
122                        a.site_dim(),
123                        b.site_dim()
124                    ),
125                });
126            }
127
128            let result_tf = typed_tensor_from_row_major_slice(
129                result.as_slice(),
130                &[result.nrows(), result.ncols()],
131            );
132
133            result = Matrix::from_raw_vec(
134                a.right_dim(),
135                b.right_dim(),
136                tensor_to_row_major_vec(&einsum_tensors(
137                    "ij,isk,jsl->kl",
138                    &[&result_tf, a.as_inner(), b.as_inner()],
139                )),
140            );
141        }
142
143        // Final result should be 1x1
144        Ok(result[[0, 0]])
145    }
146}
147
148/// Free-function wrapper for [`TensorTrain::dot`].
149///
150/// # Examples
151///
152/// ```
153/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, contraction::dot};
154///
155/// let a = TensorTrain::<f64>::constant(&[2, 3], 3.0);
156/// let b = TensorTrain::<f64>::constant(&[2, 3], 4.0);
157/// let d = dot(&a, &b).unwrap();
158/// // 3*4 * 2*3 = 72
159/// assert!((d - 72.0).abs() < 1e-10);
160/// ```
161pub fn dot<T: TTScalar + Scalar + Default + EinsumScalar>(
162    a: &TensorTrain<T>,
163    b: &TensorTrain<T>,
164) -> Result<T> {
165    a.dot(b)
166}
167
168#[cfg(test)]
169mod tests;