1use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
4use crate::error::Result;
5use crate::tensortrain::TensorTrain;
6use crate::traits::{AbstractTensorTrain, TTScalar};
7use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
8use tenferro_tensor::{TensorScalar, TypedTensor};
9use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, zeros, Matrix};
10use tensor4all_tcicore::Scalar;
11use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions};
12use tensor4all_tensorbackend::BackendLinalgScalar;
13
/// Strategy used to factorize site tensors during tensor-train compression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionMethod {
    /// Rank-revealing LU factorization (the default).
    #[default]
    LU,
    /// Cross interpolation built on the LU-based `MatrixLUCI`.
    CI,
    /// Truncated singular value decomposition (only available for
    /// `f64` / `Complex64` scalars — see `svd_dispatch`).
    SVD,
}
45
/// Tunable parameters for tensor-train compression.
#[derive(Debug, Clone)]
pub struct CompressionOptions {
    /// Factorization backend used at each bond.
    pub method: CompressionMethod,
    /// Relative truncation tolerance: singular values / pivots below
    /// `tolerance * largest` are discarded during the truncating sweep.
    pub tolerance: f64,
    /// Hard cap on the bond dimension after compression.
    pub max_bond_dim: usize,
    // NOTE(review): not read anywhere in this file — presumably controls
    // whether truncation error is reported relative to the train's norm
    // elsewhere; confirm against callers.
    pub normalize_error: bool,
}
106
107impl Default for CompressionOptions {
108 fn default() -> Self {
109 Self {
110 method: CompressionMethod::LU,
111 tolerance: 1e-12,
112 max_bond_dim: usize::MAX,
113 normalize_error: true,
114 }
115 }
116}
117
118fn tensor3_to_left_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
120 let left_dim = tensor.left_dim();
121 let site_dim = tensor.site_dim();
122 let right_dim = tensor.right_dim();
123 let rows = left_dim * site_dim;
124 let cols = right_dim;
125
126 let mut mat = zeros(rows, cols);
127 for l in 0..left_dim {
128 for s in 0..site_dim {
129 for r in 0..right_dim {
130 mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
131 }
132 }
133 }
134 mat
135}
136
137fn tensor3_to_right_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
139 let left_dim = tensor.left_dim();
140 let site_dim = tensor.site_dim();
141 let right_dim = tensor.right_dim();
142 let rows = left_dim;
143 let cols = site_dim * right_dim;
144
145 let mut mat = zeros(rows, cols);
146 for l in 0..left_dim {
147 for s in 0..site_dim {
148 for r in 0..right_dim {
149 mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
150 }
151 }
152 }
153 mat
154}
155
156fn factorize<T>(
158 matrix: &Matrix<T>,
159 method: CompressionMethod,
160 tolerance: f64,
161 max_bond_dim: usize,
162 left_orthogonal: bool,
163) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)>
164where
165 T: TTScalar + Scalar + tensor4all_tcicore::MatrixLuciScalar,
166 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
167{
168 let reltol = if tolerance > 0.0 { tolerance } else { 1e-14 };
169 let abstol = 0.0;
170
171 let options = RrLUOptions {
172 max_rank: max_bond_dim,
173 rel_tol: reltol,
174 abs_tol: abstol,
175 left_orthogonal,
176 };
177
178 match method {
179 CompressionMethod::LU => {
180 let lu = rrlu(matrix, Some(options))?;
181 let left = lu.left(true); let right = lu.right(true); let npivots = lu.npivots();
184 Ok((left, right, npivots))
185 }
186 CompressionMethod::CI => {
187 let luci = MatrixLUCI::from_matrix(matrix, Some(options))?;
188 let left = luci.left();
189 let right = luci.right();
190 let npivots = luci.rank();
191 Ok((left, right, npivots))
192 }
193 CompressionMethod::SVD => {
194 svd_dispatch(matrix, tolerance, max_bond_dim, left_orthogonal)
197 }
198 }
199}
200
/// Scalar types that support SVD-based compression: bridges between the
/// scalar `Self` and its real singular-value type.
trait SVDCompressScalar:
    TTScalar + Scalar + Default + Copy + BackendLinalgScalar + num_complex::ComplexFloat + 'static
{
    /// Convert a singular value (the real type of `Self`) to `f64` for
    /// tolerance comparisons.
    fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64;
    /// Embed a singular value back into the scalar type `Self`.
    fn from_sv(real: <Self as TensorScalar>::Real) -> Self;
}
208
209impl SVDCompressScalar for f64 {
210 fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
211 real
212 }
213 fn from_sv(real: <Self as TensorScalar>::Real) -> Self {
214 real
215 }
216}
217
218impl SVDCompressScalar for num_complex::Complex64 {
219 fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
220 real
221 }
222 fn from_sv(real: <Self as TensorScalar>::Real) -> Self {
223 num_complex::Complex64::new(real, 0.0)
224 }
225}
226
/// Dispatch SVD factorization to a concrete scalar implementation.
///
/// `Matrix<T>` is generic, but the SVD path only exists for `f64` and
/// `Complex64`; this routine identifies the concrete scalar at runtime
/// via `Any` downcasting and returns an error for any other type.
fn svd_dispatch<T: TTScalar + Scalar>(
    matrix: &Matrix<T>,
    tolerance: f64,
    max_bond_dim: usize,
    left_orthogonal: bool,
) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)> {
    use std::any::Any;

    // Dimensions are kept only for the error message below.
    let m = nrows(matrix);
    let n = ncols(matrix);

    if let Some(mat_f64) = (matrix as &dyn Any).downcast_ref::<Matrix<f64>>() {
        let (l, r, rank) = factorize_svd(mat_f64, tolerance, max_bond_dim, left_orthogonal)?;
        // SAFETY: the successful downcast above proves `T == f64`, so
        // `Matrix<f64>` and `Matrix<T>` are the same type (same layout).
        let left = unsafe { std::mem::transmute::<Matrix<f64>, Matrix<T>>(l) };
        let right = unsafe { std::mem::transmute::<Matrix<f64>, Matrix<T>>(r) };
        return Ok((left, right, rank));
    }

    if let Some(mat_c64) = (matrix as &dyn Any).downcast_ref::<Matrix<num_complex::Complex64>>() {
        let (l, r, rank) = factorize_svd(mat_c64, tolerance, max_bond_dim, left_orthogonal)?;
        // SAFETY: the successful downcast above proves `T == Complex64`.
        let left = unsafe { std::mem::transmute::<Matrix<num_complex::Complex64>, Matrix<T>>(l) };
        let right = unsafe { std::mem::transmute::<Matrix<num_complex::Complex64>, Matrix<T>>(r) };
        return Ok((left, right, rank));
    }

    Err(crate::error::TensorTrainError::InvalidOperation {
        message: format!(
            "SVD compression not supported for this scalar type (matrix {}x{})",
            m, n
        ),
    })
}
266
267fn typed_tensor_row_major<T: TensorScalar>(
269 tensor: &TypedTensor<T>,
270) -> crate::error::Result<Vec<T>> {
271 Ok(tensor_to_row_major_vec(tensor))
272}
273
/// Truncated SVD factorization `matrix ≈ left * right`.
///
/// The rank is the number of singular values `s_i >= tolerance * s_max`,
/// capped at `max_bond_dim` and floored at 1 so the bond never vanishes.
/// With `left_orthogonal == true`: `left = U`, `right = S * Vt`
/// (orthonormal columns on the left); otherwise `left = U * S`,
/// `right = Vt` (orthonormal rows on the right).
///
/// # Errors
/// Fails on an empty matrix or if the SVD backend reports an error.
fn factorize_svd<T: SVDCompressScalar>(
    matrix: &Matrix<T>,
    tolerance: f64,
    max_bond_dim: usize,
    left_orthogonal: bool,
) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)> {
    let m = nrows(matrix);
    let n = ncols(matrix);

    if m == 0 || n == 0 {
        return Err(crate::error::TensorTrainError::InvalidOperation {
            message: "Cannot factorize empty matrix".to_string(),
        });
    }

    // Copy into a dense row-major buffer for the SVD backend.
    let mut data = vec![T::zero(); m * n];
    for i in 0..m {
        for j in 0..n {
            data[i * n + j] = matrix[[i, j]];
        }
    }
    let a_tensor = typed_tensor_from_row_major_slice(&data, &[m, n]);

    let svd_result = tensor4all_tensorbackend::svd_backend(&a_tensor).map_err(|e| {
        crate::error::TensorTrainError::InvalidOperation {
            message: format!("SVD computation failed: {e:?}"),
        }
    })?;

    // Row-major views of U, S, Vt; the `*_cols` values are the leading
    // dimensions used for flat indexing below.
    let u_data = typed_tensor_row_major(&svd_result.u)?;
    let u_cols = svd_result.u.shape[1];
    let s_data: Vec<<T as TensorScalar>::Real> = typed_tensor_row_major(&svd_result.s)?;
    let vt_data = typed_tensor_row_major(&svd_result.vt)?;
    let vt_cols = svd_result.vt.shape[1];

    let min_dim = m.min(n);
    // Largest singular value. This (and the early `break` below) assumes
    // the backend returns singular values in descending order — the
    // standard SVD convention; confirm against the backend contract.
    let s_max = if !s_data.is_empty() {
        T::sv_to_f64(s_data[0])
    } else {
        0.0
    };

    // Count singular values above the relative threshold, respecting the
    // bond-dimension cap.
    let mut rank = 0;
    for &singular_value in s_data.iter().take(min_dim) {
        if rank >= max_bond_dim {
            break;
        }
        let sv = T::sv_to_f64(singular_value);
        if sv < tolerance * s_max {
            break;
        }
        rank += 1;
    }
    // Always keep at least one singular value.
    rank = rank.max(1);

    let mut left = zeros(m, rank);
    let mut right = zeros(rank, n);

    if left_orthogonal {
        // left = U (first `rank` columns), right = S * Vt.
        for i in 0..m {
            for j in 0..rank {
                left[[i, j]] = u_data[i * u_cols + j];
            }
        }
        for i in 0..rank {
            let sv = T::from_sv(s_data[i]);
            for j in 0..n {
                right[[i, j]] = sv * vt_data[i * vt_cols + j];
            }
        }
    } else {
        // left = U * S, right = Vt (first `rank` rows).
        for i in 0..m {
            for j in 0..rank {
                let sv = T::from_sv(s_data[j]);
                left[[i, j]] = u_data[i * u_cols + j] * sv;
            }
        }
        for i in 0..rank {
            for j in 0..n {
                right[[i, j]] = vt_data[i * vt_cols + j];
            }
        }
    }

    Ok((left, right, rank))
}
367
impl<T: TTScalar + Scalar + Default> TensorTrain<T> {
    /// Compress this tensor train in place using a two-sweep scheme.
    ///
    /// 1. Left-to-right sweep: lossless factorization (tolerance 0,
    ///    unbounded rank) that makes every core except the last
    ///    left-orthogonal, concentrating the train's weight in the
    ///    rightmost core.
    /// 2. Right-to-left sweep: truncating factorization controlled by
    ///    `options.tolerance` and `options.max_bond_dim`, absorbing the
    ///    non-orthogonal factor into the left neighbor.
    ///
    /// Trains with zero or one site are returned unchanged.
    ///
    /// NOTE(review): `options.normalize_error` is not consulted here —
    /// confirm whether it is meant to affect this routine.
    ///
    /// # Errors
    /// Propagates any error from the underlying factorization
    /// (`rrlu`, `MatrixLUCI`, or the SVD backend).
    pub fn compress(&mut self, options: &CompressionOptions) -> Result<()>
    where
        T: tensor4all_tcicore::MatrixLuciScalar,
        tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
    {
        let n = self.len();
        if n <= 1 {
            return Ok(());
        }

        let tensors = self.site_tensors_mut();

        // Sweep 1: left-to-right orthogonalization (no truncation).
        for ell in 0..n - 1 {
            let left_dim = tensors[ell].left_dim();
            let site_dim = tensors[ell].site_dim();

            let mat = tensor3_to_left_matrix(&tensors[ell]);

            // tolerance 0.0 / usize::MAX: this sweep is exact by design;
            // all truncation happens in sweep 2.
            let (left_factor, right_factor, new_bond_dim) = factorize(
                &mat,
                options.method,
                0.0, usize::MAX, true, )?;

            // Replace core `ell` with the left-orthogonal factor.
            let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..new_bond_dim {
                        let row = l * site_dim + s;
                        // Guard against factors smaller than the requested
                        // bond dimension; missing entries stay zero.
                        if row < nrows(&left_factor) && r < ncols(&left_factor) {
                            new_tensor.set3(l, s, r, left_factor[[row, r]]);
                        }
                    }
                }
            }
            tensors[ell] = new_tensor;

