tensor4all_tensorci/globalsearch.rs
1//! Post-hoc error estimation for tensor train approximations.
2//!
3//! After running [`crossinterpolate2`](crate::crossinterpolate2),
4//! [`estimate_true_error`] can verify the approximation quality by
5//! searching for multi-indices with large interpolation error via
6//! [`floating_zone`] optimization.
7
8use rand::Rng;
9use tensor4all_simplett::{AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain};
10use tensor4all_tcicore::{MultiIndex, Scalar};
11
12/// Estimate the true interpolation error by searching for worst-case indices.
13///
14/// Launches [`floating_zone`] from `nsearch` random starting points (or
15/// from explicit `initial_points`), returning all found (pivot, error)
16/// pairs sorted by descending error.
17///
18/// This is useful as a post-hoc check: if the largest returned error is
19/// below your tolerance, you can be more confident in the approximation.
20///
21/// # Examples
22///
23/// ```
24/// use tensor4all_tensorci::estimate_true_error;
25/// use tensor4all_simplett::TensorTrain;
26///
27/// // Build a constant TT (value = 1.0) on a 4×4 grid
28/// let tt = TensorTrain::<f64>::constant(&[4, 4], 1.0);
29///
30/// // Exact function differs from the constant
31/// let f = |idx: &Vec<usize>| (idx[0] * idx[1]) as f64;
32/// let mut rng = rand::rng();
33///
34/// let errors = estimate_true_error(&tt, &f, 10, None, &mut rng);
35///
36/// // Results are sorted by descending error
37/// for w in errors.windows(2) {
38/// assert!(w[0].1 >= w[1].1, "must be sorted descending");
39/// }
40///
41/// // The worst-case error for |i*j - 1| on [0..4]x[0..4] is at (3,3): |9-1|=8
42/// let (best_pivot, max_err) = &errors[0];
43/// assert_eq!(*best_pivot, vec![3, 3]);
44/// assert!((max_err - 8.0).abs() < 1e-10);
45/// ```
46///
47/// # Arguments
48///
49/// * `tt` -- the tensor train approximation
50/// * `f` -- the exact function
51/// * `nsearch` -- number of random starting points (ignored when
52/// `initial_points` is `Some`)
53/// * `initial_points` -- explicit starting points for the search
54/// * `rng` -- random number generator
55///
56/// # Returns
57///
58/// `Vec<(MultiIndex, f64)>` sorted by descending error, with duplicate
59/// pivots removed.
60pub fn estimate_true_error<T, F>(
61 tt: &TensorTrain<T>,
62 f: &F,
63 nsearch: usize,
64 initial_points: Option<Vec<MultiIndex>>,
65 rng: &mut impl Rng,
66) -> Vec<(MultiIndex, f64)>
67where
68 T: Scalar + TTScalar,
69 F: Fn(&MultiIndex) -> T,
70{
71 let site_dims: Vec<usize> = (0..tt.len())
72 .map(|i| tt.site_tensor(i).site_dim())
73 .collect();
74
75 let points = if let Some(pts) = initial_points {
76 pts
77 } else {
78 (0..nsearch)
79 .map(|_| {
80 site_dims
81 .iter()
82 .map(|&d| rng.random_range(0..d))
83 .collect::<MultiIndex>()
84 })
85 .collect()
86 };
87
88 let mut pivot_errors: Vec<(MultiIndex, f64)> = points
89 .into_iter()
90 .map(|init_p| floating_zone(tt, f, &site_dims, Some(&init_p), f64::MAX))
91 .collect();
92
93 // Sort by descending error
94 pivot_errors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
95
96 // Remove duplicates (same pivot)
97 pivot_errors.dedup_by(|a, b| a.0 == b.0);
98
99 pivot_errors
100}
101
102/// Local search for the multi-index with the largest interpolation error.
103///
104/// Starting from `init_p`, sweeps through each site position, evaluating
105/// all local indices while fixing the others, and picks the index with
106/// the maximum error `|f(idx) - tt(idx)|`. Repeats until the error
107/// stops increasing or `early_stop_tol` is exceeded.
108///
109/// # Examples
110///
111/// ```
112/// use tensor4all_tensorci::floating_zone;
113/// use tensor4all_simplett::TensorTrain;
114///
115/// // Constant TT (value = 0.0) on a 4×4 grid
116/// let tt = TensorTrain::<f64>::constant(&[4, 4], 0.0);
117///
118/// // f(i,j) = i * j, so TT error = |i*j|
119/// let f = |idx: &Vec<usize>| (idx[0] * idx[1]) as f64;
120/// let local_dims = vec![4, 4];
121///
122/// // Search from (2, 2) without early stopping
123/// let (pivot, error) = floating_zone(&tt, &f, &local_dims, Some(&vec![2, 2]), f64::MAX);
124///
125/// // Should find maximum error at (3, 3): |3*3 - 0| = 9
126/// assert_eq!(pivot, vec![3, 3]);
127/// assert!((error - 9.0).abs() < 1e-10);
128/// ```
129///
130/// # Arguments
131///
132/// * `tt` -- the tensor train approximation
133/// * `f` -- the exact function
134/// * `local_dims` -- number of values each index can take
135/// * `init_p` -- starting point (`None` defaults to the all-zeros index)
136/// * `early_stop_tol` -- stop early once the error exceeds this value
137/// (use `f64::MAX` to search exhaustively)
138///
139/// # Returns
140///
141/// `(pivot, max_error)` -- the best multi-index found and its error.
142pub fn floating_zone<T, F>(
143 tt: &TensorTrain<T>,
144 f: &F,
145 local_dims: &[usize],
146 init_p: Option<&MultiIndex>,
147 early_stop_tol: f64,
148) -> (MultiIndex, f64)
149where
150 T: Scalar + TTScalar,
151 F: Fn(&MultiIndex) -> T,
152{
153 let n = local_dims.len();
154
155 let mut pivot = if let Some(p) = init_p {
156 p.clone()
157 } else {
158 vec![0; n]
159 };
160
161 let f_val = f(&pivot);
162 let tt_val = tt.evaluate(&pivot).unwrap_or(T::zero());
163 let diff = f_val - tt_val;
164 let mut max_error = f64::sqrt(Scalar::abs_sq(diff));
165
166 let max_sweeps = n * 10; // Reasonable upper bound
167 for _ in 0..max_sweeps {
168 let prev_max_error = max_error;
169
170 for ipos in 0..n {
171 // Evaluate all local indices at this position
172 let mut best_local_error = 0.0f64;
173 let mut best_local_idx = pivot[ipos];
174
175 for v in 0..local_dims[ipos] {
176 pivot[ipos] = v;
177 let f_val = f(&pivot);
178 let tt_val = tt.evaluate(&pivot).unwrap_or(T::zero());
179 let diff = f_val - tt_val;
180 let error = f64::sqrt(Scalar::abs_sq(diff));
181 if error > best_local_error {
182 best_local_error = error;
183 best_local_idx = v;
184 }
185 }
186
187 pivot[ipos] = best_local_idx;
188 // Keep max_error monotonically non-decreasing
189 max_error = max_error.max(best_local_error);
190 }
191
192 if max_error == prev_max_error || max_error > early_stop_tol {
193 break;
194 }
195 }
196
197 (pivot, max_error)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_floating_zone_finds_error() {
        use tensor4all_simplett::Tensor3Ops;

        // Rank-1 TT whose every entry is 1.0 on a 4×4 grid.
        let mut t0 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let mut t1 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        for s in 0..4 {
            t0.set3(0, s, 0, 1.0);
            t1.set3(0, s, 0, 1.0);
        }
        let tt = TensorTrain::new(vec![t0, t1]).unwrap();

        // Sanity check: the TT evaluates to 1.0 at the corners.
        assert!((tt.evaluate(&[0, 0]).unwrap() - 1.0).abs() < 1e-14);
        assert!((tt.evaluate(&[3, 3]).unwrap() - 1.0).abs() < 1e-14);

        // f(i, j) = i * j, so the interpolation error is |i * j - 1|.
        let f = |idx: &MultiIndex| (idx[0] * idx[1]) as f64;
        let local_dims = vec![4, 4];

        // Start from (2, 2), where the error is already nonzero: |4 - 1| = 3.
        let (pivot, error) = floating_zone(&tt, &f, &local_dims, Some(&vec![2, 2]), f64::MAX);

        // The maximum error on the grid is at (3, 3): |9 - 1| = 8.
        assert!(error > 0.0, "Error should be positive, got {}", error);
        assert_eq!(pivot, vec![3, 3], "Should find max error at (3,3)");
        assert!(
            (error - 8.0).abs() < 1e-10,
            "Error should be 8.0, got {}",
            error
        );
    }

    #[test]
    fn test_floating_zone_default_start() {
        // Zero TT with f(i, j) = i + j. From the default all-zeros start the
        // search climbs to the corner (3, 3), where the error is 6.
        let tt = TensorTrain::<f64>::constant(&[4, 4], 0.0);
        let f = |idx: &MultiIndex| (idx[0] + idx[1]) as f64;

        let (pivot, error) = floating_zone(&tt, &f, &[4, 4], None, f64::MAX);

        assert_eq!(pivot, vec![3, 3], "Should find max error at (3,3)");
        assert!(
            (error - 6.0).abs() < 1e-10,
            "Error should be 6.0, got {}",
            error
        );
    }

    #[test]
    fn test_estimate_true_error_sorted() {
        // Zero TT on a 4×4 grid.
        let t0 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let t1 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let tt = TensorTrain::new(vec![t0, t1]).unwrap();

        let f = |idx: &MultiIndex| (idx[0] + idx[1]) as f64;
        let mut rng = rand::rng();

        let errors = estimate_true_error(&tt, &f, 10, None, &mut rng);

        // Ten searches always yield at least one (deduplicated) result.
        assert!(!errors.is_empty());

        // Verify sorted in descending order.
        for w in errors.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "Errors should be sorted descending: {} < {}",
                w[0].1,
                w[1].1
            );
        }
    }
}