1use crate::error::{Result, TCIError};
8use crate::globalpivot::{DefaultGlobalPivotFinder, GlobalPivotFinder, GlobalPivotSearchInput};
9use rand::SeedableRng;
10use std::cell::{Cell, RefCell};
11use std::collections::HashMap;
12use tensor4all_simplett::{tensor3_zeros, TTScalar, Tensor3, Tensor3Ops, TensorTrain};
13use tensor4all_tcicore::matrix::zeros;
14use tensor4all_tcicore::MultiIndex;
15use tensor4all_tcicore::Scalar;
16use tensor4all_tcicore::{
17 rrlu, AbstractMatrixCI, CrossFactors, DenseFaerLuKernel, DenseMatrixSource,
18 LazyBlockRookKernel, LazyMatrixSource, MatrixLUCI, PivotKernel, PivotKernelOptions,
19 RrLUOptions,
20};
21
/// Options for the 2-site tensor cross interpolation driver
/// ([`crossinterpolate2`]) and for [`TensorCI2`] sweeps.
#[derive(Debug, Clone)]
pub struct TCI2Options {
    /// Target tolerance. Used as the relative tolerance of the 2-site pivot
    /// search and, multiplied by the error normalization, as the absolute
    /// tolerance for convergence and the final 1-site compression sweep.
    pub tolerance: f64,
    /// Maximum number of 2-site sweep iterations.
    pub max_iter: usize,
    /// Maximum bond dimension (rank) allowed at every bond.
    pub max_bond_dim: usize,
    /// How new pivots are searched at each bond (dense full search or lazy
    /// rook search).
    pub pivot_search: PivotSearchStrategy,
    /// When `true`, errors are divided by the largest sampled function value.
    pub normalize_error: bool,
    /// When `> 0`, per-iteration progress is printed to stdout.
    pub verbosity: usize,
    /// Maximum number of global pivots added per iteration (forwarded to the
    /// global pivot finder).
    pub max_nglobal_pivot: usize,
    /// Search-effort parameter forwarded to the global pivot finder.
    pub nsearch: usize,
    /// Direction pattern of the 2-site sweeps.
    pub sweep_strategy: Sweep2Strategy,
    /// Number of most recent iterations inspected by the convergence check.
    pub ncheck_history: usize,
    /// When `false`, pivots from the previous iteration are re-offered to the
    /// pivot search at every bond, relaxing strict nesting of the index sets.
    pub strictly_nested: bool,
    /// Tolerance margin forwarded to the global pivot finder.
    pub tol_margin_global_search: f64,
    /// Seed for the random number generator; `None` seeds from the OS.
    pub seed: Option<u64>,
}
139
140impl Default for TCI2Options {
141 fn default() -> Self {
142 Self {
143 tolerance: 1e-8,
144 max_iter: 20,
145 max_bond_dim: usize::MAX,
146 pivot_search: PivotSearchStrategy::Full,
147 normalize_error: true,
148 verbosity: 0,
149 max_nglobal_pivot: 5,
150 nsearch: 5,
151 sweep_strategy: Sweep2Strategy::BackAndForth,
152 ncheck_history: 3,
153 strictly_nested: false,
154 tol_margin_global_search: 10.0,
155 seed: None,
156 }
157 }
158}
159
/// Per-bond parameters threaded through a 1-site sweep
/// (see `TensorCI2::sweep1site_at_bond`).
#[derive(Clone, Copy)]
struct Sweep1SiteBondConfig {
    /// Relative truncation tolerance for the rank-revealing LU.
    rel_tol: f64,
    /// Absolute truncation tolerance for the rank-revealing LU.
    abs_tol: f64,
    /// Maximum rank kept at the bond.
    max_bond_dim: usize,
    /// Whether the site tensors are rebuilt during the sweep.
    update_tensors: bool,
}
167
/// Borrowed inputs for a single 2-site pivot update (see `update_pivots`).
struct PivotUpdateContext<'a, B> {
    /// Optional batched evaluator; preferred over the scalar function when
    /// present.
    batched_f: &'a Option<B>,
    /// `true` for forward (left-orthogonalizing) sweeps, `false` for backward.
    left_orthogonal: bool,
    /// Sweep options (tolerance, max rank, pivot search strategy, ...).
    options: &'a TCI2Options,
    /// Extra row multi-indices offered to the pivot search beyond the
    /// Kronecker-expanded set (used when nesting is not strict).
    extra_i_set: &'a [MultiIndex],
    /// Extra column multi-indices offered to the pivot search.
    extra_j_set: &'a [MultiIndex],
}
175
/// Strategy used to search for new pivots in the 2-site update.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PivotSearchStrategy {
    /// Evaluate the full local matrix and run a dense LU pivot search
    /// (`DenseFaerLuKernel`).
    #[default]
    Full,
    /// Lazily evaluate matrix entries on demand with a block rook search
    /// (`LazyBlockRookKernel`).
    Rook,
}
210
/// Direction pattern of the 2-site sweeps in [`crossinterpolate2`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Sweep2Strategy {
    /// Always sweep left to right.
    Forward,
    /// Always sweep right to left.
    Backward,
    /// Alternate: forward on even iterations, backward on odd ones.
    #[default]
    BackAndForth,
}
241
/// State of a 2-site tensor cross interpolation: nested pivot index sets per
/// site plus the site tensors reconstructed from them.
#[derive(Debug, Clone)]
pub struct TensorCI2<T: Scalar + TTScalar> {
    /// Row pivots per site: `i_set[p]` holds multi-indices over sites `0..p`.
    i_set: Vec<Vec<MultiIndex>>,
    /// Column pivots per site: `j_set[p]` holds multi-indices over sites `p+1..`.
    j_set: Vec<Vec<MultiIndex>>,
    /// Local (physical) dimension of each site.
    local_dims: Vec<usize>,
    /// Current 3-leg site tensors of the interpolation.
    site_tensors: Vec<Tensor3<T>>,
    /// Running element-wise maxima of per-pivot errors over a sweep.
    pivot_errors: Vec<f64>,
    /// Last truncation error recorded at each bond.
    bond_errors: Vec<f64>,
    /// Largest absolute function value sampled so far (used to normalize errors).
    max_sample_value: f64,
    /// Snapshots of `i_set` taken at the start of each iteration.
    i_set_history: Vec<Vec<Vec<MultiIndex>>>,
    /// Snapshots of `j_set` taken at the start of each iteration.
    j_set_history: Vec<Vec<Vec<MultiIndex>>>,
}
282
283impl<T> TensorCI2<T>
284where
285 T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
286 DenseFaerLuKernel: PivotKernel<T>,
287 LazyBlockRookKernel: PivotKernel<T>,
288{
289 pub fn new(local_dims: Vec<usize>) -> Result<Self> {
291 if local_dims.len() < 2 {
292 return Err(TCIError::DimensionMismatch {
293 message: "local_dims should have at least 2 elements".to_string(),
294 });
295 }
296
297 let n = local_dims.len();
298 Ok(Self {
299 i_set: (0..n).map(|_| Vec::new()).collect(),
300 j_set: (0..n).map(|_| Vec::new()).collect(),
301 local_dims: local_dims.clone(),
302 site_tensors: local_dims.iter().map(|&d| tensor3_zeros(0, d, 0)).collect(),
303 pivot_errors: Vec::new(),
304 bond_errors: vec![0.0; n.saturating_sub(1)],
305 max_sample_value: 0.0,
306 i_set_history: Vec::new(),
307 j_set_history: Vec::new(),
308 })
309 }
310
311 pub fn len(&self) -> usize {
313 self.local_dims.len()
314 }
315
316 pub fn is_empty(&self) -> bool {
318 self.local_dims.is_empty()
319 }
320
    /// Local (physical) dimension of each site.
    pub fn local_dims(&self) -> &[usize] {
        &self.local_dims
    }
325
326 pub fn rank(&self) -> usize {
328 if self.len() <= 1 {
329 return if self.i_set.is_empty() || self.i_set[0].is_empty() {
330 0
331 } else {
332 1
333 };
334 }
335 self.i_set
336 .iter()
337 .skip(1)
338 .map(|s| s.len())
339 .max()
340 .unwrap_or(0)
341 }
342
343 pub fn link_dims(&self) -> Vec<usize> {
345 if self.len() <= 1 {
346 return Vec::new();
347 }
348 self.i_set.iter().skip(1).map(|s| s.len()).collect()
349 }
350
    /// Largest absolute function value sampled so far.
    pub fn max_sample_value(&self) -> f64 {
        self.max_sample_value
    }
355
356 pub fn max_bond_error(&self) -> f64 {
358 self.bond_errors.iter().cloned().fold(0.0, f64::max)
359 }
360
    /// Running per-pivot error maxima accumulated over the current sweep.
    pub fn pivot_errors(&self) -> &[f64] {
        &self.pivot_errors
    }
365
366 pub fn is_site_tensors_available(&self) -> bool {
368 self.site_tensors
369 .iter()
370 .all(|t| t.left_dim() > 0 || t.right_dim() > 0)
371 }
372
    /// Site tensor at position `p`.
    ///
    /// # Panics
    /// Panics when `p >= self.len()`.
    pub fn site_tensor(&self, p: usize) -> &Tensor3<T> {
        &self.site_tensors[p]
    }
377
378 pub fn to_tensor_train(&self) -> Result<TensorTrain<T>> {
380 let tensors = self.site_tensors.clone();
381 TensorTrain::new(tensors).map_err(TCIError::TensorTrainError)
382 }
383
384 pub fn add_global_pivots(&mut self, pivots: &[MultiIndex]) -> Result<()> {
386 for pivot in pivots {
387 if pivot.len() != self.len() {
388 return Err(TCIError::DimensionMismatch {
389 message: format!(
390 "Pivot length ({}) must match number of sites ({})",
391 pivot.len(),
392 self.len()
393 ),
394 });
395 }
396
397 for p in 0..self.len() {
399 let i_indices: MultiIndex = pivot[0..p].to_vec();
400 let j_indices: MultiIndex = pivot[p + 1..].to_vec();
401
402 if !self.i_set[p].contains(&i_indices) {
403 self.i_set[p].push(i_indices);
404 }
405 if !self.j_set[p].contains(&j_indices) {
406 self.j_set[p].push(j_indices);
407 }
408 }
409 }
410
411 self.invalidate_site_tensors();
413
414 Ok(())
415 }
416
    /// Row pivot multi-indices at site `p` (each covers sites `0..p`).
    ///
    /// # Panics
    /// Panics when `p >= self.len()`.
    pub fn i_set(&self, p: usize) -> &[MultiIndex] {
        &self.i_set[p]
    }
421
    /// Column pivot multi-indices at site `p` (each covers sites `p+1..`).
    ///
    /// # Panics
    /// Panics when `p >= self.len()`.
    pub fn j_set(&self, p: usize) -> &[MultiIndex] {
        &self.j_set[p]
    }
426
427 pub fn invalidate_site_tensors(&mut self) {
429 for p in 0..self.len() {
430 self.site_tensors[p] = tensor3_zeros(0, self.local_dims[p], 0);
431 }
432 }
433
    /// Clears the accumulated per-pivot error maxima (done at sweep start).
    pub fn flush_pivot_errors(&mut self) {
        self.pivot_errors.clear();
    }
438
439 pub fn sweep2site<F, B>(
444 &mut self,
445 f: &F,
446 batched_f: &Option<B>,
447 forward: bool,
448 options: &TCI2Options,
449 ) -> Result<()>
450 where
451 F: Fn(&MultiIndex) -> T,
452 B: Fn(&[MultiIndex]) -> Vec<T>,
453 {
454 let n = self.len();
455 self.invalidate_site_tensors();
456 self.flush_pivot_errors();
457
458 let empty: Vec<MultiIndex> = Vec::new();
459 if forward {
460 for b in 0..n - 1 {
461 update_pivots(
462 self,
463 b,
464 f,
465 PivotUpdateContext {
466 batched_f,
467 left_orthogonal: true,
468 options,
469 extra_i_set: &empty,
470 extra_j_set: &empty,
471 },
472 )?;
473 }
474 } else {
475 for b in (0..n - 1).rev() {
476 update_pivots(
477 self,
478 b,
479 f,
480 PivotUpdateContext {
481 batched_f,
482 left_orthogonal: false,
483 options,
484 extra_i_set: &empty,
485 extra_j_set: &empty,
486 },
487 )?;
488 }
489 }
490
491 self.fill_site_tensors(f)?;
493 Ok(())
494 }
495
496 fn update_pivot_errors(&mut self, errors: &[f64]) {
498 if self.pivot_errors.len() < errors.len() {
499 self.pivot_errors.resize(errors.len(), 0.0);
500 }
501 for (i, &e) in errors.iter().enumerate() {
502 self.pivot_errors[i] = self.pivot_errors[i].max(e);
503 }
504 }
505
506 fn fill_tensor<F>(
510 &self,
511 f: &F,
512 i_indices: &[MultiIndex],
513 j_indices: &[MultiIndex],
514 local_dim: usize,
515 site: usize,
516 ) -> Tensor3<T>
517 where
518 F: Fn(&MultiIndex) -> T,
519 {
520 let ni = i_indices.len();
521 let nj = j_indices.len();
522 let mut tensor = tensor3_zeros(ni, local_dim, nj);
523 for (ii, i_multi) in i_indices.iter().enumerate() {
524 for s in 0..local_dim {
525 for (jj, j_multi) in j_indices.iter().enumerate() {
526 let mut full_idx = i_multi.clone();
527 full_idx.push(s);
528 full_idx.extend(j_multi.iter().cloned());
529 debug_assert_eq!(
530 full_idx.len(),
531 self.local_dims.len(),
532 "fill_tensor: full_idx length {} != n_sites {} at site {}",
533 full_idx.len(),
534 self.local_dims.len(),
535 site
536 );
537 let val = f(&full_idx);
538 tensor.set3(ii, s, jj, val);
539 }
540 }
541 }
542 tensor
543 }
544
545 pub fn sweep1site<F>(
552 &mut self,
553 f: &F,
554 forward: bool,
555 rel_tol: f64,
556 abs_tol: f64,
557 max_bond_dim: usize,
558 update_tensors: bool,
559 ) -> Result<()>
560 where
561 F: Fn(&MultiIndex) -> T,
562 {
563 self.flush_pivot_errors();
564 self.invalidate_site_tensors();
565
566 let n = self.len();
567 let bond_config = Sweep1SiteBondConfig {
568 rel_tol,
569 abs_tol,
570 max_bond_dim,
571 update_tensors,
572 };
573
574 if forward {
575 for b in 0..n - 1 {
576 self.sweep1site_at_bond(f, b, true, bond_config)?;
577 }
578 } else {
579 for b in (1..n).rev() {
580 self.sweep1site_at_bond(f, b, false, bond_config)?;
581 }
582 }
583
584 if update_tensors {
586 let last_idx = if forward { n - 1 } else { 0 };
587 let tensor = self.fill_tensor(
588 f,
589 &self.i_set[last_idx].clone(),
590 &self.j_set[last_idx].clone(),
591 self.local_dims[last_idx],
592 last_idx,
593 );
594 self.site_tensors[last_idx] = tensor;
595 }
596
597 Ok(())
598 }
599
    /// Performs the 1-site pivot update at bond `b`.
    ///
    /// Builds the local matrix over the Kronecker-expanded index set on the
    /// sweep side, factorizes it with a truncated rank-revealing LU, stores
    /// the selected pivots in the adjacent index sets, optionally writes the
    /// orthogonalized site tensor, and records the truncation error.
    fn sweep1site_at_bond<F>(
        &mut self,
        f: &F,
        b: usize,
        forward: bool,
        config: Sweep1SiteBondConfig,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        // Forward: rows are i_set[b] expanded with the local index, columns
        // are j_set[b]. Backward: the mirror image.
        let (is, js) = if forward {
            (self.kronecker_i(b), self.j_set[b].clone())
        } else {
            (self.i_set[b].clone(), self.kronecker_j(b))
        };

        if is.is_empty() || js.is_empty() {
            return Ok(());
        }

        // Evaluate the full local matrix Pi and track the largest |f| sample.
        let ni = is.len();
        let nj = js.len();
        let mut pi = zeros(ni, nj);
        for (i, i_multi) in is.iter().enumerate() {
            for (j, j_multi) in js.iter().enumerate() {
                let mut full_idx = i_multi.clone();
                full_idx.extend(j_multi.iter().cloned());
                let val = f(&full_idx);
                pi[[i, j]] = val;
                let abs_val = f64::sqrt(Scalar::abs_sq(val));
                if abs_val > self.max_sample_value {
                    self.max_sample_value = abs_val;
                }
            }
        }

        // Truncated rank-revealing LU selects the pivot rows/columns.
        let lu_options = RrLUOptions {
            max_rank: config.max_bond_dim,
            rel_tol: config.rel_tol,
            abs_tol: config.abs_tol,
            left_orthogonal: forward,
        };
        let luci = MatrixLUCI::from_matrix(&pi, Some(lu_options))?;

        let row_indices = luci.row_indices();
        let col_indices = luci.col_indices();

        // Store the selected pivots in the index sets adjacent to bond `b`.
        if forward {
            self.i_set[b + 1] = row_indices.iter().map(|&i| is[i].clone()).collect();
            self.j_set[b] = col_indices.iter().map(|&j| js[j].clone()).collect();
        } else {
            self.i_set[b] = row_indices.iter().map(|&i| is[i].clone()).collect();
            self.j_set[b - 1] = col_indices.iter().map(|&j| js[j].clone()).collect();
        }

        if config.update_tensors {
            // Keep the factor that is orthogonal in the sweep direction:
            // the left factor when forward, the right factor when backward.
            let mat = if forward { luci.left() } else { luci.right() };
            let local_dim = self.local_dims[b];
            if forward {
                // Reshape rows (l, s) of the left factor into the 3-leg tensor.
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let right_dim = luci.rank().max(1);
                let mut tensor = tensor3_zeros(left_dim, local_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..local_dim {
                        for r in 0..right_dim {
                            let row = l * local_dim + s;
                            if row < mat.nrows() && r < mat.ncols() {
                                tensor.set3(l, s, r, mat[[row, r]]);
                            }
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            } else {
                // Reshape columns (s, r) of the right factor into the tensor.
                let left_dim = luci.rank().max(1);
                let right_dim = if b == self.len() - 1 {
                    1
                } else {
                    self.j_set[b].len()
                };
                let mut tensor = tensor3_zeros(left_dim, local_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..local_dim {
                        for r in 0..right_dim {
                            let col = s * right_dim + r;
                            if l < mat.nrows() && col < mat.ncols() {
                                tensor.set3(l, s, r, mat[[l, col]]);
                            }
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            }
        }

        // The last pivot error is the truncation error at this bond.
        let errors = luci.pivot_errors();
        if !errors.is_empty() {
            let bond_idx = if forward { b } else { b - 1 };
            self.bond_errors[bond_idx] = *errors.last().unwrap_or(&0.0);
        }
        self.update_pivot_errors(&errors);

        Ok(())
    }
712
    /// Rebuilds every site tensor from the current pivot index sets by
    /// evaluating `f` on the cross of each site's expanded row set and its
    /// column pivot set.
    ///
    /// # Errors
    /// Propagates failures from the LU factorization / solve of the pivot
    /// matrix at interior bonds.
    pub fn fill_site_tensors<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        let n = self.len();
        for b in 0..n {
            let i_kron = self.kronecker_i(b);
            let j_set_b = self.j_set[b].clone();

            // Nothing to build until pivots exist on both sides of the site.
            if i_kron.is_empty() || j_set_b.is_empty() {
                continue;
            }

            // Pi1[(l,s), j] = f over the Kronecker-expanded rows and j_set[b].
            let ni = i_kron.len();
            let nj = j_set_b.len();
            let mut pi1 = zeros(ni, nj);
            for (i, i_multi) in i_kron.iter().enumerate() {
                for (j, j_multi) in j_set_b.iter().enumerate() {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    pi1[[i, j]] = f(&full_idx);
                }
            }

            if b == n - 1 {
                // Last site: j_set is the empty suffix, so Pi1 has a single
                // column that is copied directly into the tensor.
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let site_dim = self.local_dims[b];
                let right_dim = 1;
                let mut tensor = tensor3_zeros(left_dim, site_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..site_dim {
                        let row = l * site_dim + s;
                        if row < ni {
                            tensor.set3(l, s, 0, pi1[[row, 0]]);
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            } else {
                // Pivot (cross) matrix P at bond b: next site's row pivots
                // against this site's column pivots.
                let i_set_bp1 = self.i_set[b + 1].clone();
                let np = i_set_bp1.len();

                let mut p_mat = zeros(np, nj);
                for (i, i_multi) in i_set_bp1.iter().enumerate() {
                    for (j, j_multi) in j_set_b.iter().enumerate() {
                        let mut full_idx = i_multi.clone();
                        full_idx.extend(j_multi.iter().cloned());
                        p_mat[[i, j]] = f(&full_idx);
                    }
                }

                // Work with transposes so the solve has P^T on the left.
                let mut p_t = zeros(nj, np);
                for i in 0..np {
                    for j in 0..nj {
                        p_t[[j, i]] = p_mat[[i, j]];
                    }
                }
                let mut pi1_t = zeros(nj, ni);
                for i in 0..ni {
                    for j in 0..nj {
                        pi1_t[[j, i]] = pi1[[i, j]];
                    }
                }

                // LU-factorize P^T and solve for X^T. NOTE(review): this
                // presumably computes X^T with P^T * X^T = Pi1^T, i.e.
                // X = Pi1 * P^{-1} — confirm against the solve_lu contract.
                let lu = rrlu(&p_t, None)?;
                let l_mat = lu.left(true);
                let u_mat = lu.right(true);

                let x_t = tensor4all_tcicore::matrixlu::solve_lu(&l_mat, &u_mat, &pi1_t)?;

                // Reshape X^T (np x (left*site)) into the 3-leg site tensor.
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let site_dim = self.local_dims[b];
                let right_dim = np;
                let mut tensor = tensor3_zeros(left_dim, site_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..site_dim {
                        for r in 0..right_dim {
                            let row = l * site_dim + s;
                            tensor.set3(l, s, r, x_t[[r, row]]);
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            }
        }
        Ok(())
    }
817
    /// Brings the interpolation into canonical form with three 1-site sweeps:
    /// 1. forward with zero tolerances and unbounded rank (no truncation);
    /// 2. backward truncating to the requested tolerances and rank;
    /// 3. forward with the same truncation, this time storing the tensors.
    ///
    /// # Errors
    /// Propagates failures from the underlying sweeps.
    pub fn make_canonical<F>(
        &mut self,
        f: &F,
        rel_tol: f64,
        abs_tol: f64,
        max_bond_dim: usize,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        self.sweep1site(f, true, 0.0, 0.0, usize::MAX, false)?;
        self.sweep1site(f, false, rel_tol, abs_tol, max_bond_dim, false)?;
        self.sweep1site(f, true, rel_tol, abs_tol, max_bond_dim, true)?;
        Ok(())
    }
843
844 fn kronecker_i(&self, p: usize) -> Vec<MultiIndex> {
846 let mut result = Vec::new();
847 for i_multi in &self.i_set[p] {
848 for local_idx in 0..self.local_dims[p] {
849 let mut new_idx = i_multi.clone();
850 new_idx.push(local_idx);
851 result.push(new_idx);
852 }
853 }
854 result
855 }
856
857 fn kronecker_j(&self, p: usize) -> Vec<MultiIndex> {
858 let mut result = Vec::new();
859 for local_idx in 0..self.local_dims[p] {
860 for j_multi in &self.j_set[p] {
861 let mut new_idx = vec![local_idx];
862 new_idx.extend(j_multi.iter().cloned());
863 result.push(new_idx);
864 }
865 }
866 result
867 }
868}
869
/// Decides whether the 2-site iteration has converged, looking at the last
/// `ncheck_history` iterations.
///
/// Converged when either every recent rank has hit `max_bond_dim`, or all of:
/// every recent error is below `tolerance`, no recent iteration added global
/// pivots, and the rank has stopped growing (recent minimum equals the last
/// value). Returns `false` while fewer than `ncheck_history` iterations have
/// been recorded.
fn convergence_criterion(
    ranks: &[usize],
    errors: &[f64],
    nglobal_pivots: &[usize],
    tolerance: f64,
    max_bond_dim: usize,
    ncheck_history: usize,
) -> bool {
    let n = errors.len();
    if n < ncheck_history {
        return false;
    }
    let start = n - ncheck_history;
    let recent_ranks = &ranks[start..];

    if recent_ranks.iter().all(|&r| r >= max_bond_dim) {
        return true;
    }

    let small_errors = errors[start..].iter().all(|&e| e < tolerance);
    let no_new_pivots = nglobal_pivots[start..].iter().all(|&k| k == 0);
    let stable_rank = recent_ranks.iter().min().copied().unwrap_or(0)
        == recent_ranks.last().copied().unwrap_or(0);
    small_errors && no_new_pivots && stable_rank
}
898
899pub fn crossinterpolate2<T, F, B>(
975 f: F,
976 batched_f: Option<B>,
977 local_dims: Vec<usize>,
978 initial_pivots: Vec<MultiIndex>,
979 options: TCI2Options,
980) -> Result<(TensorCI2<T>, Vec<usize>, Vec<f64>)>
981where
982 T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
983 DenseFaerLuKernel: PivotKernel<T>,
984 LazyBlockRookKernel: PivotKernel<T>,
985 F: Fn(&MultiIndex) -> T,
986 B: Fn(&[MultiIndex]) -> Vec<T>,
987{
988 if local_dims.len() < 2 {
989 return Err(TCIError::DimensionMismatch {
990 message: "local_dims should have at least 2 elements".to_string(),
991 });
992 }
993
994 let pivots = if initial_pivots.is_empty() {
995 vec![vec![0; local_dims.len()]]
996 } else {
997 initial_pivots
998 };
999
1000 let mut tci = TensorCI2::new(local_dims)?;
1001 tci.add_global_pivots(&pivots)?;
1002
1003 for pivot in &pivots {
1005 let value = f(pivot);
1006 let abs_val = f64::sqrt(Scalar::abs_sq(value));
1007 if abs_val > tci.max_sample_value {
1008 tci.max_sample_value = abs_val;
1009 }
1010 }
1011
1012 if tci.max_sample_value < 1e-30 {
1013 return Err(TCIError::InvalidPivot {
1014 message: "Initial pivots have zero function values".to_string(),
1015 });
1016 }
1017
1018 let n = tci.len();
1019 let mut errors = Vec::new();
1020 let mut ranks = Vec::new();
1021 let mut nglobal_pivots_history: Vec<usize> = Vec::new();
1022
1023 let mut rng = if let Some(seed) = options.seed {
1025 rand::rngs::StdRng::seed_from_u64(seed)
1026 } else {
1027 rand::rngs::StdRng::from_os_rng()
1028 };
1029
1030 let finder = DefaultGlobalPivotFinder::new(
1032 options.nsearch,
1033 options.max_nglobal_pivot,
1034 options.tol_margin_global_search,
1035 );
1036
1037 for iter in 0..options.max_iter {
1039 let error_normalization = if options.normalize_error && tci.max_sample_value > 0.0 {
1040 tci.max_sample_value
1041 } else {
1042 1.0
1043 };
1044 let abs_tol = options.tolerance * error_normalization;
1045
1046 let is_forward = match options.sweep_strategy {
1048 Sweep2Strategy::Forward => true,
1049 Sweep2Strategy::Backward => false,
1050 Sweep2Strategy::BackAndForth => iter % 2 == 0,
1051 };
1052
1053 let (extra_i_set, extra_j_set) =
1055 if !options.strictly_nested && !tci.i_set_history.is_empty() {
1056 let last = tci.i_set_history.len() - 1;
1057 (
1058 tci.i_set_history[last].clone(),
1059 tci.j_set_history[last].clone(),
1060 )
1061 } else {
1062 let empty: Vec<Vec<MultiIndex>> = (0..n).map(|_| Vec::new()).collect();
1063 (empty.clone(), empty)
1064 };
1065
1066 tci.i_set_history.push(tci.i_set.clone());
1068 tci.j_set_history.push(tci.j_set.clone());
1069
1070 tci.invalidate_site_tensors();
1072 tci.flush_pivot_errors();
1073
1074 if is_forward {
1075 for b in 0..n - 1 {
1076 update_pivots(
1077 &mut tci,
1078 b,
1079 &f,
1080 PivotUpdateContext {
1081 batched_f: &batched_f,
1082 left_orthogonal: true,
1083 options: &options,
1084 extra_i_set: &extra_i_set[b + 1],
1085 extra_j_set: &extra_j_set[b],
1086 },
1087 )?;
1088 }
1089 } else {
1090 for b in (0..n - 1).rev() {
1091 update_pivots(
1092 &mut tci,
1093 b,
1094 &f,
1095 PivotUpdateContext {
1096 batched_f: &batched_f,
1097 left_orthogonal: false,
1098 options: &options,
1099 extra_i_set: &extra_i_set[b + 1],
1100 extra_j_set: &extra_j_set[b],
1101 },
1102 )?;
1103 }
1104 }
1105
1106 tci.fill_site_tensors(&f)?;
1108
1109 let error = tci.max_bond_error();
1111 let error_normalized = error / error_normalization;
1112 errors.push(error_normalized);
1113
1114 let tt = tci.to_tensor_train()?;
1116 let input = GlobalPivotSearchInput {
1117 local_dims: tci.local_dims.clone(),
1118 current_tt: tt,
1119 max_sample_value: tci.max_sample_value,
1120 i_set: tci.i_set.clone(),
1121 j_set: tci.j_set.clone(),
1122 };
1123
1124 let global_pivots = finder.find_global_pivots(&input, &f, abs_tol, &mut rng);
1125 let n_global = global_pivots.len();
1126 tci.add_global_pivots(&global_pivots)?;
1127 nglobal_pivots_history.push(n_global);
1128
1129 ranks.push(tci.rank());
1130
1131 if options.verbosity > 0 {
1132 println!(
1133 "iteration = {}, rank = {}, error = {:.2e}, maxsamplevalue = {:.2e}, nglobalpivot = {}",
1134 iter + 1,
1135 tci.rank(),
1136 error_normalized,
1137 tci.max_sample_value,
1138 n_global
1139 );
1140 }
1141
1142 if convergence_criterion(
1144 &ranks,
1145 &errors,
1146 &nglobal_pivots_history,
1147 abs_tol,
1148 options.max_bond_dim,
1149 options.ncheck_history,
1150 ) {
1151 break;
1152 }
1153 }
1154
1155 let error_normalization = if options.normalize_error && tci.max_sample_value > 0.0 {
1159 tci.max_sample_value
1160 } else {
1161 1.0
1162 };
1163 let abs_tol = options.tolerance * error_normalization;
1164 tci.sweep1site(&f, true, 1e-14, abs_tol, options.max_bond_dim, true)?;
1165
1166 let normalized_errors = errors.to_vec();
1168
1169 Ok((tci, ranks, normalized_errors))
1170}
1171
/// Performs the 2-site pivot update at bond `b`.
///
/// Builds the combined row/column candidate sets (Kronecker expansions plus
/// any extra non-nested candidates), runs the configured pivot search (dense
/// full LU or lazy block rook), stores the selected pivots in `i_set[b+1]` /
/// `j_set[b]`, writes the two orthogonalized site tensors (unless extra
/// candidates were offered), and records the bond's truncation error.
fn update_pivots<T, F, B>(
    tci: &mut TensorCI2<T>,
    b: usize,
    f: &F,
    context: PivotUpdateContext<'_, B>,
) -> Result<()>
where
    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
    DenseFaerLuKernel: PivotKernel<T>,
    LazyBlockRookKernel: PivotKernel<T>,
    F: Fn(&MultiIndex) -> T,
    B: Fn(&[MultiIndex]) -> Vec<T>,
{
    // Candidate rows span sites 0..=b, candidate columns span sites b+1.. .
    let mut i_combined = tci.kronecker_i(b);
    let mut j_combined = tci.kronecker_j(b + 1);

    // Offer extra (possibly non-nested) candidates, deduplicated.
    for extra in context.extra_i_set {
        if !i_combined.contains(extra) {
            i_combined.push(extra.clone());
        }
    }
    for extra in context.extra_j_set {
        if !j_combined.contains(extra) {
            j_combined.push(extra.clone());
        }
    }

    if i_combined.is_empty() || j_combined.is_empty() {
        return Ok(());
    }

    let lu_options = PivotKernelOptions {
        max_rank: context.options.max_bond_dim,
        rel_tol: context.options.tolerance,
        abs_tol: 0.0,
        left_orthogonal: context.left_orthogonal,
    };

    let selection;
    let factors;
    if context.options.pivot_search == PivotSearchStrategy::Full {
        // Dense path: evaluate the entire local matrix up front.
        let mut pi = zeros(i_combined.len(), j_combined.len());

        if let Some(ref batch_fn) = context.batched_f {
            // One batched call for all entries, row-major order.
            let mut all_indices: Vec<MultiIndex> =
                Vec::with_capacity(i_combined.len() * j_combined.len());
            for i_multi in &i_combined {
                for j_multi in &j_combined {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    all_indices.push(full_idx);
                }
            }

            let values = batch_fn(&all_indices);
            if values.len() != all_indices.len() {
                return Err(callback_length_mismatch(values.len(), all_indices.len()));
            }
            let mut idx = 0;
            for i in 0..i_combined.len() {
                for j in 0..j_combined.len() {
                    pi[[i, j]] = values[idx];
                    update_max_sample_value(tci, values[idx]);
                    idx += 1;
                }
            }
        } else {
            // Scalar fallback: evaluate entry by entry.
            for (i, i_multi) in i_combined.iter().enumerate() {
                for (j, j_multi) in j_combined.iter().enumerate() {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    let value = f(&full_idx);
                    pi[[i, j]] = value;
                    update_max_sample_value(tci, value);
                }
            }
        }

        // Repack into column-major storage for the dense kernel.
        let mut data = Vec::with_capacity(pi.nrows() * pi.ncols());
        for col in 0..pi.ncols() {
            for row in 0..pi.nrows() {
                data.push(pi[[row, col]]);
            }
        }
        let source = DenseMatrixSource::from_column_major(&data, pi.nrows(), pi.ncols());
        selection = DenseFaerLuKernel.factorize(&source, &lu_options)?;
        factors = CrossFactors::from_source(&source, &selection)?;
    } else {
        // Lazy rook path: entries are evaluated on demand; callback errors
        // are captured inside the evaluator and surfaced via take_error().
        let evaluator = LazyPiEvaluator::new(
            &i_combined,
            &j_combined,
            f,
            context.batched_f,
            tci.max_sample_value,
        );
        let source = LazyMatrixSource::new(
            i_combined.len(),
            j_combined.len(),
            |rows, cols, out: &mut [T]| {
                evaluator.fill_block(rows, cols, out);
            },
        );
        let selection_result = LazyBlockRookKernel.factorize(&source, &lu_options);
        if let Some(err) = evaluator.take_error() {
            return Err(err);
        }
        selection = selection_result?;

        let factors_result = CrossFactors::from_source(&source, &selection);
        if let Some(err) = evaluator.take_error() {
            return Err(err);
        }
        factors = factors_result?;
        tci.max_sample_value = evaluator.sampled_max();
    }

    // Translate selected positions back into pivot multi-indices.
    let row_indices = &selection.row_indices;
    let col_indices = &selection.col_indices;

    tci.i_set[b + 1] = row_indices.iter().map(|&i| i_combined[i].clone()).collect();
    tci.j_set[b] = col_indices.iter().map(|&j| j_combined[j].clone()).collect();

    // With extra candidates the factor shapes no longer match the Kronecker
    // layout, so skip the tensor reconstruction (the caller refills tensors
    // afterwards) and only record the error.
    if !context.extra_i_set.is_empty() || !context.extra_j_set.is_empty() {
        let errors = &selection.pivot_errors;
        if !errors.is_empty() {
            tci.bond_errors[b] = *errors.last().unwrap_or(&0.0);
        }
        return Ok(());
    }

    // Orthogonal factor goes to the site on the sweep side: for forward
    // (left-orthogonal) sweeps the pivot inverse is folded into the left
    // factor, otherwise into the right one.
    let left = if context.left_orthogonal {
        factors.cols_times_pivot_inv()?
    } else {
        factors.pivot_cols.clone()
    };
    let right = if context.left_orthogonal {
        factors.pivot_rows.clone()
    } else {
        factors.pivot_inv_times_rows()?
    };

    let left_dim = if b == 0 { 1 } else { tci.i_set[b].len() };
    let site_dim_b = tci.local_dims[b];
    let new_bond_dim = selection.rank.max(1);

    // Reshape the left factor rows (l, s) into the site-b tensor.
    let mut tensor_b = tensor3_zeros(left_dim, site_dim_b, new_bond_dim);
    for l in 0..left_dim {
        for s in 0..site_dim_b {
            for r in 0..new_bond_dim {
                let row = l * site_dim_b + s;
                if row < left.nrows() && r < left.ncols() {
                    tensor_b.set3(l, s, r, left[[row, r]]);
                }
            }
        }
    }
    tci.site_tensors[b] = tensor_b;

    let site_dim_bp1 = tci.local_dims[b + 1];
    let right_dim = if b + 1 == tci.len() - 1 {
        1
    } else {
        tci.j_set[b + 1].len()
    };

    // Reshape the right factor columns (s, r) into the site-(b+1) tensor.
    let mut tensor_bp1 = tensor3_zeros(new_bond_dim, site_dim_bp1, right_dim);
    for l in 0..new_bond_dim {
        for s in 0..site_dim_bp1 {
            for r in 0..right_dim {
                let col = s * right_dim + r;
                if l < right.nrows() && col < right.ncols() {
                    tensor_bp1.set3(l, s, r, right[[l, col]]);
                }
            }
        }
    }
    tci.site_tensors[b + 1] = tensor_bp1;

    // The last pivot error is the truncation error at this bond.
    if !selection.pivot_errors.is_empty() {
        tci.bond_errors[b] = *selection.pivot_errors.last().unwrap_or(&0.0);
    }

    Ok(())
}
1368
1369fn update_max_sample_value<T: Scalar + TTScalar>(tci: &mut TensorCI2<T>, value: T) {
1370 let abs_val = f64::sqrt(Scalar::abs_sq(value));
1371 if abs_val > tci.max_sample_value {
1372 tci.max_sample_value = abs_val;
1373 }
1374}
1375
1376fn build_full_index(
1377 i_combined: &[MultiIndex],
1378 j_combined: &[MultiIndex],
1379 row: usize,
1380 col: usize,
1381) -> MultiIndex {
1382 let mut full_idx = i_combined[row].clone();
1383 full_idx.extend(j_combined[col].iter().cloned());
1384 full_idx
1385}
1386
1387fn callback_length_mismatch(actual: usize, expected: usize) -> TCIError {
1388 TCIError::InvalidOperation {
1389 message: format!(
1390 "batch callback returned {actual} values for {expected} requested entries"
1391 ),
1392 }
1393}
1394
/// Lazily evaluates entries of the local Pi matrix for the rook pivot
/// search, caching results, deferring callback errors (the kernel callback
/// cannot return `Result`), and tracking the largest sampled magnitude.
struct LazyPiEvaluator<'a, T, F, B>
where
    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
    F: Fn(&MultiIndex) -> T,
    B: Fn(&[MultiIndex]) -> Vec<T>,
{
    /// Row multi-indices of the local matrix.
    i_combined: &'a [MultiIndex],
    /// Column multi-indices of the local matrix.
    j_combined: &'a [MultiIndex],
    /// Scalar evaluator (fallback when no batched evaluator is given).
    f: &'a F,
    /// Optional batched evaluator, preferred when present.
    batched_f: &'a Option<B>,
    /// Cache of already evaluated entries keyed by `(row, col)`.
    cache: RefCell<HashMap<(usize, usize), T>>,
    /// First callback error, surfaced later via `take_error`.
    pending_error: RefCell<Option<TCIError>>,
    /// Largest |value| sampled so far.
    sampled_max: Cell<f64>,
}
1409
1410impl<'a, T, F, B> LazyPiEvaluator<'a, T, F, B>
1411where
1412 T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
1413 F: Fn(&MultiIndex) -> T,
1414 B: Fn(&[MultiIndex]) -> Vec<T>,
1415{
1416 fn new(
1417 i_combined: &'a [MultiIndex],
1418 j_combined: &'a [MultiIndex],
1419 f: &'a F,
1420 batched_f: &'a Option<B>,
1421 initial_max: f64,
1422 ) -> Self {
1423 Self {
1424 i_combined,
1425 j_combined,
1426 f,
1427 batched_f,
1428 cache: RefCell::new(HashMap::new()),
1429 pending_error: RefCell::new(None),
1430 sampled_max: Cell::new(initial_max),
1431 }
1432 }
1433
    /// Fills a column-major block for the requested `rows` x `cols`, serving
    /// cached entries and evaluating the rest (batched when possible).
    ///
    /// On a callback length mismatch the error is stored in `pending_error`
    /// and the missing entries are zero-filled; callers must check
    /// `take_error` after the kernel finishes.
    fn fill_block(&self, rows: &[usize], cols: &[usize], out: &mut [T]) {
        // After a recorded error, produce zeros and stop evaluating.
        if self.pending_error.borrow().is_some() {
            out.fill(T::zero());
            return;
        }

        let mut missing_entries = Vec::new();
        let mut missing_indices = Vec::new();

        {
            let cache_ref = self.cache.borrow();
            for (j_pos, &col) in cols.iter().enumerate() {
                for (i_pos, &row) in rows.iter().enumerate() {
                    // Column-major position within the output block.
                    let out_idx = i_pos + rows.len() * j_pos;
                    if let Some(&value) = cache_ref.get(&(row, col)) {
                        out[out_idx] = value;
                    } else {
                        missing_entries.push((out_idx, row, col));
                        missing_indices.push(build_full_index(
                            self.i_combined,
                            self.j_combined,
                            row,
                            col,
                        ));
                    }
                }
            }
        }

        if missing_entries.is_empty() {
            return;
        }

        // Evaluate all cache misses in one batch when a batched callback
        // exists, otherwise one scalar call per entry.
        let values = if let Some(batch_fn) = self.batched_f {
            batch_fn(&missing_indices)
        } else {
            missing_indices.iter().map(self.f).collect()
        };
        if values.len() != missing_entries.len() {
            *self.pending_error.borrow_mut() = Some(callback_length_mismatch(
                values.len(),
                missing_entries.len(),
            ));
            for (out_idx, _, _) in missing_entries {
                out[out_idx] = T::zero();
            }
            return;
        }

        let mut cache_ref = self.cache.borrow_mut();
        for ((out_idx, row, col), value) in missing_entries.into_iter().zip(values) {
            out[out_idx] = value;
            cache_ref.insert((row, col), value);

            // Track the largest |value| for error normalization.
            let abs_val = f64::sqrt(Scalar::abs_sq(value));
            if abs_val > self.sampled_max.get() {
                self.sampled_max.set(abs_val);
            }
        }
    }
1494
    /// Largest |value| sampled so far (at least the seed passed to `new`).
    fn sampled_max(&self) -> f64 {
        self.sampled_max.get()
    }
1498
    /// Removes and returns the first deferred callback error, if any.
    fn take_error(&self) -> Option<TCIError> {
        self.pending_error.borrow_mut().take()
    }
1502}
1503
1504#[cfg(test)]
1505mod tests;