tensor4all_treetn/algorithm.rs
1//! Algorithm selection types for tensor train operations.
2//!
3//! This module provides enums for selecting algorithms in various tensor train operations.
4//! These types are designed to be FFI-friendly (representable as C integers).
5//!
6//! # Design Philosophy
7//!
8//! Following ITensors.jl's pattern, algorithms are represented as static types.
9//! Unlike ITensors.jl which uses symbol-based dispatch (`Algorithm"svd"`),
10//! we use Rust enums for:
11//! - Compile-time exhaustiveness checking
12//! - Easy FFI mapping to C integers
13//! - Clear documentation of available algorithms
14//!
15//! # Example
16//!
17//! ```
18//! use tensor4all_treetn::algorithm::{ContractionAlgorithm, CanonicalForm};
19//!
20//! // Select contraction algorithm
21//! let alg = ContractionAlgorithm::ZipUp;
22//!
23//! // Select canonical form
24//! let form = CanonicalForm::Unitary;
25//! ```
26
/// Strategy for contracting two tensor trains (TT-TT or MPO-MPO).
///
/// Every strategy consumes two tensor trains and yields a new tensor train,
/// possibly compressing/truncating it along the way.
///
/// # C API Representation
/// - `T4A_CONTRACT_NAIVE` = 0
/// - `T4A_CONTRACT_ZIPUP` = 1
/// - `T4A_CONTRACT_FIT` = 2
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum ContractionAlgorithm {
    /// Contract everything first, compress afterwards (the default choice).
    ///
    /// All site tensors are contracted up front, which creates large
    /// intermediate tensors; compression is applied to the result at the end.
    /// Exact up to the compression tolerance.
    ///
    /// Memory: O(D^4), with D the bond dimension.
    #[default]
    Naive = 0,

    /// Zip-up contraction that compresses as it goes.
    ///
    /// Site tensors are contracted and compressed one by one so bond
    /// dimensions stay small. Uses less memory than [`Self::Naive`], at the
    /// price of possibly introducing additional error.
    ///
    /// Memory: O(D^2)
    ZipUp = 1,

    /// Variational fitting.
    ///
    /// Sweeps over the result tensor train, optimizing it to minimize its
    /// distance to the exact contraction.
    ///
    /// Most useful when the desired bond dimension is much smaller than that
    /// of the exact result.
    Fit = 2,
}
65
66impl ContractionAlgorithm {
67 /// Create from C API integer representation.
68 ///
69 /// Returns `None` for invalid values.
70 pub fn from_i32(value: i32) -> Option<Self> {
71 match value {
72 0 => Some(Self::Naive),
73 1 => Some(Self::ZipUp),
74 2 => Some(Self::Fit),
75 _ => None,
76 }
77 }
78
79 /// Convert to C API integer representation.
80 pub fn to_i32(self) -> i32 {
81 self as i32
82 }
83
84 /// Get algorithm name as string.
85 pub fn name(&self) -> &'static str {
86 match self {
87 Self::Naive => "naive",
88 Self::ZipUp => "zipup",
89 Self::Fit => "fit",
90 }
91 }
92}
93
/// Canonical form of a tensor train / tree tensor network.
///
/// Selects which mathematical canonical representation is produced; each
/// variant relies on its own factorization algorithm under the hood.
///
/// # C API Representation
/// - `T4A_CANONICAL_UNITARY` = 0
/// - `T4A_CANONICAL_LU` = 1
/// - `T4A_CANONICAL_CI` = 2
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum CanonicalForm {
    /// Unitary (orthogonal/isometric) canonical form — the default.
    ///
    /// Every tensor is an isometry pointing towards the orthogonality center;
    /// built with QR decompositions.
    /// Properties:
    /// - Numerically stable
    /// - Norms are easy to compute
    /// - The standard canonical form used by DMRG
    #[default]
    Unitary = 0,

    /// Canonical form based on LU factorization.
    ///
    /// Built with rank-revealing LU decompositions.
    /// Properties:
    /// - Faster than QR
    /// - One of the factors has a unit diagonal
    LU = 1,

    /// Cross Interpolation (CI) canonical form.
    ///
    /// Built with CI/skeleton decompositions.
    /// Properties:
    /// - Rank is selected adaptively
    /// - Efficient for low-rank structures
    CI = 2,
}
133
134impl CanonicalForm {
135 /// Create from C API integer representation.
136 ///
137 /// Returns `None` for invalid values.
138 pub fn from_i32(value: i32) -> Option<Self> {
139 match value {
140 0 => Some(Self::Unitary),
141 1 => Some(Self::LU),
142 2 => Some(Self::CI),
143 _ => None,
144 }
145 }
146
147 /// Convert to C API integer representation.
148 pub fn to_i32(self) -> i32 {
149 self as i32
150 }
151
152 /// Get form name as string.
153 pub fn name(&self) -> &'static str {
154 match self {
155 Self::Unitary => "unitary",
156 Self::LU => "lu",
157 Self::CI => "ci",
158 }
159 }
160}
161
/// Strategy for compressing a tensor train.
///
/// Each strategy shrinks the bond dimensions of a tensor train.
///
/// # C API Representation
/// - `T4A_COMPRESS_SVD` = 0
/// - `T4A_COMPRESS_LU` = 1
/// - `T4A_COMPRESS_CI` = 2
/// - `T4A_COMPRESS_VARIATIONAL` = 3
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum CompressionAlgorithm {
    /// Compression via SVD (orthogonalization followed by truncation) — the default.
    ///
    /// Performs a sweep over the tensor train and applies an SVD at every
    /// bond, giving the optimal truncation for a given tolerance.
    #[default]
    SVD = 0,

    /// Compression via LU decomposition.
    ///
    /// Replaces the SVD with an LU decomposition: faster, but the resulting
    /// truncation is not guaranteed to be optimal.
    LU = 1,

    /// Compression via Cross Interpolation.
    ///
    /// Relies on CI/skeleton decompositions to compress.
    CI = 2,

    /// Variational compression.
    ///
    /// Optimizes the compressed tensor train by sweeping; handy when the
    /// target bond dimension is known in advance.
    Variational = 3,
}
198
199impl CompressionAlgorithm {
200 /// Create from C API integer representation.
201 ///
202 /// Returns `None` for invalid values.
203 pub fn from_i32(value: i32) -> Option<Self> {
204 match value {
205 0 => Some(Self::SVD),
206 1 => Some(Self::LU),
207 2 => Some(Self::CI),
208 3 => Some(Self::Variational),
209 _ => None,
210 }
211 }
212
213 /// Convert to C API integer representation.
214 pub fn to_i32(self) -> i32 {
215 self as i32
216 }
217
218 /// Get algorithm name as string.
219 pub fn name(&self) -> &'static str {
220 match self {
221 Self::SVD => "svd",
222 Self::LU => "lu",
223 Self::CI => "ci",
224 Self::Variational => "variational",
225 }
226 }
227}
228
229#[cfg(test)]
230mod tests;