1use std::mem::size_of;
2
3use crate::EinsumSubscripts;
4
/// Identifier string of the einsum extension family (`"tenferro.einsum.v1"`).
///
/// NOTE(review): the `.v1` suffix looks like a version tag; confirm how consumers
/// match on this value before changing it.
pub const EINSUM_EXTENSION_FAMILY_ID: &str = "tenferro.einsum.v1";

/// Name of the cache that holds static einsum plans.
pub(crate) const EINSUM_STATIC_PLANS_CACHE: &str = "static_plans";

/// Name of the cache associated with einsum parsing.
///
/// NOTE(review): presumably stores `ParsedEinsum` entries; confirm at the insertion site.
pub(crate) const EINSUM_PARSE_CACHE: &str = "parse";

/// Name of the cache that holds runtime einsum plans.
pub(crate) const EINSUM_RUNTIME_PLANS_CACHE: &str = "runtime_plans";

/// Name of the cache that holds runtime execution programs.
pub(crate) const EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE: &str = "runtime_exec_programs";

/// Name of the cache that holds eager expanded programs.
///
/// Only compiled when the `autodiff` feature is enabled.
#[cfg(feature = "autodiff")]
pub(crate) const EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE: &str = "eager_expanded_programs";
19
/// Parsed einsum specification, wrapping its [`EinsumSubscripts`].
///
/// NOTE(review): presumably the value kept in the `parse` cache; confirm against
/// the code that inserts into that cache.
pub(crate) struct ParsedEinsum {
    /// The parsed subscripts (per-input lists plus the output list).
    pub(crate) subscripts: EinsumSubscripts,
}
25
26#[must_use]
28pub(crate) fn einsum_subscripts_retained_bytes(subscripts: &EinsumSubscripts) -> usize {
29 saturating_sum([
30 vec_of_vec_retained_bytes(&subscripts.inputs),
31 vec_retained_bytes(&subscripts.output),
32 ])
33}
34
/// Returns the number of heap bytes reserved by the buffer of `values`.
///
/// The result is `capacity() * size_of::<T>()`, so spare but already-allocated
/// slots are included. It saturates at `usize::MAX` instead of overflowing, and is
/// `0` for zero-sized element types.
///
/// The parameter is `&Vec<T>` rather than `&[T]` because a slice cannot report capacity.
pub(crate) fn vec_retained_bytes<T>(values: &Vec<T>) -> usize {
    let element_bytes = size_of::<T>();
    element_bytes.saturating_mul(values.capacity())
}
38
39pub(crate) fn vec_of_vec_retained_bytes<T>(values: &[Vec<T>]) -> usize {
40 saturating_sum(values.iter().map(vec_retained_bytes))
41}
42
/// Adds up every value from `values`, clamping the running total at `usize::MAX`.
///
/// Returns `0` for an empty input. Every item of the iterator is consumed, even
/// after the total has saturated.
pub(crate) fn saturating_sum(values: impl IntoIterator<Item = usize>) -> usize {
    let mut total = 0usize;
    for value in values {
        total = total.saturating_add(value);
    }
    total
}
46
#[cfg(test)]
mod tests {
    use super::saturating_sum;

    /// Sums whose exact value would exceed `usize::MAX` must clamp to `usize::MAX`.
    #[test]
    fn retained_byte_sums_saturate() {
        let overflowing_cases: [&[usize]; 2] = [&[usize::MAX, 1], &[usize::MAX - 4, 2, 8]];
        for case in overflowing_cases.iter() {
            assert_eq!(saturating_sum(case.iter().copied()), usize::MAX);
        }
    }
}
56}