Skip to main content

tensor4all_tensorci/
tensorci1.rs

1//! TensorCI1 -- One-site Tensor Cross Interpolation algorithm (legacy).
2//!
3//! This module is kept for backward compatibility. For new code, prefer
4//! [`crossinterpolate2`](crate::crossinterpolate2) which uses two-site
5//! updates and generally converges faster.
6
7use crate::error::{Result, TCIError};
8use tensor4all_simplett::{tensor3_zeros, TTScalar, Tensor3, Tensor3Ops, TensorTrain};
9use tensor4all_tcicore::matrix::{a_times_b_inv, mat_mul, zeros, Matrix};
10use tensor4all_tcicore::Scalar;
11use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA};
12use tensor4all_tcicore::{IndexSet, MultiIndex};
13
/// Direction in which a TCI1 sweep visits the bonds.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorci::SweepStrategy;
///
/// // BackAndForth is the default
/// let strategy = SweepStrategy::default();
/// assert_eq!(strategy, SweepStrategy::BackAndForth);
///
/// // All three variants are distinct
/// assert_ne!(SweepStrategy::Forward, SweepStrategy::Backward);
/// assert_ne!(SweepStrategy::Forward, SweepStrategy::BackAndForth);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SweepStrategy {
    /// Every sweep runs left-to-right.
    Forward,
    /// Every sweep runs right-to-left.
    Backward,
    /// Sweeps alternate direction from one iteration to the next (the default).
    #[default]
    BackAndForth,
}

/// Decide whether iteration number `iter` is a left-to-right sweep.
///
/// `Forward` always sweeps forward and `Backward` never does; `BackAndForth`
/// sweeps forward on odd iteration numbers and backward on even ones.
fn forward_sweep(strategy: SweepStrategy, iter: usize) -> bool {
    let odd_iteration = iter % 2 == 1;
    strategy == SweepStrategy::Forward
        || (strategy == SweepStrategy::BackAndForth && odd_iteration)
}

/// Tuning parameters for the legacy TCI1 algorithm ([`crossinterpolate1`]).
///
/// The recommended two-site algorithm is configured through
/// [`TCI2Options`](crate::TCI2Options) instead.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorci::{TCI1Options, SweepStrategy};
///
/// // Default options
/// let opts = TCI1Options::default();
/// assert!((opts.tolerance - 1e-8).abs() < 1e-15);
/// assert_eq!(opts.max_iter, 200);
/// assert_eq!(opts.sweep_strategy, SweepStrategy::BackAndForth);
/// assert!((opts.pivot_tolerance - 1e-12).abs() < 1e-20);
/// assert!(opts.normalize_error);
/// assert_eq!(opts.verbosity, 0);
///
/// // Custom options via struct update syntax
/// let custom = TCI1Options {
///     tolerance: 1e-10,
///     max_iter: 50,
///     sweep_strategy: SweepStrategy::Forward,
///     ..TCI1Options::default()
/// };
/// assert!((custom.tolerance - 1e-10).abs() < 1e-20);
/// assert_eq!(custom.max_iter, 50);
/// assert_eq!(custom.sweep_strategy, SweepStrategy::Forward);
/// ```
#[derive(Debug, Clone)]
pub struct TCI1Options {
    /// Error level at which the interpolation counts as converged (default `1e-8`).
    pub tolerance: f64,
    /// Upper bound on the number of sweeps (default `200`).
    pub max_iter: usize,
    /// Direction policy for successive sweeps (default [`SweepStrategy::BackAndForth`]).
    pub sweep_strategy: SweepStrategy,
    /// Pivots whose error falls below this value are not added (default `1e-12`).
    pub pivot_tolerance: f64,
    /// Divide errors by the largest sampled magnitude (default `true`).
    pub normalize_error: bool,
    /// Amount of progress output; `0` means silent (default `0`).
    pub verbosity: usize,
}

