tenferro/engine.rs
1use std::collections::HashMap;
2use std::num::NonZeroUsize;
3use std::sync::Arc;
4
5use lru::LruCache;
6
7use super::exec::ExecProgram;
8use tenferro_einsum::ContractionTree;
9use tenferro_tensor::{cpu::CpuBackend, TensorBackend};
10
/// Key used for the N-ary einsum cache: `(subscripts, shapes)`.
///
/// The subscripts string (e.g. `"ij,jk->ik"`) together with the concrete
/// input shapes identifies one optimization problem; the cached value is
/// the contraction tree computed for it.
pub(crate) type EinsumCacheKey = (String, Vec<Vec<usize>>);

/// LRU cache of optimized contraction trees keyed by einsum subscripts + input shapes.
///
/// Shared `Arc`s let callers hold a tree without blocking cache eviction.
pub(crate) type NaryEinsumCache = LruCache<EinsumCacheKey, Arc<ContractionTree>>;

/// Default capacity for `Engine::einsum_cache`.
///
/// Each `ContractionTree` is typically a few KB; 256 entries ≈ under 1 MB.
pub const DEFAULT_EINSUM_CACHE_CAPACITY: usize = 256;
21
/// Cache key derived from the compiled graph topology.
///
/// Uses the number and order of instructions, their op variants, and
/// slot counts as a cheap proxy for structural identity.
///
/// NOTE(review): two distinct programs could in principle collide on both
/// fields (the op hash is a `Debug`-string hash) — presumably acceptable
/// for a compile cache; confirm collisions only cost recompilation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct CacheKey {
    /// Number of instructions, input slots, output slots, and total slots.
    /// Cheap coarse filter checked (via `Eq`) before the hash matters.
    shape: (usize, usize, usize, usize),
    /// A hash of the instruction ops (using Debug representation for simplicity)
    /// plus the per-instruction and program-level slot lists.
    op_hash: u64,
}
33
34fn compute_cache_key(exec: &ExecProgram) -> CacheKey {
35 use std::hash::{Hash, Hasher};
36 let mut hasher = std::collections::hash_map::DefaultHasher::new();
37 for inst in &exec.instructions {
38 format!("{:?}", inst.op).hash(&mut hasher);
39 inst.input_slots.hash(&mut hasher);
40 inst.output_slots.hash(&mut hasher);
41 }
42 exec.input_slots.hash(&mut hasher);
43 exec.output_slots.hash(&mut hasher);
44
45 CacheKey {
46 shape: (
47 exec.instructions.len(),
48 exec.input_slots.len(),
49 exec.output_slots.len(),
50 exec.n_slots,
51 ),
52 op_hash: hasher.finish(),
53 }
54}
55
/// Execution engine holding the backend and compile caches.
///
/// Owns the tensor backend plus two caches: a map from graph-topology keys
/// to compiled [`ExecProgram`]s, and an LRU of optimized einsum contraction
/// trees (see [`NaryEinsumCache`]).
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::cpu::CpuBackend;
/// use tenferro::engine::Engine;
///
/// let mut engine = Engine::new(CpuBackend::new());
/// ```
pub struct Engine<B: TensorBackend> {
    /// Backend that actually executes tensor operations.
    pub(crate) backend: B,
    /// Compiled programs keyed by graph topology; unbounded.
    pub(crate) compile_cache: HashMap<CacheKey, ExecProgram>,
    /// Bounded LRU of einsum contraction trees, shared across evaluations.
    pub(crate) einsum_cache: NaryEinsumCache,
}
71
72impl<B: TensorBackend> Engine<B> {
73 /// Create a new engine with the given backend.
74 ///
75 /// # Examples
76 ///
77 /// ```ignore
78 /// use tenferro_tensor::cpu::CpuBackend;
79 /// use tenferro::engine::Engine;
80 ///
81 /// let engine = Engine::new(CpuBackend::new());
82 /// ```
83 pub fn new(backend: B) -> Self {
84 Self {
85 backend,
86 compile_cache: HashMap::new(),
87 einsum_cache: LruCache::new(
88 NonZeroUsize::new(DEFAULT_EINSUM_CACHE_CAPACITY)
89 .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
90 ),
91 }
92 }
93
94 /// Borrow the backend used by this engine.
95 ///
96 /// # Examples
97 ///
98 /// ```
99 /// use tenferro::{CpuBackend, Engine};
100 ///
101 /// let engine = Engine::new(CpuBackend::new());
102 /// let _backend = engine.backend();
103 /// ```
104 pub fn backend(&self) -> &B {
105 &self.backend
106 }
107
108 /// Number of cached einsum contraction trees currently retained by the engine.
109 ///
110 /// # Examples
111 ///
112 /// ```ignore
113 /// use tenferro::{CpuBackend, Engine};
114 ///
115 /// let engine = Engine::new(CpuBackend::new());
116 /// assert_eq!(engine.einsum_cache_len(), 0);
117 /// ```
118 pub fn einsum_cache_len(&self) -> usize {
119 self.einsum_cache.len()
120 }
121
122 /// Construct a new engine with an explicit `einsum_cache` capacity.
123 ///
124 /// # Examples
125 ///
126 /// ```ignore
127 /// use std::num::NonZeroUsize;
128 /// use tenferro::{CpuBackend, Engine};
129 ///
130 /// let engine = Engine::with_einsum_cache_capacity(
131 /// CpuBackend::new(),
132 /// NonZeroUsize::new(64).unwrap(),
133 /// );
134 /// ```
135 pub fn with_einsum_cache_capacity(backend: B, capacity: NonZeroUsize) -> Self {
136 Self {
137 backend,
138 compile_cache: HashMap::new(),
139 einsum_cache: LruCache::new(capacity),
140 }
141 }
142
143 /// Current capacity of the einsum contraction-tree cache.
144 ///
145 /// # Examples
146 ///
147 /// ```ignore
148 /// use tenferro::{CpuBackend, Engine};
149 ///
150 /// let engine = Engine::new(CpuBackend::new());
151 /// assert_eq!(engine.einsum_cache_capacity().get(), tenferro::engine::DEFAULT_EINSUM_CACHE_CAPACITY);
152 /// ```
153 pub fn einsum_cache_capacity(&self) -> NonZeroUsize {
154 self.einsum_cache.cap()
155 }
156
157 /// Resize the einsum contraction-tree cache.
158 ///
159 /// Shrinking below the current length evicts least-recently-used entries.
160 ///
161 /// # Examples
162 ///
163 /// ```ignore
164 /// use std::num::NonZeroUsize;
165 /// use tenferro::{CpuBackend, Engine};
166 ///
167 /// let mut engine = Engine::new(CpuBackend::new());
168 /// engine.set_einsum_cache_capacity(NonZeroUsize::new(32).unwrap());
169 /// ```
170 pub fn set_einsum_cache_capacity(&mut self, capacity: NonZeroUsize) {
171 self.einsum_cache.resize(capacity);
172 }
173
174 /// Returns `true` if the einsum cache contains a tree for `key`.
175 ///
176 /// Does not modify LRU recency.
177 ///
178 /// # Examples
179 ///
180 /// ```ignore
181 /// use tenferro::{CpuBackend, Engine};
182 ///
183 /// let engine = Engine::new(CpuBackend::new());
184 /// let key = ("ij,jk->ik".to_string(), vec![vec![2, 3], vec![3, 4]]);
185 /// assert!(!engine.einsum_cache_contains(&key));
186 /// ```
187 pub fn einsum_cache_contains(&self, key: &(String, Vec<Vec<usize>>)) -> bool {
188 self.einsum_cache.contains(key)
189 }
190
191 /// Look up a cached ExecProgram, or cache and return the given one.
192 ///
193 /// Returns a clone of the cached program to avoid borrow conflicts
194 /// with `self.backend`.
195 pub(crate) fn get_or_compile(&mut self, exec: ExecProgram) -> ExecProgram {
196 let key = compute_cache_key(&exec);
197 self.compile_cache.entry(key).or_insert(exec).clone()
198 }
199
200 /// Evaluate an `ExecProgram` through this engine, reusing the persistent
201 /// `einsum_cache` for any `NaryEinsum` ops encountered in the program.
202 ///
203 /// # Examples
204 ///
205 /// ```ignore
206 /// use tenferro::{CpuBackend, Engine};
207 /// use tenferro::exec::ExecProgram;
208 ///
209 /// let mut engine = Engine::new(CpuBackend::new());
210 /// // let outputs = engine.eval_exec_ir(&program, inputs)?;
211 /// ```
212 pub fn eval_exec_ir(
213 &mut self,
214 program: &ExecProgram,
215 inputs: Vec<tenferro_tensor::Tensor>,
216 ) -> crate::error::Result<Vec<tenferro_tensor::Tensor>> {
217 crate::segment::eval_exec_segmented_with_cache(
218 &mut self.backend,
219 program,
220 inputs,
221 &mut self.einsum_cache,
222 )
223 }
224}
225
/// CPU-specific conveniences, available only when the engine is backed by
/// [`CpuBackend`].
impl Engine<CpuBackend> {
    /// Number of reusable typed host buffers currently retained by the CPU backend.
    ///
    /// Thin delegation to `CpuBackend::buffer_pool_len`; exposed here so
    /// callers can inspect pool growth without borrowing the backend directly.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro::{CpuBackend, Engine};
    ///
    /// let engine = Engine::new(CpuBackend::new());
    /// assert_eq!(engine.buffer_pool_len(), 0);
    /// ```
    pub fn buffer_pool_len(&self) -> usize {
        self.backend.buffer_pool_len()
    }
}
240}