Skip to main content

tensor4all_tensorci/
optfirstpivot.rs

//! Utility for optimizing the first pivot for tensor cross interpolation.
//!
//! A good initial pivot (one where `|f|` is large) improves TCI convergence.
//! [`opt_first_pivot`] performs a greedy local search starting from a
//! user-supplied guess.
6
7use tensor4all_simplett::TTScalar;
8use tensor4all_tcicore::{MultiIndex, Scalar};
9
10/// Optimize the initial pivot by greedy coordinate-wise search.
11///
12/// Starting from `first_pivot`, sweeps through each dimension and replaces
13/// the current index with whichever value maximizes `|f(pivot)|`. Repeats
14/// until no improvement is found or `max_sweep` sweeps have been performed.
15///
16/// # Arguments
17///
18/// * `f` -- function to interpolate
19/// * `local_dims` -- number of values each index can take
20/// * `first_pivot` -- starting point for the search
21/// * `max_sweep` -- maximum number of full sweeps (1000 is a safe default)
22///
23/// # Returns
24///
25/// The optimized pivot (multi-index with the largest `|f|` found).
26///
27/// # Examples
28///
29/// ```
30/// use tensor4all_tensorci::opt_first_pivot;
31///
32/// let f = |idx: &Vec<usize>| (idx[0] as f64 + idx[1] as f64 + 1.0).powi(2);
33/// let local_dims = vec![4, 4];
34/// let start = vec![0, 0]; // f(0,0) = 1.0
35///
36/// let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &start, 1000);
37/// // Should find the maximum: f(3,3) = 49.0
38/// assert_eq!(pivot, vec![3, 3]);
39/// ```
40pub fn opt_first_pivot<T, F>(
41    f: &F,
42    local_dims: &[usize],
43    first_pivot: &MultiIndex,
44    max_sweep: usize,
45) -> MultiIndex
46where
47    T: Scalar + TTScalar,
48    F: Fn(&MultiIndex) -> T,
49{
50    let n = local_dims.len();
51    let mut pivot = first_pivot.to_vec();
52    let mut val_f = f64::sqrt(Scalar::abs_sq(f(&pivot)));
53
54    for _ in 0..max_sweep {
55        let prev_val = val_f;
56        for i in 0..n {
57            for d in 0..local_dims[i] {
58                let bak = pivot[i];
59                pivot[i] = d;
60                let new_val = f64::sqrt(Scalar::abs_sq(f(&pivot)));
61                if new_val > val_f {
62                    val_f = new_val;
63                } else {
64                    pivot[i] = bak;
65                }
66            }
67        }
68        if prev_val == val_f {
69            break;
70        }
71    }
72
73    pivot
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opt_first_pivot_finds_maximum() {
        // |f| is smallest at the origin (f(0,0) = 1) and grows with both
        // indices, so the greedy search must climb to the opposite corner.
        let f = |idx: &MultiIndex| (idx[0] as f64 + idx[1] as f64 + 1.0).powi(2);
        let local_dims = vec![4, 4];
        let first_pivot = vec![0, 0]; // f(0,0) = 1.0

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        // Should find the maximum: f(3,3) = 49.0
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_already_optimal() {
        let f = |idx: &MultiIndex| idx[0] as f64 * idx[1] as f64;
        let local_dims = vec![4, 4];
        let first_pivot = vec![3, 3]; // Already at maximum

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_complex() {
        use num_complex::Complex64;
        let f = |idx: &MultiIndex| Complex64::new(idx[0] as f64 * 2.0, idx[1] as f64 * 3.0);
        let local_dims = vec![4, 4];
        let first_pivot = vec![0, 0];

        let pivot = opt_first_pivot::<Complex64, _>(&f, &local_dims, &first_pivot, 1000);
        // Maximum |f| at (3, 3) = |6 + 9i| = sqrt(117)
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_zero_sweeps_returns_start() {
        // With no sweeps allowed, the starting pivot is returned untouched.
        let f = |idx: &MultiIndex| idx[0] as f64;
        let local_dims = vec![4];
        let first_pivot = vec![1];

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 0);
        assert_eq!(pivot, vec![1]);
    }

    #[test]
    fn test_opt_first_pivot_constant_function_keeps_start() {
        // Only strict improvements move the pivot, so a flat function leaves
        // the starting point unchanged.
        let f = |_: &MultiIndex| 1.0_f64;
        let local_dims = vec![3, 3];
        let first_pivot = vec![1, 2];

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        assert_eq!(pivot, vec![1, 2]);
    }
}