Skip to main content

tensor4all_tcicore/cached_function/
mod.rs

1//! Cached function wrapper for expensive function evaluations.
2//!
3//! [`CachedFunction`] wraps a user-supplied function `Fn(&[I]) -> V` with
4//! thread-safe memoization. On first call for a given multi-index, the
5//! function is evaluated and the result is cached; subsequent calls return
6//! the cached value.
7//!
8//! The internal cache key type is automatically selected based on the total
9//! index space size (up to 1024 bits by default). For larger index spaces,
10//! use [`CachedFunction::with_key_type`] with a custom [`CacheKey`]
11//! implementation.
12//!
13//! # Examples
14//!
15//! ```
16//! use tensor4all_tcicore::CachedFunction;
17//!
18//! let cf = CachedFunction::new(
19//!     |idx: &[usize]| idx[0] + idx[1],
20//!     &[3, 4],
21//! ).unwrap();
22//!
23//! assert_eq!(cf.eval(&[1, 2]), 3);
24//! assert_eq!(cf.num_evals(), 1);
25//! assert_eq!(cf.eval(&[1, 2]), 3);
26//! assert_eq!(cf.num_cache_hits(), 1);
27//! ```
28
29pub mod cache_key;
30pub mod error;
31pub mod index_int;
32
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
36
37use bnum::types::{U1024, U256, U512};
38
39use cache_key::CacheKey;
40use index_int::IndexInt;
41
/// Compute total bits needed to represent the index space.
///
/// Each site contributes `ceil(log2(d))` bits (zero for `d <= 1`); the sum is
/// an upper bound on the bits needed to address the full index space.
pub(crate) fn total_bits(local_dims: &[usize]) -> u32 {
    let mut bits = 0u32;
    for &d in local_dims {
        if d > 1 {
            // Bits required to represent indices 0..d, i.e. ceil(log2(d)).
            bits += ((d - 1) as u64).ilog2() + 1;
        }
    }
    bits
}
55
56/// Compute mixed-radix coefficients for flat index computation.
57///
58/// Returns `Err(CacheKeyError::Overflow)` if the index space overflows key type `K`.
59pub(crate) fn compute_coeffs<K: CacheKey>(
60    local_dims: &[usize],
61) -> Result<Vec<K>, error::CacheKeyError> {
62    let bits = total_bits(local_dims);
63    if bits > K::BITS_COUNT {
64        return Err(error::CacheKeyError::Overflow {
65            total_bits: bits,
66            max_bits: K::BITS_COUNT,
67            key_type: std::any::type_name::<K>(),
68        });
69    }
70
71    let mut coeffs = Vec::with_capacity(local_dims.len());
72    let mut prod = K::ONE;
73    for &d in local_dims {
74        coeffs.push(prod.clone());
75        let dim = K::from_usize(d);
76        prod = prod
77            .checked_mul(dim)
78            .ok_or_else(|| error::CacheKeyError::Overflow {
79                total_bits: bits,
80                max_bits: K::BITS_COUNT,
81                key_type: std::any::type_name::<K>(),
82            })?;
83    }
84
85    Ok(coeffs)
86}
87
/// Acquire a read guard, recovering from lock poisoning.
///
/// A poisoned lock only means a writer panicked; the cached data is still
/// usable, so we proceed with the inner guard instead of propagating a panic.
fn read_lock<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
91
/// Acquire a write guard, recovering from lock poisoning.
///
/// Mirrors [`read_lock`]: a panic in another writer does not invalidate the
/// cache contents, so the poisoned guard is unwrapped and used directly.
fn write_lock<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
96
97/// Compute flat index from multi-index and coefficients.
98fn flat_index<K: CacheKey, I: IndexInt>(idx: &[I], coeffs: &[K]) -> K {
99    idx.iter().zip(coeffs).fold(K::ZERO, |acc, (&i, c)| {
100        acc.wrapping_add(
101            c.clone()
102                .checked_mul(K::from_usize(i.to_usize()))
103                .unwrap_or(K::ZERO),
104        )
105    })
106}
107
/// Internal cache with automatically selected key type.
///
/// Each variant pairs a hash map keyed by the selected integer width with the
/// mixed-radix coefficients used to flatten a multi-index into that key.
/// The variant is chosen once at construction from the index-space size.
enum InnerCache<V> {
    /// Index spaces up to 64 bits (native integer key).
    U64 {
        cache: RwLock<HashMap<u64, V>>,
        coeffs: Vec<u64>,
    },
    /// Index spaces up to 128 bits (native integer key).
    U128 {
        cache: RwLock<HashMap<u128, V>>,
        coeffs: Vec<u128>,
    },
    /// Index spaces up to 256 bits (bnum wide integer).
    U256 {
        cache: RwLock<HashMap<U256, V>>,
        coeffs: Vec<U256>,
    },
    /// Index spaces up to 512 bits (bnum wide integer).
    U512 {
        cache: RwLock<HashMap<U512, V>>,
        coeffs: Vec<U512>,
    },
    /// Index spaces up to 1024 bits (bnum wide integer).
    U1024 {
        cache: RwLock<HashMap<U1024, V>>,
        coeffs: Vec<U1024>,
    },
}
131
132impl<V: Clone + Send + Sync> InnerCache<V> {
133    /// Create a new cache, automatically selecting the key type.
134    fn new(local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
135        let bits = total_bits(local_dims);
136        if bits <= 64 {
137            Ok(Self::U64 {
138                cache: RwLock::new(HashMap::new()),
139                coeffs: compute_coeffs::<u64>(local_dims)?,
140            })
141        } else if bits <= 128 {
142            Ok(Self::U128 {
143                cache: RwLock::new(HashMap::new()),
144                coeffs: compute_coeffs::<u128>(local_dims)?,
145            })
146        } else if bits <= 256 {
147            Ok(Self::U256 {
148                cache: RwLock::new(HashMap::new()),
149                coeffs: compute_coeffs::<U256>(local_dims)?,
150            })
151        } else if bits <= 512 {
152            Ok(Self::U512 {
153                cache: RwLock::new(HashMap::new()),
154                coeffs: compute_coeffs::<U512>(local_dims)?,
155            })
156        } else if bits <= 1024 {
157            Ok(Self::U1024 {
158                cache: RwLock::new(HashMap::new()),
159                coeffs: compute_coeffs::<U1024>(local_dims)?,
160            })
161        } else {
162            Err(error::CacheKeyError::Overflow {
163                total_bits: bits,
164                max_bits: 1024,
165                key_type: "auto",
166            })
167        }
168    }
169
170    fn get<I: IndexInt>(&self, idx: &[I]) -> Option<V> {
171        match self {
172            Self::U64 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
173            Self::U128 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
174            Self::U256 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
175            Self::U512 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
176            Self::U1024 { cache, coeffs } => {
177                read_lock(cache).get(&flat_index(idx, coeffs)).cloned()
178            }
179        }
180    }
181
182    fn insert<I: IndexInt>(&self, idx: &[I], value: V) {
183        match self {
184            Self::U64 { cache, coeffs } => {
185                write_lock(cache).insert(flat_index(idx, coeffs), value);
186            }
187            Self::U128 { cache, coeffs } => {
188                write_lock(cache).insert(flat_index(idx, coeffs), value);
189            }
190            Self::U256 { cache, coeffs } => {
191                write_lock(cache).insert(flat_index(idx, coeffs), value);
192            }
193            Self::U512 { cache, coeffs } => {
194                write_lock(cache).insert(flat_index(idx, coeffs), value);
195            }
196            Self::U1024 { cache, coeffs } => {
197                write_lock(cache).insert(flat_index(idx, coeffs), value);
198            }
199        }
200    }
201
202    fn contains<I: IndexInt>(&self, idx: &[I]) -> bool {
203        match self {
204            Self::U64 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
205            Self::U128 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
206            Self::U256 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
207            Self::U512 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
208            Self::U1024 { cache, coeffs } => {
209                read_lock(cache).contains_key(&flat_index(idx, coeffs))
210            }
211        }
212    }
213
214    fn len(&self) -> usize {
215        match self {
216            Self::U64 { cache, .. } => read_lock(cache).len(),
217            Self::U128 { cache, .. } => read_lock(cache).len(),
218            Self::U256 { cache, .. } => read_lock(cache).len(),
219            Self::U512 { cache, .. } => read_lock(cache).len(),
220            Self::U1024 { cache, .. } => read_lock(cache).len(),
221        }
222    }
223
224    fn clear(&self) {
225        match self {
226            Self::U64 { cache, .. } => write_lock(cache).clear(),
227            Self::U128 { cache, .. } => write_lock(cache).clear(),
228            Self::U256 { cache, .. } => write_lock(cache).clear(),
229            Self::U512 { cache, .. } => write_lock(cache).clear(),
230            Self::U1024 { cache, .. } => write_lock(cache).clear(),
231        }
232    }
233
234    fn key_type_name(&self) -> &'static str {
235        match self {
236            Self::U64 { .. } => "u64",
237            Self::U128 { .. } => "u128",
238            Self::U256 { .. } => "U256",
239            Self::U512 { .. } => "U512",
240            Self::U1024 { .. } => "U1024",
241        }
242    }
243}
244
/// Type-erased cache interface for custom key types.
///
/// Mirrors the operations of `InnerCache`, but always takes `usize` indices
/// so implementations with any key type `K` can be stored uniformly as a
/// `Box<dyn DynCache<V>>`.
trait DynCache<V>: Send + Sync {
    /// Look up a cached value for `idx`, if present.
    fn get(&self, idx: &[usize]) -> Option<V>;
    /// Store `value` under `idx`, overwriting any previous entry.
    fn insert(&self, idx: &[usize], value: V);
    /// Whether a value is cached for `idx`.
    fn contains(&self, idx: &[usize]) -> bool;
    /// Number of cached entries.
    fn len(&self) -> usize;
    /// Remove all cached entries.
    fn clear(&self);
}
253
/// Generic cache for user-specified key types.
struct GenericCache<K: CacheKey, V> {
    /// Map from flattened key to cached value.
    cache: RwLock<HashMap<K, V>>,
    /// Mixed-radix coefficients used to flatten a multi-index into `K`.
    coeffs: Vec<K>,
}
259
260impl<K: CacheKey, V: Clone + Send + Sync> GenericCache<K, V> {
261    fn new(local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
262        Ok(Self {
263            cache: RwLock::new(HashMap::new()),
264            coeffs: compute_coeffs::<K>(local_dims)?,
265        })
266    }
267}
268
269impl<K: CacheKey, V: Clone + Send + Sync> DynCache<V> for GenericCache<K, V> {
270    fn get(&self, idx: &[usize]) -> Option<V> {
271        let key = flat_index::<K, usize>(idx, &self.coeffs);
272        read_lock(&self.cache).get(&key).cloned()
273    }
274
275    fn insert(&self, idx: &[usize], value: V) {
276        let key = flat_index::<K, usize>(idx, &self.coeffs);
277        write_lock(&self.cache).insert(key, value);
278    }
279
280    fn contains(&self, idx: &[usize]) -> bool {
281        let key = flat_index::<K, usize>(idx, &self.coeffs);
282        read_lock(&self.cache).contains_key(&key)
283    }
284
285    fn len(&self) -> usize {
286        read_lock(&self.cache).len()
287    }
288
289    fn clear(&self) {
290        write_lock(&self.cache).clear();
291    }
292}
293
/// Internal backend: auto-selected enum or custom type-erased cache.
enum CacheBackend<V: Clone + Send + Sync + 'static> {
    /// Key width chosen automatically from the index-space size (<= 1024 bits).
    Auto(InnerCache<V>),
    /// Caller-supplied key type behind the type-erased [`DynCache`] interface.
    Custom(Box<dyn DynCache<V>>),
}
299
300impl<V: Clone + Send + Sync + 'static> CacheBackend<V> {
301    fn get<I: IndexInt>(&self, idx: &[I]) -> Option<V> {
302        match self {
303            Self::Auto(inner) => inner.get(idx),
304            Self::Custom(cache) => {
305                let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
306                cache.get(&usize_idx)
307            }
308        }
309    }
310
311    fn insert<I: IndexInt>(&self, idx: &[I], value: V) {
312        match self {
313            Self::Auto(inner) => inner.insert(idx, value),
314            Self::Custom(cache) => {
315                let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
316                cache.insert(&usize_idx, value);
317            }
318        }
319    }
320
321    fn contains<I: IndexInt>(&self, idx: &[I]) -> bool {
322        match self {
323            Self::Auto(inner) => inner.contains(idx),
324            Self::Custom(cache) => {
325                let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
326                cache.contains(&usize_idx)
327            }
328        }
329    }
330
331    fn len(&self) -> usize {
332        match self {
333            Self::Auto(inner) => inner.len(),
334            Self::Custom(cache) => cache.len(),
335        }
336    }
337
338    fn clear(&self) {
339        match self {
340            Self::Auto(inner) => inner.clear(),
341            Self::Custom(cache) => cache.clear(),
342        }
343    }
344
345    fn key_type_name(&self) -> &'static str {
346        match self {
347            Self::Auto(inner) => inner.key_type_name(),
348            Self::Custom(_) => "custom",
349        }
350    }
351}
352
/// Batch-evaluation callback: maps a slice of multi-indices to one value per
/// index, in the same order.
type BatchFunc<I, V> = dyn Fn(&[Vec<I>]) -> Vec<V> + Send + Sync;
354
355/// A wrapper that caches function evaluations for multi-index inputs.
356///
357/// Thread-safe: all methods take `&self`. Multiple threads can call `eval`
358/// concurrently.
359///
360/// # Type parameters
361///
362/// - `V` - cached value type
363/// - `F` - single-evaluation function `Fn(&[I]) -> V`
364/// - `I` - index element type (default `usize`); use `u8` for quantics
365///
366/// # Examples
367///
368/// ```
369/// use tensor4all_tcicore::CachedFunction;
370///
371/// // Cache a 2-site function with local dimensions [3, 4]
372/// let cf = CachedFunction::new(
373///     |idx: &[usize]| (idx[0] * 4 + idx[1]) as f64,
374///     &[3, 4],
375/// ).unwrap();
376///
377/// // First call evaluates and caches
378/// let v00 = cf.eval(&[0, 0]);
379/// assert_eq!(v00, 0.0);
380/// assert_eq!(cf.num_evals(), 1);
381/// assert_eq!(cf.num_cache_hits(), 0);
382///
383/// // Second call uses cache
384/// let v00_again = cf.eval(&[0, 0]);
385/// assert_eq!(v00_again, 0.0);
386/// assert_eq!(cf.num_cache_hits(), 1);
387///
388/// let v12 = cf.eval(&[1, 2]);
389/// assert_eq!(v12, 6.0); // 1*4 + 2
390/// ```
pub struct CachedFunction<V, F, I = usize>
where
    I: IndexInt,
    V: Clone + Send + Sync + 'static,
    F: Fn(&[I]) -> V + Send + Sync,
{
    /// Single-point evaluation function, called on cache misses.
    func: F,
    /// Optional batched evaluation function used by `eval_batch` for misses.
    batch_func: Option<Box<BatchFunc<I, V>>>,
    /// Memoization backend: auto-selected key width or custom key type.
    cache: CacheBackend<V>,
    /// Local dimension of each site; defines the valid index space.
    local_dims: Vec<usize>,
    /// Count of function evaluations (cache misses).
    num_evals: AtomicUsize,
    /// Count of calls served from the cache.
    num_cache_hits: AtomicUsize,
    /// Ties the index element type `I` to the struct without storing one.
    _phantom: std::marker::PhantomData<I>,
}
405
impl<V, F, I> CachedFunction<V, F, I>
where
    I: IndexInt,
    V: Clone + Send + Sync + 'static,
    F: Fn(&[I]) -> V + Send + Sync,
{
    /// Create a new cached function with automatic key selection (up to 1024 bits).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] + idx[1], &[2, 3]).unwrap();
    /// assert_eq!(cf.eval(&[1, 2]), 3);
    /// assert_eq!(cf.num_sites(), 2);
    /// assert_eq!(cf.local_dims(), &[2, 3]);
    /// ```
    pub fn new(func: F, local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
        Ok(Self {
            func,
            batch_func: None,
            cache: CacheBackend::Auto(InnerCache::new(local_dims)?),
            local_dims: local_dims.to_vec(),
            num_evals: AtomicUsize::new(0),
            num_cache_hits: AtomicUsize::new(0),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Create with a batch function for efficient multi-point evaluation.
    ///
    /// The batch function is used for cache misses during [`eval_batch`](Self::eval_batch)
    /// calls, enabling amortized cost when evaluating many indices at once
    /// (e.g., batch FFI calls or vectorized computations).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::with_batch(
    ///     |idx: &[usize]| idx[0] * 10 + idx[1],
    ///     |indices: &[Vec<usize>]| indices.iter().map(|idx| idx[0] * 10 + idx[1]).collect(),
    ///     &[3, 4],
    /// ).unwrap();
    ///
    /// let results = cf.eval_batch(&[vec![0, 1], vec![2, 3]]);
    /// assert_eq!(results, vec![1, 23]);
    /// assert_eq!(cf.num_evals(), 2);
    /// ```
    pub fn with_batch<B>(
        func: F,
        batch_func: B,
        local_dims: &[usize],
    ) -> Result<Self, error::CacheKeyError>
    where
        B: Fn(&[Vec<I>]) -> Vec<V> + Send + Sync + 'static,
    {
        Ok(Self {
            func,
            batch_func: Some(Box::new(batch_func)),
            cache: CacheBackend::Auto(InnerCache::new(local_dims)?),
            local_dims: local_dims.to_vec(),
            num_evals: AtomicUsize::new(0),
            num_cache_hits: AtomicUsize::new(0),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Create with an explicit key type for index spaces larger than 1024 bits.
    ///
    /// # Example
    ///
    /// ```
    /// use bnum::types::U2048;
    /// use tensor4all_tcicore::{CacheKey, CachedFunction};
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct U2048Key(U2048);
    ///
    /// impl CacheKey for U2048Key {
    ///     const BITS_COUNT: u32 = 2048;
    ///     const ZERO: Self = Self(U2048::ZERO);
    ///     const ONE: Self = Self(U2048::ONE);
    ///
    ///     fn from_usize(v: usize) -> Self {
    ///         Self(U2048::from(v as u64))
    ///     }
    ///
    ///     fn checked_mul(self, rhs: Self) -> Option<Self> {
    ///         self.0.checked_mul(rhs.0).map(Self)
    ///     }
    ///
    ///     fn wrapping_add(self, rhs: Self) -> Self {
    ///         Self(self.0.wrapping_add(rhs.0))
    ///     }
    /// }
    ///
    /// let local_dims = vec![2usize; 1025];
    /// let cf = CachedFunction::with_key_type::<U2048Key>(
    ///     |idx: &[usize]| idx.iter().sum::<usize>(),
    ///     &local_dims,
    /// ).unwrap();
    /// let zeros = vec![0usize; 1025];
    ///
    /// assert_eq!(cf.eval(&zeros), 0);
    /// assert_eq!(cf.key_type(), "custom");
    /// ```
    pub fn with_key_type<K: CacheKey>(
        func: F,
        local_dims: &[usize],
    ) -> Result<Self, error::CacheKeyError> {
        Ok(Self {
            func,
            batch_func: None,
            cache: CacheBackend::Custom(Box::new(GenericCache::<K, V>::new(local_dims)?)),
            local_dims: local_dims.to_vec(),
            num_evals: AtomicUsize::new(0),
            num_cache_hits: AtomicUsize::new(0),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Create with explicit key type and batch function.
    ///
    /// Combines [`with_key_type`](Self::with_key_type) and
    /// [`with_batch`](Self::with_batch) for index spaces larger than 1024
    /// bits that also benefit from batch evaluation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// // Use u128 key type with batch support
    /// let cf = CachedFunction::with_key_type_and_batch::<u128, _>(
    ///     |idx: &[usize]| idx.iter().sum::<usize>(),
    ///     |indices: &[Vec<usize>]| indices.iter().map(|idx| idx.iter().sum()).collect(),
    ///     &[2, 3, 4],
    /// ).unwrap();
    ///
    /// let results = cf.eval_batch(&[vec![0, 0, 0], vec![1, 2, 3]]);
    /// assert_eq!(results, vec![0, 6]);
    /// ```
    pub fn with_key_type_and_batch<K: CacheKey, B>(
        func: F,
        batch_func: B,
        local_dims: &[usize],
    ) -> Result<Self, error::CacheKeyError>
    where
        B: Fn(&[Vec<I>]) -> Vec<V> + Send + Sync + 'static,
    {
        Ok(Self {
            func,
            batch_func: Some(Box::new(batch_func)),
            cache: CacheBackend::Custom(Box::new(GenericCache::<K, V>::new(local_dims)?)),
            local_dims: local_dims.to_vec(),
            num_evals: AtomicUsize::new(0),
            num_cache_hits: AtomicUsize::new(0),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Evaluate at a given index, using cache if available.
    ///
    /// On the first call for a given index, the wrapped function is invoked and
    /// the result is cached. Subsequent calls with the same index return the
    /// cached value. This method is thread-safe.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] * idx[1], &[5, 5]).unwrap();
    /// assert_eq!(cf.eval(&[3, 4]), 12);
    /// assert_eq!(cf.num_evals(), 1);
    ///
    /// // Cache hit
    /// assert_eq!(cf.eval(&[3, 4]), 12);
    /// assert_eq!(cf.num_evals(), 1);
    /// assert_eq!(cf.num_cache_hits(), 1);
    /// ```
    pub fn eval(&self, idx: &[I]) -> V {
        // Fast path: serve from the cache.
        if let Some(value) = self.cache.get(idx) {
            self.num_cache_hits.fetch_add(1, Ordering::Relaxed);
            return value;
        }

        // Miss: evaluate outside any lock. Two threads can both miss the same
        // index and both evaluate; the later insert simply overwrites, so the
        // race only costs duplicate work (each is counted in num_evals).
        self.num_evals.fetch_add(1, Ordering::Relaxed);
        let value = (self.func)(idx);
        self.cache.insert(idx, value.clone());
        value
    }

    /// Evaluate bypassing the cache.
    ///
    /// The result is neither read from nor stored in the cache, and
    /// evaluation counters are not updated. Useful for verification or
    /// when the caller intentionally wants a fresh evaluation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] + 1, &[4]).unwrap();
    /// assert_eq!(cf.eval_no_cache(&[2]), 3);
    /// assert_eq!(cf.cache_size(), 0);
    /// assert_eq!(cf.num_evals(), 0);
    /// ```
    pub fn eval_no_cache(&self, idx: &[I]) -> V {
        (self.func)(idx)
    }

    /// Evaluate at multiple indices. Uses batch function for cache misses if available.
    ///
    /// Returns results in the same order as the input indices.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] * 2 + idx[1], &[2, 2]).unwrap();
    /// let results = cf.eval_batch(&[vec![0, 0], vec![0, 1], vec![1, 0]]);
    /// assert_eq!(results, vec![0, 1, 2]);
    /// ```
    pub fn eval_batch(&self, indices: &[Vec<I>]) -> Vec<V> {
        if indices.is_empty() {
            return Vec::new();
        }

        // Phase 1: resolve hits, record each miss position and its index.
        // NOTE: duplicate uncached indices within one call are each recorded
        // as a separate miss and evaluated separately.
        let mut results: Vec<Option<V>> = Vec::with_capacity(indices.len());
        let mut miss_positions: Vec<usize> = Vec::new();
        let mut miss_indices: Vec<Vec<I>> = Vec::new();

        for (pos, idx) in indices.iter().enumerate() {
            if let Some(value) = self.cache.get(idx) {
                self.num_cache_hits.fetch_add(1, Ordering::Relaxed);
                results.push(Some(value));
            } else {
                results.push(None);
                miss_positions.push(pos);
                miss_indices.push(idx.clone());
            }
        }

        // All hits: every slot is Some, so flatten loses nothing.
        if miss_indices.is_empty() {
            return results.into_iter().flatten().collect();
        }

        // Phase 2: evaluate all misses at once (batch function if provided,
        // otherwise point-wise with the single-evaluation function).
        self.num_evals
            .fetch_add(miss_indices.len(), Ordering::Relaxed);
        let miss_values = if let Some(batch_func) = self.batch_func.as_ref() {
            batch_func(&miss_indices)
        } else {
            miss_indices.iter().map(|idx| (self.func)(idx)).collect()
        };

        // Phase 3: cache the new values and fill the recorded positions.
        // Indexing panics if the batch function returned fewer values than
        // indices — it is expected to return exactly one value per input.
        for (i, pos) in miss_positions.iter().enumerate() {
            self.cache.insert(&miss_indices[i], miss_values[i].clone());
            results[*pos] = Some(miss_values[i].clone());
        }

        results.into_iter().flatten().collect()
    }

    /// Get the local dimensions.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[3, 4, 5]).unwrap();
    /// assert_eq!(cf.local_dims(), &[3, 4, 5]);
    /// ```
    pub fn local_dims(&self) -> &[usize] {
        &self.local_dims
    }

    /// Get the number of sites (length of the multi-index).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[2, 3]).unwrap();
    /// assert_eq!(cf.num_sites(), 2);
    /// ```
    pub fn num_sites(&self) -> usize {
        self.local_dims.len()
    }

    /// Get the number of function evaluations (cache misses).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[0]);
    /// cf.eval(&[1]);
    /// cf.eval(&[0]); // cache hit, not a new eval
    /// assert_eq!(cf.num_evals(), 2);
    /// ```
    pub fn num_evals(&self) -> usize {
        self.num_evals.load(Ordering::Relaxed)
    }

    /// Get the number of cache hits.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[0]);
    /// assert_eq!(cf.num_cache_hits(), 0);
    /// cf.eval(&[0]);
    /// assert_eq!(cf.num_cache_hits(), 1);
    /// ```
    pub fn num_cache_hits(&self) -> usize {
        self.num_cache_hits.load(Ordering::Relaxed)
    }

    /// Get total calls (evaluations + cache hits).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[0]);
    /// cf.eval(&[1]);
    /// cf.eval(&[0]); // cache hit
    /// assert_eq!(cf.total_calls(), 3);
    /// assert_eq!(cf.total_calls(), cf.num_evals() + cf.num_cache_hits());
    /// ```
    pub fn total_calls(&self) -> usize {
        self.num_evals() + self.num_cache_hits()
    }

    /// Get cache hit ratio (0.0 when no calls have been made).
    ///
    /// Returns `num_cache_hits() / total_calls()` as a value in `[0.0, 1.0]`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// assert_eq!(cf.cache_hit_ratio(), 0.0); // no calls yet
    ///
    /// cf.eval(&[0]);
    /// cf.eval(&[0]); // cache hit
    /// assert!((cf.cache_hit_ratio() - 0.5).abs() < 1e-10);
    /// ```
    pub fn cache_hit_ratio(&self) -> f64 {
        let total = self.total_calls();
        // Guard the zero-call case to avoid 0/0 = NaN.
        if total == 0 {
            0.0
        } else {
            self.num_cache_hits() as f64 / total as f64
        }
    }

    /// Clear the cache.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[2]);
    /// assert_eq!(cf.cache_size(), 1);
    /// cf.clear_cache();
    /// assert_eq!(cf.cache_size(), 0);
    /// ```
    pub fn clear_cache(&self) {
        // Note: the eval/hit counters are intentionally left untouched.
        self.cache.clear();
    }

    /// Number of cached entries.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// assert_eq!(cf.cache_size(), 0);
    /// cf.eval(&[0]);
    /// cf.eval(&[1]);
    /// assert_eq!(cf.cache_size(), 2);
    /// cf.eval(&[0]); // cache hit, no new entry
    /// assert_eq!(cf.cache_size(), 2);
    /// ```
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }

    /// Check if an index is cached.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// assert!(!cf.is_cached(&[1]));
    /// cf.eval(&[1]);
    /// assert!(cf.is_cached(&[1]));
    /// ```
    pub fn is_cached(&self, idx: &[I]) -> bool {
        self.cache.contains(idx)
    }

    /// Internal key type name (for debugging).
    ///
    /// Returns `"u64"`, `"u128"`, `"U256"`, `"U512"`, `"U1024"` for
    /// automatically selected types, or `"custom"` when constructed with
    /// [`with_key_type`](Self::with_key_type).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// // Small index space uses u64
    /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[2, 3]).unwrap();
    /// assert_eq!(cf.key_type(), "u64");
    /// ```
    pub fn key_type(&self) -> &'static str {
        self.cache.key_type_name()
    }
}
851
852#[cfg(test)]
853mod tests;