Skip to main content

tensor4all_tensorci/
globalpivot.rs

1//! Global pivot finder for the TCI2 algorithm.
2//!
3//! After each two-site sweep, the TCI2 algorithm calls a
4//! [`GlobalPivotFinder`] to locate regions of high interpolation error
5//! that the local sweeps may have missed. The default implementation
6//! ([`DefaultGlobalPivotFinder`]) uses random starting points with local
7//! optimization.
8
9use rand::Rng;
10use tensor4all_simplett::{AbstractTensorTrain, TTScalar, TensorTrain};
11use tensor4all_tcicore::{MultiIndex, Scalar};
12
/// Snapshot of the current TCI state, passed to [`GlobalPivotFinder`].
///
/// All fields are public so that custom [`GlobalPivotFinder`] implementations
/// can read whichever parts of the state their search strategy needs. The
/// built-in [`DefaultGlobalPivotFinder`] only reads `local_dims` and
/// `current_tt`; the remaining fields are provided for other strategies.
pub struct GlobalPivotSearchInput<T: Scalar + TTScalar> {
    /// Local dimensions of each tensor index, one entry per site.
    /// NOTE(review): presumably must agree with the site dimensions of
    /// `current_tt` — confirm at the construction site.
    pub local_dims: Vec<usize>,
    /// Current tensor train approximation; its values are compared against
    /// the target function to measure the interpolation error.
    pub current_tt: TensorTrain<T>,
    /// Maximum absolute function value encountered so far.
    pub max_sample_value: f64,
    /// Left index sets (I) for each site.
    pub i_set: Vec<Vec<MultiIndex>>,
    /// Right index sets (J) for each site.
    pub j_set: Vec<Vec<MultiIndex>>,
}
26
/// Trait for global pivot finders.
///
/// Implementors search for multi-indices where the interpolation error
/// `|f(idx) - tt(idx)|` is large. Found pivots are added to the TCI state
/// to improve the approximation in the next sweep.
///
/// The default implementation is [`DefaultGlobalPivotFinder`]. Implement
/// this trait to supply domain-specific search strategies.
///
/// # Examples
///
/// Using the default implementation via [`DefaultGlobalPivotFinder`]:
///
/// ```
/// use tensor4all_tensorci::{DefaultGlobalPivotFinder, GlobalPivotFinder,
///     GlobalPivotSearchInput};
/// use tensor4all_simplett::TensorTrain;
///
/// // Constant-zero TT on a 4×4 grid
/// let tt = TensorTrain::<f64>::constant(&[4, 4], 0.0);
///
/// let input = GlobalPivotSearchInput {
///     local_dims: vec![4, 4],
///     current_tt: tt,
///     max_sample_value: 9.0,
///     i_set: vec![vec![vec![]], vec![vec![0]]],
///     j_set: vec![vec![vec![0]], vec![vec![]]],
/// };
///
/// let finder = DefaultGlobalPivotFinder::default();
/// let mut rng = rand::rng();
///
/// // f(i,j) = i*j has nonzero values, so pivots should be found
/// let pivots = finder.find_global_pivots(
///     &input,
///     &|idx: &Vec<usize>| (idx[0] * idx[1]) as f64,
///     0.1,
///     &mut rng,
/// );
///
/// // All returned pivots must have valid indices
/// for p in &pivots {
///     assert_eq!(p.len(), 2);
///     assert!(p[0] < 4 && p[1] < 4);
/// }
/// ```
pub trait GlobalPivotFinder {
    /// Find multi-indices with high interpolation error.
    ///
    /// # Arguments
    ///
    /// * `input` -- current TCI state (tensor train, index sets, etc.)
    /// * `f` -- the function being interpolated
    /// * `abs_tol` -- absolute tolerance of the TCI run. How it is turned
    ///   into an acceptance threshold is up to the implementation; for
    ///   example, [`DefaultGlobalPivotFinder`] keeps candidates whose error
    ///   exceeds `abs_tol * tol_margin`.
    /// * `rng` -- random number generator for stochastic search;
    ///   deterministic implementations may ignore it
    ///
    /// # Returns
    ///
    /// Multi-indices where the interpolation error is large, up to the
    /// implementation's maximum count. Each multi-index is expected to hold
    /// one 0-based index per site, i.e. `idx[p] < input.local_dims[p]`.
    fn find_global_pivots<T, F>(
        &self,
        input: &GlobalPivotSearchInput<T>,
        f: &F,
        abs_tol: f64,
        rng: &mut impl Rng,
    ) -> Vec<MultiIndex>
    where
        T: Scalar + TTScalar,
        F: Fn(&MultiIndex) -> T;
}
99
/// Default global pivot finder: random starting points, each refined by a
/// single-coordinate scan.
///
/// Algorithm:
///
/// 1. Generate `nsearch` random initial points.
/// 2. From each initial point, scan every single-coordinate variation: for
///    each site `p`, the `p`-th index runs over all of its values while the
///    other indices stay at the initial point's values. The variation with
///    the largest interpolation error `|f(idx) - tt(idx)|` becomes that
///    point's candidate. (The scan is a single pass; it is not iterated.)
/// 3. Keep candidates whose error exceeds `abs_tol * tol_margin`.
/// 4. Return at most `max_nglobal_pivot` results.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorci::DefaultGlobalPivotFinder;
///
/// // Default configuration
/// let finder = DefaultGlobalPivotFinder::default();
/// assert_eq!(finder.nsearch, 5);
/// assert_eq!(finder.max_nglobal_pivot, 5);
/// assert!((finder.tol_margin - 10.0).abs() < 1e-15);
///
/// // Custom configuration
/// let custom = DefaultGlobalPivotFinder::new(20, 10, 5.0);
/// assert_eq!(custom.nsearch, 20);
/// assert_eq!(custom.max_nglobal_pivot, 10);
/// assert!((custom.tol_margin - 5.0).abs() < 1e-15);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct DefaultGlobalPivotFinder {
    /// Number of random initial points to search from
    pub nsearch: usize,
    /// Maximum number of pivots to add per iteration
    pub max_nglobal_pivot: usize,
    /// Search for pivots with error > abs_tol × tol_margin
    pub tol_margin: f64,
}
136
137impl Default for DefaultGlobalPivotFinder {
138    fn default() -> Self {
139        Self {
140            nsearch: 5,
141            max_nglobal_pivot: 5,
142            tol_margin: 10.0,
143        }
144    }
145}
146
147impl DefaultGlobalPivotFinder {
148    /// Create a new DefaultGlobalPivotFinder with the given parameters.
149    pub fn new(nsearch: usize, max_nglobal_pivot: usize, tol_margin: f64) -> Self {
150        Self {
151            nsearch,
152            max_nglobal_pivot,
153            tol_margin,
154        }
155    }
156}
157
158impl GlobalPivotFinder for DefaultGlobalPivotFinder {
159    fn find_global_pivots<T, F>(
160        &self,
161        input: &GlobalPivotSearchInput<T>,
162        f: &F,
163        abs_tol: f64,
164        rng: &mut impl Rng,
165    ) -> Vec<MultiIndex>
166    where
167        T: Scalar + TTScalar,
168        F: Fn(&MultiIndex) -> T,
169    {
170        let n = input.local_dims.len();
171
172        // Generate random initial points
173        let initial_points: Vec<MultiIndex> = (0..self.nsearch)
174            .map(|_| {
175                (0..n)
176                    .map(|p| rng.random_range(0..input.local_dims[p]))
177                    .collect()
178            })
179            .collect();
180
181        let mut found_pivots: Vec<MultiIndex> = Vec::new();
182
183        for point in &initial_points {
184            let mut current_point = point.clone();
185            let mut best_error = 0.0f64;
186            let mut best_point = point.clone();
187
188            // Local search: sweep all dimensions
189            for p in 0..n {
190                let original = current_point[p];
191                for v in 0..input.local_dims[p] {
192                    current_point[p] = v;
193                    let f_val = f(&current_point);
194                    let tt_val = input
195                        .current_tt
196                        .evaluate(&current_point)
197                        .unwrap_or(T::zero());
198                    let diff = f_val - tt_val;
199                    let error = f64::sqrt(Scalar::abs_sq(diff));
200                    if error > best_error {
201                        best_error = error;
202                        best_point = current_point.clone();
203                    }
204                }
205                current_point[p] = original; // Reset to original for next dimension
206            }
207
208            // Add point if error exceeds threshold
209            if best_error > abs_tol * self.tol_margin {
210                found_pivots.push(best_point);
211            }
212        }
213
214        // Limit number of pivots
215        found_pivots.truncate(self.max_nglobal_pivot);
216
217        found_pivots
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_global_pivot_finder() {
        // Simple function: f(i, j) = i * j
        let f = |idx: &MultiIndex| (idx[0] * idx[1]) as f64;
        let local_dims = vec![4, 4];

        // Build a deliberately bad TT (constant 0) so that the interpolation
        // error equals |f| everywhere.
        let tensors = vec![
            tensor4all_simplett::tensor3_zeros(1, 4, 1),
            tensor4all_simplett::tensor3_zeros(1, 4, 1),
        ];
        let tt = TensorTrain::new(tensors).unwrap();

        let input = GlobalPivotSearchInput {
            local_dims: local_dims.clone(),
            current_tt: tt,
            max_sample_value: 9.0,
            i_set: vec![vec![vec![]], vec![vec![0]]],
            j_set: vec![vec![vec![0]], vec![vec![]]],
        };

        let abs_tol = 0.1;
        let finder = DefaultGlobalPivotFinder::new(10, 3, 1.0);
        let mut rng = rand::rng();

        let pivots = finder.find_global_pivots(&input, &f, abs_tol, &mut rng);

        assert!(
            pivots.len() <= 3,
            "Should limit to max_nglobal_pivot=3, got {}",
            pivots.len()
        );

        // A starting point (i, j) yields no pivot only when i = j = 0, so an
        // empty result from 10 uniform starts on a 4x4 grid has probability
        // 16^-10.
        assert!(!pivots.is_empty(), "expected at least one pivot");

        for p in &pivots {
            assert_eq!(p.len(), local_dims.len());
            assert!(
                p.iter().zip(&local_dims).all(|(&i, &d)| i < d),
                "pivot {p:?} out of range"
            );
            // The TT is zero, so the error at a pivot is f itself and must
            // exceed the acceptance threshold abs_tol * tol_margin.
            assert!(
                f(p) > abs_tol * finder.tol_margin,
                "pivot {p:?} has error {} below threshold",
                f(p)
            );
            // Scanning one coordinate of f(i, j) = i * j is maximized at the
            // largest index, so every pivot lies on the i = 3 or j = 3 edge.
            assert!(p.contains(&3), "pivot {p:?} is not a scan maximum");
        }
    }

    #[test]
    fn test_custom_global_pivot_finder() {
        struct FixedPivotFinder;

        impl GlobalPivotFinder for FixedPivotFinder {
            fn find_global_pivots<T, F>(
                &self,
                _input: &GlobalPivotSearchInput<T>,
                _f: &F,
                _abs_tol: f64,
                _rng: &mut impl Rng,
            ) -> Vec<MultiIndex>
            where
                T: Scalar + TTScalar,
                F: Fn(&MultiIndex) -> T,
            {
                // Always return a fixed pivot
                vec![vec![1, 2]]
            }
        }

        let finder = FixedPivotFinder;
        let f = |_: &MultiIndex| 1.0f64;
        let tensors = vec![
            tensor4all_simplett::tensor3_zeros(1, 3, 1),
            tensor4all_simplett::tensor3_zeros(1, 3, 1),
        ];
        let tt = TensorTrain::new(tensors).unwrap();

        let input = GlobalPivotSearchInput {
            local_dims: vec![3, 3],
            current_tt: tt,
            max_sample_value: 1.0,
            i_set: vec![],
            j_set: vec![],
        };

        let mut rng = rand::rng();
        let pivots = finder.find_global_pivots(&input, &f, 0.0, &mut rng);
        assert_eq!(pivots, vec![vec![1, 2]]);
    }
}