Skip to main content

tensor4all_tensorci/
globalsearch.rs

//! Post-hoc error estimation for tensor train approximations.
//!
//! After running [`crossinterpolate2`](crate::crossinterpolate2),
//! [`estimate_true_error`] can verify the approximation quality by
//! searching for multi-indices with large interpolation error via
//! [`floating_zone`] optimization.
7
8use rand::Rng;
9use tensor4all_simplett::{AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain};
10use tensor4all_tcicore::{MultiIndex, Scalar};
11
12/// Estimate the true interpolation error by searching for worst-case indices.
13///
14/// Launches [`floating_zone`] from `nsearch` random starting points (or
15/// from explicit `initial_points`), returning all found (pivot, error)
16/// pairs sorted by descending error.
17///
18/// This is useful as a post-hoc check: if the largest returned error is
19/// below your tolerance, you can be more confident in the approximation.
20///
21/// # Examples
22///
23/// ```
24/// use tensor4all_tensorci::estimate_true_error;
25/// use tensor4all_simplett::TensorTrain;
26///
27/// // Build a constant TT (value = 1.0) on a 4×4 grid
28/// let tt = TensorTrain::<f64>::constant(&[4, 4], 1.0);
29///
30/// // Exact function differs from the constant
31/// let f = |idx: &Vec<usize>| (idx[0] * idx[1]) as f64;
32/// let mut rng = rand::rng();
33///
34/// let errors = estimate_true_error(&tt, &f, 10, None, &mut rng);
35///
36/// // Results are sorted by descending error
37/// for w in errors.windows(2) {
38///     assert!(w[0].1 >= w[1].1, "must be sorted descending");
39/// }
40///
41/// // The worst-case error for |i*j - 1| on [0..4]x[0..4] is at (3,3): |9-1|=8
42/// let (best_pivot, max_err) = &errors[0];
43/// assert_eq!(*best_pivot, vec![3, 3]);
44/// assert!((max_err - 8.0).abs() < 1e-10);
45/// ```
46///
47/// # Arguments
48///
49/// * `tt` -- the tensor train approximation
50/// * `f` -- the exact function
51/// * `nsearch` -- number of random starting points (ignored when
52///   `initial_points` is `Some`)
53/// * `initial_points` -- explicit starting points for the search
54/// * `rng` -- random number generator
55///
56/// # Returns
57///
58/// `Vec<(MultiIndex, f64)>` sorted by descending error, with duplicate
59/// pivots removed.
60pub fn estimate_true_error<T, F>(
61    tt: &TensorTrain<T>,
62    f: &F,
63    nsearch: usize,
64    initial_points: Option<Vec<MultiIndex>>,
65    rng: &mut impl Rng,
66) -> Vec<(MultiIndex, f64)>
67where
68    T: Scalar + TTScalar,
69    F: Fn(&MultiIndex) -> T,
70{
71    let site_dims: Vec<usize> = (0..tt.len())
72        .map(|i| tt.site_tensor(i).site_dim())
73        .collect();
74
75    let points = if let Some(pts) = initial_points {
76        pts
77    } else {
78        (0..nsearch)
79            .map(|_| {
80                site_dims
81                    .iter()
82                    .map(|&d| rng.random_range(0..d))
83                    .collect::<MultiIndex>()
84            })
85            .collect()
86    };
87
88    let mut pivot_errors: Vec<(MultiIndex, f64)> = points
89        .into_iter()
90        .map(|init_p| floating_zone(tt, f, &site_dims, Some(&init_p), f64::MAX))
91        .collect();
92
93    // Sort by descending error
94    pivot_errors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
95
96    // Remove duplicates (same pivot)
97    pivot_errors.dedup_by(|a, b| a.0 == b.0);
98
99    pivot_errors
100}
101
102/// Local search for the multi-index with the largest interpolation error.
103///
104/// Starting from `init_p`, sweeps through each site position, evaluating
105/// all local indices while fixing the others, and picks the index with
106/// the maximum error `|f(idx) - tt(idx)|`. Repeats until the error
107/// stops increasing or `early_stop_tol` is exceeded.
108///
109/// # Examples
110///
111/// ```
112/// use tensor4all_tensorci::floating_zone;
113/// use tensor4all_simplett::TensorTrain;
114///
115/// // Constant TT (value = 0.0) on a 4×4 grid
116/// let tt = TensorTrain::<f64>::constant(&[4, 4], 0.0);
117///
118/// // f(i,j) = i * j, so TT error = |i*j|
119/// let f = |idx: &Vec<usize>| (idx[0] * idx[1]) as f64;
120/// let local_dims = vec![4, 4];
121///
122/// // Search from (2, 2) without early stopping
123/// let (pivot, error) = floating_zone(&tt, &f, &local_dims, Some(&vec![2, 2]), f64::MAX);
124///
125/// // Should find maximum error at (3, 3): |3*3 - 0| = 9
126/// assert_eq!(pivot, vec![3, 3]);
127/// assert!((error - 9.0).abs() < 1e-10);
128/// ```
129///
130/// # Arguments
131///
132/// * `tt` -- the tensor train approximation
133/// * `f` -- the exact function
134/// * `local_dims` -- number of values each index can take
135/// * `init_p` -- starting point (`None` defaults to the all-zeros index)
136/// * `early_stop_tol` -- stop early once the error exceeds this value
137///   (use `f64::MAX` to search exhaustively)
138///
139/// # Returns
140///
141/// `(pivot, max_error)` -- the best multi-index found and its error.
142pub fn floating_zone<T, F>(
143    tt: &TensorTrain<T>,
144    f: &F,
145    local_dims: &[usize],
146    init_p: Option<&MultiIndex>,
147    early_stop_tol: f64,
148) -> (MultiIndex, f64)
149where
150    T: Scalar + TTScalar,
151    F: Fn(&MultiIndex) -> T,
152{
153    let n = local_dims.len();
154
155    let mut pivot = if let Some(p) = init_p {
156        p.clone()
157    } else {
158        vec![0; n]
159    };
160
161    let f_val = f(&pivot);
162    let tt_val = tt.evaluate(&pivot).unwrap_or(T::zero());
163    let diff = f_val - tt_val;
164    let mut max_error = f64::sqrt(Scalar::abs_sq(diff));
165
166    let max_sweeps = n * 10; // Reasonable upper bound
167    for _ in 0..max_sweeps {
168        let prev_max_error = max_error;
169
170        for ipos in 0..n {
171            // Evaluate all local indices at this position
172            let mut best_local_error = 0.0f64;
173            let mut best_local_idx = pivot[ipos];
174
175            for v in 0..local_dims[ipos] {
176                pivot[ipos] = v;
177                let f_val = f(&pivot);
178                let tt_val = tt.evaluate(&pivot).unwrap_or(T::zero());
179                let diff = f_val - tt_val;
180                let error = f64::sqrt(Scalar::abs_sq(diff));
181                if error > best_local_error {
182                    best_local_error = error;
183                    best_local_idx = v;
184                }
185            }
186
187            pivot[ipos] = best_local_idx;
188            // Keep max_error monotonically non-decreasing
189            max_error = max_error.max(best_local_error);
190        }
191
192        if max_error == prev_max_error || max_error > early_stop_tol {
193            break;
194        }
195    }
196
197    (pivot, max_error)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// `floating_zone` should climb from (2, 2) to the global maximum of
    /// `|i*j - 1|` on a 4×4 grid, which sits at (3, 3) with value 8.
    #[test]
    fn test_floating_zone_finds_error() {
        use tensor4all_simplett::Tensor3Ops;

        // Rank-1 tensor train whose value is 1.0 at every index.
        let mut core0 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let mut core1 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        for s in 0..4 {
            core0.set3(0, s, 0, 1.0);
            core1.set3(0, s, 0, 1.0);
        }
        let tt = TensorTrain::new(vec![core0, core1]).unwrap();

        // Sanity check: the TT really is the constant 1.0.
        assert!((tt.evaluate(&[0, 0]).unwrap() - 1.0).abs() < 1e-14);
        assert!((tt.evaluate(&[3, 3]).unwrap() - 1.0).abs() < 1e-14);

        // The exact function is not constant, so the TT is a poor fit.
        let f = |idx: &MultiIndex| (idx[0] * idx[1]) as f64;
        let local_dims = vec![4, 4];

        // Start at (2, 2), where the error is |4 - 1| = 3. (At (1, 1) the
        // error would be |1 - 1| = 0.)
        let start = vec![2, 2];
        let (pivot, error) = floating_zone(&tt, &f, &local_dims, Some(&start), f64::MAX);

        // tt = 1 everywhere while f(i, j) = i * j varies, so the error is
        // positive and peaks at (3, 3): |9 - 1| = 8.
        assert!(error > 0.0, "Error should be positive, got {}", error);
        assert_eq!(pivot, vec![3, 3], "Should find max error at (3,3)");
        assert!(
            (error - 8.0).abs() < 1e-10,
            "Error should be 8.0, got {}",
            error
        );
    }

    /// The result of `estimate_true_error` must be ordered by descending error.
    #[test]
    fn test_estimate_true_error_sorted() {
        // Tensor train that is identically zero.
        let core0 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let core1 = tensor4all_simplett::tensor3_zeros(1, 4, 1);
        let tt = TensorTrain::new(vec![core0, core1]).unwrap();

        let f = |idx: &MultiIndex| (idx[0] + idx[1]) as f64;
        let mut rng = rand::rng();

        let errors = estimate_true_error(&tt, &f, 10, None, &mut rng);

        // Every adjacent pair must be in non-increasing order.
        for pair in errors.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "Errors should be sorted descending: {} < {}",
                pair[0].1,
                pair[1].1
            );
        }
    }
}