// tensor4all_tensorci — tensorci2.rs

1//! TensorCI2 - Two-site Tensor Cross Interpolation algorithm
2//!
3//! This implements the TCI2 algorithm which uses two-site updates for
4//! more efficient convergence. Unlike TCI1, it supports batch evaluation
5//! of function values through an explicit batch function parameter.
6
7use crate::error::{Result, TCIError};
8use crate::globalpivot::{DefaultGlobalPivotFinder, GlobalPivotFinder, GlobalPivotSearchInput};
9use rand::SeedableRng;
10use std::cell::{Cell, RefCell};
11use std::collections::HashMap;
12use tensor4all_simplett::{tensor3_zeros, TTScalar, Tensor3, Tensor3Ops, TensorTrain};
13use tensor4all_tcicore::matrix::zeros;
14use tensor4all_tcicore::MultiIndex;
15use tensor4all_tcicore::Scalar;
16use tensor4all_tcicore::{
17    rrlu, AbstractMatrixCI, CrossFactors, DenseFaerLuKernel, DenseMatrixSource,
18    LazyBlockRookKernel, LazyMatrixSource, MatrixLUCI, PivotKernel, PivotKernelOptions,
19    RrLUOptions,
20};
21
22/// Configuration for the TCI2 algorithm ([`crossinterpolate2`]).
23///
24/// Controls convergence criteria, bond dimension limits, pivot search
25/// strategy, and global pivot search parameters.
26///
27/// # Recommended starting point
28///
29/// For most problems the defaults work well. The fields you will most
30/// commonly adjust are:
31///
32/// | Field | Typical range | Purpose |
33/// |---|---|---|
34/// | `tolerance` | `1e-6` -- `1e-12` | Relative convergence threshold |
35/// | `max_bond_dim` | `50` -- `500` | Hard cap on bond dimension |
36/// | `max_iter` | `20` -- `100` | Maximum number of half-sweeps |
37/// | `seed` | `Some(42)` | Fix seed for reproducibility |
38///
39/// # Convergence criterion
40///
41/// The algorithm stops when **all** of the following hold for
42/// `ncheck_history` consecutive iterations:
43///
44/// 1. The normalized bond error is below `tolerance`.
45/// 2. No global pivots were added.
46/// 3. The rank is stable.
47///
48/// Alternatively, it stops when the rank reaches `max_bond_dim`.
49///
50/// # Examples
51///
52/// ```
53/// use tensor4all_tensorci::TCI2Options;
54///
55/// // Default options
56/// let opts = TCI2Options::default();
57/// assert!((opts.tolerance - 1e-8).abs() < 1e-15);
58/// assert_eq!(opts.max_iter, 20);
59/// assert_eq!(opts.max_bond_dim, usize::MAX);
60/// assert_eq!(opts.verbosity, 0);
61///
62/// // Custom options via struct update syntax
63/// let custom = TCI2Options {
64///     tolerance: 1e-12,
65///     max_bond_dim: 100,
66///     seed: Some(42),
67///     ..TCI2Options::default()
68/// };
69/// assert!((custom.tolerance - 1e-12).abs() < 1e-20);
70/// assert_eq!(custom.max_bond_dim, 100);
71/// assert_eq!(custom.seed, Some(42));
72/// ```
#[derive(Debug, Clone)]
pub struct TCI2Options {
    /// Convergence tolerance (default: `1e-8`).
    ///
    /// When `normalize_error` is enabled (the default), this is the
    /// *relative* threshold: the bond error is divided by the maximum
    /// absolute function value seen so far. Typical choices are `1e-8` for
    /// moderate accuracy and `1e-12` for high accuracy.
    pub tolerance: f64,
    /// Maximum number of half-sweep iterations (default: `20`).
    ///
    /// Each iteration performs one forward or backward two-site sweep
    /// followed by a global pivot search. Increase if the function is
    /// difficult to converge.
    pub max_iter: usize,
    /// Hard upper bound on bond dimension (default: `usize::MAX`, i.e. no limit).
    ///
    /// The algorithm stops early once the rank reaches this value. For
    /// expensive functions, setting this to `50`--`500` avoids runaway
    /// computation.
    pub max_bond_dim: usize,
    /// Pivot search strategy (default: [`PivotSearchStrategy::Full`]).
    ///
    /// `Full` materializes the entire candidate matrix and finds the best
    /// pivot exactly. `Rook` uses lazy block-rook search and is faster for
    /// very large local dimensions but may miss some pivots.
    pub pivot_search: PivotSearchStrategy,
    /// Whether to normalize the bond error by the maximum observed sample
    /// value (default: `true`).
    ///
    /// When enabled, `tolerance` acts as a *relative* threshold. Disable
    /// to use `tolerance` as an *absolute* threshold.
    pub normalize_error: bool,
    /// Verbosity level (default: `0` = silent).
    ///
    /// `1` prints per-iteration summaries, `2` adds per-bond details.
    pub verbosity: usize,
    /// Maximum number of global pivots to add per iteration (default: `5`).
    pub max_nglobal_pivot: usize,
    /// Number of random starting points for global pivot search (default: `5`).
    ///
    /// Each starting point undergoes local optimization to find regions of
    /// high interpolation error.
    pub nsearch: usize,
    /// Sweep strategy for 2-site sweeps (default: [`Sweep2Strategy::BackAndForth`]).
    pub sweep_strategy: Sweep2Strategy,
    /// Number of recent iterations checked for convergence (default: `3`).
    ///
    /// The algorithm requires `ncheck_history` consecutive converged
    /// iterations before declaring convergence.
    pub ncheck_history: usize,
    /// Whether to use strictly nested index sets (default: `false`).
    ///
    /// When `false`, the algorithm keeps a history of previous index sets
    /// and merges them during sweeps, which generally improves convergence.
    pub strictly_nested: bool,
    /// Tolerance margin for global pivot search (default: `10.0`).
    ///
    /// Global pivots are accepted when their error exceeds
    /// `abs_tolerance * tol_margin_global_search`. Larger margins make the
    /// global search more conservative.
    pub tol_margin_global_search: f64,
    /// Random seed for reproducibility (default: `None` = OS entropy).
    ///
    /// Set to `Some(seed)` for deterministic results across runs.
    pub seed: Option<u64>,
}
139
140impl Default for TCI2Options {
141    fn default() -> Self {
142        Self {
143            tolerance: 1e-8,
144            max_iter: 20,
145            max_bond_dim: usize::MAX,
146            pivot_search: PivotSearchStrategy::Full,
147            normalize_error: true,
148            verbosity: 0,
149            max_nglobal_pivot: 5,
150            nsearch: 5,
151            sweep_strategy: Sweep2Strategy::BackAndForth,
152            ncheck_history: 3,
153            strictly_nested: false,
154            tol_margin_global_search: 10.0,
155            seed: None,
156        }
157    }
158}
159
/// Per-bond configuration shared by the 1-site sweep helpers.
///
/// Bundles the truncation thresholds and flags that `sweep1site` forwards
/// to `sweep1site_at_bond` for every bond it visits.
#[derive(Clone, Copy)]
struct Sweep1SiteBondConfig {
    // Relative truncation tolerance passed to the per-bond LU (RrLUOptions::rel_tol).
    rel_tol: f64,
    // Absolute truncation tolerance passed to the per-bond LU (RrLUOptions::abs_tol).
    abs_tol: f64,
    // Rank cap for the per-bond LU (RrLUOptions::max_rank).
    max_bond_dim: usize,
    // When true, site tensors are rebuilt during the sweep.
    update_tensors: bool,
}
167
/// Borrowed parameters for a single two-site pivot update.
///
/// Constructed per-bond by [`TensorCI2::sweep2site`] and handed to the
/// (file-internal) `update_pivots` routine.
struct PivotUpdateContext<'a, B> {
    // Optional batch evaluation function; None falls back to element-wise calls.
    batched_f: &'a Option<B>,
    // true during forward sweeps, false during backward sweeps
    // (see sweep2site, which sets it from the sweep direction).
    left_orthogonal: bool,
    // Algorithm options in effect for this sweep.
    options: &'a TCI2Options,
    // Extra left multi-indices forwarded to update_pivots (empty for plain
    // sweeps; presumably merged into the candidate set — see update_pivots).
    extra_i_set: &'a [MultiIndex],
    // Extra right multi-indices forwarded to update_pivots (empty for plain sweeps).
    extra_j_set: &'a [MultiIndex],
}
175
176/// Strategy for finding new pivots during a two-site update.
177///
178/// Controls how much of the candidate matrix is evaluated when searching
179/// for the next pivot at each bond.
180///
181/// # Examples
182///
183/// ```
184/// use tensor4all_tensorci::PivotSearchStrategy;
185///
186/// // Full is the default
187/// let strategy = PivotSearchStrategy::default();
188/// assert_eq!(strategy, PivotSearchStrategy::Full);
189///
190/// // Both variants can be compared
191/// assert_ne!(PivotSearchStrategy::Full, PivotSearchStrategy::Rook);
192/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PivotSearchStrategy {
    /// Evaluate the entire candidate matrix and pick the best pivot.
    ///
    /// Exact but costs O(D_left * d * d * D_right) evaluations per bond,
    /// where d is the local dimension. This is the default and recommended
    /// for most problems.
    #[default]
    Full,
    /// Use lazy block-rook pivoting over partial matrix blocks.
    ///
    /// Avoids materializing the full candidate matrix, making it faster
    /// for very large local dimensions but potentially missing some pivots.
    /// Error normalization uses the maximum sample value observed through
    /// the lazy requests rather than a full-grid scan.
    Rook,
}
210
211/// Direction of two-site sweeps.
212///
213/// # Examples
214///
215/// ```
216/// use tensor4all_tensorci::Sweep2Strategy;
217///
218/// // BackAndForth is the default
219/// let strategy = Sweep2Strategy::default();
220/// assert_eq!(strategy, Sweep2Strategy::BackAndForth);
221///
222/// // All three variants
223/// let fwd = Sweep2Strategy::Forward;
224/// let bwd = Sweep2Strategy::Backward;
225/// assert_ne!(fwd, bwd);
226/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Sweep2Strategy {
    /// Sweep left-to-right only.
    Forward,
    /// Sweep right-to-left only.
    Backward,
    /// Alternate between forward and backward sweeps (default).
    ///
    /// This is the most robust choice: forward sweeps improve
    /// right-canonical form while backward sweeps improve left-canonical
    /// form, and alternating achieves better overall convergence.
    #[default]
    BackAndForth,
}
241
242/// State object for the two-site Tensor Cross Interpolation algorithm.
243///
244/// Holds the current index sets (I, J), site tensors, and error statistics
245/// produced by [`crossinterpolate2`]. After interpolation, call
246/// [`to_tensor_train`](Self::to_tensor_train) to extract the resulting
247/// [`TensorTrain`].
248///
249/// # Key methods
250///
251/// | Method | Purpose |
252/// |---|---|
253/// | [`rank`](Self::rank) | Maximum bond dimension |
254/// | [`link_dims`](Self::link_dims) | Bond dimensions at each bond |
255/// | [`to_tensor_train`](Self::to_tensor_train) | Extract the tensor train |
256/// | [`max_bond_error`](Self::max_bond_error) | Largest bond error from last sweep |
257/// | [`pivot_errors`](Self::pivot_errors) | Per-bond pivot errors from back-truncation |
258///
259/// You normally do not construct `TensorCI2` directly; use
260/// [`crossinterpolate2`] instead.
#[derive(Debug, Clone)]
pub struct TensorCI2<T: Scalar + TTScalar> {
    /// Index sets I for each site (left multi-indices / pivot rows).
    i_set: Vec<Vec<MultiIndex>>,
    /// Index sets J for each site (right multi-indices / pivot columns).
    j_set: Vec<Vec<MultiIndex>>,
    /// Local dimensions (number of values each site index can take).
    local_dims: Vec<usize>,
    /// Site tensors (3-leg tensors); reset to `(0, d, 0)` placeholders
    /// when invalidated (see `invalidate_site_tensors`).
    site_tensors: Vec<Tensor3<T>>,
    /// Pivot errors during back-truncation (element-wise max over sweeps).
    pivot_errors: Vec<f64>,
    /// Bond errors from 2-site sweep, one entry per bond (n_sites - 1).
    bond_errors: Vec<f64>,
    /// Maximum observed sample value found during function evaluation.
    max_sample_value: f64,
    /// History of I-sets for non-strictly-nested mode.
    i_set_history: Vec<Vec<Vec<MultiIndex>>>,
    /// History of J-sets for non-strictly-nested mode.
    j_set_history: Vec<Vec<Vec<MultiIndex>>>,
}
282
283impl<T> TensorCI2<T>
284where
285    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
286    DenseFaerLuKernel: PivotKernel<T>,
287    LazyBlockRookKernel: PivotKernel<T>,
288{
289    /// Create a new empty TensorCI2
290    pub fn new(local_dims: Vec<usize>) -> Result<Self> {
291        if local_dims.len() < 2 {
292            return Err(TCIError::DimensionMismatch {
293                message: "local_dims should have at least 2 elements".to_string(),
294            });
295        }
296
297        let n = local_dims.len();
298        Ok(Self {
299            i_set: (0..n).map(|_| Vec::new()).collect(),
300            j_set: (0..n).map(|_| Vec::new()).collect(),
301            local_dims: local_dims.clone(),
302            site_tensors: local_dims.iter().map(|&d| tensor3_zeros(0, d, 0)).collect(),
303            pivot_errors: Vec::new(),
304            bond_errors: vec![0.0; n.saturating_sub(1)],
305            max_sample_value: 0.0,
306            i_set_history: Vec::new(),
307            j_set_history: Vec::new(),
308        })
309    }
310
    /// Number of sites (length of `local_dims`).
    pub fn len(&self) -> usize {
        self.local_dims.len()
    }
315
    /// Check if empty (no sites).
    ///
    /// Note: [`new`](Self::new) rejects fewer than two sites, so this is
    /// always `false` for instances built through the constructor.
    pub fn is_empty(&self) -> bool {
        self.local_dims.is_empty()
    }
320
    /// Get local dimensions (number of values each site index can take).
    pub fn local_dims(&self) -> &[usize] {
        &self.local_dims
    }
325
326    /// Get current rank (maximum bond dimension)
327    pub fn rank(&self) -> usize {
328        if self.len() <= 1 {
329            return if self.i_set.is_empty() || self.i_set[0].is_empty() {
330                0
331            } else {
332                1
333            };
334        }
335        self.i_set
336            .iter()
337            .skip(1)
338            .map(|s| s.len())
339            .max()
340            .unwrap_or(0)
341    }
342
343    /// Get bond dimensions
344    pub fn link_dims(&self) -> Vec<usize> {
345        if self.len() <= 1 {
346            return Vec::new();
347        }
348        self.i_set.iter().skip(1).map(|s| s.len()).collect()
349    }
350
    /// Get the maximum observed sample value seen so far.
    ///
    /// This is the largest `|f(x)|` encountered while evaluating the
    /// function; it is the normalization factor referenced by
    /// `TCI2Options::normalize_error`.
    pub fn max_sample_value(&self) -> f64 {
        self.max_sample_value
    }
355
356    /// Get maximum bond error
357    pub fn max_bond_error(&self) -> f64 {
358        self.bond_errors.iter().cloned().fold(0.0, f64::max)
359    }
360
    /// Get pivot errors from back-truncation.
    ///
    /// Entry `i` holds the element-wise maximum over all updates recorded
    /// since the last [`flush_pivot_errors`](Self::flush_pivot_errors).
    pub fn pivot_errors(&self) -> &[f64] {
        &self.pivot_errors
    }
365
366    /// Check if site tensors are available
367    pub fn is_site_tensors_available(&self) -> bool {
368        self.site_tensors
369            .iter()
370            .all(|t| t.left_dim() > 0 || t.right_dim() > 0)
371    }
372
    /// Get site tensor at position p.
    ///
    /// # Panics
    ///
    /// Panics if `p >= self.len()`.
    pub fn site_tensor(&self, p: usize) -> &Tensor3<T> {
        &self.site_tensors[p]
    }
377
378    /// Convert to TensorTrain
379    pub fn to_tensor_train(&self) -> Result<TensorTrain<T>> {
380        let tensors = self.site_tensors.clone();
381        TensorTrain::new(tensors).map_err(TCIError::TensorTrainError)
382    }
383
384    /// Add global pivots to the TCI
385    pub fn add_global_pivots(&mut self, pivots: &[MultiIndex]) -> Result<()> {
386        for pivot in pivots {
387            if pivot.len() != self.len() {
388                return Err(TCIError::DimensionMismatch {
389                    message: format!(
390                        "Pivot length ({}) must match number of sites ({})",
391                        pivot.len(),
392                        self.len()
393                    ),
394                });
395            }
396
397            // Add to I and J sets
398            for p in 0..self.len() {
399                let i_indices: MultiIndex = pivot[0..p].to_vec();
400                let j_indices: MultiIndex = pivot[p + 1..].to_vec();
401
402                if !self.i_set[p].contains(&i_indices) {
403                    self.i_set[p].push(i_indices);
404                }
405                if !self.j_set[p].contains(&j_indices) {
406                    self.j_set[p].push(j_indices);
407                }
408            }
409        }
410
411        // Invalidate site tensors after adding pivots
412        self.invalidate_site_tensors();
413
414        Ok(())
415    }
416
    /// Get I-set (left multi-indices) at site p.
    ///
    /// # Panics
    ///
    /// Panics if `p >= self.len()`.
    pub fn i_set(&self, p: usize) -> &[MultiIndex] {
        &self.i_set[p]
    }
421
    /// Get J-set (right multi-indices) at site p.
    ///
    /// # Panics
    ///
    /// Panics if `p >= self.len()`.
    pub fn j_set(&self, p: usize) -> &[MultiIndex] {
        &self.j_set[p]
    }
426
427    /// Invalidate all site tensors
428    pub fn invalidate_site_tensors(&mut self) {
429        for p in 0..self.len() {
430            self.site_tensors[p] = tensor3_zeros(0, self.local_dims[p], 0);
431        }
432    }
433
    /// Flush pivot errors (reset to empty).
    ///
    /// Called at the start of each sweep so the per-sweep element-wise
    /// maxima start from a clean slate.
    pub fn flush_pivot_errors(&mut self) {
        self.pivot_errors.clear();
    }
438
    /// Perform one 2-site sweep.
    ///
    /// This is a public wrapper around the internal `update_pivots` logic,
    /// suitable for calling from C-API.
    ///
    /// # Arguments
    ///
    /// * `f` - Element-wise evaluation function.
    /// * `batched_f` - Optional batch evaluation function.
    /// * `forward` - `true` for a left-to-right sweep, `false` for right-to-left.
    /// * `options` - Algorithm options forwarded to each per-bond update.
    pub fn sweep2site<F, B>(
        &mut self,
        f: &F,
        batched_f: &Option<B>,
        forward: bool,
        options: &TCI2Options,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
        B: Fn(&[MultiIndex]) -> Vec<T>,
    {
        let n = self.len();
        // Stale tensors and per-sweep error maxima are discarded up front.
        self.invalidate_site_tensors();
        self.flush_pivot_errors();

        // Plain sweeps pass no extra candidate indices.
        let empty: Vec<MultiIndex> = Vec::new();
        if forward {
            // Visit bonds left-to-right; left_orthogonal matches the direction.
            for b in 0..n - 1 {
                update_pivots(
                    self,
                    b,
                    f,
                    PivotUpdateContext {
                        batched_f,
                        left_orthogonal: true,
                        options,
                        extra_i_set: &empty,
                        extra_j_set: &empty,
                    },
                )?;
            }
        } else {
            // Visit bonds right-to-left.
            for b in (0..n - 1).rev() {
                update_pivots(
                    self,
                    b,
                    f,
                    PivotUpdateContext {
                        batched_f,
                        left_orthogonal: false,
                        options,
                        extra_i_set: &empty,
                        extra_j_set: &empty,
                    },
                )?;
            }
        }

        // Fill site tensors after sweep
        self.fill_site_tensors(f)?;
        Ok(())
    }
495
496    /// Update pivot errors with element-wise max
497    fn update_pivot_errors(&mut self, errors: &[f64]) {
498        if self.pivot_errors.len() < errors.len() {
499            self.pivot_errors.resize(errors.len(), 0.0);
500        }
501        for (i, &e) in errors.iter().enumerate() {
502            self.pivot_errors[i] = self.pivot_errors[i].max(e);
503        }
504    }
505
506    /// Evaluate function at all combinations of left indices × local index × right indices.
507    ///
508    /// Returns a 3D array of shape (len(i_indices), local_dim, len(j_indices)).
509    fn fill_tensor<F>(
510        &self,
511        f: &F,
512        i_indices: &[MultiIndex],
513        j_indices: &[MultiIndex],
514        local_dim: usize,
515        site: usize,
516    ) -> Tensor3<T>
517    where
518        F: Fn(&MultiIndex) -> T,
519    {
520        let ni = i_indices.len();
521        let nj = j_indices.len();
522        let mut tensor = tensor3_zeros(ni, local_dim, nj);
523        for (ii, i_multi) in i_indices.iter().enumerate() {
524            for s in 0..local_dim {
525                for (jj, j_multi) in j_indices.iter().enumerate() {
526                    let mut full_idx = i_multi.clone();
527                    full_idx.push(s);
528                    full_idx.extend(j_multi.iter().cloned());
529                    debug_assert_eq!(
530                        full_idx.len(),
531                        self.local_dims.len(),
532                        "fill_tensor: full_idx length {} != n_sites {} at site {}",
533                        full_idx.len(),
534                        self.local_dims.len(),
535                        site
536                    );
537                    let val = f(&full_idx);
538                    tensor.set3(ii, s, jj, val);
539                }
540            }
541        }
542        tensor
543    }
544
    /// Perform a 1-site sweep, updating I/J sets and optionally site tensors.
    ///
    /// This is used for cleanup after adding global pivots, and for computing
    /// canonical site tensors.
    ///
    /// # Arguments
    ///
    /// * `f` - Function being interpolated.
    /// * `forward` - `true` for left-to-right, `false` for right-to-left.
    /// * `rel_tol` / `abs_tol` - Truncation tolerances for the per-bond LU.
    /// * `max_bond_dim` - Rank cap for the per-bond LU.
    /// * `update_tensors` - When `true`, rebuild site tensors during the sweep.
    ///
    /// Port of Julia's `sweep1site!` from `tensorci2.jl`.
    pub fn sweep1site<F>(
        &mut self,
        f: &F,
        forward: bool,
        rel_tol: f64,
        abs_tol: f64,
        max_bond_dim: usize,
        update_tensors: bool,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        self.flush_pivot_errors();
        self.invalidate_site_tensors();

        let n = self.len();
        let bond_config = Sweep1SiteBondConfig {
            rel_tol,
            abs_tol,
            max_bond_dim,
            update_tensors,
        };

        if forward {
            for b in 0..n - 1 {
                self.sweep1site_at_bond(f, b, true, bond_config)?;
            }
        } else {
            // Backward sweeps visit sites n-1 down to 1 (bond b-1 each time).
            for b in (1..n).rev() {
                self.sweep1site_at_bond(f, b, false, bond_config)?;
            }
        }

        // Update last tensor according to last index set
        if update_tensors {
            let last_idx = if forward { n - 1 } else { 0 };
            // Clones are needed because fill_tensor borrows &self immutably.
            let tensor = self.fill_tensor(
                f,
                &self.i_set[last_idx].clone(),
                &self.j_set[last_idx].clone(),
                self.local_dims[last_idx],
                last_idx,
            );
            self.site_tensors[last_idx] = tensor;
        }

        Ok(())
    }
599
    /// Process one bond during 1-site sweep.
    ///
    /// Builds the Pi matrix for the bond, runs a rank-revealing LU cross
    /// interpolation on it, and updates the adjacent I/J sets (and, when
    /// requested, the site tensor) from the chosen pivot rows/columns.
    fn sweep1site_at_bond<F>(
        &mut self,
        f: &F,
        b: usize,
        forward: bool,
        config: Sweep1SiteBondConfig,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        // Build combined indices: for forward, Kronecker(I_b, local_b) × J_b
        //                         for backward, I_b × Kronecker(local_b, J_b)
        let (is, js) = if forward {
            (self.kronecker_i(b), self.j_set[b].clone())
        } else {
            (self.i_set[b].clone(), self.kronecker_j(b))
        };

        // Nothing to do when either index set has no pivots yet.
        if is.is_empty() || js.is_empty() {
            return Ok(());
        }

        // Build Pi matrix by evaluating function at all (I, J) combinations
        let ni = is.len();
        let nj = js.len();
        let mut pi = zeros(ni, nj);
        for (i, i_multi) in is.iter().enumerate() {
            for (j, j_multi) in js.iter().enumerate() {
                let mut full_idx = i_multi.clone();
                full_idx.extend(j_multi.iter().cloned());
                let val = f(&full_idx);
                pi[[i, j]] = val;
                // Track the largest |f| seen so far (used for relative
                // error normalization elsewhere).
                let abs_val = f64::sqrt(Scalar::abs_sq(val));
                if abs_val > self.max_sample_value {
                    self.max_sample_value = abs_val;
                }
            }
        }

        // LU-based cross interpolation
        let lu_options = RrLUOptions {
            max_rank: config.max_bond_dim,
            rel_tol: config.rel_tol,
            abs_tol: config.abs_tol,
            left_orthogonal: forward,
        };
        let luci = MatrixLUCI::from_matrix(&pi, Some(lu_options))?;

        let row_indices = luci.row_indices();
        let col_indices = luci.col_indices();

        // Update I/J sets with the pivot rows/columns chosen by the LU.
        if forward {
            self.i_set[b + 1] = row_indices.iter().map(|&i| is[i].clone()).collect();
            self.j_set[b] = col_indices.iter().map(|&j| js[j].clone()).collect();
        } else {
            self.i_set[b] = row_indices.iter().map(|&i| is[i].clone()).collect();
            self.j_set[b - 1] = col_indices.iter().map(|&j| js[j].clone()).collect();
        }

        // Update site tensor
        if config.update_tensors {
            // Forward sweeps keep the left (row) factor, backward sweeps
            // the right (column) factor of the cross decomposition.
            let mat = if forward { luci.left() } else { luci.right() };
            let local_dim = self.local_dims[b];
            if forward {
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let right_dim = luci.rank().max(1);
                let mut tensor = tensor3_zeros(left_dim, local_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..local_dim {
                        for r in 0..right_dim {
                            // Rows of `mat` are the Kronecker index (l, s).
                            let row = l * local_dim + s;
                            if row < mat.nrows() && r < mat.ncols() {
                                tensor.set3(l, s, r, mat[[row, r]]);
                            }
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            } else {
                let left_dim = luci.rank().max(1);
                let right_dim = if b == self.len() - 1 {
                    1
                } else {
                    self.j_set[b].len()
                };
                let mut tensor = tensor3_zeros(left_dim, local_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..local_dim {
                        for r in 0..right_dim {
                            // Columns of `mat` are the Kronecker index (s, r).
                            let col = s * right_dim + r;
                            if l < mat.nrows() && col < mat.ncols() {
                                tensor.set3(l, s, r, mat[[l, col]]);
                            }
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            }
        }

        // Update errors
        let errors = luci.pivot_errors();
        if !errors.is_empty() {
            // Bond index: the bond to the right of site b (forward) or to
            // its left (backward); backward sweeps only visit b >= 1.
            let bond_idx = if forward { b } else { b - 1 };
            self.bond_errors[bond_idx] = *errors.last().unwrap_or(&0.0);
        }
        self.update_pivot_errors(&errors);

        Ok(())
    }
712
    /// Fill all site tensors using 1-site LU decomposition at each bond.
    ///
    /// For each site b (except the last), computes the Pi matrix
    /// (Kronecker(I_b, d_b) × J_b) and the pivot matrix P (I_{b+1} × J_b),
    /// then solves P^T \ Pi^T to get the site tensor T_b = Pi * P^{-1}.
    /// The last site tensor is set by direct evaluation.
    ///
    /// Port of Julia's `fillsitetensors!` / `setsitetensor!`.
    pub fn fill_site_tensors<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        let n = self.len();
        for b in 0..n {
            let i_kron = self.kronecker_i(b);
            let j_set_b = self.j_set[b].clone();

            // Skip sites with no pivots yet.
            if i_kron.is_empty() || j_set_b.is_empty() {
                continue;
            }

            // Pi1: evaluate f at (Kronecker(I_b, d_b), J_b)
            let ni = i_kron.len();
            let nj = j_set_b.len();
            let mut pi1 = zeros(ni, nj);
            for (i, i_multi) in i_kron.iter().enumerate() {
                for (j, j_multi) in j_set_b.iter().enumerate() {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    pi1[[i, j]] = f(&full_idx);
                }
            }

            if b == n - 1 {
                // Last site: store Pi1 directly
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let site_dim = self.local_dims[b];
                let right_dim = 1; // last site
                let mut tensor = tensor3_zeros(left_dim, site_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..site_dim {
                        // Rows of Pi1 are the Kronecker index (l, s).
                        let row = l * site_dim + s;
                        if row < ni {
                            tensor.set3(l, s, 0, pi1[[row, 0]]);
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            } else {
                // Non-last site: solve P^T \ Pi1^T to get Tmat = Pi1 * P^{-1}
                // P = pivot matrix (I_{b+1} × J_b)
                let i_set_bp1 = self.i_set[b + 1].clone();
                let np = i_set_bp1.len();

                let mut p_mat = zeros(np, nj);
                for (i, i_multi) in i_set_bp1.iter().enumerate() {
                    for (j, j_multi) in j_set_b.iter().enumerate() {
                        let mut full_idx = i_multi.clone();
                        full_idx.extend(j_multi.iter().cloned());
                        p_mat[[i, j]] = f(&full_idx);
                    }
                }

                // Solve P * X^T = Pi1^T via LU factorization
                // First transpose P and Pi1
                let mut p_t = zeros(nj, np);
                for i in 0..np {
                    for j in 0..nj {
                        p_t[[j, i]] = p_mat[[i, j]];
                    }
                }
                let mut pi1_t = zeros(nj, ni);
                for i in 0..ni {
                    for j in 0..nj {
                        pi1_t[[j, i]] = pi1[[i, j]];
                    }
                }

                // LU factorize P^T with full pivoting
                let lu = rrlu(&p_t, None)?;
                let l_mat = lu.left(true);
                let u_mat = lu.right(true);

                // Solve L*U * X_t = Pi1^T
                let x_t = tensor4all_tcicore::matrixlu::solve_lu(&l_mat, &u_mat, &pi1_t)?;

                // X = X_t^T → shape (ni, np) = (|I_b|*d_b, |I_{b+1}|)
                let left_dim = if b == 0 { 1 } else { self.i_set[b].len() };
                let site_dim = self.local_dims[b];
                let right_dim = np; // = |I_{b+1}|
                let mut tensor = tensor3_zeros(left_dim, site_dim, right_dim);
                for l in 0..left_dim {
                    for s in 0..site_dim {
                        for r in 0..right_dim {
                            // Read X transposed: x_t is (np, ni).
                            let row = l * site_dim + s;
                            tensor.set3(l, s, r, x_t[[r, row]]);
                        }
                    }
                }
                self.site_tensors[b] = tensor;
            }
        }
        Ok(())
    }
817
    /// Make the TCI canonical by performing 3 one-site sweeps.
    ///
    /// 1. Forward sweep (exact, no truncation)
    /// 2. Backward sweep (with truncation)
    /// 3. Forward sweep (with truncation + update tensors)
    ///
    /// # Arguments
    ///
    /// * `rel_tol` / `abs_tol` - Truncation tolerances for sweeps 2 and 3.
    /// * `max_bond_dim` - Rank cap for sweeps 2 and 3.
    ///
    /// Port of Julia's `makecanonical!`.
    pub fn make_canonical<F>(
        &mut self,
        f: &F,
        rel_tol: f64,
        abs_tol: f64,
        max_bond_dim: usize,
    ) -> Result<()>
    where
        F: Fn(&MultiIndex) -> T,
    {
        // First half-sweep: exact, no truncation
        self.sweep1site(f, true, 0.0, 0.0, usize::MAX, false)?;
        // Second half-sweep: backward with truncation
        self.sweep1site(f, false, rel_tol, abs_tol, max_bond_dim, false)?;
        // Third half-sweep: forward with truncation and tensor updates
        self.sweep1site(f, true, rel_tol, abs_tol, max_bond_dim, true)?;
        Ok(())
    }
843
844    /// Expand indices by Kronecker product with local dimension
845    fn kronecker_i(&self, p: usize) -> Vec<MultiIndex> {
846        let mut result = Vec::new();
847        for i_multi in &self.i_set[p] {
848            for local_idx in 0..self.local_dims[p] {
849                let mut new_idx = i_multi.clone();
850                new_idx.push(local_idx);
851                result.push(new_idx);
852            }
853        }
854        result
855    }
856
857    fn kronecker_j(&self, p: usize) -> Vec<MultiIndex> {
858        let mut result = Vec::new();
859        for local_idx in 0..self.local_dims[p] {
860            for j_multi in &self.j_set[p] {
861                let mut new_idx = vec![local_idx];
862                new_idx.extend(j_multi.iter().cloned());
863                result.push(new_idx);
864            }
865        }
866        result
867    }
868}
869
/// Decide whether the main TCI2 loop may stop, based on the trailing
/// `ncheck_history` entries of the rank / error / global-pivot histories.
///
/// Returns `true` when, over the whole window, either
/// * every error is below `tolerance`, no global pivots were added, and the
///   rank did not grow (window minimum equals the latest rank), or
/// * every rank has already reached `max_bond_dim`.
///
/// Port of Julia's `convergencecriterion`.
fn convergence_criterion(
    ranks: &[usize],
    errors: &[f64],
    nglobal_pivots: &[usize],
    tolerance: f64,
    max_bond_dim: usize,
    ncheck_history: usize,
) -> bool {
    // Not enough history accumulated yet: keep iterating.
    if errors.len() < ncheck_history {
        return false;
    }

    let start = errors.len() - ncheck_history;
    let recent_errors = &errors[start..];
    let recent_ranks = &ranks[start..];
    let recent_pivots = &nglobal_pivots[start..];

    let all_below_tol = recent_errors.iter().all(|&err| err < tolerance);
    let pivot_free = recent_pivots.iter().all(|&count| count == 0);
    let latest_rank = recent_ranks.last().copied().unwrap_or(0);
    let rank_stable = recent_ranks.iter().min().copied().unwrap_or(0) == latest_rank;
    let saturated = recent_ranks.iter().all(|&rank| rank >= max_bond_dim);

    (all_below_tol && pivot_free && rank_stable) || saturated
}
898
899/// Approximate a function as a tensor train using the TCI2 algorithm.
900///
901/// This is the main entry point for tensor cross interpolation. It
902/// performs alternating two-site sweeps with global pivot search until
903/// convergence, then cleans up with a final one-site sweep.
904///
905/// # Arguments
906///
907/// * `f` -- Function to interpolate. Takes a multi-index `&Vec<usize>` where
908///   each element is in `0..local_dims[i]` (0-indexed) and returns a scalar.
909/// * `batched_f` -- Optional batch evaluation function for efficiency.
910///   Takes `&[Vec<usize>]` and returns `Vec<T>`. Pass `None` to use
911///   element-wise evaluation only.
912/// * `local_dims` -- Number of values each index can take. Must have at
913///   least 2 elements (TCI requires at least 2 sites).
914/// * `initial_pivots` -- Starting multi-indices for the algorithm. At least
915///   one pivot must have a non-zero function value. Choose pivots where
916///   `|f|` is large for best convergence. If empty, defaults to the
917///   all-zeros index.
918/// * `options` -- Algorithm configuration; see [`TCI2Options`].
919///
920/// # Returns
921///
922/// A tuple `(tci, ranks, errors)`:
923///
924/// * `tci: TensorCI2<T>` -- The interpolation state. Call
925///   [`to_tensor_train`](TensorCI2::to_tensor_train) to get a
926///   [`TensorTrain`].
927/// * `ranks: Vec<usize>` -- Bond dimension after each half-sweep.
928/// * `errors: Vec<f64>` -- Normalized error estimate after each half-sweep.
929///   The last value should be below `tolerance` if the algorithm converged.
930///
931/// # Errors
932///
933/// Returns [`TCIError::DimensionMismatch`] if `local_dims` has fewer than
934/// 2 elements, or [`TCIError::InvalidPivot`] if all initial pivots
935/// evaluate to zero.
936///
937/// # Examples
938///
939/// Basic usage: interpolate `f(i, j) = i + j + 1` on a 4x4 grid.
940///
941/// ```
942/// use tensor4all_tensorci::{crossinterpolate2, TCI2Options};
943/// use tensor4all_simplett::AbstractTensorTrain;
944///
945/// let f = |idx: &Vec<usize>| (idx[0] + idx[1] + 1) as f64;
946/// let local_dims = vec![4, 4];
947/// let initial_pivots = vec![vec![3, 3]]; // pick where |f| is large
948///
949/// let (tci, _ranks, errors) =
950///     crossinterpolate2::<f64, _, fn(&[Vec<usize>]) -> Vec<f64>>(
951///         f,
952///         None,
953///         local_dims,
954///         initial_pivots,
955///         TCI2Options {
956///             tolerance: 1e-10,
957///             seed: Some(42),
958///             ..TCI2Options::default()
959///         },
960///     )
961///     .unwrap();
962///
963/// // Verify convergence
964/// assert!(*errors.last().unwrap() < 1e-10);
965///
966/// // Evaluate through the tensor train
967/// let tt = tci.to_tensor_train().unwrap();
968/// let val = tt.evaluate(&[2, 3]).unwrap();
969/// assert!((val - 6.0).abs() < 1e-10); // f(2,3) = 2+3+1 = 6
970///
971/// // Bond dimensions are available
972/// assert!(!tci.link_dims().is_empty());
973/// ```
974pub fn crossinterpolate2<T, F, B>(
975    f: F,
976    batched_f: Option<B>,
977    local_dims: Vec<usize>,
978    initial_pivots: Vec<MultiIndex>,
979    options: TCI2Options,
980) -> Result<(TensorCI2<T>, Vec<usize>, Vec<f64>)>
981where
982    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
983    DenseFaerLuKernel: PivotKernel<T>,
984    LazyBlockRookKernel: PivotKernel<T>,
985    F: Fn(&MultiIndex) -> T,
986    B: Fn(&[MultiIndex]) -> Vec<T>,
987{
988    if local_dims.len() < 2 {
989        return Err(TCIError::DimensionMismatch {
990            message: "local_dims should have at least 2 elements".to_string(),
991        });
992    }
993
994    let pivots = if initial_pivots.is_empty() {
995        vec![vec![0; local_dims.len()]]
996    } else {
997        initial_pivots
998    };
999
1000    let mut tci = TensorCI2::new(local_dims)?;
1001    tci.add_global_pivots(&pivots)?;
1002
1003    // Initialize max_sample_value
1004    for pivot in &pivots {
1005        let value = f(pivot);
1006        let abs_val = f64::sqrt(Scalar::abs_sq(value));
1007        if abs_val > tci.max_sample_value {
1008            tci.max_sample_value = abs_val;
1009        }
1010    }
1011
1012    if tci.max_sample_value < 1e-30 {
1013        return Err(TCIError::InvalidPivot {
1014            message: "Initial pivots have zero function values".to_string(),
1015        });
1016    }
1017
1018    let n = tci.len();
1019    let mut errors = Vec::new();
1020    let mut ranks = Vec::new();
1021    let mut nglobal_pivots_history: Vec<usize> = Vec::new();
1022
1023    // Create RNG
1024    let mut rng = if let Some(seed) = options.seed {
1025        rand::rngs::StdRng::seed_from_u64(seed)
1026    } else {
1027        rand::rngs::StdRng::from_os_rng()
1028    };
1029
1030    // Create global pivot finder
1031    let finder = DefaultGlobalPivotFinder::new(
1032        options.nsearch,
1033        options.max_nglobal_pivot,
1034        options.tol_margin_global_search,
1035    );
1036
1037    // Main optimization loop
1038    for iter in 0..options.max_iter {
1039        let error_normalization = if options.normalize_error && tci.max_sample_value > 0.0 {
1040            tci.max_sample_value
1041        } else {
1042            1.0
1043        };
1044        let abs_tol = options.tolerance * error_normalization;
1045
1046        // Determine sweep direction
1047        let is_forward = match options.sweep_strategy {
1048            Sweep2Strategy::Forward => true,
1049            Sweep2Strategy::Backward => false,
1050            Sweep2Strategy::BackAndForth => iter % 2 == 0,
1051        };
1052
1053        // Get extra index sets from history for non-strictly-nested mode
1054        let (extra_i_set, extra_j_set) =
1055            if !options.strictly_nested && !tci.i_set_history.is_empty() {
1056                let last = tci.i_set_history.len() - 1;
1057                (
1058                    tci.i_set_history[last].clone(),
1059                    tci.j_set_history[last].clone(),
1060                )
1061            } else {
1062                let empty: Vec<Vec<MultiIndex>> = (0..n).map(|_| Vec::new()).collect();
1063                (empty.clone(), empty)
1064            };
1065
1066        // Save current sets to history
1067        tci.i_set_history.push(tci.i_set.clone());
1068        tci.j_set_history.push(tci.j_set.clone());
1069
1070        // 2-site sweep
1071        tci.invalidate_site_tensors();
1072        tci.flush_pivot_errors();
1073
1074        if is_forward {
1075            for b in 0..n - 1 {
1076                update_pivots(
1077                    &mut tci,
1078                    b,
1079                    &f,
1080                    PivotUpdateContext {
1081                        batched_f: &batched_f,
1082                        left_orthogonal: true,
1083                        options: &options,
1084                        extra_i_set: &extra_i_set[b + 1],
1085                        extra_j_set: &extra_j_set[b],
1086                    },
1087                )?;
1088            }
1089        } else {
1090            for b in (0..n - 1).rev() {
1091                update_pivots(
1092                    &mut tci,
1093                    b,
1094                    &f,
1095                    PivotUpdateContext {
1096                        batched_f: &batched_f,
1097                        left_orthogonal: false,
1098                        options: &options,
1099                        extra_i_set: &extra_i_set[b + 1],
1100                        extra_j_set: &extra_j_set[b],
1101                    },
1102                )?;
1103            }
1104        }
1105
1106        // Fill site tensors after sweep
1107        tci.fill_site_tensors(&f)?;
1108
1109        // Record error
1110        let error = tci.max_bond_error();
1111        let error_normalized = error / error_normalization;
1112        errors.push(error_normalized);
1113
1114        // Global pivot search
1115        let tt = tci.to_tensor_train()?;
1116        let input = GlobalPivotSearchInput {
1117            local_dims: tci.local_dims.clone(),
1118            current_tt: tt,
1119            max_sample_value: tci.max_sample_value,
1120            i_set: tci.i_set.clone(),
1121            j_set: tci.j_set.clone(),
1122        };
1123
1124        let global_pivots = finder.find_global_pivots(&input, &f, abs_tol, &mut rng);
1125        let n_global = global_pivots.len();
1126        tci.add_global_pivots(&global_pivots)?;
1127        nglobal_pivots_history.push(n_global);
1128
1129        ranks.push(tci.rank());
1130
1131        if options.verbosity > 0 {
1132            println!(
1133                "iteration = {}, rank = {}, error = {:.2e}, maxsamplevalue = {:.2e}, nglobalpivot = {}",
1134                iter + 1,
1135                tci.rank(),
1136                error_normalized,
1137                tci.max_sample_value,
1138                n_global
1139            );
1140        }
1141
1142        // Check convergence
1143        if convergence_criterion(
1144            &ranks,
1145            &errors,
1146            &nglobal_pivots_history,
1147            abs_tol,
1148            options.max_bond_dim,
1149            options.ncheck_history,
1150        ) {
1151            break;
1152        }
1153    }
1154
1155    // Final 1-site sweep to:
1156    // 1. Remove unnecessary pivots added by global pivots
1157    // 2. Compute site tensors
1158    let error_normalization = if options.normalize_error && tci.max_sample_value > 0.0 {
1159        tci.max_sample_value
1160    } else {
1161        1.0
1162    };
1163    let abs_tol = options.tolerance * error_normalization;
1164    tci.sweep1site(&f, true, 1e-14, abs_tol, options.max_bond_dim, true)?;
1165
1166    // Normalize errors for return
1167    let normalized_errors = errors.to_vec();
1168
1169    Ok((tci, ranks, normalized_errors))
1170}
1171
/// Update pivots at bond `b` using LU-based cross interpolation.
///
/// Kronecker-expands the index sets at the bond into row indices
/// (`i_set[b]` x site `b`) and column indices (site `b+1` x `j_set[b+1]`),
/// factorizes the resulting Pi matrix — fully materialized for
/// `PivotSearchStrategy::Full`, lazily otherwise — and writes the selected
/// pivots back into `tci.i_set[b + 1]` / `tci.j_set[b]`. Unless extra
/// (history) index sets are in play, the two site tensors at `b` and `b + 1`
/// are also rebuilt from the cross factors.
fn update_pivots<T, F, B>(
    tci: &mut TensorCI2<T>,
    b: usize,
    f: &F,
    context: PivotUpdateContext<'_, B>,
) -> Result<()>
where
    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
    DenseFaerLuKernel: PivotKernel<T>,
    LazyBlockRookKernel: PivotKernel<T>,
    F: Fn(&MultiIndex) -> T,
    B: Fn(&[MultiIndex]) -> Vec<T>,
{
    // Build combined index sets, including extra sets from history
    let mut i_combined = tci.kronecker_i(b);
    let mut j_combined = tci.kronecker_j(b + 1);

    // Union with extra sets (for non-strictly-nested mode).
    // Linear `contains` scans keep insertion order stable.
    for extra in context.extra_i_set {
        if !i_combined.contains(extra) {
            i_combined.push(extra.clone());
        }
    }
    for extra in context.extra_j_set {
        if !j_combined.contains(extra) {
            j_combined.push(extra.clone());
        }
    }

    // Nothing to factorize at an empty bond.
    if i_combined.is_empty() || j_combined.is_empty() {
        return Ok(());
    }

    // Apply LU-based cross interpolation
    let lu_options = PivotKernelOptions {
        max_rank: context.options.max_bond_dim,
        rel_tol: context.options.tolerance,
        abs_tol: 0.0,
        left_orthogonal: context.left_orthogonal,
    };

    let selection;
    let factors;
    if context.options.pivot_search == PivotSearchStrategy::Full {
        // Full search: materialize the entire Pi matrix before factorizing.
        let mut pi = zeros(i_combined.len(), j_combined.len());

        if let Some(ref batch_fn) = context.batched_f {
            // Evaluate every entry in one batch call; indices are generated
            // in row-major order and unpacked in the same order below.
            let mut all_indices: Vec<MultiIndex> =
                Vec::with_capacity(i_combined.len() * j_combined.len());
            for i_multi in &i_combined {
                for j_multi in &j_combined {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    all_indices.push(full_idx);
                }
            }

            let values = batch_fn(&all_indices);
            if values.len() != all_indices.len() {
                return Err(callback_length_mismatch(values.len(), all_indices.len()));
            }
            let mut idx = 0;
            for i in 0..i_combined.len() {
                for j in 0..j_combined.len() {
                    pi[[i, j]] = values[idx];
                    update_max_sample_value(tci, values[idx]);
                    idx += 1;
                }
            }
        } else {
            // Element-wise fallback when no batch callback was provided.
            for (i, i_multi) in i_combined.iter().enumerate() {
                for (j, j_multi) in j_combined.iter().enumerate() {
                    let mut full_idx = i_multi.clone();
                    full_idx.extend(j_multi.iter().cloned());
                    let value = f(&full_idx);
                    pi[[i, j]] = value;
                    update_max_sample_value(tci, value);
                }
            }
        }

        // Repack into column-major storage, the layout expected by
        // DenseMatrixSource::from_column_major.
        let mut data = Vec::with_capacity(pi.nrows() * pi.ncols());
        for col in 0..pi.ncols() {
            for row in 0..pi.nrows() {
                data.push(pi[[row, col]]);
            }
        }
        let source = DenseMatrixSource::from_column_major(&data, pi.nrows(), pi.ncols());
        selection = DenseFaerLuKernel.factorize(&source, &lu_options)?;
        factors = CrossFactors::from_source(&source, &selection)?;
    } else {
        // Rook search: evaluate Pi lazily, block by block, with caching.
        let evaluator = LazyPiEvaluator::new(
            &i_combined,
            &j_combined,
            f,
            context.batched_f,
            tci.max_sample_value,
        );
        let source = LazyMatrixSource::new(
            i_combined.len(),
            j_combined.len(),
            |rows, cols, out: &mut [T]| {
                evaluator.fill_block(rows, cols, out);
            },
        );
        // The fill closure cannot return a Result, so the evaluator latches
        // callback errors internally; check for them after each kernel step
        // and give them priority over the kernel's own error.
        let selection_result = LazyBlockRookKernel.factorize(&source, &lu_options);
        if let Some(err) = evaluator.take_error() {
            return Err(err);
        }
        selection = selection_result?;

        let factors_result = CrossFactors::from_source(&source, &selection);
        if let Some(err) = evaluator.take_error() {
            return Err(err);
        }
        factors = factors_result?;
        // Adopt the running maximum of sampled |f| values.
        tci.max_sample_value = evaluator.sampled_max();
    }

    // Update I and J sets from the selected pivot rows/columns.
    let row_indices = &selection.row_indices;
    let col_indices = &selection.col_indices;

    tci.i_set[b + 1] = row_indices.iter().map(|&i| i_combined[i].clone()).collect();
    tci.j_set[b] = col_indices.iter().map(|&j| j_combined[j].clone()).collect();

    // Skip site tensor update if extra sets were used (tensors will be
    // filled separately by fill_site_tensors after the sweep).
    if !context.extra_i_set.is_empty() || !context.extra_j_set.is_empty() {
        // Update bond error only
        let errors = &selection.pivot_errors;
        if !errors.is_empty() {
            tci.bond_errors[b] = *errors.last().unwrap_or(&0.0);
        }
        return Ok(());
    }

    // Update site tensors. The orthogonality direction decides which factor
    // absorbs the pivot inverse.
    let left = if context.left_orthogonal {
        factors.cols_times_pivot_inv()?
    } else {
        factors.pivot_cols.clone()
    };
    let right = if context.left_orthogonal {
        factors.pivot_rows.clone()
    } else {
        factors.pivot_inv_times_rows()?
    };

    // Convert left matrix to tensor at site b.
    // Row layout `l * site_dim_b + s` matches kronecker_i's ordering
    // (multi-index outer, local index inner).
    let left_dim = if b == 0 { 1 } else { tci.i_set[b].len() };
    let site_dim_b = tci.local_dims[b];
    let new_bond_dim = selection.rank.max(1);

    let mut tensor_b = tensor3_zeros(left_dim, site_dim_b, new_bond_dim);
    for l in 0..left_dim {
        for s in 0..site_dim_b {
            for r in 0..new_bond_dim {
                let row = l * site_dim_b + s;
                // Out-of-range entries stay zero (rank was clamped to >= 1).
                if row < left.nrows() && r < left.ncols() {
                    tensor_b.set3(l, s, r, left[[row, r]]);
                }
            }
        }
    }
    tci.site_tensors[b] = tensor_b;

    // Convert right matrix to tensor at site b+1.
    // Column layout `s * right_dim + r` matches kronecker_j's ordering
    // (local index outer, multi-index inner).
    let site_dim_bp1 = tci.local_dims[b + 1];
    let right_dim = if b + 1 == tci.len() - 1 {
        1
    } else {
        tci.j_set[b + 1].len()
    };

    let mut tensor_bp1 = tensor3_zeros(new_bond_dim, site_dim_bp1, right_dim);
    for l in 0..new_bond_dim {
        for s in 0..site_dim_bp1 {
            for r in 0..right_dim {
                let col = s * right_dim + r;
                if l < right.nrows() && col < right.ncols() {
                    tensor_bp1.set3(l, s, r, right[[l, col]]);
                }
            }
        }
    }
    tci.site_tensors[b + 1] = tensor_bp1;

    // Update bond error from the last pivot error of the factorization.
    if !selection.pivot_errors.is_empty() {
        tci.bond_errors[b] = *selection.pivot_errors.last().unwrap_or(&0.0);
    }

    Ok(())
}
1368
1369fn update_max_sample_value<T: Scalar + TTScalar>(tci: &mut TensorCI2<T>, value: T) {
1370    let abs_val = f64::sqrt(Scalar::abs_sq(value));
1371    if abs_val > tci.max_sample_value {
1372        tci.max_sample_value = abs_val;
1373    }
1374}
1375
1376fn build_full_index(
1377    i_combined: &[MultiIndex],
1378    j_combined: &[MultiIndex],
1379    row: usize,
1380    col: usize,
1381) -> MultiIndex {
1382    let mut full_idx = i_combined[row].clone();
1383    full_idx.extend(j_combined[col].iter().cloned());
1384    full_idx
1385}
1386
1387fn callback_length_mismatch(actual: usize, expected: usize) -> TCIError {
1388    TCIError::InvalidOperation {
1389        message: format!(
1390            "batch callback returned {actual} values for {expected} requested entries"
1391        ),
1392    }
1393}
1394
/// Lazily evaluates entries of the Pi matrix (f applied to the concatenation
/// of a row and a column multi-index) for the rook-search pivot kernel,
/// caching every sampled entry and latching the first callback error —
/// the kernel's fill closure cannot itself return a `Result`.
struct LazyPiEvaluator<'a, T, F, B>
where
    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
    F: Fn(&MultiIndex) -> T,
    B: Fn(&[MultiIndex]) -> Vec<T>,
{
    // Row multi-indices (left part of each full index).
    i_combined: &'a [MultiIndex],
    // Column multi-indices (right part of each full index).
    j_combined: &'a [MultiIndex],
    // Element-wise evaluation callback.
    f: &'a F,
    // Optional batched evaluation callback; preferred when present.
    batched_f: &'a Option<B>,
    // Cache of already-sampled entries keyed by (row, col).
    cache: RefCell<HashMap<(usize, usize), T>>,
    // First callback error; surfaced later via `take_error`.
    pending_error: RefCell<Option<TCIError>>,
    // Largest |f| sampled so far (Cell: updated through a shared reference).
    sampled_max: Cell<f64>,
}
1409
impl<'a, T, F, B> LazyPiEvaluator<'a, T, F, B>
where
    T: Scalar + TTScalar + Default + tensor4all_tcicore::MatrixLuciScalar,
    F: Fn(&MultiIndex) -> T,
    B: Fn(&[MultiIndex]) -> Vec<T>,
{
    /// Create an evaluator over the given row/column index sets.
    ///
    /// `initial_max` seeds the running maximum of sampled |f| values
    /// (typically the TCI's current `max_sample_value`).
    fn new(
        i_combined: &'a [MultiIndex],
        j_combined: &'a [MultiIndex],
        f: &'a F,
        batched_f: &'a Option<B>,
        initial_max: f64,
    ) -> Self {
        Self {
            i_combined,
            j_combined,
            f,
            batched_f,
            cache: RefCell::new(HashMap::new()),
            pending_error: RefCell::new(None),
            sampled_max: Cell::new(initial_max),
        }
    }

    /// Fill `out` with the Pi-matrix block at `rows` x `cols`, stored
    /// column-major: `out[i_pos + rows.len() * j_pos]`.
    ///
    /// Cached entries are copied directly; missing entries are evaluated via
    /// the batch callback when available, otherwise element-wise. If the
    /// batch callback returns the wrong number of values, the error is
    /// latched in `pending_error`, the missing entries are zero-filled, and
    /// all later fills become zero-filled no-ops until `take_error` is called.
    fn fill_block(&self, rows: &[usize], cols: &[usize], out: &mut [T]) {
        // A previous callback already failed: emit zeros and let the caller
        // surface the latched error.
        if self.pending_error.borrow().is_some() {
            out.fill(T::zero());
            return;
        }

        let mut missing_entries = Vec::new();
        let mut missing_indices = Vec::new();

        // Scope the immutable cache borrow so it is released before the
        // mutable borrow taken further down.
        {
            let cache_ref = self.cache.borrow();
            for (j_pos, &col) in cols.iter().enumerate() {
                for (i_pos, &row) in rows.iter().enumerate() {
                    let out_idx = i_pos + rows.len() * j_pos;
                    if let Some(&value) = cache_ref.get(&(row, col)) {
                        out[out_idx] = value;
                    } else {
                        missing_entries.push((out_idx, row, col));
                        missing_indices.push(build_full_index(
                            self.i_combined,
                            self.j_combined,
                            row,
                            col,
                        ));
                    }
                }
            }
        }

        // Everything was cached: nothing left to evaluate.
        if missing_entries.is_empty() {
            return;
        }

        // Evaluate the missing entries, batched when possible.
        let values = if let Some(batch_fn) = self.batched_f {
            batch_fn(&missing_indices)
        } else {
            missing_indices.iter().map(self.f).collect()
        };
        if values.len() != missing_entries.len() {
            // Latch the mismatch error and zero-fill what we could not get.
            *self.pending_error.borrow_mut() = Some(callback_length_mismatch(
                values.len(),
                missing_entries.len(),
            ));
            for (out_idx, _, _) in missing_entries {
                out[out_idx] = T::zero();
            }
            return;
        }

        // Write the new values out, populate the cache, and track the
        // largest sampled magnitude for error normalization.
        let mut cache_ref = self.cache.borrow_mut();
        for ((out_idx, row, col), value) in missing_entries.into_iter().zip(values) {
            out[out_idx] = value;
            cache_ref.insert((row, col), value);

            let abs_val = f64::sqrt(Scalar::abs_sq(value));
            if abs_val > self.sampled_max.get() {
                self.sampled_max.set(abs_val);
            }
        }
    }

    /// Largest |f| sampled so far (including the seed value).
    fn sampled_max(&self) -> f64 {
        self.sampled_max.get()
    }

    /// Take (and clear) any error latched during `fill_block`.
    fn take_error(&self) -> Option<TCIError> {
        self.pending_error.borrow_mut().take()
    }
}
1503
1504#[cfg(test)]
1505mod tests;