Skip to main content

tensor4all_treetn/
algorithm.rs

//! Algorithm selection types for tensor train operations.
//!
//! This module provides enums for selecting algorithms in various tensor train operations.
//! These types are designed to be FFI-friendly (representable as C integers).
//!
//! # Design Philosophy
//!
//! Following ITensors.jl's pattern, algorithms are represented as static types.
//! Whereas ITensors.jl dispatches on symbol-parameterized types (`Algorithm"svd"`),
//! this crate uses Rust enums, which provide:
//! - Compile-time exhaustiveness checking
//! - Easy FFI mapping to C integers
//! - Clear documentation of available algorithms
//!
//! # Example
//!
//! ```
//! use tensor4all_treetn::algorithm::{ContractionAlgorithm, CanonicalForm};
//!
//! // Select contraction algorithm
//! let alg = ContractionAlgorithm::ZipUp;
//!
//! // Select canonical form
//! let form = CanonicalForm::Unitary;
//! ```

/// Algorithm for tensor train contraction (TT-TT or MPO-MPO).
///
/// These algorithms contract two tensor trains and produce a new tensor train,
/// optionally with compression/truncation.
///
/// # C API Representation
/// - `T4A_CONTRACT_NAIVE` = 0
/// - `T4A_CONTRACT_ZIPUP` = 1
/// - `T4A_CONTRACT_FIT` = 2
///
/// The discriminants are fixed explicitly under `#[repr(i32)]`; they are the
/// values exchanged over the C API, so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum ContractionAlgorithm {
    /// Naive contraction followed by compression.
    ///
    /// Contracts all site tensors first (producing large intermediate tensors),
    /// then compresses the result. Exact up to compression tolerance.
    ///
    /// Memory: O(D^4) where D is the bond dimension.
    #[default]
    Naive = 0,

    /// Zip-up contraction with on-the-fly compression.
    ///
    /// Contracts and compresses site-by-site, keeping bond dimensions small.
    /// More memory efficient than naive, but may introduce additional error.
    ///
    /// Memory: O(D^2)
    ZipUp = 1,

    /// Variational fitting algorithm.
    ///
    /// Optimizes the result tensor train to minimize the distance to the
    /// exact contraction. Uses sweeping optimization.
    ///
    /// Best for cases where target bond dimension is much smaller than
    /// the exact result.
    Fit = 2,
}

impl ContractionAlgorithm {
    /// Create from C API integer representation.
    ///
    /// Returns `None` for values that do not match any variant's discriminant.
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    pub const fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Naive),
            1 => Some(Self::ZipUp),
            2 => Some(Self::Fit),
            _ => None,
        }
    }

    /// Convert to C API integer representation (the `#[repr(i32)]` discriminant).
    pub const fn to_i32(self) -> i32 {
        self as i32
    }

    /// Get the algorithm name as a lowercase string (`"naive"`, `"zipup"`, `"fit"`).
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Naive => "naive",
            Self::ZipUp => "zipup",
            Self::Fit => "fit",
        }
    }
}

/// Standard fallible conversion from the C API integer; equivalent to
/// [`ContractionAlgorithm::from_i32`].
impl std::convert::TryFrom<i32> for ContractionAlgorithm {
    /// The rejected integer is handed back unchanged.
    type Error = i32;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Self::from_i32(value).ok_or(value)
    }
}

/// Standard conversion to the C API integer; equivalent to
/// [`ContractionAlgorithm::to_i32`].
impl From<ContractionAlgorithm> for i32 {
    fn from(alg: ContractionAlgorithm) -> Self {
        alg.to_i32()
    }
}
93
/// Canonical form for tensor train / tree tensor network.
///
/// Specifies the mathematical form of the canonical representation.
/// Each form uses a specific factorization algorithm internally.
///
/// # C API Representation
/// - `T4A_CANONICAL_UNITARY` = 0
/// - `T4A_CANONICAL_LU` = 1
/// - `T4A_CANONICAL_CI` = 2
///
/// The discriminants are fixed explicitly under `#[repr(i32)]`; they are the
/// values exchanged over the C API, so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum CanonicalForm {
    /// Unitary (orthogonal/isometric) canonical form.
    ///
    /// Each tensor is isometric towards the orthogonality center.
    /// Uses QR decomposition internally.
    /// Properties:
    /// - Numerically stable
    /// - Easy norm computation
    /// - Standard canonical form for DMRG
    #[default]
    Unitary = 0,

    /// LU-based canonical form.
    ///
    /// Uses rank-revealing LU decomposition.
    /// Properties:
    /// - Faster than QR
    /// - One factor has unit diagonal
    LU = 1,

    /// Cross Interpolation (CI) canonical form.
    ///
    /// Uses CI/skeleton decomposition.
    /// Properties:
    /// - Adaptive rank selection
    /// - Efficient for low-rank structures
    CI = 2,
}

