1use crate::defaults::DynIndex;
36use crate::{unfold_split, TensorDynLen};
37use num_complex::{Complex64, ComplexFloat};
38use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
39use tensor4all_tensorbackend::TensorElement;
40
41use crate::defaults::svd::svd_for_factorize;
42use crate::qr::{qr_with, QrOptions};
43use crate::svd::SvdOptions;
44
45pub use crate::tensor_like::{
47 Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
48};
49
50pub fn factorize(
77 t: &TensorDynLen,
78 left_inds: &[DynIndex],
79 options: &FactorizeOptions,
80) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
81 options.validate()?;
82
83 if t.is_diag() {
84 return Err(FactorizeError::UnsupportedStorage(
85 "Diagonal storage not supported for factorize",
86 ));
87 }
88
89 if t.is_f64() {
90 factorize_impl_f64(t, left_inds, options)
91 } else if t.is_complex() {
92 factorize_impl_c64(t, left_inds, options)
93 } else {
94 Err(FactorizeError::UnsupportedStorage(
95 "factorize currently supports only f64 and Complex64 tensors",
96 ))
97 }
98}
99
100pub fn factorize_full_rank(
146 t: &TensorDynLen,
147 left_inds: &[DynIndex],
148 alg: FactorizeAlg,
149 canonical: Canonical,
150) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
151 if t.is_diag() {
152 return Err(FactorizeError::UnsupportedStorage(
153 "Diagonal storage not supported for factorize",
154 ));
155 }
156
157 if t.is_f64() {
158 factorize_impl_f64_full_rank(t, left_inds, alg, canonical)
159 } else if t.is_complex() {
160 factorize_impl_c64_full_rank(t, left_inds, alg, canonical)
161 } else {
162 Err(FactorizeError::UnsupportedStorage(
163 "factorize currently supports only f64 and Complex64 tensors",
164 ))
165 }
166}
167
168fn factorize_impl_f64(
169 t: &TensorDynLen,
170 left_inds: &[DynIndex],
171 options: &FactorizeOptions,
172) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
173 match options.alg {
174 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
175 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
176 FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
177 FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
178 }
179}
180
181fn factorize_impl_f64_full_rank(
182 t: &TensorDynLen,
183 left_inds: &[DynIndex],
184 alg: FactorizeAlg,
185 canonical: Canonical,
186) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
187 match alg {
188 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
189 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
190 FactorizeAlg::LU => factorize_lu_full_rank::<f64>(t, left_inds, canonical),
191 FactorizeAlg::CI => factorize_ci_full_rank::<f64>(t, left_inds, canonical),
192 }
193}
194
195fn factorize_impl_c64(
196 t: &TensorDynLen,
197 left_inds: &[DynIndex],
198 options: &FactorizeOptions,
199) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
200 match options.alg {
201 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
202 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
203 FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
204 FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
205 }
206}
207
208fn factorize_impl_c64_full_rank(
209 t: &TensorDynLen,
210 left_inds: &[DynIndex],
211 alg: FactorizeAlg,
212 canonical: Canonical,
213) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
214 match alg {
215 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
216 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
217 FactorizeAlg::LU => factorize_lu_full_rank::<Complex64>(t, left_inds, canonical),
218 FactorizeAlg::CI => factorize_ci_full_rank::<Complex64>(t, left_inds, canonical),
219 }
220}
221
222fn factorize_svd(
224 t: &TensorDynLen,
225 left_inds: &[DynIndex],
226 options: &FactorizeOptions,
227) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
228 let mut svd_options = SvdOptions::new();
229 if let Some(policy) = options.svd_policy {
230 svd_options = svd_options.with_policy(policy);
231 }
232 if let Some(max_rank) = options.max_rank {
233 svd_options = svd_options.with_max_rank(max_rank);
234 }
235
236 factorize_svd_with_options(t, left_inds, options.canonical, &svd_options)
237}
238
239fn factorize_svd_full_rank(
240 t: &TensorDynLen,
241 left_inds: &[DynIndex],
242 canonical: Canonical,
243) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
244 let svd_options = SvdOptions::full_rank();
245 factorize_svd_with_options(t, left_inds, canonical, &svd_options)
246}
247
248fn factorize_svd_with_options(
249 t: &TensorDynLen,
250 left_inds: &[DynIndex],
251 canonical: Canonical,
252 svd_options: &SvdOptions,
253) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
254 let result = svd_for_factorize(t, left_inds, svd_options)?;
255 let u = result.u;
256 let s = result.s;
257 let vh = result.vh;
258 let bond_index = result.bond_index;
259 let singular_values = result.singular_values;
260 let rank = result.rank;
261 let sim_bond_index = s.indices[1].clone();
262
263 match canonical {
264 Canonical::Left => {
265 let right_contracted = s.contract(&vh);
267 let right = right_contracted.replaceind(&sim_bond_index, &bond_index);
268 Ok(FactorizeResult {
269 left: u,
270 right,
271 bond_index,
272 singular_values: Some(singular_values),
273 rank,
274 })
275 }
276 Canonical::Right => {
277 let left_contracted = u.contract(&s);
279 let left = left_contracted.replaceind(&sim_bond_index, &bond_index);
280 Ok(FactorizeResult {
281 left,
282 right: vh,
283 bond_index,
284 singular_values: Some(singular_values),
285 rank,
286 })
287 }
288 }
289}
290
291fn factorize_qr(
293 t: &TensorDynLen,
294 left_inds: &[DynIndex],
295 options: &FactorizeOptions,
296) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
297 if options.canonical == Canonical::Right {
298 return Err(FactorizeError::UnsupportedCanonical(
299 "QR only supports Canonical::Left (would need LQ for right)",
300 ));
301 }
302
303 let qr_options = if let Some(rtol) = options.qr_rtol {
304 QrOptions::new().with_rtol(rtol)
305 } else {
306 QrOptions::new()
307 };
308
309 factorize_qr_with_options(t, left_inds, &qr_options)
310}
311
312fn factorize_qr_full_rank(
313 t: &TensorDynLen,
314 left_inds: &[DynIndex],
315 canonical: Canonical,
316) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
317 if canonical == Canonical::Right {
318 return Err(FactorizeError::UnsupportedCanonical(
319 "QR only supports Canonical::Left (would need LQ for right)",
320 ));
321 }
322
323 factorize_qr_with_options(t, left_inds, &QrOptions::full_rank())
324}
325
326fn factorize_qr_with_options(
327 t: &TensorDynLen,
328 left_inds: &[DynIndex],
329 qr_options: &QrOptions,
330) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
331 let (q, r) = qr_with::<f64>(t, left_inds, qr_options)?;
332
333 let bond_index = q.indices.last().unwrap().clone();
335 let q_dims = q.dims();
337 let rank = *q_dims.last().unwrap();
338
339 Ok(FactorizeResult {
340 left: q,
341 right: r,
342 bond_index,
343 singular_values: None,
344 rank,
345 })
346}
347
348fn factorize_lu<T>(
350 t: &TensorDynLen,
351 left_inds: &[DynIndex],
352 options: &FactorizeOptions,
353) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
354where
355 T: TensorElement
356 + ComplexFloat
357 + Default
358 + From<<T as ComplexFloat>::Real>
359 + MatrixScalar
360 + tensor4all_tcicore::MatrixLuciScalar
361 + 'static,
362 <T as ComplexFloat>::Real: Into<f64> + 'static,
363 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
364{
365 factorize_lu_with_options::<T>(
366 t,
367 left_inds,
368 options.canonical,
369 options.max_rank.unwrap_or(usize::MAX),
370 1e-14,
371 )
372}
373
374fn factorize_lu_full_rank<T>(
375 t: &TensorDynLen,
376 left_inds: &[DynIndex],
377 canonical: Canonical,
378) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
379where
380 T: TensorElement
381 + ComplexFloat
382 + Default
383 + From<<T as ComplexFloat>::Real>
384 + MatrixScalar
385 + tensor4all_tcicore::MatrixLuciScalar
386 + 'static,
387 <T as ComplexFloat>::Real: Into<f64> + 'static,
388 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
389{
390 factorize_lu_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
391}
392
393fn factorize_lu_with_options<T>(
394 t: &TensorDynLen,
395 left_inds: &[DynIndex],
396 canonical: Canonical,
397 max_rank: usize,
398 rel_tol: f64,
399) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
400where
401 T: TensorElement
402 + ComplexFloat
403 + Default
404 + From<<T as ComplexFloat>::Real>
405 + MatrixScalar
406 + tensor4all_tcicore::MatrixLuciScalar
407 + 'static,
408 <T as ComplexFloat>::Real: Into<f64> + 'static,
409 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
410{
411 let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
413 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
414
415 let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
417
418 let left_orthogonal = canonical == Canonical::Left;
420 let lu_options = RrLUOptions {
421 max_rank,
422 rel_tol,
423 abs_tol: 0.0,
424 left_orthogonal,
425 };
426
427 let lu = rrlu(&a_matrix, Some(lu_options))?;
429 let rank = lu.npivots();
430
431 let l_matrix = lu.left(true);
433 let u_matrix = lu.right(true);
434
435 let bond_index = DynIndex::new_bond(rank)
437 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
438
439 let l_vec = matrix_to_vec(&l_matrix);
441 let mut l_indices = left_indices.clone();
442 l_indices.push(bond_index.clone());
443 let left =
444 TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
445
446 let u_vec = matrix_to_vec(&u_matrix);
448 let mut r_indices = vec![bond_index.clone()];
449 r_indices.extend_from_slice(&right_indices);
450 let right =
451 TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;
452
453 Ok(FactorizeResult {
454 left,
455 right,
456 bond_index,
457 singular_values: None,
458 rank,
459 })
460}
461
462fn factorize_ci<T>(
464 t: &TensorDynLen,
465 left_inds: &[DynIndex],
466 options: &FactorizeOptions,
467) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
468where
469 T: TensorElement
470 + ComplexFloat
471 + Default
472 + From<<T as ComplexFloat>::Real>
473 + MatrixScalar
474 + tensor4all_tcicore::MatrixLuciScalar
475 + 'static,
476 <T as ComplexFloat>::Real: Into<f64> + 'static,
477 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
478{
479 factorize_ci_with_options::<T>(
480 t,
481 left_inds,
482 options.canonical,
483 options.max_rank.unwrap_or(usize::MAX),
484 1e-14,
485 )
486}
487
488fn factorize_ci_full_rank<T>(
489 t: &TensorDynLen,
490 left_inds: &[DynIndex],
491 canonical: Canonical,
492) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
493where
494 T: TensorElement
495 + ComplexFloat
496 + Default
497 + From<<T as ComplexFloat>::Real>
498 + MatrixScalar
499 + tensor4all_tcicore::MatrixLuciScalar
500 + 'static,
501 <T as ComplexFloat>::Real: Into<f64> + 'static,
502 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
503{
504 factorize_ci_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
505}
506
507fn factorize_ci_with_options<T>(
508 t: &TensorDynLen,
509 left_inds: &[DynIndex],
510 canonical: Canonical,
511 max_rank: usize,
512 rel_tol: f64,
513) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
514where
515 T: TensorElement
516 + ComplexFloat
517 + Default
518 + From<<T as ComplexFloat>::Real>
519 + MatrixScalar
520 + tensor4all_tcicore::MatrixLuciScalar
521 + 'static,
522 <T as ComplexFloat>::Real: Into<f64> + 'static,
523 tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
524{
525 let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
527 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
528
529 let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
531
532 let left_orthogonal = canonical == Canonical::Left;
534 let lu_options = RrLUOptions {
535 max_rank,
536 rel_tol,
537 abs_tol: 0.0,
538 left_orthogonal,
539 };
540
541 let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
543 let rank = ci.rank();
544
545 let l_matrix = ci.left();
547 let r_matrix = ci.right();
548
549 let bond_index = DynIndex::new_bond(rank)
551 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
552
553 let l_vec = matrix_to_vec(&l_matrix);
555 let mut l_indices = left_indices.clone();
556 l_indices.push(bond_index.clone());
557 let left =
558 TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
559
560 let r_vec = matrix_to_vec(&r_matrix);
562 let mut r_indices = vec![bond_index.clone()];
563 r_indices.extend_from_slice(&right_indices);
564 let right =
565 TensorDynLen::from_dense(r_indices, r_vec).map_err(FactorizeError::ComputationError)?;
566
567 Ok(FactorizeResult {
568 left,
569 right,
570 bond_index,
571 singular_values: None,
572 rank,
573 })
574}
575
576fn native_tensor_to_matrix<T>(
578 tensor: &tenferro::Tensor,
579 m: usize,
580 n: usize,
581) -> Result<tensor4all_tcicore::Matrix<T>, FactorizeError>
582where
583 T: TensorElement + MatrixScalar + Copy,
584{
585 let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
586 FactorizeError::ComputationError(anyhow::anyhow!(
587 "failed to extract dense matrix entries from native tensor: {e}"
588 ))
589 })?;
590 if data.len() != m * n {
591 return Err(FactorizeError::ComputationError(anyhow::anyhow!(
592 "native matrix materialization produced {} entries for shape ({m}, {n})",
593 data.len()
594 )));
595 }
596
597 let mut matrix = tensor4all_tcicore::matrix::zeros(m, n);
598 for i in 0..m {
599 for j in 0..n {
600 matrix[[i, j]] = data[j * m + i];
601 }
602 }
603 Ok(matrix)
604}
605
606fn matrix_to_vec<T>(matrix: &tensor4all_tcicore::Matrix<T>) -> Vec<T>
608where
609 T: Clone,
610{
611 let m = tensor4all_tcicore::matrix::nrows(matrix);
612 let n = tensor4all_tcicore::matrix::ncols(matrix);
613 let mut vec = Vec::with_capacity(m * n);
614 for j in 0..n {
615 for i in 0..m {
616 vec.push(matrix[[i, j]].clone());
617 }
618 }
619 vec
620}
621
622#[cfg(test)]
623mod tests;