tensor4all_tensorci/globalpivot.rs
1//! Global pivot finder for the TCI2 algorithm.
2//!
3//! After each two-site sweep, the TCI2 algorithm calls a
4//! [`GlobalPivotFinder`] to locate regions of high interpolation error
5//! that the local sweeps may have missed. The default implementation
6//! ([`DefaultGlobalPivotFinder`]) uses random starting points with local
7//! optimization.
8
9use rand::Rng;
10use tensor4all_simplett::{AbstractTensorTrain, TTScalar, TensorTrain};
11use tensor4all_tcicore::{MultiIndex, Scalar};
12
13/// Snapshot of the current TCI state, passed to [`GlobalPivotFinder`].
14pub struct GlobalPivotSearchInput<T: Scalar + TTScalar> {
15 /// Local dimensions of each tensor index.
16 pub local_dims: Vec<usize>,
17 /// Current tensor train approximation.
18 pub current_tt: TensorTrain<T>,
19 /// Maximum absolute function value encountered so far.
20 pub max_sample_value: f64,
21 /// Left index sets (I) for each site.
22 pub i_set: Vec<Vec<MultiIndex>>,
23 /// Right index sets (J) for each site.
24 pub j_set: Vec<Vec<MultiIndex>>,
25}
26
27/// Trait for global pivot finders.
28///
29/// Implementors search for multi-indices where the interpolation error
30/// `|f(idx) - tt(idx)|` is large. Found pivots are added to the TCI state
31/// to improve the approximation in the next sweep.
32///
33/// The default implementation is [`DefaultGlobalPivotFinder`]. Implement
34/// this trait to supply domain-specific search strategies.
35///
36/// # Examples
37///
38/// Using the default implementation via [`DefaultGlobalPivotFinder`]:
39///
40/// ```
41/// use tensor4all_tensorci::{DefaultGlobalPivotFinder, GlobalPivotFinder,
42/// GlobalPivotSearchInput};
43/// use tensor4all_simplett::TensorTrain;
44///
45/// // Constant-zero TT on a 4×4 grid
46/// let tt = TensorTrain::<f64>::constant(&[4, 4], 0.0);
47///
48/// let input = GlobalPivotSearchInput {
49/// local_dims: vec![4, 4],
50/// current_tt: tt,
51/// max_sample_value: 9.0,
52/// i_set: vec![vec![vec![]], vec![vec![0]]],
53/// j_set: vec![vec![vec![0]], vec![vec![]]],
54/// };
55///
56/// let finder = DefaultGlobalPivotFinder::default();
57/// let mut rng = rand::rng();
58///
59/// // f(i,j) = i*j has nonzero values, so pivots should be found
60/// let pivots = finder.find_global_pivots(
61/// &input,
62/// &|idx: &Vec<usize>| (idx[0] * idx[1]) as f64,
63/// 0.1,
64/// &mut rng,
65/// );
66///
67/// // All returned pivots must have valid indices
68/// for p in &pivots {
69/// assert_eq!(p.len(), 2);
70/// assert!(p[0] < 4 && p[1] < 4);
71/// }
72/// ```
73pub trait GlobalPivotFinder {
74 /// Find multi-indices with high interpolation error.
75 ///
76 /// # Arguments
77 ///
78 /// * `input` -- current TCI state (tensor train, index sets, etc.)
79 /// * `f` -- the function being interpolated
80 /// * `abs_tol` -- absolute tolerance; pivots with error above this
81 /// threshold (times `tol_margin`) are interesting
82 /// * `rng` -- random number generator for stochastic search
83 ///
84 /// # Returns
85 ///
86 /// Multi-indices where the interpolation error is large, up to the
87 /// implementation's maximum count.
88 fn find_global_pivots<T, F>(
89 &self,
90 input: &GlobalPivotSearchInput<T>,
91 f: &F,
92 abs_tol: f64,
93 rng: &mut impl Rng,
94 ) -> Vec<MultiIndex>
95 where
96 T: Scalar + TTScalar,
97 F: Fn(&MultiIndex) -> T;
98}
99
/// Stock [`GlobalPivotFinder`]: random restarts refined by a sweep over the sites.
///
/// Outline of one call to `find_global_pivots`:
///
/// 1. Draw `nsearch` random starting multi-indices.
/// 2. From each of them, sweep over the sites once, trying every value of
///    each site index and remembering the candidate with the largest
///    interpolation error.
/// 3. Accept a candidate only if its error is larger than `abs_tol * tol_margin`.
/// 4. Hand back no more than `max_nglobal_pivot` of the accepted candidates.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorci::DefaultGlobalPivotFinder;
///
/// // Default configuration
/// let finder = DefaultGlobalPivotFinder::default();
/// assert_eq!(finder.nsearch, 5);
/// assert_eq!(finder.max_nglobal_pivot, 5);
/// assert!((finder.tol_margin - 10.0).abs() < 1e-15);
///
/// // Custom configuration
/// let custom = DefaultGlobalPivotFinder::new(20, 10, 5.0);
/// assert_eq!(custom.nsearch, 20);
/// assert_eq!(custom.max_nglobal_pivot, 10);
/// assert!((custom.tol_margin - 5.0).abs() < 1e-15);
/// ```
#[derive(Debug, Clone)]
pub struct DefaultGlobalPivotFinder {
    /// How many random starting points are tried per call.
    pub nsearch: usize,
    /// Upper bound on the number of pivots returned per call.
    pub max_nglobal_pivot: usize,
    /// Multiplier applied to `abs_tol`; only errors above `abs_tol * tol_margin`
    /// are reported.
    pub tol_margin: f64,
}
136
137impl Default for DefaultGlobalPivotFinder {
138 fn default() -> Self {
139 Self {
140 nsearch: 5,
141 max_nglobal_pivot: 5,
142 tol_margin: 10.0,
143 }
144 }
145}
146
147impl DefaultGlobalPivotFinder {
148 /// Create a new DefaultGlobalPivotFinder with the given parameters.
149 pub fn new(nsearch: usize, max_nglobal_pivot: usize, tol_margin: f64) -> Self {
150 Self {
151 nsearch,
152 max_nglobal_pivot,
153 tol_margin,
154 }
155 }
156}
157
158impl GlobalPivotFinder for DefaultGlobalPivotFinder {
159 fn find_global_pivots<T, F>(
160 &self,
161 input: &GlobalPivotSearchInput<T>,
162 f: &F,
163 abs_tol: f64,
164 rng: &mut impl Rng,
165 ) -> Vec<MultiIndex>
166 where
167 T: Scalar + TTScalar,
168 F: Fn(&MultiIndex) -> T,
169 {
170 let n = input.local_dims.len();
171
172 // Generate random initial points
173 let initial_points: Vec<MultiIndex> = (0..self.nsearch)
174 .map(|_| {
175 (0..n)
176 .map(|p| rng.random_range(0..input.local_dims[p]))
177 .collect()
178 })
179 .collect();
180
181 let mut found_pivots: Vec<MultiIndex> = Vec::new();
182
183 for point in &initial_points {
184 let mut current_point = point.clone();
185 let mut best_error = 0.0f64;
186 let mut best_point = point.clone();
187
188 // Local search: sweep all dimensions
189 for p in 0..n {
190 let original = current_point[p];
191 for v in 0..input.local_dims[p] {
192 current_point[p] = v;
193 let f_val = f(¤t_point);
194 let tt_val = input
195 .current_tt
196 .evaluate(¤t_point)
197 .unwrap_or(T::zero());
198 let diff = f_val - tt_val;
199 let error = f64::sqrt(Scalar::abs_sq(diff));
200 if error > best_error {
201 best_error = error;
202 best_point = current_point.clone();
203 }
204 }
205 current_point[p] = original; // Reset to original for next dimension
206 }
207
208 // Add point if error exceeds threshold
209 if best_error > abs_tol * self.tol_margin {
210 found_pivots.push(best_point);
211 }
212 }
213
214 // Limit number of pivots
215 found_pivots.truncate(self.max_nglobal_pivot);
216
217 found_pivots
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-site search input over a `d x d` grid whose tensor train is identically zero,
    /// so the interpolation error at any index equals `|f(idx)|`.
    fn zero_tt_input(d: usize, max_sample_value: f64) -> GlobalPivotSearchInput<f64> {
        let tensors = vec![
            tensor4all_simplett::tensor3_zeros(1, d, 1),
            tensor4all_simplett::tensor3_zeros(1, d, 1),
        ];
        GlobalPivotSearchInput {
            local_dims: vec![d, d],
            current_tt: TensorTrain::new(tensors).unwrap(),
            max_sample_value,
            i_set: vec![vec![vec![]], vec![vec![0]]],
            j_set: vec![vec![vec![0]], vec![vec![]]],
        }
    }

    #[test]
    fn test_default_global_pivot_finder() {
        // Simple function: f(i, j) = i * j
        let f = |idx: &MultiIndex| (idx[0] * idx[1]) as f64;

        // Deliberately bad TT (constant 0) to ensure large errors
        let input = zero_tt_input(4, 9.0);

        let finder = DefaultGlobalPivotFinder::new(10, 3, 1.0);
        let mut rng = rand::rng();

        let pivots = finder.find_global_pivots(&input, &f, 0.1, &mut rng);

        assert!(
            pivots.len() <= 3,
            "Should limit to max_nglobal_pivot=3, got {}",
            pivots.len()
        );

        for p in &pivots {
            assert_eq!(p.len(), 2, "pivot {:?} has the wrong length", p);
            assert!(p[0] < 4 && p[1] < 4, "pivot {:?} out of range", p);
            // The TT is zero, so the error at a pivot is f(p) itself and must
            // exceed abs_tol * tol_margin = 0.1.
            assert!(f(p) > 0.1, "pivot {:?} has error below threshold", p);
        }
    }

    #[test]
    fn test_no_pivots_when_approximation_is_exact() {
        // f is identically zero, exactly like the TT: no error anywhere.
        let f = |_: &MultiIndex| 0.0f64;
        let input = zero_tt_input(3, 0.0);

        let finder = DefaultGlobalPivotFinder::new(8, 8, 1.0);
        let mut rng = rand::rng();

        let pivots = finder.find_global_pivots(&input, &f, 1e-12, &mut rng);
        assert!(pivots.is_empty(), "expected no pivots, got {:?}", pivots);
    }

    #[test]
    fn test_zero_max_pivots_returns_empty() {
        // Errors are >= 1 everywhere, but the cap forbids returning anything.
        let f = |idx: &MultiIndex| ((idx[0] + 1) * (idx[1] + 1)) as f64;
        let input = zero_tt_input(4, 16.0);

        let finder = DefaultGlobalPivotFinder::new(10, 0, 1.0);
        let mut rng = rand::rng();

        let pivots = finder.find_global_pivots(&input, &f, 0.1, &mut rng);
        assert!(pivots.is_empty(), "expected no pivots, got {:?}", pivots);
    }

    #[test]
    fn test_custom_global_pivot_finder() {
        struct FixedPivotFinder;

        impl GlobalPivotFinder for FixedPivotFinder {
            fn find_global_pivots<T, F>(
                &self,
                _input: &GlobalPivotSearchInput<T>,
                _f: &F,
                _abs_tol: f64,
                _rng: &mut impl Rng,
            ) -> Vec<MultiIndex>
            where
                T: Scalar + TTScalar,
                F: Fn(&MultiIndex) -> T,
            {
                // Always return a fixed pivot
                vec![vec![1, 2]]
            }
        }

        let finder = FixedPivotFinder;
        let f = |_: &MultiIndex| 1.0f64;
        let tensors = vec![
            tensor4all_simplett::tensor3_zeros(1, 3, 1),
            tensor4all_simplett::tensor3_zeros(1, 3, 1),
        ];
        let tt = TensorTrain::new(tensors).unwrap();

        let input = GlobalPivotSearchInput {
            local_dims: vec![3, 3],
            current_tt: tt,
            max_sample_value: 1.0,
            i_set: vec![],
            j_set: vec![],
        };

        let mut rng = rand::rng();
        let pivots = finder.find_global_pivots(&input, &f, 0.0, &mut rng);
        assert_eq!(pivots, vec![vec![1, 2]]);
    }
}