impl CanonicalForm {
    /// Create from C API integer representation.
    ///
    /// Returns `None` for values that do not match any variant's discriminant.
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    pub const fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Unitary),
            1 => Some(Self::LU),
            2 => Some(Self::CI),
            _ => None,
        }
    }

    /// Convert to C API integer representation (the `#[repr(i32)]` discriminant).
    pub const fn to_i32(self) -> i32 {
        self as i32
    }

    /// Get the form name as a lowercase string (`"unitary"`, `"lu"`, `"ci"`).
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Unitary => "unitary",
            Self::LU => "lu",
            Self::CI => "ci",
        }
    }
}

/// Standard fallible conversion from the C API integer; equivalent to
/// [`CanonicalForm::from_i32`].
impl std::convert::TryFrom<i32> for CanonicalForm {
    /// The rejected integer is handed back unchanged.
    type Error = i32;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Self::from_i32(value).ok_or(value)
    }
}

/// Standard conversion to the C API integer; equivalent to
/// [`CanonicalForm::to_i32`].
impl From<CanonicalForm> for i32 {
    fn from(form: CanonicalForm) -> Self {
        form.to_i32()
    }
}
161
/// Algorithm for tensor train compression.
///
/// These algorithms compress a tensor train to reduce bond dimensions.
///
/// # C API Representation
/// - `T4A_COMPRESS_SVD` = 0
/// - `T4A_COMPRESS_LU` = 1
/// - `T4A_COMPRESS_CI` = 2
/// - `T4A_COMPRESS_VARIATIONAL` = 3
///
/// The discriminants are fixed explicitly under `#[repr(i32)]`; they are the
/// values exchanged over the C API, so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum CompressionAlgorithm {
    /// SVD-based compression (orthogonalization + truncation).
    ///
    /// Sweeps through the tensor train, applying SVD at each bond.
    /// Optimal truncation for given tolerance.
    #[default]
    SVD = 0,

    /// LU-based compression.
    ///
    /// Uses LU decomposition instead of SVD. Faster but may not give
    /// optimal truncation.
    LU = 1,

    /// Cross Interpolation based compression.
    ///
    /// Uses CI/skeleton decomposition for compression.
    CI = 2,

    /// Variational compression.
    ///
    /// Optimizes the compressed tensor train using sweeping.
    /// Useful when target bond dimension is known.
    Variational = 3,
}

impl CompressionAlgorithm {
    /// Create from C API integer representation.
    ///
    /// Returns `None` for values that do not match any variant's discriminant.
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    pub const fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::SVD),
            1 => Some(Self::LU),
            2 => Some(Self::CI),
            3 => Some(Self::Variational),
            _ => None,
        }
    }

    /// Convert to C API integer representation (the `#[repr(i32)]` discriminant).
    pub const fn to_i32(self) -> i32 {
        self as i32
    }

    /// Get the algorithm name as a lowercase string
    /// (`"svd"`, `"lu"`, `"ci"`, `"variational"`).
    pub const fn name(&self) -> &'static str {
        match self {
            Self::SVD => "svd",
            Self::LU => "lu",
            Self::CI => "ci",
            Self::Variational => "variational",
        }
    }
}

/// Standard fallible conversion from the C API integer; equivalent to
/// [`CompressionAlgorithm::from_i32`].
impl std::convert::TryFrom<i32> for CompressionAlgorithm {
    /// The rejected integer is handed back unchanged.
    type Error = i32;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Self::from_i32(value).ok_or(value)
    }
}

/// Standard conversion to the C API integer; equivalent to
/// [`CompressionAlgorithm::to_i32`].
impl From<CompressionAlgorithm> for i32 {
    fn from(alg: CompressionAlgorithm) -> Self {
        alg.to_i32()
    }
}
228
// Unit tests live in a separate `tests` submodule file; it is compiled only
// when building with `cfg(test)`.
#[cfg(test)]
mod tests;