1use crate::defaults::tensordynlen::unfold_split_inner;
36use crate::defaults::DynIndex;
37use crate::{contract_pair, unfold_split, TensorDynLen};
38use num_complex::{Complex64, ComplexFloat};
39use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
40use tensor4all_tensorbackend::{Matrix, TensorElement};
41
42use crate::defaults::svd::svd_for_factorize;
43use crate::qr::{qr_with, QrOptions};
44use crate::svd::SvdOptions;
45
46pub use crate::tensor_like::{
48 Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
49};
50
51pub fn factorize(
78 t: &TensorDynLen,
79 left_inds: &[DynIndex],
80 options: &FactorizeOptions,
81) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
82 options.validate()?;
83
84 if t.is_diag() {
85 return Err(FactorizeError::UnsupportedStorage(
86 "Diagonal storage not supported for factorize",
87 ));
88 }
89
90 if t.is_f64() {
91 factorize_impl_f64(t, left_inds, options)
92 } else if t.is_complex() {
93 factorize_impl_c64(t, left_inds, options)
94 } else {
95 Err(FactorizeError::UnsupportedStorage(
96 "factorize currently supports only f64 and Complex64 tensors",
97 ))
98 }
99}
100
101pub fn factorize_full_rank(
147 t: &TensorDynLen,
148 left_inds: &[DynIndex],
149 alg: FactorizeAlg,
150 canonical: Canonical,
151) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
152 if t.is_diag() {
153 return Err(FactorizeError::UnsupportedStorage(
154 "Diagonal storage not supported for factorize",
155 ));
156 }
157
158 if t.is_f64() {
159 factorize_impl_f64_full_rank(t, left_inds, alg, canonical)
160 } else if t.is_complex() {
161 factorize_impl_c64_full_rank(t, left_inds, alg, canonical)
162 } else {
163 Err(FactorizeError::UnsupportedStorage(
164 "factorize currently supports only f64 and Complex64 tensors",
165 ))
166 }
167}
168
169fn factorize_impl_f64(
170 t: &TensorDynLen,
171 left_inds: &[DynIndex],
172 options: &FactorizeOptions,
173) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
174 match options.alg {
175 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
176 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
177 FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
178 FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
179 }
180}
181
182fn factorize_impl_f64_full_rank(
183 t: &TensorDynLen,
184 left_inds: &[DynIndex],
185 alg: FactorizeAlg,
186 canonical: Canonical,
187) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
188 match alg {
189 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
190 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
191 FactorizeAlg::LU => factorize_lu_full_rank::<f64>(t, left_inds, canonical),
192 FactorizeAlg::CI => factorize_ci_full_rank::<f64>(t, left_inds, canonical),
193 }
194}
195
196fn factorize_impl_c64(
197 t: &TensorDynLen,
198 left_inds: &[DynIndex],
199 options: &FactorizeOptions,
200) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
201 match options.alg {
202 FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
203 FactorizeAlg::QR => factorize_qr(t, left_inds, options),
204 FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
205 FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
206 }
207}
208
209fn factorize_impl_c64_full_rank(
210 t: &TensorDynLen,
211 left_inds: &[DynIndex],
212 alg: FactorizeAlg,
213 canonical: Canonical,
214) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
215 match alg {
216 FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
217 FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
218 FactorizeAlg::LU => factorize_lu_full_rank::<Complex64>(t, left_inds, canonical),
219 FactorizeAlg::CI => factorize_ci_full_rank::<Complex64>(t, left_inds, canonical),
220 }
221}
222
223fn factorize_svd(
225 t: &TensorDynLen,
226 left_inds: &[DynIndex],
227 options: &FactorizeOptions,
228) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
229 let mut svd_options = SvdOptions::new();
230 if let Some(policy) = options.svd_policy {
231 svd_options = svd_options.with_policy(policy);
232 }
233 if let Some(max_rank) = options.max_rank {
234 svd_options = svd_options.with_max_rank(max_rank);
235 }
236
237 factorize_svd_with_options(t, left_inds, options.canonical, &svd_options)
238}
239
240fn factorize_svd_full_rank(
241 t: &TensorDynLen,
242 left_inds: &[DynIndex],
243 canonical: Canonical,
244) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
245 let svd_options = SvdOptions::full_rank();
246 factorize_svd_with_options(t, left_inds, canonical, &svd_options)
247}
248
249fn factorize_svd_with_options(
250 t: &TensorDynLen,
251 left_inds: &[DynIndex],
252 canonical: Canonical,
253 svd_options: &SvdOptions,
254) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
255 let result = svd_for_factorize(t, left_inds, svd_options)?;
256 let u = result.u;
257 let s = result.s;
258 let vh = result.vh;
259 let bond_index = result.bond_index;
260 let singular_values = result.singular_values;
261 let rank = result.rank;
262 let sim_bond_index = s.indices[1].clone();
263
264 match canonical {
265 Canonical::Left => {
266 let right_contracted = contract_pair(&s, &vh)?;
268 let right = right_contracted.replaceind(&sim_bond_index, &bond_index)?;
269 Ok(FactorizeResult {
270 left: u,
271 right,
272 bond_index,
273 singular_values: Some(singular_values),
274 rank,
275 })
276 }
277 Canonical::Right => {
278 let left_contracted = contract_pair(&u, &s)?;
280 let left = left_contracted.replaceind(&sim_bond_index, &bond_index)?;
281 Ok(FactorizeResult {
282 left,
283 right: vh,
284 bond_index,
285 singular_values: Some(singular_values),
286 rank,
287 })
288 }
289 }
290}
291
292fn factorize_qr(
294 t: &TensorDynLen,
295 left_inds: &[DynIndex],
296 options: &FactorizeOptions,
297) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
298 if options.canonical == Canonical::Right {
299 return Err(FactorizeError::UnsupportedCanonical(
300 "QR only supports Canonical::Left (would need LQ for right)",
301 ));
302 }
303
304 let qr_options = if let Some(rtol) = options.qr_rtol {
305 QrOptions::new().with_rtol(rtol)
306 } else {
307 QrOptions::new()
308 };
309
310 factorize_qr_with_options(t, left_inds, &qr_options)
311}
312
313fn factorize_qr_full_rank(
314 t: &TensorDynLen,
315 left_inds: &[DynIndex],
316 canonical: Canonical,
317) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
318 if canonical == Canonical::Right {
319 return Err(FactorizeError::UnsupportedCanonical(
320 "QR only supports Canonical::Left (would need LQ for right)",
321 ));
322 }
323
324 factorize_qr_with_options(t, left_inds, &QrOptions::full_rank())
325}
326
327fn factorize_qr_with_options(
328 t: &TensorDynLen,
329 left_inds: &[DynIndex],
330 qr_options: &QrOptions,
331) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
332 let (q, r) = qr_with::<f64>(t, left_inds, qr_options)?;
333
334 let bond_index = q
336 .indices
337 .last()
338 .cloned()
339 .ok_or_else(|| anyhow::anyhow!("QR factorization returned a rank-0 Q tensor"))?;
340 let q_dims = q.dims();
342 let rank = *q_dims
343 .last()
344 .ok_or_else(|| anyhow::anyhow!("QR factorization returned Q with no dimensions"))?;
345
346 Ok(FactorizeResult {
347 left: q,
348 right: r,
349 bond_index,
350 singular_values: None,
351 rank,
352 })
353}
354
355fn factorize_lu<T>(
357 t: &TensorDynLen,
358 left_inds: &[DynIndex],
359 options: &FactorizeOptions,
360) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
361where
362 T: TensorElement
363 + ComplexFloat
364 + Default
365 + From<<T as ComplexFloat>::Real>
366 + MatrixScalar
367 + tensor4all_tcicore::MatrixLuciScalar
368 + 'static,
369 <T as ComplexFloat>::Real: Into<f64> + 'static,
370{
371 factorize_lu_with_options::<T>(
372 t,
373 left_inds,
374 options.canonical,
375 options.max_rank.unwrap_or(usize::MAX),
376 1e-14,
377 )
378}
379
380fn factorize_lu_full_rank<T>(
381 t: &TensorDynLen,
382 left_inds: &[DynIndex],
383 canonical: Canonical,
384) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
385where
386 T: TensorElement
387 + ComplexFloat
388 + Default
389 + From<<T as ComplexFloat>::Real>
390 + MatrixScalar
391 + tensor4all_tcicore::MatrixLuciScalar
392 + 'static,
393 <T as ComplexFloat>::Real: Into<f64> + 'static,
394{
395 factorize_lu_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
396}
397
398fn factorize_lu_with_options<T>(
399 t: &TensorDynLen,
400 left_inds: &[DynIndex],
401 canonical: Canonical,
402 max_rank: usize,
403 rel_tol: f64,
404) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
405where
406 T: TensorElement
407 + ComplexFloat
408 + Default
409 + From<<T as ComplexFloat>::Real>
410 + MatrixScalar
411 + tensor4all_tcicore::MatrixLuciScalar
412 + 'static,
413 <T as ComplexFloat>::Real: Into<f64> + 'static,
414{
415 let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
417 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
418
419 let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
421
422 let left_orthogonal = canonical == Canonical::Left;
424 let lu_options = RrLUOptions {
425 max_rank,
426 rel_tol,
427 abs_tol: 0.0,
428 left_orthogonal,
429 };
430
431 let lu = rrlu(&a_matrix, Some(lu_options))?;
433 let rank = lu.npivots();
434
435 let l_matrix = lu.left(true);
437 let u_matrix = lu.right(true);
438
439 let bond_index = DynIndex::new_bond(rank)
441 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
442
443 let l_vec = matrix_to_vec(&l_matrix);
445 let mut l_indices = left_indices.clone();
446 l_indices.push(bond_index.clone());
447 let left =
448 TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
449
450 let u_vec = matrix_to_vec(&u_matrix);
452 let mut r_indices = vec![bond_index.clone()];
453 r_indices.extend_from_slice(&right_indices);
454 let right =
455 TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;
456
457 Ok(FactorizeResult {
458 left,
459 right,
460 bond_index,
461 singular_values: None,
462 rank,
463 })
464}
465
466fn factorize_ci<T>(
468 t: &TensorDynLen,
469 left_inds: &[DynIndex],
470 options: &FactorizeOptions,
471) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
472where
473 T: TensorElement
474 + ComplexFloat
475 + Default
476 + From<<T as ComplexFloat>::Real>
477 + MatrixScalar
478 + tensor4all_tcicore::MatrixLuciScalar
479 + 'static,
480 <T as ComplexFloat>::Real: Into<f64> + 'static,
481{
482 factorize_ci_with_options::<T>(
483 t,
484 left_inds,
485 options.canonical,
486 options.max_rank.unwrap_or(usize::MAX),
487 1e-14,
488 )
489}
490
491fn factorize_ci_full_rank<T>(
492 t: &TensorDynLen,
493 left_inds: &[DynIndex],
494 canonical: Canonical,
495) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
496where
497 T: TensorElement
498 + ComplexFloat
499 + Default
500 + From<<T as ComplexFloat>::Real>
501 + MatrixScalar
502 + tensor4all_tcicore::MatrixLuciScalar
503 + 'static,
504 <T as ComplexFloat>::Real: Into<f64> + 'static,
505{
506 factorize_ci_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
507}
508
509fn factorize_ci_with_options<T>(
510 t: &TensorDynLen,
511 left_inds: &[DynIndex],
512 canonical: Canonical,
513 max_rank: usize,
514 rel_tol: f64,
515) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
516where
517 T: TensorElement
518 + ComplexFloat
519 + Default
520 + From<<T as ComplexFloat>::Real>
521 + MatrixScalar
522 + tensor4all_tcicore::MatrixLuciScalar
523 + 'static,
524 <T as ComplexFloat>::Real: Into<f64> + 'static,
525{
526 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
529 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
530
531 let a_matrix = native_tensor_to_matrix::<T>(matrix_inner.data(), m, n)?;
533
534 let left_orthogonal = canonical == Canonical::Left;
536 let lu_options = RrLUOptions {
537 max_rank,
538 rel_tol,
539 abs_tol: 0.0,
540 left_orthogonal,
541 };
542
543 let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
545 let rank = ci.rank();
546
547 let row_indices = ci.row_indices().to_vec();
548 let col_indices = ci.col_indices().to_vec();
549 let cols_inner = matrix_inner.take_cols(&col_indices).map_err(|e| {
550 FactorizeError::ComputationError(anyhow::anyhow!(
551 "fixed-pivot CI column gather failed: {e}"
552 ))
553 })?;
554 let rows_inner = matrix_inner.take_rows(&row_indices).map_err(|e| {
555 FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI row gather failed: {e}"))
556 })?;
557 let pivot_inner = matrix_inner
558 .take_block(&row_indices, &col_indices)
559 .map_err(|e| {
560 FactorizeError::ComputationError(anyhow::anyhow!(
561 "fixed-pivot CI pivot block gather failed: {e}"
562 ))
563 })?;
564 let (left_inner, right_inner) = match canonical {
565 Canonical::Left => {
566 let left = pivot_inner.right_solve(&cols_inner).map_err(|e| {
567 FactorizeError::ComputationError(anyhow::anyhow!(
568 "fixed-pivot CI right solve failed: {e}"
569 ))
570 })?;
571 (left, rows_inner)
572 }
573 Canonical::Right => {
574 let right = pivot_inner.solve(&rows_inner).map_err(|e| {
575 FactorizeError::ComputationError(anyhow::anyhow!(
576 "fixed-pivot CI solve failed: {e}"
577 ))
578 })?;
579 (cols_inner, right)
580 }
581 };
582
583 let bond_index = DynIndex::new_bond(rank)
585 .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
586
587 let mut l_indices = left_indices.clone();
588 l_indices.push(bond_index.clone());
589 let l_dims: Vec<usize> = l_indices.iter().map(|idx| idx.dim).collect();
590 let left_inner = left_inner.reshape(&l_dims).map_err(|e| {
591 FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI left reshape failed: {e}"))
592 })?;
593 let left = TensorDynLen::from_inner(l_indices, left_inner)
594 .map_err(FactorizeError::ComputationError)?;
595
596 let mut r_indices = vec![bond_index.clone()];
597 r_indices.extend_from_slice(&right_indices);
598 let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
599 let right_inner = right_inner.reshape(&r_dims).map_err(|e| {
600 FactorizeError::ComputationError(anyhow::anyhow!(
601 "fixed-pivot CI right reshape failed: {e}"
602 ))
603 })?;
604 let right = TensorDynLen::from_inner(r_indices, right_inner)
605 .map_err(FactorizeError::ComputationError)?;
606
607 Ok(FactorizeResult {
608 left,
609 right,
610 bond_index,
611 singular_values: None,
612 rank,
613 })
614}
615
616fn native_tensor_to_matrix<T>(
618 tensor: &tenferro::Tensor,
619 m: usize,
620 n: usize,
621) -> Result<Matrix<T>, FactorizeError>
622where
623 T: TensorElement + MatrixScalar + Copy,
624{
625 let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
626 FactorizeError::ComputationError(anyhow::anyhow!(
627 "failed to extract dense matrix entries from native tensor: {e}"
628 ))
629 })?;
630 if data.len() != m * n {
631 return Err(FactorizeError::ComputationError(anyhow::anyhow!(
632 "native matrix materialization produced {} entries for shape ({m}, {n})",
633 data.len()
634 )));
635 }
636
637 let mut matrix = Matrix::zeros(m, n);
638 for i in 0..m {
639 for j in 0..n {
640 matrix[[i, j]] = data[j * m + i];
641 }
642 }
643 Ok(matrix)
644}
645
646fn matrix_to_vec<T>(matrix: &Matrix<T>) -> Vec<T>
648where
649 T: Clone,
650{
651 let m = matrix.nrows();
652 let n = matrix.ncols();
653 let mut vec = Vec::with_capacity(m * n);
654 for j in 0..n {
655 for i in 0..m {
656 vec.push(matrix[[i, j]].clone());
657 }
658 }
659 vec
660}
661
662#[cfg(test)]
663mod tests;