impl Default for TCI1Options {
    /// The documented defaults: tolerance `1e-8`, 200 iterations,
    /// back-and-forth sweeps, pivot tolerance `1e-12`, normalized errors, silent.
    fn default() -> Self {
        TCI1Options {
            verbosity: 0,
            normalize_error: true,
            pivot_tolerance: 1e-12,
            sweep_strategy: SweepStrategy::BackAndForth,
            max_iter: 200,
            tolerance: 1e-8,
        }
    }
}
107
108/// State object for the one-site TCI algorithm (legacy).
109///
110/// Prefer [`TensorCI2`](crate::TensorCI2) for new code.
111/// `TensorCI1` uses [`MatrixACA`] for pivot
112/// selection and performs one-site updates (one index at a time).
113#[derive(Debug, Clone)]
114pub struct TensorCI1<T: Scalar + TTScalar> {
115    /// Index sets I for each site
116    i_set: Vec<IndexSet<MultiIndex>>,
117    /// Index sets J for each site
118    j_set: Vec<IndexSet<MultiIndex>>,
119    /// Local dimensions
120    local_dims: Vec<usize>,
121    /// T tensors (3-leg tensors)
122    t_tensors: Vec<Tensor3<T>>,
123    /// P matrices (pivot matrices)
124    p_matrices: Vec<Matrix<T>>,
125    /// ACA decompositions at each bond
126    aca: Vec<MatrixACA<T>>,
127    /// Pi matrices at each bond
128    pi: Vec<Matrix<T>>,
129    /// Pi I sets
130    pi_i_set: Vec<IndexSet<MultiIndex>>,
131    /// Pi J sets
132    pi_j_set: Vec<IndexSet<MultiIndex>>,
133    /// Pivot errors at each bond
134    pivot_errors: Vec<f64>,
135    /// Maximum sample value found
136    max_sample_value: f64,
137}
138
139impl<T: Scalar + TTScalar + Default> TensorCI1<T> {
140    /// Create a new empty TensorCI1
141    pub fn new(local_dims: Vec<usize>) -> Self {
142        let n = local_dims.len();
143        Self {
144            i_set: (0..n).map(|_| IndexSet::new()).collect(),
145            j_set: (0..n).map(|_| IndexSet::new()).collect(),
146            local_dims: local_dims.clone(),
147            t_tensors: local_dims.iter().map(|&d| tensor3_zeros(0, d, 0)).collect(),
148            p_matrices: (0..n).map(|_| zeros(0, 0)).collect(),
149            aca: (0..n).map(|_| MatrixACA::new(0, 0)).collect(),
150            pi: (0..n).map(|_| zeros(0, 0)).collect(),
151            pi_i_set: (0..n).map(|_| IndexSet::new()).collect(),
152            pi_j_set: (0..n).map(|_| IndexSet::new()).collect(),
153            pivot_errors: vec![f64::INFINITY; n.saturating_sub(1)],
154            max_sample_value: 0.0,
155        }
156    }
157
158    /// Number of sites
159    pub fn len(&self) -> usize {
160        self.t_tensors.len()
161    }
162
163    /// Check if empty
164    pub fn is_empty(&self) -> bool {
165        self.t_tensors.is_empty()
166    }
167
168    /// Get local dimensions
169    pub fn local_dims(&self) -> &[usize] {
170        &self.local_dims
171    }
172
173    /// Get current rank (maximum bond dimension)
174    pub fn rank(&self) -> usize {
175        if self.t_tensors.len() <= 1 {
176            return if self.i_set.is_empty() || self.i_set[0].is_empty() {
177                0
178            } else {
179                1
180            };
181        }
182        self.t_tensors
183            .iter()
184            .skip(1)
185            .map(|t| t.left_dim())
186            .max()
187            .unwrap_or(0)
188    }
189
190    /// Get bond dimensions
191    pub fn link_dims(&self) -> Vec<usize> {
192        if self.t_tensors.len() <= 1 {
193            return Vec::new();
194        }
195        self.t_tensors
196            .iter()
197            .skip(1)
198            .map(|t| t.left_dim())
199            .collect()
200    }
201
202    /// Get the maximum pivot error from the last sweep
203    pub fn last_sweep_pivot_error(&self) -> f64 {
204        self.pivot_errors.iter().cloned().fold(0.0, f64::max)
205    }
206
207    /// Get site tensor at position p (T * P^{-1})
208    pub fn site_tensor(&self, p: usize) -> Tensor3<T> {
209        if p >= self.len() {
210            return tensor3_zeros(1, 1, 1);
211        }
212
213        let t = &self.t_tensors[p];
214        let shape = (t.left_dim(), t.site_dim(), t.right_dim());
215
216        // If empty, return identity-like tensor
217        if shape.0 == 0 || shape.2 == 0 {
218            return t.clone();
219        }
220
221        // Compute T * P^{-1}
222        let t_mat = tensor3_to_matrix(t);
223        let p_mat = &self.p_matrices[p];
224
225        if p_mat.nrows() == 0 || p_mat.ncols() == 0 {
226            return self.t_tensors[p].clone();
227        }
228
229        let result = a_times_b_inv(&t_mat, p_mat);
230        matrix_to_tensor3(&result, shape.0, shape.1, shape.2)
231    }
232
233    /// Get all site tensors
234    pub fn site_tensors(&self) -> Vec<Tensor3<T>> {
235        (0..self.len()).map(|p| self.site_tensor(p)).collect()
236    }
237
238    /// Convert to TensorTrain
239    pub fn to_tensor_train(&self) -> Result<TensorTrain<T>> {
240        let tensors = self.site_tensors();
241        TensorTrain::new(tensors).map_err(TCIError::TensorTrainError)
242    }
243
244    /// Get maximum sample value
245    pub fn max_sample_value(&self) -> f64 {
246        self.max_sample_value
247    }
248
249    /// Update maximum sample value from a slice of values
250    fn update_max_sample(&mut self, values: &[T]) {
251        for v in values {
252            let abs_val = f64::sqrt(Scalar::abs_sq(*v));
253            if abs_val > self.max_sample_value {
254                self.max_sample_value = abs_val;
255            }
256        }
257    }
258
259    /// Update maximum sample value from a matrix
260    fn update_max_sample_matrix(&mut self, mat: &Matrix<T>) {
261        for i in 0..mat.nrows() {
262            for j in 0..mat.ncols() {
263                let abs_val = f64::sqrt(Scalar::abs_sq(mat[[i, j]]));
264                if abs_val > self.max_sample_value {
265                    self.max_sample_value = abs_val;
266                }
267            }
268        }
269    }
270
271    /// Evaluate the TCI at a specific set of indices
272    #[allow(clippy::needless_range_loop)]
273    pub fn evaluate(&self, indices: &[usize]) -> Result<T> {
274        if indices.len() != self.len() {
275            return Err(TCIError::DimensionMismatch {
276                message: format!(
277                    "Index length ({}) must match number of sites ({})",
278                    indices.len(),
279                    self.len()
280                ),
281            });
282        }
283
284        if self.is_empty() {
285            return Err(TCIError::Empty);
286        }
287
288        // Check rank
289        if self.rank() == 0 {
290            return Err(TCIError::Empty);
291        }
292
293        // Evaluate by contracting site tensors
294        // For each site p, compute T[p][:, indices[p], :] * P[p]^{-1}
295        let mut result = {
296            let t = &self.t_tensors[0];
297            let idx = indices[0];
298            if idx >= t.site_dim() {
299                return Err(TCIError::IndexOutOfBounds {
300                    message: format!(
301                        "Index {} out of bounds at site 0 (max {})",
302                        idx,
303                        t.site_dim()
304                    ),
305                });
306            }
307            let slice = t.slice_site(idx);
308
309            // Apply P^{-1} if needed
310            let p = &self.p_matrices[0];
311            if p.nrows() > 0 && p.ncols() > 0 {
312                let slice_mat = vec_to_row_matrix(&slice);
313                let result_mat = a_times_b_inv(&slice_mat, p);
314                row_matrix_to_vec(&result_mat)
315            } else {
316                slice
317            }
318        };
319
320        // Contract with remaining sites
321        for p in 1..self.len() {
322            let t = &self.t_tensors[p];
323            let idx = indices[p];
324            if idx >= t.site_dim() {
325                return Err(TCIError::IndexOutOfBounds {
326                    message: format!(
327                        "Index {} out of bounds at site {} (max {})",
328                        idx,
329                        p,
330                        t.site_dim()
331                    ),
332                });
333            }
334
335            let right_dim = t.right_dim();
336
337            let next = row_vec_times_matrix(&result, &t.slice_site(idx), right_dim);
338
339            // Apply P^{-1}
340            let p_mat = &self.p_matrices[p];
341            if p_mat.nrows() > 0 && p_mat.ncols() > 0 {
342                let next_mat = vec_to_row_matrix(&next);
343                let result_mat = a_times_b_inv(&next_mat, p_mat);
344                result = row_matrix_to_vec(&result_mat);
345            } else {
346                result = next;
347            }
348        }
349
350        if result.len() != 1 {
351            return Err(TCIError::InvalidOperation {
352                message: format!("Final result should have size 1, got {}", result.len()),
353            });
354        }
355
356        Ok(result[0])
357    }
358
359    /// Get I set for a site
360    pub fn i_set(&self, p: usize) -> &IndexSet<MultiIndex> {
361        &self.i_set[p]
362    }
363
364    /// Get J set for a site
365    pub fn j_set(&self, p: usize) -> &IndexSet<MultiIndex> {
366        &self.j_set[p]
367    }
368
369    /// Build the Pi I set for site p
370    /// PiIset[p] = { [i..., up] : i in Iset[p], up in 1..localdims[p] }
371    fn get_pi_i_set(&self, p: usize) -> IndexSet<MultiIndex> {
372        let mut result = Vec::new();
373        for i_multi in self.i_set[p].iter() {
374            for up in 0..self.local_dims[p] {
375                let mut new_idx = i_multi.clone();
376                new_idx.push(up);
377                result.push(new_idx);
378            }
379        }
380        IndexSet::from_vec(result)
381    }
382
383    /// Build the Pi J set for site p
384    /// PiJset[p] = { [up+1, j...] : up+1 in 1..localdims[p], j in Jset[p] }
385    fn get_pi_j_set(&self, p: usize) -> IndexSet<MultiIndex> {
386        let mut result = Vec::new();
387        for up1 in 0..self.local_dims[p] {
388            for j_multi in self.j_set[p].iter() {
389                let mut new_idx = vec![up1];
390                new_idx.extend(j_multi.iter().cloned());
391                result.push(new_idx);
392            }
393        }
394        IndexSet::from_vec(result)
395    }
396
397    /// Build the Pi matrix at bond p
398    /// Pi[p][i, j] = f([PiIset[p][i]..., PiJset[p+1][j]...])
399    fn get_pi<F>(&mut self, p: usize, f: &F) -> Matrix<T>
400    where
401        F: Fn(&MultiIndex) -> T,
402    {
403        let i_set = &self.pi_i_set[p];
404        let j_set = &self.pi_j_set[p + 1];
405
406        let mut pi = zeros(i_set.len(), j_set.len());
407        for (i, i_multi) in i_set.iter().enumerate() {
408            for (j, j_multi) in j_set.iter().enumerate() {
409                let mut full_idx = i_multi.clone();
410                full_idx.extend(j_multi.iter().cloned());
411                pi[[i, j]] = f(&full_idx);
412            }
413        }
414
415        self.update_max_sample_matrix(&pi);
416        pi
417    }
418
419    /// Update Pi rows at site p (after I set changed at p+1)
420    fn update_pi_rows<F>(&mut self, p: usize, f: &F)
421    where
422        F: Fn(&MultiIndex) -> T,
423    {
424        let new_i_set = self.get_pi_i_set(p);
425        // Clone the old set to avoid borrow issues
426        let old_i_set: Vec<MultiIndex> = self.pi_i_set[p].iter().cloned().collect();
427        let old_i_set_ref = IndexSet::from_vec(old_i_set.clone());
428
429        // Find new indices
430        let new_indices: Vec<MultiIndex> = new_i_set
431            .iter()
432            .filter(|i| old_i_set_ref.pos(i).is_none())
433            .cloned()
434            .collect();
435
436        // Create new Pi matrix
437        let mut new_pi = zeros(new_i_set.len(), self.pi[p].ncols());
438
439        // Copy old rows
440        for (old_i, i_multi) in old_i_set.iter().enumerate() {
441            if let Some(new_i) = new_i_set.pos(i_multi) {
442                for j in 0..new_pi.ncols() {
443                    new_pi[[new_i, j]] = self.pi[p][[old_i, j]];
444                }
445            }
446        }
447
448        // Compute new rows
449        for i_multi in &new_indices {
450            if let Some(new_i) = new_i_set.pos(i_multi) {
451                for (j, j_multi) in self.pi_j_set[p + 1].iter().enumerate() {
452                    let mut full_idx = i_multi.clone();
453                    full_idx.extend(j_multi.iter().cloned());
454                    new_pi[[new_i, j]] = f(&full_idx);
455                }
456                // Update max sample
457                let row: Vec<T> = (0..new_pi.ncols()).map(|j| new_pi[[new_i, j]]).collect();
458                self.update_max_sample(&row);
459            }
460        }
461
462        self.pi[p] = new_pi;
463        self.pi_i_set[p] = new_i_set;
464
465        // Update ACA rows
466        let t_shape = (
467            self.t_tensors[p].left_dim(),
468            self.t_tensors[p].site_dim(),
469            self.t_tensors[p].right_dim(),
470        );
471        let t_p = tensor3_to_matrix_cols(&self.t_tensors[p], t_shape.0 * t_shape.1, t_shape.2);
472        let permutation: Vec<usize> = old_i_set
473            .iter()
474            .filter_map(|i| self.pi_i_set[p].pos(i))
475            .collect();
476        self.aca[p].set_rows(&t_p, &permutation);
477    }
478
479    /// Update Pi cols at site p (after J set changed at p)
480    fn update_pi_cols<F>(&mut self, p: usize, f: &F)
481    where
482        F: Fn(&MultiIndex) -> T,
483    {
484        let new_j_set = self.get_pi_j_set(p + 1);
485        // Clone the old set to avoid borrow issues
486        let old_j_set: Vec<MultiIndex> = self.pi_j_set[p + 1].iter().cloned().collect();
487        let old_j_set_ref = IndexSet::from_vec(old_j_set.clone());
488
489        // Find new indices
490        let new_indices: Vec<MultiIndex> = new_j_set
491            .iter()
492            .filter(|j| old_j_set_ref.pos(j).is_none())
493            .cloned()
494            .collect();
495
496        // Create new Pi matrix
497        let mut new_pi = zeros(self.pi[p].nrows(), new_j_set.len());
498
499        // Copy old columns
500        for (old_j, j_multi) in old_j_set.iter().enumerate() {
501            if let Some(new_j) = new_j_set.pos(j_multi) {
502                for i in 0..new_pi.nrows() {
503                    new_pi[[i, new_j]] = self.pi[p][[i, old_j]];
504                }
505            }
506        }
507
508        // Compute new columns
509        for j_multi in &new_indices {
510            if let Some(new_j) = new_j_set.pos(j_multi) {
511                for (i, i_multi) in self.pi_i_set[p].iter().enumerate() {
512                    let mut full_idx = i_multi.clone();
513                    full_idx.extend(j_multi.iter().cloned());
514                    new_pi[[i, new_j]] = f(&full_idx);
515                }
516                // Update max sample
517                let col: Vec<T> = (0..new_pi.nrows()).map(|i| new_pi[[i, new_j]]).collect();
518                self.update_max_sample(&col);
519            }
520        }
521
522        self.pi[p] = new_pi;
523        self.pi_j_set[p + 1] = new_j_set;
524
525        // Update ACA cols
526        let t_p1 = &self.t_tensors[p + 1];
527        let t_shape = (t_p1.left_dim(), t_p1.site_dim(), t_p1.right_dim());
528        let t_mat = tensor3_to_matrix_rows(t_p1, t_shape.0, t_shape.1 * t_shape.2);
529        let permutation: Vec<usize> = old_j_set
530            .iter()
531            .filter_map(|j| self.pi_j_set[p + 1].pos(j))
532            .collect();
533        self.aca[p].set_cols(&t_mat, &permutation);
534    }
535
536    /// Add a pivot row at bond p
537    fn add_pivot_row<F>(&mut self, p: usize, new_i: usize, f: &F) -> Result<()>
538    where
539        F: Fn(&MultiIndex) -> T,
540    {
541        // Add to ACA
542        let _ = self.aca[p].add_pivot_row(&self.pi[p], new_i);
543
544        // Add to I set at p+1
545        let new_i_multi =
546            self.pi_i_set[p]
547                .get(new_i)
548                .cloned()
549                .ok_or_else(|| TCIError::IndexInconsistency {
550                    message: format!("Missing pivot row index: bond={}, row={}", p, new_i),
551                })?;
552        self.i_set[p + 1].push(new_i_multi);
553
554        // Update T[p+1] - get all pivot rows from Pi
555        // Each row in the I set corresponds to a pivot row in Pi
556        let i_set_len = self.i_set[p + 1].len();
557        let local_dim = self.local_dims[p + 1];
558        let j_set_len = self.j_set[p + 1].len();
559
560        // Build pivot rows matrix: shape (i_set_len, local_dim * j_set_len)
561        let mut pivot_rows = zeros(i_set_len, local_dim * j_set_len);
562        for (row_idx, i_multi) in self.i_set[p + 1].iter().enumerate() {
563            if let Some(pi_row) = self.pi_i_set[p].pos(i_multi) {
564                for j in 0..self.pi[p].ncols() {
565                    pivot_rows[[row_idx, j]] = self.pi[p][[pi_row, j]];
566                }
567            }
568        }
569
570        self.t_tensors[p + 1] = matrix_to_tensor3(&pivot_rows, i_set_len, local_dim, j_set_len);
571
572        // Update P matrix using pivot values
573        self.update_p_matrix(p);
574
575        // Update adjacent Pi matrix if exists
576        if p < self.len() - 2 {
577            self.update_pi_rows(p + 1, f);
578        }
579        Ok(())
580    }
581
582    /// Add a pivot col at bond p
583    fn add_pivot_col<F>(&mut self, p: usize, new_j: usize, f: &F) -> Result<()>
584    where
585        F: Fn(&MultiIndex) -> T,
586    {
587        // Add to ACA
588        let _ = self.aca[p].add_pivot_col(&self.pi[p], new_j);
589
590        // Add to J set at p
591        let new_j_multi = self.pi_j_set[p + 1].get(new_j).cloned().ok_or_else(|| {
592            TCIError::IndexInconsistency {
593                message: format!("Missing pivot col index: bond={}, col={}", p, new_j),
594            }
595        })?;
596        self.j_set[p].push(new_j_multi);
597
598        // Update T[p] - get all pivot columns from Pi
599        // Each column in the J set corresponds to a pivot column in Pi
600        let i_set_len = self.i_set[p].len();
601        let local_dim = self.local_dims[p];
602        let j_set_len = self.j_set[p].len();
603
604        // Build pivot cols matrix: shape (i_set_len * local_dim, j_set_len)
605        let mut pivot_cols = zeros(i_set_len * local_dim, j_set_len);
606        for (col_idx, j_multi) in self.j_set[p].iter().enumerate() {
607            if let Some(pi_col) = self.pi_j_set[p + 1].pos(j_multi) {
608                for i in 0..self.pi[p].nrows() {
609                    pivot_cols[[i, col_idx]] = self.pi[p][[i, pi_col]];
610                }
611            }
612        }
613
614        self.t_tensors[p] = matrix_to_tensor3(&pivot_cols, i_set_len, local_dim, j_set_len);
615
616        // Update P matrix using pivot values
617        self.update_p_matrix(p);
618
619        // Update adjacent Pi matrix if exists
620        if p > 0 {
621            self.update_pi_cols(p - 1, f);
622        }
623        Ok(())
624    }
625
626    /// Update P matrix at bond p from current I and J sets
627    fn update_p_matrix(&mut self, p: usize) {
628        let i_set_len = self.i_set[p + 1].len();
629        let j_set_len = self.j_set[p].len();
630
631        let mut p_mat = zeros(i_set_len, j_set_len);
632        for (i, i_multi) in self.i_set[p + 1].iter().enumerate() {
633            for (j, j_multi) in self.j_set[p].iter().enumerate() {
634                if let (Some(pi_i), Some(pi_j)) = (
635                    self.pi_i_set[p].pos(i_multi),
636                    self.pi_j_set[p + 1].pos(j_multi),
637                ) {
638                    p_mat[[i, j]] = self.pi[p][[pi_i, pi_j]];
639                }
640            }
641        }
642        self.p_matrices[p] = p_mat;
643    }
644
645    /// Add a pivot at bond p
646    fn add_pivot<F>(&mut self, p: usize, f: &F, tolerance: f64) -> Result<()>
647    where
648        F: Fn(&MultiIndex) -> T,
649    {
650        if p >= self.len() - 1 {
651            return Ok(());
652        }
653
654        // Check if we've reached full rank
655        let pi_rows = self.pi[p].nrows();
656        let pi_cols = self.pi[p].ncols();
657        if self.aca[p].rank() >= pi_rows.min(pi_cols) {
658            self.pivot_errors[p] = 0.0;
659            return Ok(());
660        }
661
662        // Find new pivot using ACA
663        let new_pivot = self.aca[p].find_new_pivot(&self.pi[p]);
664
665        match new_pivot {
666            Ok(((new_i, new_j), error)) => {
667                let error_val = f64::sqrt(Scalar::abs_sq(error));
668                self.pivot_errors[p] = error_val;
669
670                if error_val < tolerance {
671                    return Ok(());
672                }
673
674                // Add pivot column first, then row
675                self.add_pivot_col(p, new_j, f)?;
676                self.add_pivot_row(p, new_i, f)?;
677            }
678            Err(_) => {
679                self.pivot_errors[p] = 0.0;
680            }
681        }
682        Ok(())
683    }
684
685    /// Initialize from function with first pivot
686    fn initialize_from_pivot<F>(&mut self, f: &F, first_pivot: &MultiIndex) -> Result<()>
687    where
688        F: Fn(&MultiIndex) -> T,
689    {
690        let first_value = f(first_pivot);
691        if Scalar::abs_sq(first_value) < 1e-30 {
692            return Err(TCIError::InvalidPivot {
693                message: "First pivot must have non-zero function value".to_string(),
694            });
695        }
696
697        self.max_sample_value = f64::sqrt(Scalar::abs_sq(first_value));
698        let n = self.len();
699
700        // Initialize I and J sets from first pivot
701        for p in 0..n {
702            let i_indices: MultiIndex = first_pivot[0..p].to_vec();
703            let j_indices: MultiIndex = first_pivot[p + 1..].to_vec();
704            self.i_set[p] = IndexSet::from_vec(vec![i_indices]);
705            self.j_set[p] = IndexSet::from_vec(vec![j_indices]);
706        }
707
708        // Build Pi I and J sets
709        for p in 0..n {
710            self.pi_i_set[p] = self.get_pi_i_set(p);
711            self.pi_j_set[p] = self.get_pi_j_set(p);
712        }
713
714        // Build Pi matrices
715        for p in 0..n - 1 {
716            self.pi[p] = self.get_pi(p, f);
717        }
718
719        // Initialize ACA and T tensors for each bond
720        for p in 0..n - 1 {
721            // Find local pivot position in Pi
722            let local_pivot = (
723                self.pi_i_set[p]
724                    .pos(&self.i_set[p + 1].get(0).cloned().unwrap_or_default())
725                    .unwrap_or(0),
726                self.pi_j_set[p + 1]
727                    .pos(&self.j_set[p].get(0).cloned().unwrap_or_default())
728                    .unwrap_or(0),
729            );
730
731            // Initialize ACA from Pi with the pivot
732            self.aca[p] = MatrixACA::from_matrix_with_pivot(&self.pi[p], local_pivot)?;
733
734            // Update T tensors
735            if p == 0 {
736                // T[0] from pivot column of Pi[0]
737                let pivot_col: Vec<T> = (0..self.pi[p].nrows())
738                    .map(|i| self.pi[p][[i, local_pivot.1]])
739                    .collect();
740                let col_mat = vec_to_col_matrix(&pivot_col);
741                self.t_tensors[0] = matrix_to_tensor3(&col_mat, 1, self.local_dims[0], 1);
742            }
743
744            // T[p+1] from pivot row of Pi[p]
745            let pivot_row: Vec<T> = (0..self.pi[p].ncols())
746                .map(|j| self.pi[p][[local_pivot.0, j]])
747                .collect();
748            let row_mat = vec_to_row_matrix(&pivot_row);
749            self.t_tensors[p + 1] = matrix_to_tensor3(&row_mat, 1, self.local_dims[p + 1], 1);
750
751            // P[p] = Pi[p][local_pivot]
752            let mut p_mat = zeros(1, 1);
753            p_mat[[0, 0]] = self.pi[p][[local_pivot.0, local_pivot.1]];
754            self.p_matrices[p] = p_mat;
755        }
756
757        // P[n-1] = identity (1x1 of ones)
758        let mut p_last = zeros(1, 1);
759        p_last[[0, 0]] = T::one();
760        self.p_matrices[n - 1] = p_last;
761
762        Ok(())
763    }
764}
765
766// Helper functions
767
768/// Convert a vector to a row matrix
769fn vec_to_row_matrix<T: Scalar>(v: &[T]) -> Matrix<T> {
770    let mut mat = zeros(1, v.len());
771    for (j, &val) in v.iter().enumerate() {
772        mat[[0, j]] = val;
773    }
774    mat
775}
776
777fn row_vec_times_matrix<T: Scalar>(row: &[T], matrix: &[T], ncols: usize) -> Vec<T> {
778    let nrows = row.len();
779    let row_mat = vec_to_row_matrix(row);
780    let matrix_mat = Matrix::from_raw_vec(nrows, ncols, matrix.to_vec());
781    let result_mat = mat_mul(&row_mat, &matrix_mat);
782    row_matrix_to_vec(&result_mat)
783}
784
785/// Convert a vector to a column matrix
786fn vec_to_col_matrix<T: Scalar>(v: &[T]) -> Matrix<T> {
787    let mut mat = zeros(v.len(), 1);
788    for (i, &val) in v.iter().enumerate() {
789        mat[[i, 0]] = val;
790    }
791    mat
792}
793
794/// Convert a row matrix to a vector
795fn row_matrix_to_vec<T: Scalar>(mat: &Matrix<T>) -> Vec<T> {
796    (0..mat.ncols()).map(|j| mat[[0, j]]).collect()
797}
798
799/// Convert Tensor3 to Matrix (reshape for columns: (left*site, right))
800fn tensor3_to_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
801    let left_dim = tensor.left_dim();
802    let site_dim = tensor.site_dim();
803    let right_dim = tensor.right_dim();
804    let rows = left_dim * site_dim;
805    let cols = right_dim;
806
807    let mut mat = zeros(rows, cols);
808    for l in 0..left_dim {
809        for s in 0..site_dim {
810            for r in 0..right_dim {
811                mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
812            }
813        }
814    }
815    mat
816}
817
818/// Convert Tensor3 to Matrix for columns (left*site, right)
819fn tensor3_to_matrix_cols<T: TTScalar + Scalar + Default>(
820    tensor: &Tensor3<T>,
821    rows: usize,
822    cols: usize,
823) -> Matrix<T> {
824    let left_dim = tensor.left_dim();
825    let site_dim = tensor.site_dim();
826    let right_dim = tensor.right_dim();
827
828    let mut mat = zeros(rows, cols);
829    for l in 0..left_dim {
830        for s in 0..site_dim {
831            for r in 0..right_dim {
832                if l * site_dim + s < rows && r < cols {
833                    mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
834                }
835            }
836        }
837    }
838    mat
839}
840
841/// Convert Tensor3 to Matrix for rows (left, site*right)
842fn tensor3_to_matrix_rows<T: TTScalar + Scalar + Default>(
843    tensor: &Tensor3<T>,
844    rows: usize,
845    cols: usize,
846) -> Matrix<T> {
847    let left_dim = tensor.left_dim();
848    let site_dim = tensor.site_dim();
849    let right_dim = tensor.right_dim();
850
851    let mut mat = zeros(rows, cols);
852    for l in 0..left_dim {
853        for s in 0..site_dim {
854            for r in 0..right_dim {
855                if l < rows && s * right_dim + r < cols {
856                    mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
857                }
858            }
859        }
860    }
861    mat
862}
863
864/// Convert Matrix to Tensor3
865fn matrix_to_tensor3<T: TTScalar + Scalar + Default>(
866    mat: &Matrix<T>,
867    left_dim: usize,
868    site_dim: usize,
869    right_dim: usize,
870) -> Tensor3<T> {
871    let mut tensor = tensor3_zeros(left_dim, site_dim, right_dim);
872
873    // Determine the layout based on matrix dimensions
874    if mat.nrows() == left_dim * site_dim && mat.ncols() == right_dim {
875        // Column layout: (left*site, right)
876        for l in 0..left_dim {
877            for s in 0..site_dim {
878                for r in 0..right_dim {
879                    tensor.set3(l, s, r, mat[[l * site_dim + s, r]]);
880                }
881            }
882        }
883    } else if mat.nrows() == left_dim && mat.ncols() == site_dim * right_dim {
884        // Row layout: (left, site*right)
885        for l in 0..left_dim {
886            for s in 0..site_dim {
887                for r in 0..right_dim {
888                    tensor.set3(l, s, r, mat[[l, s * right_dim + r]]);
889                }
890            }
891        }
892    } else if mat.nrows() == 1 && mat.ncols() == site_dim {
893        // Single row with site values
894        for s in 0..site_dim {
895            tensor.set3(0, s, 0, mat[[0, s]]);
896        }
897    } else if mat.nrows() == site_dim && mat.ncols() == 1 {
898        // Single column with site values
899        for s in 0..site_dim {
900            tensor.set3(0, s, 0, mat[[s, 0]]);
901        }
902    }
903
904    tensor
905}
906
907/// Approximate a function as a tensor train using the TCI1 algorithm (legacy).
908///
909/// Prefer [`crossinterpolate2`](crate::crossinterpolate2) for new code.
910///
911/// # Arguments
912///
913/// * `f` -- Function to interpolate. Takes `&MultiIndex` (0-indexed) and
914///   returns a scalar.
915/// * `local_dims` -- Number of values each index can take.
916/// * `first_pivot` -- Starting multi-index. Must have the same length as
917///   `local_dims` and the function value must be non-zero.
918/// * `options` -- Algorithm configuration; see [`TCI1Options`].
919///
920/// # Returns
921///
922/// A tuple `(tci, ranks, errors)`:
923///
924/// * `tci: TensorCI1<T>` -- The interpolation state.
925/// * `ranks: Vec<usize>` -- Bond dimension after each sweep.
926/// * `errors: Vec<f64>` -- Error estimate after each sweep.
927///
928/// # Examples
929///
930/// ```
931/// use tensor4all_tensorci::{crossinterpolate1, TCI1Options};
932/// use tensor4all_simplett::AbstractTensorTrain;
933///
934/// let f = |idx: &Vec<usize>| (idx[0] + idx[1] + 1) as f64;
935/// let local_dims = vec![4, 4];
936/// let first_pivot = vec![3, 3];
937///
938/// let (tci, _ranks, _errors) = crossinterpolate1::<f64, _>(
939///     f,
940///     local_dims,
941///     first_pivot,
942///     TCI1Options { tolerance: 1e-10, ..TCI1Options::default() },
943/// ).unwrap();
944///
945/// let tt = tci.to_tensor_train().unwrap();
946/// let val = tt.evaluate(&[2, 3]).unwrap();
947/// assert!((val - 6.0).abs() < 1e-6); // f(2,3) = 2+3+1 = 6
948/// ```
949pub fn crossinterpolate1<T, F>(
950    f: F,
951    local_dims: Vec<usize>,
952    first_pivot: MultiIndex,
953    options: TCI1Options,
954) -> Result<(TensorCI1<T>, Vec<usize>, Vec<f64>)>
955where
956    T: Scalar + TTScalar + Default,
957    F: Fn(&MultiIndex) -> T,
958{
959    if local_dims.len() != first_pivot.len() {
960        return Err(TCIError::DimensionMismatch {
961            message: format!(
962                "local_dims length ({}) must match first_pivot length ({})",
963                local_dims.len(),
964                first_pivot.len()
965            ),
966        });
967    }
968
969    let mut tci = TensorCI1::new(local_dims.clone());
970    tci.initialize_from_pivot(&f, &first_pivot)?;
971
972    let n = tci.len();
973    let mut errors = Vec::new();
974    let mut ranks = Vec::new();
975
976    // Main iteration loop
977    for iter in 1..=options.max_iter {
978        // Sweep
979        if forward_sweep(options.sweep_strategy, iter) {
980            for bond_index in 0..n - 1 {
981                tci.add_pivot(bond_index, &f, options.pivot_tolerance)?;
982            }
983        } else {
984            for bond_index in (0..n - 1).rev() {
985                tci.add_pivot(bond_index, &f, options.pivot_tolerance)?;
986            }
987        }
988
989        // Record error and rank
990        let error = tci.last_sweep_pivot_error();
991        let error_normalized = if options.normalize_error && tci.max_sample_value > 0.0 {
992            error / tci.max_sample_value
993        } else {
994            error
995        };
996
997        errors.push(error_normalized);
998        ranks.push(tci.rank());
999
1000        if options.verbosity > 0 && iter % 10 == 0 {
1001            println!(
1002                "iteration = {}, rank = {}, error = {:.2e}",
1003                iter,
1004                tci.rank(),
1005                error_normalized
1006            );
1007        }
1008
1009        // Check convergence
1010        if error_normalized < options.tolerance {
1011            break;
1012        }
1013    }
1014
1015    Ok((tci, ranks, errors))
1016}
1017
1018#[cfg(test)]
1019mod tests;