// tenferro/engine.rs
1use std::collections::HashMap;
2use std::num::NonZeroUsize;
3use std::sync::Arc;
4
5use lru::LruCache;
6
7use super::exec::ExecProgram;
8use tenferro_einsum::ContractionTree;
9use tenferro_tensor::{cpu::CpuBackend, TensorBackend};
10
/// Key used for the N-ary einsum cache: `(subscripts, shapes)`.
///
/// The first element is the einsum subscript string (e.g. `"ij,jk->ik"`);
/// the second is the list of input tensor shapes, in argument order.
pub(crate) type EinsumCacheKey = (String, Vec<Vec<usize>>);

/// LRU cache of optimized contraction trees keyed by einsum subscripts + input shapes.
///
/// Trees are stored behind [`Arc`] so a cache hit hands out a cheap
/// reference-counted clone rather than copying the tree.
pub(crate) type NaryEinsumCache = LruCache<EinsumCacheKey, Arc<ContractionTree>>;

/// Default capacity for `Engine::einsum_cache`.
///
/// Each `ContractionTree` is typically a few KB; 256 entries ≈ under 1 MB.
pub const DEFAULT_EINSUM_CACHE_CAPACITY: usize = 256;
21
/// Cache key derived from the compiled graph topology.
///
/// Uses the number and order of instructions, their op variants, and
/// slot counts as a cheap proxy for structural identity.
///
/// Two programs with the same `shape` and `op_hash` are treated as
/// identical by `Engine::compile_cache`; see `compute_cache_key` for
/// exactly which fields feed the hash.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct CacheKey {
    /// Number of instructions, input slots, output slots, and total slots.
    shape: (usize, usize, usize, usize),
    /// A hash of the instruction ops (using Debug representation for simplicity).
    op_hash: u64,
}
33
34fn compute_cache_key(exec: &ExecProgram) -> CacheKey {
35    use std::hash::{Hash, Hasher};
36    let mut hasher = std::collections::hash_map::DefaultHasher::new();
37    for inst in &exec.instructions {
38        format!("{:?}", inst.op).hash(&mut hasher);
39        inst.input_slots.hash(&mut hasher);
40        inst.output_slots.hash(&mut hasher);
41    }
42    exec.input_slots.hash(&mut hasher);
43    exec.output_slots.hash(&mut hasher);
44
45    CacheKey {
46        shape: (
47            exec.instructions.len(),
48            exec.input_slots.len(),
49            exec.output_slots.len(),
50            exec.n_slots,
51        ),
52        op_hash: hasher.finish(),
53    }
54}
55
56/// Execution engine holding the backend and compile caches.
57///
58/// # Examples
59///
60/// ```ignore
61/// use tenferro_tensor::cpu::CpuBackend;
62/// use tenferro::engine::Engine;
63///
64/// let mut engine = Engine::new(CpuBackend::new());
65/// ```
/// Execution engine holding the backend and compile caches.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::cpu::CpuBackend;
/// use tenferro::engine::Engine;
///
/// let mut engine = Engine::new(CpuBackend::new());
/// ```
pub struct Engine<B: TensorBackend> {
    /// Backend that actually executes tensor operations.
    pub(crate) backend: B,
    /// Cache of compiled programs keyed by graph topology (see `CacheKey`).
    pub(crate) compile_cache: HashMap<CacheKey, ExecProgram>,
    /// LRU cache of optimized einsum contraction trees, persistent across
    /// evaluations (see `eval_exec_ir`).
    pub(crate) einsum_cache: NaryEinsumCache,
}
71
72impl<B: TensorBackend> Engine<B> {
73    /// Create a new engine with the given backend.
74    ///
75    /// # Examples
76    ///
77    /// ```ignore
78    /// use tenferro_tensor::cpu::CpuBackend;
79    /// use tenferro::engine::Engine;
80    ///
81    /// let engine = Engine::new(CpuBackend::new());
82    /// ```
83    pub fn new(backend: B) -> Self {
84        Self {
85            backend,
86            compile_cache: HashMap::new(),
87            einsum_cache: LruCache::new(
88                NonZeroUsize::new(DEFAULT_EINSUM_CACHE_CAPACITY)
89                    .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
90            ),
91        }
92    }
93
94    /// Borrow the backend used by this engine.
95    ///
96    /// # Examples
97    ///
98    /// ```
99    /// use tenferro::{CpuBackend, Engine};
100    ///
101    /// let engine = Engine::new(CpuBackend::new());
102    /// let _backend = engine.backend();
103    /// ```
104    pub fn backend(&self) -> &B {
105        &self.backend
106    }
107
108    /// Number of cached einsum contraction trees currently retained by the engine.
109    ///
110    /// # Examples
111    ///
112    /// ```ignore
113    /// use tenferro::{CpuBackend, Engine};
114    ///
115    /// let engine = Engine::new(CpuBackend::new());
116    /// assert_eq!(engine.einsum_cache_len(), 0);
117    /// ```
118    pub fn einsum_cache_len(&self) -> usize {
119        self.einsum_cache.len()
120    }
121
122    /// Construct a new engine with an explicit `einsum_cache` capacity.
123    ///
124    /// # Examples
125    ///
126    /// ```ignore
127    /// use std::num::NonZeroUsize;
128    /// use tenferro::{CpuBackend, Engine};
129    ///
130    /// let engine = Engine::with_einsum_cache_capacity(
131    ///     CpuBackend::new(),
132    ///     NonZeroUsize::new(64).unwrap(),
133    /// );
134    /// ```
135    pub fn with_einsum_cache_capacity(backend: B, capacity: NonZeroUsize) -> Self {
136        Self {
137            backend,
138            compile_cache: HashMap::new(),
139            einsum_cache: LruCache::new(capacity),
140        }
141    }
142
143    /// Current capacity of the einsum contraction-tree cache.
144    ///
145    /// # Examples
146    ///
147    /// ```ignore
148    /// use tenferro::{CpuBackend, Engine};
149    ///
150    /// let engine = Engine::new(CpuBackend::new());
151    /// assert_eq!(engine.einsum_cache_capacity().get(), tenferro::engine::DEFAULT_EINSUM_CACHE_CAPACITY);
152    /// ```
153    pub fn einsum_cache_capacity(&self) -> NonZeroUsize {
154        self.einsum_cache.cap()
155    }
156
157    /// Resize the einsum contraction-tree cache.
158    ///
159    /// Shrinking below the current length evicts least-recently-used entries.
160    ///
161    /// # Examples
162    ///
163    /// ```ignore
164    /// use std::num::NonZeroUsize;
165    /// use tenferro::{CpuBackend, Engine};
166    ///
167    /// let mut engine = Engine::new(CpuBackend::new());
168    /// engine.set_einsum_cache_capacity(NonZeroUsize::new(32).unwrap());
169    /// ```
170    pub fn set_einsum_cache_capacity(&mut self, capacity: NonZeroUsize) {
171        self.einsum_cache.resize(capacity);
172    }
173
174    /// Returns `true` if the einsum cache contains a tree for `key`.
175    ///
176    /// Does not modify LRU recency.
177    ///
178    /// # Examples
179    ///
180    /// ```ignore
181    /// use tenferro::{CpuBackend, Engine};
182    ///
183    /// let engine = Engine::new(CpuBackend::new());
184    /// let key = ("ij,jk->ik".to_string(), vec![vec![2, 3], vec![3, 4]]);
185    /// assert!(!engine.einsum_cache_contains(&key));
186    /// ```
187    pub fn einsum_cache_contains(&self, key: &(String, Vec<Vec<usize>>)) -> bool {
188        self.einsum_cache.contains(key)
189    }
190
191    /// Look up a cached ExecProgram, or cache and return the given one.
192    ///
193    /// Returns a clone of the cached program to avoid borrow conflicts
194    /// with `self.backend`.
195    pub(crate) fn get_or_compile(&mut self, exec: ExecProgram) -> ExecProgram {
196        let key = compute_cache_key(&exec);
197        self.compile_cache.entry(key).or_insert(exec).clone()
198    }
199
200    /// Evaluate an `ExecProgram` through this engine, reusing the persistent
201    /// `einsum_cache` for any `NaryEinsum` ops encountered in the program.
202    ///
203    /// # Examples
204    ///
205    /// ```ignore
206    /// use tenferro::{CpuBackend, Engine};
207    /// use tenferro::exec::ExecProgram;
208    ///
209    /// let mut engine = Engine::new(CpuBackend::new());
210    /// // let outputs = engine.eval_exec_ir(&program, inputs)?;
211    /// ```
212    pub fn eval_exec_ir(
213        &mut self,
214        program: &ExecProgram,
215        inputs: Vec<tenferro_tensor::Tensor>,
216    ) -> crate::error::Result<Vec<tenferro_tensor::Tensor>> {
217        crate::segment::eval_exec_segmented_with_cache(
218            &mut self.backend,
219            program,
220            inputs,
221            &mut self.einsum_cache,
222        )
223    }
224}
225
226impl Engine<CpuBackend> {
227    /// Number of reusable typed host buffers currently retained by the CPU backend.
228    ///
229    /// # Examples
230    ///
231    /// ```ignore
232    /// use tenferro::{CpuBackend, Engine};
233    ///
234    /// let engine = Engine::new(CpuBackend::new());
235    /// assert_eq!(engine.buffer_pool_len(), 0);
236    /// ```
237    pub fn buffer_pool_len(&self) -> usize {
238        self.backend.buffer_pool_len()
239    }
240}