tensor4all_tensorci/optfirstpivot.rs
1//! Utility for optimizing the first pivot for tensor cross interpolation.
2//!
3//! A good initial pivot (where `|f|` is large) improves TCI convergence.
4//! [`opt_first_pivot`] performs a greedy local search starting from a
5//! user-supplied guess.
6
7use tensor4all_simplett::TTScalar;
8use tensor4all_tcicore::{MultiIndex, Scalar};
9
10/// Optimize the initial pivot by greedy coordinate-wise search.
11///
12/// Starting from `first_pivot`, sweeps through each dimension and replaces
13/// the current index with whichever value maximizes `|f(pivot)|`. Repeats
14/// until no improvement is found or `max_sweep` sweeps have been performed.
15///
16/// # Arguments
17///
18/// * `f` -- function to interpolate
19/// * `local_dims` -- number of values each index can take
20/// * `first_pivot` -- starting point for the search
21/// * `max_sweep` -- maximum number of full sweeps (1000 is a safe default)
22///
23/// # Returns
24///
25/// The optimized pivot (multi-index with the largest `|f|` found).
26///
27/// # Examples
28///
29/// ```
30/// use tensor4all_tensorci::opt_first_pivot;
31///
32/// let f = |idx: &Vec<usize>| (idx[0] as f64 + idx[1] as f64 + 1.0).powi(2);
33/// let local_dims = vec![4, 4];
34/// let start = vec![0, 0]; // f(0,0) = 1.0
35///
36/// let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &start, 1000);
37/// // Should find the maximum: f(3,3) = 49.0
38/// assert_eq!(pivot, vec![3, 3]);
39/// ```
40pub fn opt_first_pivot<T, F>(
41 f: &F,
42 local_dims: &[usize],
43 first_pivot: &MultiIndex,
44 max_sweep: usize,
45) -> MultiIndex
46where
47 T: Scalar + TTScalar,
48 F: Fn(&MultiIndex) -> T,
49{
50 let n = local_dims.len();
51 let mut pivot = first_pivot.to_vec();
52 let mut val_f = f64::sqrt(Scalar::abs_sq(f(&pivot)));
53
54 for _ in 0..max_sweep {
55 let prev_val = val_f;
56 for i in 0..n {
57 for d in 0..local_dims[i] {
58 let bak = pivot[i];
59 pivot[i] = d;
60 let new_val = f64::sqrt(Scalar::abs_sq(f(&pivot)));
61 if new_val > val_f {
62 val_f = new_val;
63 } else {
64 pivot[i] = bak;
65 }
66 }
67 }
68 if prev_val == val_f {
69 break;
70 }
71 }
72
73 pivot
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opt_first_pivot_finds_nonzero() {
        // |f| grows with both indices and is smallest at the origin, so the
        // search must climb from the start point to the corner (3, 3).
        let f = |idx: &MultiIndex| (idx[0] as f64 + idx[1] as f64 + 1.0).powi(2);
        let local_dims = vec![4, 4];
        let first_pivot = vec![0, 0]; // f(0,0) = 1.0

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        // Should find the maximum: f(3,3) = 49.0
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_already_optimal() {
        let f = |idx: &MultiIndex| idx[0] as f64 * idx[1] as f64;
        let local_dims = vec![4, 4];
        let first_pivot = vec![3, 3]; // Already at maximum

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_complex() {
        use num_complex::Complex64;
        let f = |idx: &MultiIndex| Complex64::new(idx[0] as f64 * 2.0, idx[1] as f64 * 3.0);
        let local_dims = vec![4, 4];
        let first_pivot = vec![0, 0];

        let pivot = opt_first_pivot::<Complex64, _>(&f, &local_dims, &first_pivot, 1000);
        // Maximum |f| at (3, 3) = |6 + 9i| = sqrt(117)
        assert_eq!(pivot, vec![3, 3]);
    }

    #[test]
    fn test_opt_first_pivot_zero_sweeps_returns_start() {
        // With no sweeps allowed, the starting pivot must come back unchanged.
        let f = |idx: &MultiIndex| (idx[0] + idx[1]) as f64;
        let local_dims = vec![4, 4];
        let first_pivot = vec![1, 2];

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 0);
        assert_eq!(pivot, first_pivot);
    }

    #[test]
    fn test_opt_first_pivot_rejects_nan_candidate() {
        // A NaN value must never be accepted as an improvement; the search
        // should step over it and still reach the maximum at (2, 2).
        let f = |idx: &MultiIndex| {
            if idx[0] == 1 && idx[1] == 0 {
                f64::NAN
            } else {
                (idx[0] + idx[1]) as f64
            }
        };
        let local_dims = vec![3, 3];
        let first_pivot = vec![0, 0];

        let pivot = opt_first_pivot::<f64, _>(&f, &local_dims, &first_pivot, 1000);
        assert_eq!(pivot, vec![2, 2]);
    }

    #[test]
    fn test_opt_first_pivot_empty_dims() {
        // Zero-dimensional input: nothing to optimize, empty pivot returned.
        let f = |_: &MultiIndex| 1.0_f64;

        let pivot = opt_first_pivot::<f64, _>(&f, &[], &vec![], 10);
        assert!(pivot.is_empty());
    }
}