Skip to main content

tenferro_einsum/
cache.rs

1use std::mem::size_of;
2
3use crate::EinsumSubscripts;
4
/// Family identifier of the standard tenferro einsum extension.
///
/// This identifier is intended to stay stable; do not change the string.
pub const EINSUM_EXTENSION_FAMILY_ID: &str = "tenferro.einsum.v1";

// Cache names used on the compiler side.

/// Name of the compiler-side cache that holds static contraction plans.
pub(crate) const EINSUM_STATIC_PLANS_CACHE: &str = "static_plans";
/// Name of the compiler-side cache that holds parsed subscripts.
pub(crate) const EINSUM_PARSE_CACHE: &str = "parse";

// Cache names used on the executor side.

/// Name of the executor-side cache that holds runtime contraction plans.
pub(crate) const EINSUM_RUNTIME_PLANS_CACHE: &str = "runtime_plans";
/// Name of the executor-side cache that holds compiled inner execution programs.
pub(crate) const EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE: &str = "runtime_exec_programs";

// Cache name used by `EagerTensor`; only compiled with the `autodiff` feature.

/// Name of the cache that holds expanded standard-op programs built for `EagerTensor`.
#[cfg(feature = "autodiff")]
pub(crate) const EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE: &str = "eager_expanded_programs";
19
20/// Parsed einsum notation retained by parse caches.
21pub(crate) struct ParsedEinsum {
22    /// Canonical parsed subscripts.
23    pub(crate) subscripts: EinsumSubscripts,
24}
25
26/// Return the retained-byte estimate for canonical subscripts.
27#[must_use]
28pub(crate) fn einsum_subscripts_retained_bytes(subscripts: &EinsumSubscripts) -> usize {
29    saturating_sum([
30        vec_of_vec_retained_bytes(&subscripts.inputs),
31        vec_retained_bytes(&subscripts.output),
32    ])
33}
34
/// Estimate the heap bytes held by a `Vec`'s buffer.
///
/// The estimate is `capacity * size_of::<T>()`. Capacity is used rather than
/// length because unused capacity remains allocated. The product saturates at
/// `usize::MAX`. The parameter is `&Vec<T>` rather than `&[T]` because a slice
/// does not expose the capacity.
pub(crate) fn vec_retained_bytes<T>(values: &Vec<T>) -> usize {
    let element_bytes = size_of::<T>();
    element_bytes.saturating_mul(values.capacity())
}
38
39pub(crate) fn vec_of_vec_retained_bytes<T>(values: &[Vec<T>]) -> usize {
40    saturating_sum(values.iter().map(vec_retained_bytes))
41}
42
/// Add up byte counts, clamping the total at `usize::MAX` instead of overflowing.
///
/// Every item of `values` is consumed, even after the total has saturated.
/// An empty input yields `0`.
pub(crate) fn saturating_sum(values: impl IntoIterator<Item = usize>) -> usize {
    let mut total: usize = 0;
    for value in values {
        total = total.saturating_add(value);
    }
    total
}
46
#[cfg(test)]
mod tests {
    use super::{saturating_sum, vec_of_vec_retained_bytes, vec_retained_bytes};

    #[test]
    fn retained_byte_sums_saturate() {
        assert_eq!(saturating_sum([usize::MAX, 1]), usize::MAX);
        assert_eq!(saturating_sum([usize::MAX - 4, 2, 8]), usize::MAX);
    }

    #[test]
    fn retained_byte_sums_are_exact_below_the_limit() {
        assert_eq!(saturating_sum(std::iter::empty::<usize>()), 0);
        assert_eq!(saturating_sum([1usize, 2, 3]), 6);
    }

    #[test]
    fn vec_retained_bytes_uses_capacity() {
        let mut values: Vec<u64> = Vec::with_capacity(8);
        values.push(1);
        let expected = values.capacity() * std::mem::size_of::<u64>();
        assert_eq!(vec_retained_bytes(&values), expected);
        assert_eq!(vec_retained_bytes(&Vec::<u64>::new()), 0);
        // Zero-sized elements never allocate, even though `capacity()` is huge.
        assert_eq!(vec_retained_bytes(&vec![(); 4]), 0);
    }

    #[test]
    fn vec_of_vec_retained_bytes_sums_inner_buffers_only() {
        let first: Vec<u32> = Vec::with_capacity(3);
        let second: Vec<u32> = Vec::with_capacity(5);
        let expected = (first.capacity() + second.capacity()) * std::mem::size_of::<u32>();
        // Moving the inner vectors into the outer one keeps their capacities.
        let nested = vec![first, second];
        assert_eq!(vec_of_vec_retained_bytes(&nested), expected);
        assert_eq!(vec_of_vec_retained_bytes::<u32>(&[]), 0);
    }
}