            // Absorb the right factor into the next core:
            // contracted = right_factor * right-unfolding(core ell+1).
            let next_site_dim = tensors[ell + 1].site_dim();
            let next_right_dim = tensors[ell + 1].right_dim();

            let next_mat = tensor3_to_right_matrix(&tensors[ell + 1]);

            let contracted = mat_mul(&right_factor, &next_mat);

            let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
            for l in 0..new_bond_dim {
                for s in 0..next_site_dim {
                    for r in 0..next_right_dim {
                        new_next_tensor.set3(l, s, r, contracted[[l, s * next_right_dim + r]]);
                    }
                }
            }
            tensors[ell + 1] = new_next_tensor;
        }

        // Sweep 2: right-to-left truncation with the caller's tolerance
        // and bond-dimension cap.
        for ell in (1..n).rev() {
            let site_dim = tensors[ell].site_dim();
            let right_dim = tensors[ell].right_dim();

            let mat = tensor3_to_right_matrix(&tensors[ell]);

            let (left_factor, right_factor, new_bond_dim) = factorize(
                &mat,
                options.method,
                options.tolerance,
                options.max_bond_dim,
                false, )?;

            // Replace core `ell` with the right factor.
            let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
            for l in 0..new_bond_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        new_tensor.set3(l, s, r, right_factor[[l, s * right_dim + r]]);
                    }
                }
            }
            tensors[ell] = new_tensor;

            // Absorb the left factor into the previous core:
            // contracted = left-unfolding(core ell-1) * left_factor.
            let prev_left_dim = tensors[ell - 1].left_dim();
            let prev_site_dim = tensors[ell - 1].site_dim();

            let prev_mat = tensor3_to_left_matrix(&tensors[ell - 1]);

            let contracted = mat_mul(&prev_mat, &left_factor);

            let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
            for l in 0..prev_left_dim {
                for s in 0..prev_site_dim {
                    for r in 0..new_bond_dim {
                        new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
                    }
                }
            }
            tensors[ell - 1] = new_prev_tensor;
        }

        Ok(())
    }

    /// Return a compressed copy, leaving `self` untouched.
    ///
    /// See [`Self::compress`] for the algorithm and error conditions.
    pub fn compressed(&self, options: &CompressionOptions) -> Result<Self>
    where
        T: tensor4all_tcicore::MatrixLuciScalar,
        tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
    {
        let mut result = self.clone();
        result.compress(options)?;
        Ok(result)
    }
}
550
551#[cfg(test)]
552mod tests;