Skip to main content

tensor4all_simplett/
cache.rs

//! Cached tensor train evaluation
//!
//! This module provides `TTCache`, a wrapper around tensor trains that caches
//! partial evaluations for efficient repeated evaluation.

6use std::collections::{HashMap, HashSet};
7
8use bnum::types::{U1024, U256, U512};
9
10use crate::einsum_helper::EinsumScalar;
11use crate::einsum_helper::{matrix_times_col_vector, row_vector_times_matrix};
12use crate::error::{Result, TensorTrainError};
13use crate::traits::{AbstractTensorTrain, TTScalar};
14use crate::types::{LocalIndex, MultiIndex, Tensor3, Tensor3Ops};
15
/// Compute the number of bits needed to address the full index space.
///
/// Each dimension `d` contributes `ceil(log2(d))` bits, so the sum is an
/// upper bound on `log2(prod(local_dims))`: the largest flat index,
/// `prod(local_dims) - 1`, always fits in the returned bit count.
/// Dimensions of 0 or 1 contribute no bits.
///
/// Note: `ceil(log2(d))` is `(d - 1).ilog2() + 1` for `d >= 2`. The previous
/// form `d.ilog2() + 1` over-counted exact powers of two by one bit (every
/// binary site cost 2 bits), needlessly pushing `FlatIndexer` into wider key
/// types (e.g. `u128` at 33 binary sites instead of 64).
fn compute_total_bits(local_dims: &[usize]) -> u32 {
    local_dims
        .iter()
        .map(|&d| if d <= 1 { 0 } else { ((d - 1) as u64).ilog2() + 1 })
        .sum()
}
23
/// Index key types for different bit widths
///
/// A flattened multi-index is stored in the narrowest unsigned integer wide
/// enough for the whole index space (selected by [`FlatIndexer::new`]), so
/// hashing stays cheap for small index spaces while very large ones still
/// get an exact, collision-free key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum IndexKey {
    /// Index space fits in 64 bits.
    U64(u64),
    /// Index space fits in 128 bits.
    U128(u128),
    /// Index space fits in 256 bits (bnum wide integer).
    U256(U256),
    /// Index space fits in 512 bits (bnum wide integer).
    U512(U512),
    /// Index space fits in 1024 bits (bnum wide integer).
    U1024(U1024),
}
33
/// Flat indexer with automatic key type selection based on index space size
///
/// Each variant stores the mixed-radix stride coefficients (`coeffs[k]` is
/// the product of all dimensions before position `k`) in the integer type
/// matching the chosen [`IndexKey`] width.
enum FlatIndexer {
    U64 { coeffs: Vec<u64> },
    U128 { coeffs: Vec<u128> },
    U256 { coeffs: Vec<U256> },
    U512 { coeffs: Vec<U512> },
    U1024 { coeffs: Vec<U1024> },
}
42
/// Macro for computing coefficients for primitive integer types (u64, u128)
///
/// Emits the mixed-radix strides of `$local_dims`: element `k` is the product
/// of all dimensions before position `k`. The running product uses
/// `saturating_mul`; the caller guarantees via `compute_total_bits` that no
/// stride actually saturates.
macro_rules! compute_coeffs_primitive {
    ($local_dims:expr, $T:ty) => {{
        let mut strides: Vec<$T> = Vec::with_capacity($local_dims.len());
        let mut running: $T = 1;
        for &dim in $local_dims {
            strides.push(running);
            running = running.saturating_mul(dim as $T);
        }
        strides
    }};
}
55
/// Macro for computing coefficients for bnum types (U256, U512, U1024)
///
/// Same mixed-radix strides as `compute_coeffs_primitive!`, but built with
/// the type's `ONE` constant and `From<u64>` conversions since wide bnum
/// integers have no literal syntax.
macro_rules! compute_coeffs_bnum {
    ($local_dims:expr, $T:ty) => {{
        let mut strides: Vec<$T> = Vec::with_capacity($local_dims.len());
        let mut running = <$T>::ONE;
        for &dim in $local_dims {
            strides.push(running);
            running = running.saturating_mul(<$T>::from(dim as u64));
        }
        strides
    }};
}
68
/// Macro for computing flat index for primitive types
///
/// Dot product of the multi-index with its stride coefficients, wrapped in
/// the matching [`IndexKey`] variant. Overflow cannot occur because the key
/// type was chosen wide enough for the whole index space.
macro_rules! flat_index_primitive {
    ($idx:expr, $coeffs:expr, $T:ty, $Key:ident) => {{
        let flat = $idx
            .iter()
            .zip($coeffs)
            .fold(0 as $T, |acc, (&i, &c)| acc + c * i as $T);
        IndexKey::$Key(flat)
    }};
}
76
/// Macro for computing flat index for bnum types
///
/// Same dot product as `flat_index_primitive!`, but the terms are
/// accumulated with an explicit `fold` rather than `Iterator::sum`, which
/// avoids requiring a `Sum` implementation on the wide key type.
macro_rules! flat_index_bnum {
    ($idx:expr, $coeffs:expr, $T:ty, $Key:ident) => {{
        let key = $idx
            .iter()
            .zip($coeffs)
            .map(|(&i, &c)| c * <$T>::from(i as u64))
            .fold(<$T>::ZERO, |a, b| a + b);
        IndexKey::$Key(key)
    }};
}
88
89impl FlatIndexer {
90    /// Create a new indexer, automatically selecting the key type
91    fn new(local_dims: &[usize]) -> Self {
92        let total_bits = compute_total_bits(local_dims);
93
94        if total_bits <= 64 {
95            Self::U64 {
96                coeffs: compute_coeffs_primitive!(local_dims, u64),
97            }
98        } else if total_bits <= 128 {
99            Self::U128 {
100                coeffs: compute_coeffs_primitive!(local_dims, u128),
101            }
102        } else if total_bits <= 256 {
103            Self::U256 {
104                coeffs: compute_coeffs_bnum!(local_dims, U256),
105            }
106        } else if total_bits <= 512 {
107            Self::U512 {
108                coeffs: compute_coeffs_bnum!(local_dims, U512),
109            }
110        } else {
111            Self::U1024 {
112                coeffs: compute_coeffs_bnum!(local_dims, U1024),
113            }
114        }
115    }
116
117    /// Compute flat index key from multi-index
118    fn flat_index(&self, idx: &[usize]) -> IndexKey {
119        match self {
120            Self::U64 { coeffs } => flat_index_primitive!(idx, coeffs, u64, U64),
121            Self::U128 { coeffs } => flat_index_primitive!(idx, coeffs, u128, U128),
122            Self::U256 { coeffs } => flat_index_bnum!(idx, coeffs, U256, U256),
123            Self::U512 { coeffs } => flat_index_bnum!(idx, coeffs, U512, U512),
124            Self::U1024 { coeffs } => flat_index_bnum!(idx, coeffs, U1024, U1024),
125        }
126    }
127}
128
/// Helper struct for building unique index mappings
///
/// Deduplicates the left and right halves of a batch of multi-indices so
/// that each distinct half only needs to be contracted once (used by
/// `TTCache::evaluate_many`).
struct IndexMapper {
    /// Flat-key generator for left halves.
    left_indexer: FlatIndexer,
    /// Flat-key generator for right halves.
    right_indexer: FlatIndexer,
    /// Left flat key -> dense id (assigned 0, 1, 2, ... in first-seen order).
    left_key_to_id: HashMap<IndexKey, usize>,
    /// Right flat key -> dense id.
    right_key_to_id: HashMap<IndexKey, usize>,
    /// For each added index position, the dense id of its left half.
    idx_to_left: Vec<usize>,
    /// For each added index position, the dense id of its right half.
    idx_to_right: Vec<usize>,
    /// For each left id, the position of the first index that produced it.
    left_first_idx: Vec<usize>,
    /// For each right id, the position of the first index that produced it.
    right_first_idx: Vec<usize>,
}
140
141impl IndexMapper {
142    fn new(left_dims: &[usize], right_dims: &[usize], capacity: usize) -> Self {
143        Self {
144            left_indexer: FlatIndexer::new(left_dims),
145            right_indexer: FlatIndexer::new(right_dims),
146            left_key_to_id: HashMap::new(),
147            right_key_to_id: HashMap::new(),
148            idx_to_left: Vec::with_capacity(capacity),
149            idx_to_right: Vec::with_capacity(capacity),
150            left_first_idx: Vec::new(),
151            right_first_idx: Vec::new(),
152        }
153    }
154
155    fn add_index(&mut self, i: usize, left_part: &[usize], right_part: &[usize]) {
156        let left_key = self.left_indexer.flat_index(left_part);
157        let right_key = self.right_indexer.flat_index(right_part);
158
159        let left_id = match self.left_key_to_id.get(&left_key) {
160            Some(&id) => id,
161            None => {
162                let id = self.left_key_to_id.len();
163                self.left_key_to_id.insert(left_key, id);
164                self.left_first_idx.push(i);
165                id
166            }
167        };
168
169        let right_id = match self.right_key_to_id.get(&right_key) {
170            Some(&id) => id,
171            None => {
172                let id = self.right_key_to_id.len();
173                self.right_key_to_id.insert(right_key, id);
174                self.right_first_idx.push(i);
175                id
176            }
177        };
178
179        self.idx_to_left.push(left_id);
180        self.idx_to_right.push(right_id);
181    }
182}
183
/// Helper for counting unique keys in split heuristic
///
/// Wraps a `FlatIndexer` and a `HashSet` so the number of distinct
/// multi-indices inserted so far can be queried cheaply.
struct UniqueCounter {
    /// Converts a multi-index into its flat hash key.
    indexer: FlatIndexer,
    /// Distinct flat keys observed so far.
    keys: HashSet<IndexKey>,
}
189
190impl UniqueCounter {
191    fn new(local_dims: &[usize], capacity: usize) -> Self {
192        Self {
193            indexer: FlatIndexer::new(local_dims),
194            keys: HashSet::with_capacity(capacity),
195        }
196    }
197
198    fn insert(&mut self, idx: &[usize]) {
199        let key = self.indexer.flat_index(idx);
200        self.keys.insert(key);
201    }
202
203    fn len(&self) -> usize {
204        self.keys.len()
205    }
206}
207
/// Cached tensor train evaluator.
///
/// Wraps a tensor train and caches left and right partial contractions so
/// that repeated evaluations sharing common prefixes or suffixes reuse
/// previously computed results. This is particularly effective for
/// batch evaluation via [`evaluate_many`](Self::evaluate_many).
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, TTCache};
///
/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 5.0);
/// let mut cache = TTCache::new(&tt);
///
/// // Single evaluation (caches intermediate contractions)
/// let val = cache.evaluate(&[1, 2, 3]).unwrap();
/// assert!((val - 5.0).abs() < 1e-12);
///
/// // Batch evaluation reuses cached partial contractions
/// let indices = vec![vec![0, 0, 0], vec![1, 2, 3], vec![0, 1, 2]];
/// let vals = cache.evaluate_many(&indices, None).unwrap();
/// assert!(vals.iter().all(|&v| (v - 5.0).abs() < 1e-12));
/// ```
#[derive(Debug, Clone)]
pub struct TTCache<T: TTScalar> {
    /// The site tensors (reshaped to 3D: left_bond x flat_site x right_bond)
    tensors: Vec<Tensor3<T>>,
    /// Cache for left partial evaluations: `cache_left[ell - 1]` maps a
    /// length-`ell` index prefix to its left environment vector.
    cache_left: Vec<HashMap<MultiIndex, Vec<T>>>,
    /// Cache for right partial evaluations: `cache_right[start]` maps the
    /// index suffix covering sites `start..n` to its right environment vector.
    cache_right: Vec<HashMap<MultiIndex, Vec<T>>>,
    /// Site dimensions for each tensor (can be multi-dimensional per site)
    site_dims: Vec<Vec<usize>>,
}
243
244impl<T: TTScalar + EinsumScalar> TTCache<T> {
245    /// Create a new TTCache from a tensor train
246    pub fn new<TT: AbstractTensorTrain<T>>(tt: &TT) -> Self {
247        let n = tt.len();
248        let tensors: Vec<Tensor3<T>> = tt.site_tensors().to_vec();
249        let site_dims: Vec<Vec<usize>> = tensors.iter().map(|t| vec![t.site_dim()]).collect();
250
251        Self {
252            tensors,
253            cache_left: (0..n).map(|_| HashMap::new()).collect(),
254            cache_right: (0..n).map(|_| HashMap::new()).collect(),
255            site_dims,
256        }
257    }
258
259    /// Create a new TTCache with custom site dimensions
260    ///
261    /// This allows treating a single tensor site as multiple logical indices.
262    pub fn with_site_dims<TT: AbstractTensorTrain<T>>(
263        tt: &TT,
264        site_dims: Vec<Vec<usize>>,
265    ) -> Result<Self> {
266        let n = tt.len();
267        if site_dims.len() != n {
268            return Err(TensorTrainError::InvalidOperation {
269                message: format!(
270                    "site_dims length {} doesn't match tensor train length {}",
271                    site_dims.len(),
272                    n
273                ),
274            });
275        }
276
277        // Validate that site_dims products match tensor site dimensions
278        for (i, (tensor, dims)) in tt.site_tensors().iter().zip(site_dims.iter()).enumerate() {
279            let expected: usize = dims.iter().product();
280            if expected != tensor.site_dim() {
281                return Err(TensorTrainError::InvalidOperation {
282                    message: format!(
283                        "site_dims product {} doesn't match tensor site dim {} at site {}",
284                        expected,
285                        tensor.site_dim(),
286                        i
287                    ),
288                });
289            }
290        }
291
292        let tensors: Vec<Tensor3<T>> = tt.site_tensors().to_vec();
293
294        Ok(Self {
295            tensors,
296            cache_left: (0..n).map(|_| HashMap::new()).collect(),
297            cache_right: (0..n).map(|_| HashMap::new()).collect(),
298            site_dims,
299        })
300    }
301
302    /// Number of sites
303    pub fn len(&self) -> usize {
304        self.tensors.len()
305    }
306
307    /// Check if empty
308    pub fn is_empty(&self) -> bool {
309        self.tensors.is_empty()
310    }
311
312    /// Get site dimensions
313    pub fn site_dims(&self) -> &[Vec<usize>] {
314        &self.site_dims
315    }
316
317    /// Get link dimensions
318    pub fn link_dims(&self) -> Vec<usize> {
319        if self.len() <= 1 {
320            return Vec::new();
321        }
322        (1..self.len())
323            .map(|i| self.tensors[i].left_dim())
324            .collect()
325    }
326
327    /// Get link dimension at position i (between site i and i+1)
328    pub fn link_dim(&self, i: usize) -> usize {
329        self.tensors[i + 1].left_dim()
330    }
331
332    /// Clear all cached values
333    pub fn clear_cache(&mut self) {
334        for cache in &mut self.cache_left {
335            cache.clear();
336        }
337        for cache in &mut self.cache_right {
338            cache.clear();
339        }
340    }
341
342    /// Convert multi-index to flat index for a site
343    fn multi_to_flat(&self, site: usize, indices: &[LocalIndex]) -> LocalIndex {
344        let dims = &self.site_dims[site];
345        let mut flat = 0;
346        let mut stride = 1;
347        for (i, &idx) in indices.iter().rev().enumerate() {
348            flat += idx * stride;
349            stride *= dims[dims.len() - 1 - i];
350        }
351        flat
352    }
353
354    /// Evaluate from the left up to (but not including) site `end`
355    ///
356    /// Returns a vector of size `link_dim(end-1)` (or 1 if end == 0)
357    pub fn evaluate_left(&mut self, indices: &[LocalIndex]) -> Vec<T> {
358        let ell = indices.len();
359        if ell == 0 {
360            return vec![T::one()];
361        }
362
363        // Check cache
364        let key: MultiIndex = indices.to_vec();
365        if let Some(cached) = self.cache_left[ell - 1].get(&key) {
366            return cached.clone();
367        }
368
369        // Compute recursively
370        let result = if ell == 1 {
371            // First site: just extract the slice
372            let flat_idx = self.multi_to_flat(0, &indices[0..1]);
373            let tensor = &self.tensors[0];
374            tensor.slice_site(flat_idx)
375        } else {
376            // Recursive case: left[0..ell-1] * tensor[ell-1][:, idx, :]
377            let left = self.evaluate_left(&indices[0..ell - 1]);
378            let flat_idx = self.multi_to_flat(ell - 1, &indices[ell - 1..ell]);
379            let tensor = &self.tensors[ell - 1];
380            let slice = tensor.slice_site(flat_idx);
381            row_vector_times_matrix(&left, &slice, tensor.left_dim(), tensor.right_dim())
382        };
383
384        // Cache and return
385        self.cache_left[ell - 1].insert(key, result.clone());
386        result
387    }
388
389    /// Evaluate from the right starting at site `start`
390    ///
391    /// `indices` contains indices for sites `start` to `n-1`
392    /// Returns a vector of size `link_dim(start-1)` (or 1 if start == n)
393    pub fn evaluate_right(&mut self, indices: &[LocalIndex]) -> Vec<T> {
394        let n = self.len();
395        let ell = indices.len();
396        if ell == 0 {
397            return vec![T::one()];
398        }
399
400        let start = n - ell;
401
402        // Check cache
403        let key: MultiIndex = indices.to_vec();
404        if let Some(cached) = self.cache_right[start].get(&key) {
405            return cached.clone();
406        }
407
408        // Compute recursively
409        let result = if ell == 1 {
410            // Last site: just extract the slice
411            let flat_idx = self.multi_to_flat(n - 1, &indices[0..1]);
412            let tensor = &self.tensors[n - 1];
413            tensor.slice_site(flat_idx)
414        } else {
415            // Recursive case: tensor[start][:, idx, :] * right[1..]
416            let right = self.evaluate_right(&indices[1..]);
417            let flat_idx = self.multi_to_flat(start, &indices[0..1]);
418            let tensor = &self.tensors[start];
419            let slice = tensor.slice_site(flat_idx);
420            matrix_times_col_vector(&slice, tensor.left_dim(), tensor.right_dim(), &right)
421        };
422
423        // Cache and return
424        self.cache_right[start].insert(key, result.clone());
425        result
426    }
427
428    /// Evaluate the tensor train at a given index set using cache
429    pub fn evaluate(&mut self, indices: &[LocalIndex]) -> Result<T> {
430        let n = self.len();
431        if indices.len() != n {
432            return Err(TensorTrainError::IndexLengthMismatch {
433                expected: n,
434                got: indices.len(),
435            });
436        }
437
438        if n == 0 {
439            return Err(TensorTrainError::Empty);
440        }
441
442        // Split at midpoint for efficiency
443        let mid = n / 2;
444        let left = self.evaluate_left(&indices[0..mid]);
445        let right = self.evaluate_right(&indices[mid..]);
446
447        // Contract left and right
448        if left.len() != right.len() {
449            return Err(TensorTrainError::InvalidOperation {
450                message: format!(
451                    "Left/right dimension mismatch: {} vs {}",
452                    left.len(),
453                    right.len()
454                ),
455            });
456        }
457
458        let mut result = T::zero();
459        for i in 0..left.len() {
460            result = result + left[i] * right[i];
461        }
462
463        Ok(result)
464    }
465
466    /// Batch evaluate at multiple index sets
467    ///
468    /// This method efficiently evaluates the tensor train at multiple indices
469    /// by splitting at a given position and computing unique left/right environments.
470    ///
471    /// # Arguments
472    /// * `indices` - The indices to evaluate
473    /// * `split` - Optional split position. If `None`, uses a simple heuristic
474    ///   (checks 1/4, 1/2, 3/4 positions and picks the best).
475    ///   If you know the optimal split position (e.g., from TCI), pass `Some(split)`
476    ///   to avoid the search overhead.
477    pub fn evaluate_many(
478        &mut self,
479        indices: &[MultiIndex],
480        split: Option<usize>,
481    ) -> Result<Vec<T>> {
482        if indices.is_empty() {
483            return Ok(Vec::new());
484        }
485
486        let n = self.len();
487        if n == 0 {
488            return Err(TensorTrainError::Empty);
489        }
490
491        // Determine split position
492        let split = match split {
493            Some(s) => s,
494            None => self.find_split_heuristic(indices),
495        };
496
497        if split == 0 || split > n {
498            return Err(TensorTrainError::InvalidOperation {
499                message: format!("Invalid split position: {} (n_sites={})", split, n),
500            });
501        }
502
503        // Get local dimensions for flat index computation
504        let local_dims: Vec<usize> = self.site_dims.iter().map(|d| d.iter().product()).collect();
505
506        // Build index mapper with appropriate key type for each half
507        let mut mapper =
508            IndexMapper::new(&local_dims[..split], &local_dims[split..], indices.len());
509
510        for (i, idx) in indices.iter().enumerate() {
511            mapper.add_index(i, &idx[..split], &idx[split..]);
512        }
513
514        // Extract unique parts using first occurrence indices
515        let unique_left: Vec<MultiIndex> = mapper
516            .left_first_idx
517            .iter()
518            .map(|&i| indices[i][..split].to_vec())
519            .collect();
520
521        let unique_right: Vec<MultiIndex> = mapper
522            .right_first_idx
523            .iter()
524            .map(|&i| indices[i][split..].to_vec())
525            .collect();
526
527        // Compute left environments for all unique left parts
528        let left_envs: Vec<Vec<T>> = unique_left.iter().map(|l| self.evaluate_left(l)).collect();
529
530        // Compute right environments for all unique right parts
531        let right_envs: Vec<Vec<T>> = unique_right
532            .iter()
533            .map(|r| self.evaluate_right(r))
534            .collect();
535
536        // Compute results using position mappings
537        let results: Vec<T> = mapper
538            .idx_to_left
539            .iter()
540            .zip(&mapper.idx_to_right)
541            .map(|(&il, &ir)| {
542                let left_env = &left_envs[il];
543                let right_env = &right_envs[ir];
544                // Inner product
545                left_env
546                    .iter()
547                    .zip(right_env.iter())
548                    .fold(T::zero(), |acc, (&l, &r)| acc + l * r)
549            })
550            .collect();
551
552        Ok(results)
553    }
554
555    /// Find a good split position using 3-point sampling heuristic
556    ///
557    /// Samples at 1/4, 1/2, 3/4 positions and returns the one with
558    /// minimum total unique left + right parts.
559    fn find_split_heuristic(&self, indices: &[MultiIndex]) -> usize {
560        let n = self.len();
561        if n <= 1 {
562            return n.max(1);
563        }
564
565        let local_dims: Vec<usize> = self.site_dims.iter().map(|d| d.iter().product()).collect();
566
567        // Helper to compute cost at a split position
568        let compute_cost = |split: usize| -> usize {
569            if split == 0 || split >= n {
570                return usize::MAX;
571            }
572
573            let mut left_counter = UniqueCounter::new(&local_dims[..split], indices.len());
574            let mut right_counter = UniqueCounter::new(&local_dims[split..], indices.len());
575
576            for idx in indices {
577                left_counter.insert(&idx[..split]);
578                right_counter.insert(&idx[split..]);
579            }
580
581            left_counter.len() + right_counter.len()
582        };
583
584        // 3-point sampling: 1/4, 1/2, 3/4 positions
585        let candidates = [n / 4, n / 2, 3 * n / 4];
586        let costs: Vec<(usize, usize)> = candidates
587            .iter()
588            .filter(|&&p| p >= 1 && p < n)
589            .map(|&p| (p, compute_cost(p)))
590            .collect();
591
592        costs
593            .into_iter()
594            .min_by_key(|&(_, c)| c)
595            .map(|(p, _)| p)
596            .unwrap_or(n / 2)
597    }
598}
599
600#[cfg(test)]
601mod tests;