Skip to main content

tensor4all_tcicore/cached_function/
cache_key.rs

1//! Cache key trait for flat-index computation.
2//!
3//! Built-in implementations: `u64`, `u128`, `U256`, `U512`, `U1024`.
4//!
5//! # Custom key types
6//!
7//! To support index spaces larger than 1024 bits, implement `CacheKey` for a
8//! wider integer type and pass it via `CachedFunction::with_key_type`:
9//!
10//! ```
11//! use bnum::types::U2048;
12//! use tensor4all_tcicore::{CacheKey, CachedFunction};
13//!
14//! #[derive(Clone, Hash, PartialEq, Eq)]
15//! struct U2048Key(U2048);
16//!
17//! impl CacheKey for U2048Key {
18//!     const BITS_COUNT: u32 = 2048;
19//!     const ZERO: Self = Self(U2048::ZERO);
20//!     const ONE: Self = Self(U2048::ONE);
21//!     fn from_usize(v: usize) -> Self { Self(U2048::from(v as u64)) }
22//!     fn checked_mul(self, rhs: Self) -> Option<Self> {
23//!         self.0.checked_mul(rhs.0).map(Self)
24//!     }
25//!     fn wrapping_add(self, rhs: Self) -> Self {
26//!         Self(self.0.wrapping_add(rhs.0))
27//!     }
28//! }
29//!
30//! let local_dims = vec![2usize; 1025];
31//! let cf = CachedFunction::with_key_type::<U2048Key>(
32//!     |idx: &[usize]| idx.iter().sum::<usize>(),
33//!     &local_dims,
34//! ).unwrap();
35//! let zeros = vec![0usize; 1025];
36//!
37//! assert_eq!(cf.eval(&zeros), 0);
38//! assert_eq!(cf.key_type(), "custom");
39//! ```
40
41use std::hash::Hash;
42
43use bnum::types::{U1024, U256, U512};
44
/// Integer-like key used to encode a multi-index as a single flat value.
///
/// The crate ships implementations for `u64`, `u128`, `U256`, `U512` and
/// `U1024`. When an index space needs more than 1024 bits, implement this
/// trait on a wider integer type (the module-level example wraps `U2048` in
/// a newtype) and pass it to
/// [`CachedFunction::with_key_type`](crate::CachedFunction::with_key_type).
///
/// See the module documentation for a complete custom key example.
pub trait CacheKey: Hash + Eq + Clone + Send + Sync + 'static {
    /// How many bits of flat index this key type is able to represent.
    const BITS_COUNT: u32;

    /// The key whose value is zero.
    const ZERO: Self;

    /// The key whose value is one.
    const ONE: Self;

    /// Builds a key holding the given `usize` value.
    fn from_usize(v: usize) -> Self;

    /// Multiplies `self` by `rhs`, yielding `None` if the product overflows.
    fn checked_mul(self, rhs: Self) -> Option<Self>;

    /// Adds `rhs` to `self`, wrapping around on overflow.
    fn wrapping_add(self, rhs: Self) -> Self;
}
70
71impl CacheKey for u64 {
72    const BITS_COUNT: u32 = 64;
73    const ZERO: Self = 0;
74    const ONE: Self = 1;
75
76    fn from_usize(v: usize) -> Self {
77        v as u64
78    }
79
80    fn checked_mul(self, rhs: Self) -> Option<Self> {
81        self.checked_mul(rhs)
82    }
83
84    fn wrapping_add(self, rhs: Self) -> Self {
85        self.wrapping_add(rhs)
86    }
87}
88
89impl CacheKey for u128 {
90    const BITS_COUNT: u32 = 128;
91    const ZERO: Self = 0;
92    const ONE: Self = 1;
93
94    fn from_usize(v: usize) -> Self {
95        v as u128
96    }
97
98    fn checked_mul(self, rhs: Self) -> Option<Self> {
99        self.checked_mul(rhs)
100    }
101
102    fn wrapping_add(self, rhs: Self) -> Self {
103        self.wrapping_add(rhs)
104    }
105}
106
107macro_rules! impl_cache_key_bnum {
108    ($ty:ty, $bits:expr) => {
109        impl CacheKey for $ty {
110            const BITS_COUNT: u32 = $bits;
111            const ZERO: Self = <$ty>::ZERO;
112            const ONE: Self = <$ty>::ONE;
113
114            fn from_usize(v: usize) -> Self {
115                <$ty>::from(v as u64)
116            }
117
118            fn checked_mul(self, rhs: Self) -> Option<Self> {
119                self.checked_mul(rhs)
120            }
121
122            fn wrapping_add(self, rhs: Self) -> Self {
123                self.wrapping_add(rhs)
124            }
125        }
126    };
127}
128
129impl_cache_key_bnum!(U256, 256);
130impl_cache_key_bnum!(U512, 512);
131impl_cache_key_bnum!(U1024, 1024);