tensor4all_tcicore/cached_function/cache_key.rs
1//! Cache key trait for flat-index computation.
2//!
3//! Built-in implementations: `u64`, `u128`, `U256`, `U512`, `U1024`.
4//!
5//! # Custom key types
6//!
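//! To support index spaces larger than 1024 bits, implement `CacheKey` for a
//! wider integer type and pass it via `CachedFunction::with_key_type`. Rust's
//! orphan rule forbids implementing this crate's trait directly on a type from
//! another crate (such as `bnum::types::U2048`), so wrap it in a local newtype: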
9//!
10//! ```
11//! use bnum::types::U2048;
12//! use tensor4all_tcicore::{CacheKey, CachedFunction};
13//!
14//! #[derive(Clone, Hash, PartialEq, Eq)]
15//! struct U2048Key(U2048);
16//!
17//! impl CacheKey for U2048Key {
18//! const BITS_COUNT: u32 = 2048;
19//! const ZERO: Self = Self(U2048::ZERO);
20//! const ONE: Self = Self(U2048::ONE);
21//! fn from_usize(v: usize) -> Self { Self(U2048::from(v as u64)) }
22//! fn checked_mul(self, rhs: Self) -> Option<Self> {
23//! self.0.checked_mul(rhs.0).map(Self)
24//! }
25//! fn wrapping_add(self, rhs: Self) -> Self {
26//! Self(self.0.wrapping_add(rhs.0))
27//! }
28//! }
29//!
30//! let local_dims = vec![2usize; 1025];
31//! let cf = CachedFunction::with_key_type::<U2048Key>(
32//! |idx: &[usize]| idx.iter().sum::<usize>(),
33//! &local_dims,
34//! ).unwrap();
35//! let zeros = vec![0usize; 1025];
36//!
37//! assert_eq!(cf.eval(&zeros), 0);
38//! assert_eq!(cf.key_type(), "custom");
39//! ```
40
41use std::hash::Hash;
42
43use bnum::types::{U1024, U256, U512};
44
/// Key type used to hold a flat index in cached-function lookups.
///
/// The crate ships implementations for `u64`, `u128`, `U256`, `U512` and
/// `U1024`. When an index space needs more than 1024 bits, implement this
/// trait for a wider integer type and select it with
/// [`CachedFunction::with_key_type`](crate::CachedFunction::with_key_type).
///
/// A complete custom key example lives in the module documentation.
pub trait CacheKey
where
    Self: Hash + Eq + Clone + Send + Sync + 'static,
{
    /// How many bits a value of this key type can represent.
    const BITS_COUNT: u32;

    /// The value zero.
    const ZERO: Self;

    /// The value one.
    const ONE: Self;

    /// Builds a key from a `usize`.
    fn from_usize(v: usize) -> Self;

    /// Multiplies two keys, returning `None` if the product overflows.
    fn checked_mul(self, rhs: Self) -> Option<Self>;

    /// Adds two keys; on overflow the result wraps around.
    fn wrapping_add(self, rhs: Self) -> Self;
}
70
71impl CacheKey for u64 {
72 const BITS_COUNT: u32 = 64;
73 const ZERO: Self = 0;
74 const ONE: Self = 1;
75
76 fn from_usize(v: usize) -> Self {
77 v as u64
78 }
79
80 fn checked_mul(self, rhs: Self) -> Option<Self> {
81 self.checked_mul(rhs)
82 }
83
84 fn wrapping_add(self, rhs: Self) -> Self {
85 self.wrapping_add(rhs)
86 }
87}
88
89impl CacheKey for u128 {
90 const BITS_COUNT: u32 = 128;
91 const ZERO: Self = 0;
92 const ONE: Self = 1;
93
94 fn from_usize(v: usize) -> Self {
95 v as u128
96 }
97
98 fn checked_mul(self, rhs: Self) -> Option<Self> {
99 self.checked_mul(rhs)
100 }
101
102 fn wrapping_add(self, rhs: Self) -> Self {
103 self.wrapping_add(rhs)
104 }
105}
106
/// Implements [`CacheKey`] for a fixed-width `bnum` unsigned integer type.
///
/// `$key` must provide inherent `ZERO`/`ONE` constants, inherent
/// `checked_mul`/`wrapping_add` methods and `From<u64>`; `$width` is the
/// value reported as `BITS_COUNT`.
macro_rules! impl_cache_key_bnum {
    ($key:ty, $width:expr) => {
        impl CacheKey for $key {
            const BITS_COUNT: u32 = $width;
            const ZERO: Self = <$key>::ZERO;
            const ONE: Self = <$key>::ONE;

            fn from_usize(v: usize) -> Self {
                // Route through `u64`, which the wide type converts from.
                let widened = v as u64;
                <$key>::from(widened)
            }

            fn checked_mul(self, rhs: Self) -> Option<Self> {
                // Fully qualified call to the type's own (inherent) method.
                <$key>::checked_mul(self, rhs)
            }

            fn wrapping_add(self, rhs: Self) -> Self {
                <$key>::wrapping_add(self, rhs)
            }
        }
    };
}
128
129impl_cache_key_bnum!(U256, 256);
130impl_cache_key_bnum!(U512, 512);
131impl_cache_key_bnum!(U1024, 1024);