tensor4all_tcicore/cached_function/mod.rs
1//! Cached function wrapper for expensive function evaluations.
2//!
3//! [`CachedFunction`] wraps a user-supplied function `Fn(&[I]) -> V` with
4//! thread-safe memoization. On first call for a given multi-index, the
5//! function is evaluated and the result is cached; subsequent calls return
6//! the cached value.
7//!
8//! The internal cache key type is automatically selected based on the total
9//! index space size (up to 1024 bits by default). For larger index spaces,
10//! use [`CachedFunction::with_key_type`] with a custom [`CacheKey`]
11//! implementation.
12//!
13//! # Examples
14//!
15//! ```
16//! use tensor4all_tcicore::CachedFunction;
17//!
18//! let cf = CachedFunction::new(
19//! |idx: &[usize]| idx[0] + idx[1],
20//! &[3, 4],
21//! ).unwrap();
22//!
23//! assert_eq!(cf.eval(&[1, 2]), 3);
24//! assert_eq!(cf.num_evals(), 1);
25//! assert_eq!(cf.eval(&[1, 2]), 3);
26//! assert_eq!(cf.num_cache_hits(), 1);
27//! ```
28
29pub mod cache_key;
30pub mod error;
31pub mod index_int;
32
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
36
37use bnum::types::{U1024, U256, U512};
38
39use cache_key::CacheKey;
40use index_int::IndexInt;
41
/// Number of bits needed to address the full index space described by
/// `local_dims`.
///
/// Each dimension `d` contributes `ceil(log2(d))` bits; dimensions of size
/// 0 or 1 contribute nothing.
pub(crate) fn total_bits(local_dims: &[usize]) -> u32 {
    let mut bits = 0u32;
    for &d in local_dims {
        if d > 1 {
            // Bits required to represent every index in 0..=d-1.
            bits += ((d - 1) as u64).ilog2() + 1;
        }
    }
    bits
}
55
56/// Compute mixed-radix coefficients for flat index computation.
57///
58/// Returns `Err(CacheKeyError::Overflow)` if the index space overflows key type `K`.
59pub(crate) fn compute_coeffs<K: CacheKey>(
60 local_dims: &[usize],
61) -> Result<Vec<K>, error::CacheKeyError> {
62 let bits = total_bits(local_dims);
63 if bits > K::BITS_COUNT {
64 return Err(error::CacheKeyError::Overflow {
65 total_bits: bits,
66 max_bits: K::BITS_COUNT,
67 key_type: std::any::type_name::<K>(),
68 });
69 }
70
71 let mut coeffs = Vec::with_capacity(local_dims.len());
72 let mut prod = K::ONE;
73 for &d in local_dims {
74 coeffs.push(prod.clone());
75 let dim = K::from_usize(d);
76 prod = prod
77 .checked_mul(dim)
78 .ok_or_else(|| error::CacheKeyError::Overflow {
79 total_bits: bits,
80 max_bits: K::BITS_COUNT,
81 key_type: std::any::type_name::<K>(),
82 })?;
83 }
84
85 Ok(coeffs)
86}
87
/// Acquire a read guard, recovering from lock poisoning.
///
/// A poisoned lock only means another thread panicked while holding it; the
/// cached data is still usable, so we take the guard anyway.
fn read_lock<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
91
/// Acquire a write guard, recovering from lock poisoning.
///
/// Mirrors [`read_lock`]: poisoning is ignored because the map contents
/// remain valid even if a previous holder panicked.
fn write_lock<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
96
97/// Compute flat index from multi-index and coefficients.
98fn flat_index<K: CacheKey, I: IndexInt>(idx: &[I], coeffs: &[K]) -> K {
99 idx.iter().zip(coeffs).fold(K::ZERO, |acc, (&i, c)| {
100 acc.wrapping_add(
101 c.clone()
102 .checked_mul(K::from_usize(i.to_usize()))
103 .unwrap_or(K::ZERO),
104 )
105 })
106}
107
/// Internal cache with automatically selected key type.
///
/// Each variant pairs a hash map keyed by a fixed-width unsigned integer with
/// the mixed-radix coefficients used to flatten multi-indices into that key
/// type. The variant is chosen at construction time from the total bit count
/// of the index space (see `InnerCache::new`).
enum InnerCache<V> {
    /// Index spaces up to 64 bits.
    U64 {
        cache: RwLock<HashMap<u64, V>>,
        coeffs: Vec<u64>,
    },
    /// Index spaces up to 128 bits.
    U128 {
        cache: RwLock<HashMap<u128, V>>,
        coeffs: Vec<u128>,
    },
    /// Index spaces up to 256 bits (bnum big integer).
    U256 {
        cache: RwLock<HashMap<U256, V>>,
        coeffs: Vec<U256>,
    },
    /// Index spaces up to 512 bits (bnum big integer).
    U512 {
        cache: RwLock<HashMap<U512, V>>,
        coeffs: Vec<U512>,
    },
    /// Index spaces up to 1024 bits (bnum big integer).
    U1024 {
        cache: RwLock<HashMap<U1024, V>>,
        coeffs: Vec<U1024>,
    },
}
131
132impl<V: Clone + Send + Sync> InnerCache<V> {
133 /// Create a new cache, automatically selecting the key type.
134 fn new(local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
135 let bits = total_bits(local_dims);
136 if bits <= 64 {
137 Ok(Self::U64 {
138 cache: RwLock::new(HashMap::new()),
139 coeffs: compute_coeffs::<u64>(local_dims)?,
140 })
141 } else if bits <= 128 {
142 Ok(Self::U128 {
143 cache: RwLock::new(HashMap::new()),
144 coeffs: compute_coeffs::<u128>(local_dims)?,
145 })
146 } else if bits <= 256 {
147 Ok(Self::U256 {
148 cache: RwLock::new(HashMap::new()),
149 coeffs: compute_coeffs::<U256>(local_dims)?,
150 })
151 } else if bits <= 512 {
152 Ok(Self::U512 {
153 cache: RwLock::new(HashMap::new()),
154 coeffs: compute_coeffs::<U512>(local_dims)?,
155 })
156 } else if bits <= 1024 {
157 Ok(Self::U1024 {
158 cache: RwLock::new(HashMap::new()),
159 coeffs: compute_coeffs::<U1024>(local_dims)?,
160 })
161 } else {
162 Err(error::CacheKeyError::Overflow {
163 total_bits: bits,
164 max_bits: 1024,
165 key_type: "auto",
166 })
167 }
168 }
169
170 fn get<I: IndexInt>(&self, idx: &[I]) -> Option<V> {
171 match self {
172 Self::U64 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
173 Self::U128 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
174 Self::U256 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
175 Self::U512 { cache, coeffs } => read_lock(cache).get(&flat_index(idx, coeffs)).cloned(),
176 Self::U1024 { cache, coeffs } => {
177 read_lock(cache).get(&flat_index(idx, coeffs)).cloned()
178 }
179 }
180 }
181
182 fn insert<I: IndexInt>(&self, idx: &[I], value: V) {
183 match self {
184 Self::U64 { cache, coeffs } => {
185 write_lock(cache).insert(flat_index(idx, coeffs), value);
186 }
187 Self::U128 { cache, coeffs } => {
188 write_lock(cache).insert(flat_index(idx, coeffs), value);
189 }
190 Self::U256 { cache, coeffs } => {
191 write_lock(cache).insert(flat_index(idx, coeffs), value);
192 }
193 Self::U512 { cache, coeffs } => {
194 write_lock(cache).insert(flat_index(idx, coeffs), value);
195 }
196 Self::U1024 { cache, coeffs } => {
197 write_lock(cache).insert(flat_index(idx, coeffs), value);
198 }
199 }
200 }
201
202 fn contains<I: IndexInt>(&self, idx: &[I]) -> bool {
203 match self {
204 Self::U64 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
205 Self::U128 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
206 Self::U256 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
207 Self::U512 { cache, coeffs } => read_lock(cache).contains_key(&flat_index(idx, coeffs)),
208 Self::U1024 { cache, coeffs } => {
209 read_lock(cache).contains_key(&flat_index(idx, coeffs))
210 }
211 }
212 }
213
214 fn len(&self) -> usize {
215 match self {
216 Self::U64 { cache, .. } => read_lock(cache).len(),
217 Self::U128 { cache, .. } => read_lock(cache).len(),
218 Self::U256 { cache, .. } => read_lock(cache).len(),
219 Self::U512 { cache, .. } => read_lock(cache).len(),
220 Self::U1024 { cache, .. } => read_lock(cache).len(),
221 }
222 }
223
224 fn clear(&self) {
225 match self {
226 Self::U64 { cache, .. } => write_lock(cache).clear(),
227 Self::U128 { cache, .. } => write_lock(cache).clear(),
228 Self::U256 { cache, .. } => write_lock(cache).clear(),
229 Self::U512 { cache, .. } => write_lock(cache).clear(),
230 Self::U1024 { cache, .. } => write_lock(cache).clear(),
231 }
232 }
233
234 fn key_type_name(&self) -> &'static str {
235 match self {
236 Self::U64 { .. } => "u64",
237 Self::U128 { .. } => "u128",
238 Self::U256 { .. } => "U256",
239 Self::U512 { .. } => "U512",
240 Self::U1024 { .. } => "U1024",
241 }
242 }
243}
244
/// Type-erased cache interface for custom key types.
///
/// Mirrors the operations of [`InnerCache`] but over `usize` indices only,
/// so a `Box<dyn DynCache<V>>` can hide the concrete key type `K` chosen by
/// the caller of `CachedFunction::with_key_type`.
trait DynCache<V>: Send + Sync {
    /// Look up a cached value for the multi-index.
    fn get(&self, idx: &[usize]) -> Option<V>;
    /// Store (or overwrite) the value for the multi-index.
    fn insert(&self, idx: &[usize], value: V);
    /// Whether an entry exists for the multi-index.
    fn contains(&self, idx: &[usize]) -> bool;
    /// Number of cached entries.
    fn len(&self) -> usize;
    /// Remove all entries.
    fn clear(&self);
}
253
/// Generic cache for user-specified key types.
struct GenericCache<K: CacheKey, V> {
    /// Flat-key → value store, guarded for concurrent access.
    cache: RwLock<HashMap<K, V>>,
    /// Mixed-radix coefficients used to flatten multi-indices into `K`.
    coeffs: Vec<K>,
}
259
260impl<K: CacheKey, V: Clone + Send + Sync> GenericCache<K, V> {
261 fn new(local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
262 Ok(Self {
263 cache: RwLock::new(HashMap::new()),
264 coeffs: compute_coeffs::<K>(local_dims)?,
265 })
266 }
267}
268
269impl<K: CacheKey, V: Clone + Send + Sync> DynCache<V> for GenericCache<K, V> {
270 fn get(&self, idx: &[usize]) -> Option<V> {
271 let key = flat_index::<K, usize>(idx, &self.coeffs);
272 read_lock(&self.cache).get(&key).cloned()
273 }
274
275 fn insert(&self, idx: &[usize], value: V) {
276 let key = flat_index::<K, usize>(idx, &self.coeffs);
277 write_lock(&self.cache).insert(key, value);
278 }
279
280 fn contains(&self, idx: &[usize]) -> bool {
281 let key = flat_index::<K, usize>(idx, &self.coeffs);
282 read_lock(&self.cache).contains_key(&key)
283 }
284
285 fn len(&self) -> usize {
286 read_lock(&self.cache).len()
287 }
288
289 fn clear(&self) {
290 write_lock(&self.cache).clear();
291 }
292}
293
/// Internal backend: auto-selected enum or custom type-erased cache.
enum CacheBackend<V: Clone + Send + Sync + 'static> {
    /// Key type picked automatically from the index-space size (≤ 1024 bits).
    Auto(InnerCache<V>),
    /// Caller-supplied key type, hidden behind the `DynCache` trait object.
    Custom(Box<dyn DynCache<V>>),
}
299
300impl<V: Clone + Send + Sync + 'static> CacheBackend<V> {
301 fn get<I: IndexInt>(&self, idx: &[I]) -> Option<V> {
302 match self {
303 Self::Auto(inner) => inner.get(idx),
304 Self::Custom(cache) => {
305 let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
306 cache.get(&usize_idx)
307 }
308 }
309 }
310
311 fn insert<I: IndexInt>(&self, idx: &[I], value: V) {
312 match self {
313 Self::Auto(inner) => inner.insert(idx, value),
314 Self::Custom(cache) => {
315 let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
316 cache.insert(&usize_idx, value);
317 }
318 }
319 }
320
321 fn contains<I: IndexInt>(&self, idx: &[I]) -> bool {
322 match self {
323 Self::Auto(inner) => inner.contains(idx),
324 Self::Custom(cache) => {
325 let usize_idx: Vec<usize> = idx.iter().map(|&i| i.to_usize()).collect();
326 cache.contains(&usize_idx)
327 }
328 }
329 }
330
331 fn len(&self) -> usize {
332 match self {
333 Self::Auto(inner) => inner.len(),
334 Self::Custom(cache) => cache.len(),
335 }
336 }
337
338 fn clear(&self) {
339 match self {
340 Self::Auto(inner) => inner.clear(),
341 Self::Custom(cache) => cache.clear(),
342 }
343 }
344
345 fn key_type_name(&self) -> &'static str {
346 match self {
347 Self::Auto(inner) => inner.key_type_name(),
348 Self::Custom(_) => "custom",
349 }
350 }
351}
352
/// Boxed batch-evaluation callback: maps a slice of multi-indices to one
/// value per index, in the same order.
type BatchFunc<I, V> = dyn Fn(&[Vec<I>]) -> Vec<V> + Send + Sync;
354
/// A wrapper that caches function evaluations for multi-index inputs.
///
/// Thread-safe: all methods take `&self`. Multiple threads can call `eval`
/// concurrently.
///
/// # Type parameters
///
/// - `V` - cached value type
/// - `F` - single-evaluation function `Fn(&[I]) -> V`
/// - `I` - index element type (default `usize`); use `u8` for quantics
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::CachedFunction;
///
/// // Cache a 2-site function with local dimensions [3, 4]
/// let cf = CachedFunction::new(
///     |idx: &[usize]| (idx[0] * 4 + idx[1]) as f64,
///     &[3, 4],
/// ).unwrap();
///
/// // First call evaluates and caches
/// let v00 = cf.eval(&[0, 0]);
/// assert_eq!(v00, 0.0);
/// assert_eq!(cf.num_evals(), 1);
/// assert_eq!(cf.num_cache_hits(), 0);
///
/// // Second call uses cache
/// let v00_again = cf.eval(&[0, 0]);
/// assert_eq!(v00_again, 0.0);
/// assert_eq!(cf.num_cache_hits(), 1);
///
/// let v12 = cf.eval(&[1, 2]);
/// assert_eq!(v12, 6.0); // 1*4 + 2
/// ```
pub struct CachedFunction<V, F, I = usize>
where
    I: IndexInt,
    V: Clone + Send + Sync + 'static,
    F: Fn(&[I]) -> V + Send + Sync,
{
    // Single-point evaluation callback, invoked on cache misses.
    func: F,
    // Optional batched callback used for misses in `eval_batch`.
    batch_func: Option<Box<BatchFunc<I, V>>>,
    // Memoization store: auto-selected key width or custom key type.
    cache: CacheBackend<V>,
    // Per-site dimensions; fixes the length and valid range of indices.
    local_dims: Vec<usize>,
    // Count of cache misses (actual function invocations).
    num_evals: AtomicUsize,
    // Count of cache hits.
    num_cache_hits: AtomicUsize,
    // Ties the index element type `I` to the struct (not stored directly).
    _phantom: std::marker::PhantomData<I>,
}
405
406impl<V, F, I> CachedFunction<V, F, I>
407where
408 I: IndexInt,
409 V: Clone + Send + Sync + 'static,
410 F: Fn(&[I]) -> V + Send + Sync,
411{
412 /// Create a new cached function with automatic key selection (up to 1024 bits).
413 ///
414 /// # Examples
415 ///
416 /// ```
417 /// use tensor4all_tcicore::CachedFunction;
418 ///
419 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] + idx[1], &[2, 3]).unwrap();
420 /// assert_eq!(cf.eval(&[1, 2]), 3);
421 /// assert_eq!(cf.num_sites(), 2);
422 /// assert_eq!(cf.local_dims(), &[2, 3]);
423 /// ```
424 pub fn new(func: F, local_dims: &[usize]) -> Result<Self, error::CacheKeyError> {
425 Ok(Self {
426 func,
427 batch_func: None,
428 cache: CacheBackend::Auto(InnerCache::new(local_dims)?),
429 local_dims: local_dims.to_vec(),
430 num_evals: AtomicUsize::new(0),
431 num_cache_hits: AtomicUsize::new(0),
432 _phantom: std::marker::PhantomData,
433 })
434 }
435
436 /// Create with a batch function for efficient multi-point evaluation.
437 ///
438 /// The batch function is used for cache misses during [`eval_batch`](Self::eval_batch)
439 /// calls, enabling amortized cost when evaluating many indices at once
440 /// (e.g., batch FFI calls or vectorized computations).
441 ///
442 /// # Examples
443 ///
444 /// ```
445 /// use tensor4all_tcicore::CachedFunction;
446 ///
447 /// let cf = CachedFunction::with_batch(
448 /// |idx: &[usize]| idx[0] * 10 + idx[1],
449 /// |indices: &[Vec<usize>]| indices.iter().map(|idx| idx[0] * 10 + idx[1]).collect(),
450 /// &[3, 4],
451 /// ).unwrap();
452 ///
453 /// let results = cf.eval_batch(&[vec![0, 1], vec![2, 3]]);
454 /// assert_eq!(results, vec![1, 23]);
455 /// assert_eq!(cf.num_evals(), 2);
456 /// ```
457 pub fn with_batch<B>(
458 func: F,
459 batch_func: B,
460 local_dims: &[usize],
461 ) -> Result<Self, error::CacheKeyError>
462 where
463 B: Fn(&[Vec<I>]) -> Vec<V> + Send + Sync + 'static,
464 {
465 Ok(Self {
466 func,
467 batch_func: Some(Box::new(batch_func)),
468 cache: CacheBackend::Auto(InnerCache::new(local_dims)?),
469 local_dims: local_dims.to_vec(),
470 num_evals: AtomicUsize::new(0),
471 num_cache_hits: AtomicUsize::new(0),
472 _phantom: std::marker::PhantomData,
473 })
474 }
475
476 /// Create with an explicit key type for index spaces larger than 1024 bits.
477 ///
478 /// # Example
479 ///
480 /// ```
481 /// use bnum::types::U2048;
482 /// use tensor4all_tcicore::{CacheKey, CachedFunction};
483 ///
484 /// #[derive(Clone, Hash, PartialEq, Eq)]
485 /// struct U2048Key(U2048);
486 ///
487 /// impl CacheKey for U2048Key {
488 /// const BITS_COUNT: u32 = 2048;
489 /// const ZERO: Self = Self(U2048::ZERO);
490 /// const ONE: Self = Self(U2048::ONE);
491 ///
492 /// fn from_usize(v: usize) -> Self {
493 /// Self(U2048::from(v as u64))
494 /// }
495 ///
496 /// fn checked_mul(self, rhs: Self) -> Option<Self> {
497 /// self.0.checked_mul(rhs.0).map(Self)
498 /// }
499 ///
500 /// fn wrapping_add(self, rhs: Self) -> Self {
501 /// Self(self.0.wrapping_add(rhs.0))
502 /// }
503 /// }
504 ///
505 /// let local_dims = vec![2usize; 1025];
506 /// let cf = CachedFunction::with_key_type::<U2048Key>(
507 /// |idx: &[usize]| idx.iter().sum::<usize>(),
508 /// &local_dims,
509 /// ).unwrap();
510 /// let zeros = vec![0usize; 1025];
511 ///
512 /// assert_eq!(cf.eval(&zeros), 0);
513 /// assert_eq!(cf.key_type(), "custom");
514 /// ```
515 pub fn with_key_type<K: CacheKey>(
516 func: F,
517 local_dims: &[usize],
518 ) -> Result<Self, error::CacheKeyError> {
519 Ok(Self {
520 func,
521 batch_func: None,
522 cache: CacheBackend::Custom(Box::new(GenericCache::<K, V>::new(local_dims)?)),
523 local_dims: local_dims.to_vec(),
524 num_evals: AtomicUsize::new(0),
525 num_cache_hits: AtomicUsize::new(0),
526 _phantom: std::marker::PhantomData,
527 })
528 }
529
530 /// Create with explicit key type and batch function.
531 ///
532 /// Combines [`with_key_type`](Self::with_key_type) and
533 /// [`with_batch`](Self::with_batch) for index spaces larger than 1024
534 /// bits that also benefit from batch evaluation.
535 ///
536 /// # Examples
537 ///
538 /// ```
539 /// use tensor4all_tcicore::CachedFunction;
540 ///
541 /// // Use u128 key type with batch support
542 /// let cf = CachedFunction::with_key_type_and_batch::<u128, _>(
543 /// |idx: &[usize]| idx.iter().sum::<usize>(),
544 /// |indices: &[Vec<usize>]| indices.iter().map(|idx| idx.iter().sum()).collect(),
545 /// &[2, 3, 4],
546 /// ).unwrap();
547 ///
548 /// let results = cf.eval_batch(&[vec![0, 0, 0], vec![1, 2, 3]]);
549 /// assert_eq!(results, vec![0, 6]);
550 /// ```
551 pub fn with_key_type_and_batch<K: CacheKey, B>(
552 func: F,
553 batch_func: B,
554 local_dims: &[usize],
555 ) -> Result<Self, error::CacheKeyError>
556 where
557 B: Fn(&[Vec<I>]) -> Vec<V> + Send + Sync + 'static,
558 {
559 Ok(Self {
560 func,
561 batch_func: Some(Box::new(batch_func)),
562 cache: CacheBackend::Custom(Box::new(GenericCache::<K, V>::new(local_dims)?)),
563 local_dims: local_dims.to_vec(),
564 num_evals: AtomicUsize::new(0),
565 num_cache_hits: AtomicUsize::new(0),
566 _phantom: std::marker::PhantomData,
567 })
568 }
569
570 /// Evaluate at a given index, using cache if available.
571 ///
572 /// On the first call for a given index, the wrapped function is invoked and
573 /// the result is cached. Subsequent calls with the same index return the
574 /// cached value. This method is thread-safe.
575 ///
576 /// # Examples
577 ///
578 /// ```
579 /// use tensor4all_tcicore::CachedFunction;
580 ///
581 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] * idx[1], &[5, 5]).unwrap();
582 /// assert_eq!(cf.eval(&[3, 4]), 12);
583 /// assert_eq!(cf.num_evals(), 1);
584 ///
585 /// // Cache hit
586 /// assert_eq!(cf.eval(&[3, 4]), 12);
587 /// assert_eq!(cf.num_evals(), 1);
588 /// assert_eq!(cf.num_cache_hits(), 1);
589 /// ```
590 pub fn eval(&self, idx: &[I]) -> V {
591 if let Some(value) = self.cache.get(idx) {
592 self.num_cache_hits.fetch_add(1, Ordering::Relaxed);
593 return value;
594 }
595
596 self.num_evals.fetch_add(1, Ordering::Relaxed);
597 let value = (self.func)(idx);
598 self.cache.insert(idx, value.clone());
599 value
600 }
601
602 /// Evaluate bypassing the cache.
603 ///
604 /// The result is neither read from nor stored in the cache, and
605 /// evaluation counters are not updated. Useful for verification or
606 /// when the caller intentionally wants a fresh evaluation.
607 ///
608 /// # Examples
609 ///
610 /// ```
611 /// use tensor4all_tcicore::CachedFunction;
612 ///
613 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] + 1, &[4]).unwrap();
614 /// assert_eq!(cf.eval_no_cache(&[2]), 3);
615 /// assert_eq!(cf.cache_size(), 0);
616 /// assert_eq!(cf.num_evals(), 0);
617 /// ```
618 pub fn eval_no_cache(&self, idx: &[I]) -> V {
619 (self.func)(idx)
620 }
621
622 /// Evaluate at multiple indices. Uses batch function for cache misses if available.
623 ///
624 /// Returns results in the same order as the input indices.
625 ///
626 /// # Examples
627 ///
628 /// ```
629 /// use tensor4all_tcicore::CachedFunction;
630 ///
631 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0] * 2 + idx[1], &[2, 2]).unwrap();
632 /// let results = cf.eval_batch(&[vec![0, 0], vec![0, 1], vec![1, 0]]);
633 /// assert_eq!(results, vec![0, 1, 2]);
634 /// ```
635 pub fn eval_batch(&self, indices: &[Vec<I>]) -> Vec<V> {
636 if indices.is_empty() {
637 return Vec::new();
638 }
639
640 let mut results: Vec<Option<V>> = Vec::with_capacity(indices.len());
641 let mut miss_positions: Vec<usize> = Vec::new();
642 let mut miss_indices: Vec<Vec<I>> = Vec::new();
643
644 for (pos, idx) in indices.iter().enumerate() {
645 if let Some(value) = self.cache.get(idx) {
646 self.num_cache_hits.fetch_add(1, Ordering::Relaxed);
647 results.push(Some(value));
648 } else {
649 results.push(None);
650 miss_positions.push(pos);
651 miss_indices.push(idx.clone());
652 }
653 }
654
655 if miss_indices.is_empty() {
656 return results.into_iter().flatten().collect();
657 }
658
659 self.num_evals
660 .fetch_add(miss_indices.len(), Ordering::Relaxed);
661 let miss_values = if let Some(batch_func) = self.batch_func.as_ref() {
662 batch_func(&miss_indices)
663 } else {
664 miss_indices.iter().map(|idx| (self.func)(idx)).collect()
665 };
666
667 for (i, pos) in miss_positions.iter().enumerate() {
668 self.cache.insert(&miss_indices[i], miss_values[i].clone());
669 results[*pos] = Some(miss_values[i].clone());
670 }
671
672 results.into_iter().flatten().collect()
673 }
674
675 /// Get the local dimensions.
676 ///
677 /// # Examples
678 ///
679 /// ```
680 /// use tensor4all_tcicore::CachedFunction;
681 ///
682 /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[3, 4, 5]).unwrap();
683 /// assert_eq!(cf.local_dims(), &[3, 4, 5]);
684 /// ```
685 pub fn local_dims(&self) -> &[usize] {
686 &self.local_dims
687 }
688
689 /// Get the number of sites (length of the multi-index).
690 ///
691 /// # Examples
692 ///
693 /// ```
694 /// use tensor4all_tcicore::CachedFunction;
695 ///
696 /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[2, 3]).unwrap();
697 /// assert_eq!(cf.num_sites(), 2);
698 /// ```
699 pub fn num_sites(&self) -> usize {
700 self.local_dims.len()
701 }
702
    /// Get the number of function evaluations (cache misses).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[0]);
    /// cf.eval(&[1]);
    /// cf.eval(&[0]); // cache hit, not a new eval
    /// assert_eq!(cf.num_evals(), 2);
    /// ```
    pub fn num_evals(&self) -> usize {
        // Relaxed suffices: this is a standalone statistics counter with no
        // ordering relationship to the cached data.
        self.num_evals.load(Ordering::Relaxed)
    }
