1use crate::defaults::tensordynlen::unfold_split_inner;
36use crate::defaults::DynIndex;
37use crate::{contract_pair, unfold_split, TensorDynLen};
38use anyhow::Result as AnyhowResult;
39use num_complex::{Complex64, ComplexFloat};
40use tenferro_ad::EagerTensor;
41use tenferro_linalg::eager_tensor::full_piv_lu_solve as eager_full_piv_lu_solve;
42use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
43use tensor4all_tensorbackend::{Matrix, TensorElement};
44
45use crate::defaults::svd::svd_for_factorize;
46use crate::qr::{qr_with, QrOptions};
47use crate::svd::SvdOptions;
48
49pub use crate::tensor_like::{
51 Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
52};
53
54pub fn factorize(
81 t: &TensorDynLen,
82 left_inds: &[DynIndex],
83 options: &FactorizeOptions,
84) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
85 options.validate()?;
86
87 if t.is_diag() {
88 return Err(FactorizeError::UnsupportedStorage(
89 "Diagonal storage not supported for factorize",
90 ));
91 }
92
93 if t.is_f64() {
94 factorize_impl_f64(t, left_inds, options)
95 } else if t.is_complex() {
96 factorize_impl_c64(t, left_inds, options)
97 } else {
98 Err(FactorizeError::UnsupportedStorage(
99 "factorize currently supports only f64 and Complex64 tensors",
100 ))
101 }
102}
103
104pub fn factorize_full_rank(
150 t: &TensorDynLen,
151 left_inds: &[DynIndex],
152 alg: FactorizeAlg,
153 canonical: Canonical,
154) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
155 if t.is_diag() {
156 return Err(FactorizeError::UnsupportedStorage(
157 "Diagonal storage not supported for factorize",
158 ));
159 }
160
161 if t.is_f64() {
162 factorize_impl_f64_full_rank(t, left_inds, alg, canonical)
163 } else if t.is_complex() {
164 factorize_impl_c64_full_rank(t, left_inds, alg, canonical)
165 } else {
166 Err(FactorizeError::UnsupportedStorage(
167 "factorize currently supports only f64 and Complex64 tensors",
168 ))
169 }
170}
171
172fn factorize_impl_f64(
173 t: &TensorDynLen,
174 left_inds: &[DynIndex],
175 options: &FactorizeOptions,
176) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
177 match options.alg {
178 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
179 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
180 FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
181 FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
182 }
183}
184
185fn factorize_impl_f64_full_rank(
186 t: &TensorDynLen,
187 left_inds: &[DynIndex],
188 alg: FactorizeAlg,
189 canonical: Canonical,
190) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
191 match alg {
192 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
193 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
194 FactorizeAlg::LU => factorize_lu_full_rank::<f64>(t, left_inds, canonical),
195 FactorizeAlg::CI => factorize_ci_full_rank::<f64>(t, left_inds, canonical),
196 }
197}
198
199fn factorize_impl_c64(
200 t: &TensorDynLen,
201 left_inds: &[DynIndex],
202 options: &FactorizeOptions,
203) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
204 match options.alg {
205 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
206 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
207 FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
208 FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
209 }
210}
211
212fn factorize_impl_c64_full_rank(
213 t: &TensorDynLen,
214 left_inds: &[DynIndex],
215 alg: FactorizeAlg,
216 canonical: Canonical,
217) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
218 match alg {
219 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
220 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
221 FactorizeAlg::LU => factorize_lu_full_rank::<Complex64>(t, left_inds, canonical),
222 FactorizeAlg::CI => factorize_ci_full_rank::<Complex64>(t, left_inds, canonical),
223 }
224}
225
226fn factorize_svd(
228 t: &TensorDynLen,
229 left_inds: &[DynIndex],
230 options: &FactorizeOptions,
231) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
232 let mut svd_options = SvdOptions::new();
233 if let Some(policy) = options.svd_policy {
234 svd_options = svd_options.with_policy(policy);
235 }
236 if let Some(max_rank) = options.max_rank {
237 svd_options = svd_options.with_max_rank(max_rank);
238 }
239
240 factorize_svd_with_options(t, left_inds, options.canonical, &svd_options)
241}
242
243fn factorize_svd_full_rank(
244 t: &TensorDynLen,
245 left_inds: &[DynIndex],
246 canonical: Canonical,
247) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
248 let svd_options = SvdOptions::full_rank();
249 factorize_svd_with_options(t, left_inds, canonical, &svd_options)
250}
251
252fn factorize_svd_with_options(
253 t: &TensorDynLen,
254 left_inds: &[DynIndex],
255 canonical: Canonical,
256 svd_options: &SvdOptions,
257) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
258 let result = svd_for_factorize(t, left_inds, svd_options)?;
259 let u = result.u;
260 let s = result.s;
261 let vh = result.vh;
262 let bond_index = result.bond_index;
263 let singular_values = result.singular_values;
264 let rank = result.rank;
265 let sim_bond_index = s.indices[1].clone();
266
267 match canonical {
268 Canonical::Left => {
269 let right_contracted = contract_pair(&s, &vh)?;
271 let right = right_contracted.replaceind(&sim_bond_index, &bond_index)?;
272 Ok(FactorizeResult {
273 left: u,
274 right,
275 bond_index,
276 singular_values: Some(singular_values),
277 rank,
278 })
279 }
280 Canonical::Right => {
281 let left_contracted = contract_pair(&u, &s)?;
283 let left = left_contracted.replaceind(&sim_bond_index, &bond_index)?;
284 Ok(FactorizeResult {
285 left,
286 right: vh,
287 bond_index,
288 singular_values: Some(singular_values),
289 rank,
290 })
291 }
292 }
293}
294
295fn factorize_qr(
297 t: &TensorDynLen,
298 left_inds: &[DynIndex],
299 options: &FactorizeOptions,
300) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
301 if options.canonical == Canonical::Right {
302 return Err(FactorizeError::UnsupportedCanonical(
303 "QR only supports Canonical::Left (would need LQ for right)",
304 ));
305 }
306
307 let qr_options = if let Some(rtol) = options.qr_rtol {
308 QrOptions::new().with_rtol(rtol)
309 } else {
310 QrOptions::new()
311 };
312
313 factorize_qr_with_options(t, left_inds, &qr_options)
314}
315
316fn factorize_qr_full_rank(
317 t: &TensorDynLen,
318 left_inds: &[DynIndex],
319 canonical: Canonical,
320) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
321 if canonical == Canonical::Right {
322 return Err(FactorizeError::UnsupportedCanonical(
323 "QR only supports Canonical::Left (would need LQ for right)",
324 ));
325 }
326
327 factorize_qr_with_options(t, left_inds, &QrOptions::full_rank())
328}
329
330fn factorize_qr_with_options(
331 t: &TensorDynLen,
332 left_inds: &[DynIndex],
333 qr_options: &QrOptions,
334) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
335 let (q, r) = qr_with::<f64>(t, left_inds, qr_options)?;
336
337 let bond_index = q
339 .indices
340 .last()
341 .cloned()
342 .ok_or_else(|| anyhow::anyhow!("QR factorization returned a rank-0 Q tensor"))?;
343 let q_dims = q.dims();
345 let rank = *q_dims
346 .last()
347 .ok_or_else(|| anyhow::anyhow!("QR factorization returned Q with no dimensions"))?;
348
349 Ok(FactorizeResult {
350 left: q,
351 right: r,
352 bond_index,
353 singular_values: None,
354 rank,
355 })
356}
357
358fn factorize_lu<T>(
360 t: &TensorDynLen,
361 left_inds: &[DynIndex],
362 options: &FactorizeOptions,
363) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
364where
365 T: TensorElement
366 + ComplexFloat
367 + Default
368 + From<<T as ComplexFloat>::Real>
369 + MatrixScalar
370 + tensor4all_tcicore::MatrixLuciScalar
371 + 'static,
372 <T as ComplexFloat>::Real: Into<f64> + 'static,
373{
374 factorize_lu_with_options::<T>(
375 t,
376 left_inds,
377 options.canonical,
378 options.max_rank.unwrap_or(usize::MAX),
379 1e-14,
380 )
381}
382
383fn factorize_lu_full_rank<T>(
384 t: &TensorDynLen,
385 left_inds: &[DynIndex],
386 canonical: Canonical,
387) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
388where
389 T: TensorElement
390 + ComplexFloat
391 + Default
392 + From<<T as ComplexFloat>::Real>
393 + MatrixScalar
394 + tensor4all_tcicore::MatrixLuciScalar
395 + 'static,
396 <T as ComplexFloat>::Real: Into<f64> + 'static,
397{
398 factorize_lu_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
399}
400
401fn factorize_lu_with_options<T>(
402 t: &TensorDynLen,
403 left_inds: &[DynIndex],
404 canonical: Canonical,
405 max_rank: usize,
406 rel_tol: f64,
407) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
408where
409 T: TensorElement
410 + ComplexFloat
411 + Default
412 + From<<T as ComplexFloat>::Real>
413 + MatrixScalar
414 + tensor4all_tcicore::MatrixLuciScalar
415 + 'static,
416 <T as ComplexFloat>::Real: Into<f64> + 'static,
417{
418 let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
420 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
421
422 let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
424
425 let left_orthogonal = canonical == Canonical::Left;
427 let lu_options = RrLUOptions {
428 max_rank,
429 rel_tol,
430 abs_tol: 0.0,
431 left_orthogonal,
432 };
433
434 let lu = rrlu(&a_matrix, Some(lu_options))?;
436 let rank = lu.npivots();
437
438 let l_matrix = lu.left(true);
440 let u_matrix = lu.right(true);
441
442 let bond_index = DynIndex::new_bond(rank)
444 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
445
446 let l_vec = matrix_to_vec(&l_matrix);
448 let mut l_indices = left_indices.clone();
449 l_indices.push(bond_index.clone());
450 let left =
451 TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
452
453 let u_vec = matrix_to_vec(&u_matrix);
455 let mut r_indices = vec![bond_index.clone()];
456 r_indices.extend_from_slice(&right_indices);
457 let right =
458 TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;
459
460 Ok(FactorizeResult {
461 left,
462 right,
463 bond_index,
464 singular_values: None,
465 rank,
466 })
467}
468
469fn factorize_ci<T>(
471 t: &TensorDynLen,
472 left_inds: &[DynIndex],
473 options: &FactorizeOptions,
474) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
475where
476 T: TensorElement
477 + ComplexFloat
478 + Default
479 + From<<T as ComplexFloat>::Real>
480 + MatrixScalar
481 + tensor4all_tcicore::MatrixLuciScalar
482 + 'static,
483 <T as ComplexFloat>::Real: Into<f64> + 'static,
484{
485 factorize_ci_with_options::<T>(
486 t,
487 left_inds,
488 options.canonical,
489 options.max_rank.unwrap_or(usize::MAX),
490 1e-14,
491 )
492}
493
494fn factorize_ci_full_rank<T>(
495 t: &TensorDynLen,
496 left_inds: &[DynIndex],
497 canonical: Canonical,
498) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
499where
500 T: TensorElement
501 + ComplexFloat
502 + Default
503 + From<<T as ComplexFloat>::Real>
504 + MatrixScalar
505 + tensor4all_tcicore::MatrixLuciScalar
506 + 'static,
507 <T as ComplexFloat>::Real: Into<f64> + 'static,
508{
509 factorize_ci_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
510}
511
512fn factorize_ci_with_options<T>(
513 t: &TensorDynLen,
514 left_inds: &[DynIndex],
515 canonical: Canonical,
516 max_rank: usize,
517 rel_tol: f64,
518) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
519where
520 T: TensorElement
521 + ComplexFloat
522 + Default
523 + From<<T as ComplexFloat>::Real>
524 + MatrixScalar
525 + tensor4all_tcicore::MatrixLuciScalar
526 + 'static,
527 <T as ComplexFloat>::Real: Into<f64> + 'static,
528{
529 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
532 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
533
534 let a_matrix = native_tensor_to_matrix::<T>(matrix_inner.data(), m, n)?;
536
537 let left_orthogonal = canonical == Canonical::Left;
539 let lu_options = RrLUOptions {
540 max_rank,
541 rel_tol,
542 abs_tol: 0.0,
543 left_orthogonal,
544 };
545
546 let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
548 let rank = ci.rank();
549
550 let row_indices = ci.row_indices().to_vec();
551 let col_indices = ci.col_indices().to_vec();
552 let cols_inner = matrix_inner.take_cols(&col_indices).map_err(|e| {
553 FactorizeError::ComputationError(anyhow::anyhow!(
554 "fixed-pivot CI column gather failed: {e}"
555 ))
556 })?;
557 let rows_inner = matrix_inner.take_rows(&row_indices).map_err(|e| {
558 FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI row gather failed: {e}"))
559 })?;
560 let pivot_inner = matrix_inner
561 .take_block(&row_indices, &col_indices)
562 .map_err(|e| {
563 FactorizeError::ComputationError(anyhow::anyhow!(
564 "fixed-pivot CI pivot block gather failed: {e}"
565 ))
566 })?;
567 let (left_inner, right_inner) = match canonical {
568 Canonical::Left => {
569 let left = eager_right_solve(&pivot_inner, &cols_inner).map_err(|e| {
570 FactorizeError::ComputationError(anyhow::anyhow!(
571 "fixed-pivot CI right solve failed: {e}"
572 ))
573 })?;
574 (left, rows_inner)
575 }
576 Canonical::Right => {
577 let right = eager_full_piv_lu_solve(&pivot_inner, &rows_inner).map_err(|e| {
578 FactorizeError::ComputationError(anyhow::anyhow!(
579 "fixed-pivot CI solve failed: {e}"
580 ))
581 })?;
582 (cols_inner, right)
583 }
584 };
585
586 let bond_index = DynIndex::new_bond(rank)
588 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
589
590 let mut l_indices = left_indices.clone();
591 l_indices.push(bond_index.clone());
592 let l_dims: Vec<usize> = l_indices.iter().map(|idx| idx.dim).collect();
593 let left_inner = left_inner.reshape(&l_dims).map_err(|e| {
594 FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI left reshape failed: {e}"))
595 })?;
596 let left = TensorDynLen::from_inner(l_indices, left_inner)
597 .map_err(FactorizeError::ComputationError)?;
598
599 let mut r_indices = vec![bond_index.clone()];
600 r_indices.extend_from_slice(&right_indices);
601 let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
602 let right_inner = right_inner.reshape(&r_dims).map_err(|e| {
603 FactorizeError::ComputationError(anyhow::anyhow!(
604 "fixed-pivot CI right reshape failed: {e}"
605 ))
606 })?;
607 let right = TensorDynLen::from_inner(r_indices, right_inner)
608 .map_err(FactorizeError::ComputationError)?;
609
610 Ok(FactorizeResult {
611 left,
612 right,
613 bond_index,
614 singular_values: None,
615 rank,
616 })
617}
618
619fn eager_right_solve(a: &EagerTensor, rhs: &EagerTensor) -> AnyhowResult<EagerTensor> {
620 let a_t = a.transpose(&[1, 0])?;
621 let rhs_t = rhs.transpose(&[1, 0])?;
622 Ok(eager_full_piv_lu_solve(&a_t, &rhs_t)?.transpose(&[1, 0])?)
623}
624
625fn native_tensor_to_matrix<T>(
627 tensor: &tenferro::Tensor,
628 m: usize,
629 n: usize,
630) -> Result<Matrix<T>, FactorizeError>
631where
632 T: TensorElement + MatrixScalar + Copy,
633{
634 let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
635 FactorizeError::ComputationError(anyhow::anyhow!(
636 "failed to extract dense matrix entries from native tensor: {e}"
637 ))
638 })?;
639 if data.len() != m * n {
640 return Err(FactorizeError::ComputationError(anyhow::anyhow!(
641 "native matrix materialization produced {} entries for shape ({m}, {n})",
642 data.len()
643 )));
644 }
645
646 let mut matrix = Matrix::zeros(m, n);
647 for i in 0..m {
648 for j in 0..n {
649 matrix[[i, j]] = data[j * m + i];
650 }
651 }
652 Ok(matrix)
653}
654
655fn matrix_to_vec<T>(matrix: &Matrix<T>) -> Vec<T>
657where
658 T: Clone,
659{
660 let m = matrix.nrows();
661 let n = matrix.ncols();
662 let mut vec = Vec::with_capacity(m * n);
663 for j in 0..n {
664 for i in 0..m {
665 vec.push(matrix[[i, j]].clone());
666 }
667 }
668 vec
669}
670
671#[cfg(test)]
672mod tests;