tenferro_prims/infra/plan_cache.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::hash::Hash;
4
5use crate::{
6    AnalyticPrimsDescriptor, ScalarPrimsDescriptor, SemiringCoreDescriptor,
7    SemiringFastPathDescriptor,
8};
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub(crate) enum PlanCacheDescriptor {
12    SemiringCore(SemiringCoreDescriptor),
13    SemiringFastPath(SemiringFastPathDescriptor),
14    Scalar(ScalarPrimsDescriptor),
15    Analytic(AnalyticPrimsDescriptor),
16}
17
18pub(crate) trait CacheDescriptor: Clone + Eq + Hash + 'static {
19    fn into_cache_descriptor(self) -> PlanCacheDescriptor;
20}
21
22impl CacheDescriptor for SemiringCoreDescriptor {
23    fn into_cache_descriptor(self) -> PlanCacheDescriptor {
24        PlanCacheDescriptor::SemiringCore(self)
25    }
26}
27
28impl CacheDescriptor for SemiringFastPathDescriptor {
29    fn into_cache_descriptor(self) -> PlanCacheDescriptor {
30        PlanCacheDescriptor::SemiringFastPath(self)
31    }
32}
33
34impl CacheDescriptor for ScalarPrimsDescriptor {
35    fn into_cache_descriptor(self) -> PlanCacheDescriptor {
36        PlanCacheDescriptor::Scalar(self)
37    }
38}
39
40impl CacheDescriptor for AnalyticPrimsDescriptor {
41    fn into_cache_descriptor(self) -> PlanCacheDescriptor {
42        PlanCacheDescriptor::Analytic(self)
43    }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47struct PlanCacheKey {
48    plan_type_id: TypeId,
49    descriptor_type_id: TypeId,
50    descriptor: PlanCacheDescriptor,
51    shapes: Vec<Vec<usize>>,
52}
53
54impl PlanCacheKey {
55    fn new<P: 'static, D: CacheDescriptor>(desc: &D, shapes: &[&[usize]]) -> Self {
56        Self {
57            plan_type_id: TypeId::of::<P>(),
58            descriptor_type_id: TypeId::of::<D>(),
59            descriptor: desc.clone().into_cache_descriptor(),
60            shapes: shapes.iter().map(|shape| shape.to_vec()).collect(),
61        }
62    }
63}
64
65/// Cache for pre-computed execution plans.
66///
67/// The cache is keyed by plan type, descriptor family, descriptor value, and
68/// concrete tensor shapes. It is intentionally family-aware so the public
69/// primitive protocol can stay split into focused traits without reintroducing
70/// a monolithic descriptor surface.
71///
72/// # Examples
73///
74/// ```
75/// use tenferro_prims::PlanCache;
76///
77/// let cache = PlanCache::new();
78/// assert!(cache.is_empty());
79/// ```
80pub struct PlanCache {
81    entries: HashMap<PlanCacheKey, Box<dyn Any + Send + Sync>>,
82}
83
84impl PlanCache {
85    pub fn new() -> Self {
86        Self {
87            entries: HashMap::new(),
88        }
89    }
90
91    pub fn len(&self) -> usize {
92        self.entries.len()
93    }
94
95    pub fn is_empty(&self) -> bool {
96        self.entries.is_empty()
97    }
98
99    pub(crate) fn get<P, D>(&self, desc: &D, shapes: &[&[usize]]) -> Option<P>
100    where
101        P: Clone + Send + Sync + 'static,
102        D: CacheDescriptor,
103    {
104        let key = PlanCacheKey::new::<P, D>(desc, shapes);
105        self.entries
106            .get(&key)
107            .and_then(|boxed| boxed.downcast_ref::<P>())
108            .cloned()
109    }
110
111    pub(crate) fn insert<P, D>(&mut self, desc: &D, shapes: &[&[usize]], plan: P)
112    where
113        P: Clone + Send + Sync + 'static,
114        D: CacheDescriptor,
115    {
116        let key = PlanCacheKey::new::<P, D>(desc, shapes);
117        self.entries.insert(key, Box::new(plan));
118    }
119
120    pub fn clear(&mut self) {
121        self.entries.clear();
122    }
123}
124
125impl Default for PlanCache {
126    fn default() -> Self {
127        Self::new()
128    }
129}