719
    /// Get the number of cache hits.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[0]);
    /// assert_eq!(cf.num_cache_hits(), 0);
    /// cf.eval(&[0]);
    /// assert_eq!(cf.num_cache_hits(), 1);
    /// ```
    pub fn num_cache_hits(&self) -> usize {
        // Relaxed suffices: statistics counter only, no synchronization role.
        self.num_cache_hits.load(Ordering::Relaxed)
    }
736
737 /// Get total calls (evaluations + cache hits).
738 ///
739 /// # Examples
740 ///
741 /// ```
742 /// use tensor4all_tcicore::CachedFunction;
743 ///
744 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
745 /// cf.eval(&[0]);
746 /// cf.eval(&[1]);
747 /// cf.eval(&[0]); // cache hit
748 /// assert_eq!(cf.total_calls(), 3);
749 /// assert_eq!(cf.total_calls(), cf.num_evals() + cf.num_cache_hits());
750 /// ```
751 pub fn total_calls(&self) -> usize {
752 self.num_evals() + self.num_cache_hits()
753 }
754
755 /// Get cache hit ratio (0.0 when no calls have been made).
756 ///
757 /// Returns `num_cache_hits() / total_calls()` as a value in `[0.0, 1.0]`.
758 ///
759 /// # Examples
760 ///
761 /// ```
762 /// use tensor4all_tcicore::CachedFunction;
763 ///
764 /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
765 /// assert_eq!(cf.cache_hit_ratio(), 0.0); // no calls yet
766 ///
767 /// cf.eval(&[0]);
768 /// cf.eval(&[0]); // cache hit
769 /// assert!((cf.cache_hit_ratio() - 0.5).abs() < 1e-10);
770 /// ```
771 pub fn cache_hit_ratio(&self) -> f64 {
772 let total = self.total_calls();
773 if total == 0 {
774 0.0
775 } else {
776 self.num_cache_hits() as f64 / total as f64
777 }
778 }
779
    /// Clear the cache.
    ///
    /// Counters (`num_evals`, `num_cache_hits`) are not reset.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// cf.eval(&[2]);
    /// assert_eq!(cf.cache_size(), 1);
    /// cf.clear_cache();
    /// assert_eq!(cf.cache_size(), 0);
    /// ```
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
796
    /// Number of cached entries.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// assert_eq!(cf.cache_size(), 0);
    /// cf.eval(&[0]);
    /// cf.eval(&[1]);
    /// assert_eq!(cf.cache_size(), 2);
    /// cf.eval(&[0]); // cache hit, no new entry
    /// assert_eq!(cf.cache_size(), 2);
    /// ```
    pub fn cache_size(&self) -> usize {
        // Delegates to the backend; takes a read lock internally.
        self.cache.len()
    }
815
    /// Check if an index is cached.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// let cf = CachedFunction::new(|idx: &[usize]| idx[0], &[4]).unwrap();
    /// assert!(!cf.is_cached(&[1]));
    /// cf.eval(&[1]);
    /// assert!(cf.is_cached(&[1]));
    /// ```
    pub fn is_cached(&self, idx: &[I]) -> bool {
        // NOTE: under concurrency the answer may be stale by the time the
        // caller acts on it; `eval` handles the race itself.
        self.cache.contains(idx)
    }
831
    /// Internal key type name (for debugging).
    ///
    /// Returns `"u64"`, `"u128"`, `"U256"`, `"U512"`, `"U1024"` for
    /// automatically selected types, or `"custom"` when constructed with
    /// [`with_key_type`](Self::with_key_type).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tcicore::CachedFunction;
    ///
    /// // Small index space uses u64
    /// let cf = CachedFunction::new(|idx: &[usize]| 0, &[2, 3]).unwrap();
    /// assert_eq!(cf.key_type(), "u64");
    /// ```
    pub fn key_type(&self) -> &'static str {
        self.cache.key_type_name()
    }
850}
851
852#[cfg(test)]
853mod tests;