1use std::hash::{Hash, Hasher};
2use std::mem::{size_of, size_of_val};
3use std::sync::Arc;
4
5use lru::LruCache;
6use tenferro_ops::ext_op::ExtensionOp;
7use tenferro_tensor::CacheStats;
8
9use crate::exec::{ExecInstruction, ExecOp, ExecOutputExtents, ExecOutputShapes, ExecProgram};
10
/// Default capacity of the graph compile cache.
// NOTE(review): presumably measured in cache entries (the compile cache handled in this
// file is an `LruCache`) — confirm at the place where the cache is constructed.
// NOTE(review): the reason for allowing `dead_code` is not visible in this file;
// presumably the constant is unused in some build configurations — confirm.
#[allow(dead_code)]
pub const DEFAULT_GRAPH_COMPILE_CACHE_CAPACITY: usize = 256;

/// Crate-internal alias with the same value as [`DEFAULT_GRAPH_COMPILE_CACHE_CAPACITY`].
pub(crate) const DEFAULT_COMPILE_CACHE_CAPACITY: usize = DEFAULT_GRAPH_COMPILE_CACHE_CAPACITY;
19
/// Cache statistics grouped by cache, one [`CacheStats`] value per field.
///
/// NOTE(review): judging by the name these are the caches owned by a graph compiler;
/// the owner is not visible in this file — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct GraphCompilerCacheStats {
    /// Statistics for the compile cache.
    pub compile: CacheStats,
    /// Statistics for the extensions cache (presumably extension-op state — confirm).
    pub extensions: CacheStats,
}
42
/// Cache statistics grouped by cache, one [`CacheStats`] value per field.
///
/// NOTE(review): judging by the name these are the caches owned by a graph executor;
/// the owner is not visible in this file — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct GraphExecutorCacheStats {
    /// Statistics for the extensions cache (presumably extension-op state — confirm).
    pub extensions: CacheStats,
    /// Statistics for the backend cache (presumably backend-specific artifacts — confirm).
    pub backend: CacheStats,
}
65
66#[derive(Clone, Debug)]
68pub(crate) struct CacheKey {
69 fingerprint: ExecProgramKey,
70 extensions: Vec<Arc<dyn ExtensionOp>>,
71}
72
73impl PartialEq for CacheKey {
74 fn eq(&self, other: &Self) -> bool {
75 self.fingerprint == other.fingerprint
76 && self.extensions.len() == other.extensions.len()
77 && self
78 .extensions
79 .iter()
80 .zip(&other.extensions)
81 .all(|(lhs, rhs)| {
82 lhs.family_id() == rhs.family_id() && lhs.payload_eq(rhs.as_ref())
83 })
84 }
85}
86
87impl Eq for CacheKey {}
88
89impl Hash for CacheKey {
90 fn hash<H: Hasher>(&self, state: &mut H) {
91 self.fingerprint.hash(state);
92 }
93}
94
95pub(crate) fn compute_cache_key(exec: &ExecProgram) -> CacheKey {
96 let mut extensions = Vec::new();
97 let fingerprint = exec_program_key(exec, &mut extensions);
98 CacheKey {
99 fingerprint,
100 extensions,
101 }
102}
103
104fn cache_key_retained_bytes(key: &CacheKey) -> usize {
105 saturating_sum([
106 size_of::<CacheKey>(),
107 exec_program_key_retained_bytes(&key.fingerprint),
108 key.extensions
109 .capacity()
110 .saturating_mul(size_of::<Arc<dyn ExtensionOp>>()),
111 ])
112}
113
/// Hashable, comparable snapshot of an `ExecProgram`, used as the structural part
/// of a cache key.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct ExecProgramKey {
    /// One key per program instruction, in program order.
    instructions: Vec<ExecInstructionKey>,
    /// Copy of the program's `input_slots`.
    input_slots: Vec<usize>,
    /// Copy of the program's `output_slots`.
    output_slots: Vec<usize>,
    /// Copy of the program's `n_slots`.
    n_slots: usize,
}
121
/// Hashable, comparable snapshot of a single `ExecInstruction`.
///
/// Every field except `op` is a direct copy of the instruction field with the same
/// name; `op` is the instruction's operation translated to an [`ExecOpKey`].
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct ExecInstructionKey {
    op: ExecOpKey,
    input_slots: Vec<usize>,
    output_slots: Vec<usize>,
    dtype: tenferro_tensor::DType,
    output_shapes: ExecOutputShapes,
    output_extents: ExecOutputExtents,
    last_use: Vec<bool>,
}
132
/// Hashable, comparable mirror of `ExecOp`.
///
/// Each variant carries a copy of the matching `ExecOp` payload. The one exception
/// is [`ExecOpKey::Extension`], which stores only the extension's family id and a
/// 64-bit hash of its payload instead of the extension object itself.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum ExecOpKey {
    Transpose {
        perm: Vec<usize>,
    },
    Reshape {
        shape: Vec<tenferro_ops::dim_expr::DimExpr>,
    },
    BroadcastInDim {
        shape: Vec<tenferro_ops::dim_expr::DimExpr>,
        dims: Vec<usize>,
    },
    Convert {
        to: tenferro_tensor::DType,
    },
    Constant {
        dtype: tenferro_tensor::DType,
        bytes: Vec<u8>,
    },
    DotGeneral(tenferro_tensor::DotGeneralConfig),
    DotGeneralWithConj {
        config: tenferro_tensor::DotGeneralConfig,
        lhs_conj: bool,
        rhs_conj: bool,
    },
    ReduceSum {
        axes: Vec<usize>,
    },
    ExtractDiag {
        axis_a: usize,
        axis_b: usize,
    },
    EmbedDiag {
        axis_a: usize,
        axis_b: usize,
    },
    Tril {
        k: i64,
    },
    Triu {
        k: i64,
    },
    // Payload-free operations.
    Add,
    Multiply,
    Negate,
    Conj,
    Divide,
    Abs,
    Sign,
    Maximum,
    Minimum,
    Compare(tenferro_tensor::CompareDir),
    Select,
    Clamp,
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,
    Gather(tenferro_tensor::GatherConfig),
    GatherDynamicSliceSizes {
        offset_dims: Vec<usize>,
        collapsed_slice_dims: Vec<usize>,
        start_index_map: Vec<usize>,
        index_vector_dim: usize,
        slice_sizes: Vec<tenferro_ops::dim_expr::DimExpr>,
    },
    Scatter(tenferro_tensor::ScatterConfig),
    Slice(tenferro_tensor::SliceConfig),
    DynamicSlice {
        slice_sizes: Vec<usize>,
    },
    DynamicUpdateSlice,
    Pad(tenferro_tensor::PadConfig),
    Concatenate {
        axis: usize,
    },
    Reverse {
        axes: Vec<usize>,
    },
    ShapeOf {
        axis: usize,
    },
    DynamicTruncate {
        axis: usize,
    },
    PadToMatch {
        axis: usize,
    },
    ReduceProd {
        axes: Vec<usize>,
    },
    ReduceMax {
        axes: Vec<usize>,
    },
    ReduceMin {
        axes: Vec<usize>,
    },
    /// Stand-in for an extension op: identified by family and payload hash only.
    Extension {
        family_id: &'static str,
        payload_hash: u64,
    },
}
241
242fn exec_program_key(
243 exec: &ExecProgram,
244 extensions: &mut Vec<Arc<dyn ExtensionOp>>,
245) -> ExecProgramKey {
246 ExecProgramKey {
247 instructions: exec
248 .instructions
249 .iter()
250 .map(|inst| exec_instruction_key(inst, extensions))
251 .collect(),
252 input_slots: exec.input_slots.clone(),
253 output_slots: exec.output_slots.clone(),
254 n_slots: exec.n_slots,
255 }
256}
257
258fn exec_instruction_key(
259 inst: &ExecInstruction,
260 extensions: &mut Vec<Arc<dyn ExtensionOp>>,
261) -> ExecInstructionKey {
262 ExecInstructionKey {
263 op: exec_op_key(&inst.op, extensions),
264 input_slots: inst.input_slots.clone(),
265 output_slots: inst.output_slots.clone(),
266 dtype: inst.dtype,
267 output_shapes: inst.output_shapes.clone(),
268 output_extents: inst.output_extents.clone(),
269 last_use: inst.last_use.clone(),
270 }
271}
272
/// Translates an `ExecOp` into the matching [`ExecOpKey`] variant.
///
/// Payloads are cloned (or copied for plain values) into the key. For
/// `ExecOp::Extension` the key records the family id and a payload hash, and a
/// clone of the extension handle is appended to `extensions`.
///
/// The match is exhaustive on purpose: adding an `ExecOp` variant fails to compile
/// here until a key mapping is provided.
fn exec_op_key(op: &ExecOp, extensions: &mut Vec<Arc<dyn ExtensionOp>>) -> ExecOpKey {
    match op {
        ExecOp::Transpose { perm } => ExecOpKey::Transpose { perm: perm.clone() },
        ExecOp::Reshape { shape } => ExecOpKey::Reshape {
            shape: shape.clone(),
        },
        ExecOp::BroadcastInDim { shape, dims } => ExecOpKey::BroadcastInDim {
            shape: shape.clone(),
            dims: dims.clone(),
        },
        ExecOp::Convert { to } => ExecOpKey::Convert { to: *to },
        ExecOp::Constant { dtype, bytes } => ExecOpKey::Constant {
            dtype: *dtype,
            bytes: bytes.clone(),
        },
        ExecOp::DotGeneral(config) => ExecOpKey::DotGeneral(config.clone()),
        ExecOp::DotGeneralWithConj {
            config,
            lhs_conj,
            rhs_conj,
        } => ExecOpKey::DotGeneralWithConj {
            config: config.clone(),
            lhs_conj: *lhs_conj,
            rhs_conj: *rhs_conj,
        },
        ExecOp::ReduceSum { axes } => ExecOpKey::ReduceSum { axes: axes.clone() },
        ExecOp::ExtractDiag { axis_a, axis_b } => ExecOpKey::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        ExecOp::EmbedDiag { axis_a, axis_b } => ExecOpKey::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        ExecOp::Tril { k } => ExecOpKey::Tril { k: *k },
        ExecOp::Triu { k } => ExecOpKey::Triu { k: *k },
        ExecOp::Add => ExecOpKey::Add,
        ExecOp::Multiply => ExecOpKey::Multiply,
        ExecOp::Negate => ExecOpKey::Negate,
        ExecOp::Conj => ExecOpKey::Conj,
        ExecOp::Divide => ExecOpKey::Divide,
        ExecOp::Abs => ExecOpKey::Abs,
        ExecOp::Sign => ExecOpKey::Sign,
        ExecOp::Maximum => ExecOpKey::Maximum,
        ExecOp::Minimum => ExecOpKey::Minimum,
        ExecOp::Compare(dir) => ExecOpKey::Compare(dir.clone()),
        ExecOp::Select => ExecOpKey::Select,
        ExecOp::Clamp => ExecOpKey::Clamp,
        ExecOp::Exp => ExecOpKey::Exp,
        ExecOp::Log => ExecOpKey::Log,
        ExecOp::Sin => ExecOpKey::Sin,
        ExecOp::Cos => ExecOpKey::Cos,
        ExecOp::Tanh => ExecOpKey::Tanh,
        ExecOp::Sqrt => ExecOpKey::Sqrt,
        ExecOp::Rsqrt => ExecOpKey::Rsqrt,
        ExecOp::Pow => ExecOpKey::Pow,
        ExecOp::Expm1 => ExecOpKey::Expm1,
        ExecOp::Log1p => ExecOpKey::Log1p,
        ExecOp::Gather(config) => ExecOpKey::Gather(config.clone()),
        ExecOp::GatherDynamicSliceSizes {
            offset_dims,
            collapsed_slice_dims,
            start_index_map,
            index_vector_dim,
            slice_sizes,
        } => ExecOpKey::GatherDynamicSliceSizes {
            offset_dims: offset_dims.clone(),
            collapsed_slice_dims: collapsed_slice_dims.clone(),
            start_index_map: start_index_map.clone(),
            index_vector_dim: *index_vector_dim,
            slice_sizes: slice_sizes.clone(),
        },
        ExecOp::Scatter(config) => ExecOpKey::Scatter(config.clone()),
        ExecOp::Slice(config) => ExecOpKey::Slice(config.clone()),
        ExecOp::DynamicSlice { slice_sizes } => ExecOpKey::DynamicSlice {
            slice_sizes: slice_sizes.clone(),
        },
        ExecOp::DynamicUpdateSlice => ExecOpKey::DynamicUpdateSlice,
        ExecOp::Pad(config) => ExecOpKey::Pad(config.clone()),
        ExecOp::Concatenate { axis } => ExecOpKey::Concatenate { axis: *axis },
        ExecOp::Reverse { axes } => ExecOpKey::Reverse { axes: axes.clone() },
        ExecOp::ShapeOf { axis } => ExecOpKey::ShapeOf { axis: *axis },
        ExecOp::DynamicTruncate { axis } => ExecOpKey::DynamicTruncate { axis: *axis },
        ExecOp::PadToMatch { axis } => ExecOpKey::PadToMatch { axis: *axis },
        ExecOp::ReduceProd { axes } => ExecOpKey::ReduceProd { axes: axes.clone() },
        ExecOp::ReduceMax { axes } => ExecOpKey::ReduceMax { axes: axes.clone() },
        ExecOp::ReduceMin { axes } => ExecOpKey::ReduceMin { axes: axes.clone() },
        ExecOp::Extension(extension) => {
            // The key itself only holds a hash of the payload, so the extension handle
            // is handed back to the caller through `extensions` as well.
            let key = ExecOpKey::Extension {
                family_id: extension.family_id(),
                payload_hash: extension_payload_hash(extension.as_ref()),
            };
            extensions.push(Arc::clone(extension));
            key
        }
    }
}
370
371fn extension_payload_hash(extension: &dyn ExtensionOp) -> u64 {
372 let mut hasher = std::collections::hash_map::DefaultHasher::new();
373 extension.payload_hash(&mut DynHasherProxy::new(&mut hasher));
374 hasher.finish()
375}
376
/// Sized [`Hasher`] wrapper around a mutable reference to another hasher, which
/// itself may be unsized (for example `dyn Hasher`).
///
/// Only `write` and `finish` are overridden; both forward to the wrapped hasher.
/// The trait's provided `write_*` methods are left at their defaults.
struct DynHasherProxy<'a, H>
where
    H: Hasher + ?Sized,
{
    inner: &'a mut H,
}

impl<'a, H> DynHasherProxy<'a, H>
where
    H: Hasher + ?Sized,
{
    /// Wraps `inner`; everything written to the proxy ends up in `inner`.
    fn new(inner: &'a mut H) -> Self {
        DynHasherProxy { inner }
    }
}

impl<H> Hasher for DynHasherProxy<'_, H>
where
    H: Hasher + ?Sized,
{
    /// Forwards the raw bytes to the wrapped hasher.
    fn write(&mut self, bytes: &[u8]) {
        H::write(&mut *self.inner, bytes);
    }

    /// Returns the wrapped hasher's current digest.
    fn finish(&self) -> u64 {
        H::finish(&*self.inner)
    }
}
396
/// Bytes reserved by the buffer of `values`: capacity (not length) times the
/// element size, saturating at `usize::MAX`.
///
/// Takes `&Vec<T>` rather than a slice because the capacity is needed.
fn vec_retained_bytes<T>(values: &Vec<T>) -> usize {
    let element_size = size_of::<T>();
    let reserved_elements = values.capacity();
    reserved_elements.saturating_mul(element_size)
}
400
401fn vec_of_vec_retained_bytes<T>(values: &[Vec<T>]) -> usize {
402 saturating_sum(values.iter().map(vec_retained_bytes))
403}
404
405fn exec_program_key_retained_bytes(key: &ExecProgramKey) -> usize {
406 saturating_sum([
407 size_of::<ExecProgramKey>(),
408 vec_retained_bytes(&key.instructions),
409 saturating_sum(
410 key.instructions
411 .iter()
412 .map(exec_instruction_key_retained_bytes),
413 ),
414 vec_retained_bytes(&key.input_slots),
415 vec_retained_bytes(&key.output_slots),
416 ])
417}
418
419fn exec_instruction_key_retained_bytes(key: &ExecInstructionKey) -> usize {
420 saturating_sum([
421 size_of::<ExecInstructionKey>(),
422 exec_op_key_retained_bytes(&key.op),
423 vec_retained_bytes(&key.input_slots),
424 vec_retained_bytes(&key.output_slots),
425 vec_of_vec_retained_bytes(&key.output_shapes),
426 vec_of_vec_retained_bytes(&key.output_extents),
427 vec_retained_bytes(&key.last_use),
428 ])
429}
430
431fn exec_op_key_retained_bytes(key: &ExecOpKey) -> usize {
432 saturating_sum([
433 size_of::<ExecOpKey>(),
434 match key {
435 ExecOpKey::Transpose { perm } => vec_retained_bytes(perm),
436 ExecOpKey::Reshape { shape } => vec_retained_bytes(shape),
437 ExecOpKey::BroadcastInDim { shape, dims } => {
438 saturating_sum([vec_retained_bytes(shape), vec_retained_bytes(dims)])
439 }
440 ExecOpKey::Constant { bytes, .. } => vec_retained_bytes(bytes),
441 ExecOpKey::DotGeneral(config) => dot_general_config_retained_bytes(config),
442 ExecOpKey::DotGeneralWithConj { config, .. } => {
443 dot_general_config_retained_bytes(config)
444 }
445 ExecOpKey::ReduceSum { axes }
446 | ExecOpKey::Reverse { axes }
447 | ExecOpKey::ReduceProd { axes }
448 | ExecOpKey::ReduceMax { axes }
449 | ExecOpKey::ReduceMin { axes } => vec_retained_bytes(axes),
450 ExecOpKey::Gather(config) => gather_config_retained_bytes(config),
451 ExecOpKey::GatherDynamicSliceSizes {
452 offset_dims,
453 collapsed_slice_dims,
454 start_index_map,
455 slice_sizes,
456 ..
457 } => saturating_sum([
458 vec_retained_bytes(offset_dims),
459 vec_retained_bytes(collapsed_slice_dims),
460 vec_retained_bytes(start_index_map),
461 vec_retained_bytes(slice_sizes),
462 ]),
463 ExecOpKey::Scatter(config) => scatter_config_retained_bytes(config),
464 ExecOpKey::Slice(config) => slice_config_retained_bytes(config),
465 ExecOpKey::DynamicSlice { slice_sizes } => vec_retained_bytes(slice_sizes),
466 ExecOpKey::Pad(config) => pad_config_retained_bytes(config),
467 ExecOpKey::Convert { .. }
468 | ExecOpKey::ExtractDiag { .. }
469 | ExecOpKey::EmbedDiag { .. }
470 | ExecOpKey::Tril { .. }
471 | ExecOpKey::Triu { .. }
472 | ExecOpKey::Add
473 | ExecOpKey::Multiply
474 | ExecOpKey::Negate
475 | ExecOpKey::Conj
476 | ExecOpKey::Divide
477 | ExecOpKey::Abs
478 | ExecOpKey::Sign
479 | ExecOpKey::Maximum
480 | ExecOpKey::Minimum
481 | ExecOpKey::Compare(_)
482 | ExecOpKey::Select
483 | ExecOpKey::Clamp
484 | ExecOpKey::Exp
485 | ExecOpKey::Log
486 | ExecOpKey::Sin
487 | ExecOpKey::Cos
488 | ExecOpKey::Tanh
489 | ExecOpKey::Sqrt
490 | ExecOpKey::Rsqrt
491 | ExecOpKey::Pow
492 | ExecOpKey::Expm1
493 | ExecOpKey::Log1p
494 | ExecOpKey::DynamicUpdateSlice
495 | ExecOpKey::Concatenate { .. }
496 | ExecOpKey::ShapeOf { .. }
497 | ExecOpKey::DynamicTruncate { .. }
498 | ExecOpKey::PadToMatch { .. }
499 | ExecOpKey::Extension { .. } => 0,
500 },
501 ])
502}
503
504fn dot_general_config_retained_bytes(config: &tenferro_tensor::DotGeneralConfig) -> usize {
505 saturating_sum([
506 vec_retained_bytes(&config.lhs_contracting_dims),
507 vec_retained_bytes(&config.rhs_contracting_dims),
508 vec_retained_bytes(&config.lhs_batch_dims),
509 vec_retained_bytes(&config.rhs_batch_dims),
510 ])
511}
512
513fn gather_config_retained_bytes(config: &tenferro_tensor::GatherConfig) -> usize {
514 saturating_sum([
515 vec_retained_bytes(&config.offset_dims),
516 vec_retained_bytes(&config.collapsed_slice_dims),
517 vec_retained_bytes(&config.start_index_map),
518 vec_retained_bytes(&config.slice_sizes),
519 ])
520}
521
522fn scatter_config_retained_bytes(config: &tenferro_tensor::ScatterConfig) -> usize {
523 saturating_sum([
524 vec_retained_bytes(&config.update_window_dims),
525 vec_retained_bytes(&config.inserted_window_dims),
526 vec_retained_bytes(&config.scatter_dims_to_operand_dims),
527 ])
528}
529
530fn slice_config_retained_bytes(config: &tenferro_tensor::SliceConfig) -> usize {
531 saturating_sum([
532 vec_retained_bytes(&config.starts),
533 vec_retained_bytes(&config.limits),
534 vec_retained_bytes(&config.strides),
535 ])
536}
537
538fn pad_config_retained_bytes(config: &tenferro_tensor::PadConfig) -> usize {
539 saturating_sum([
540 vec_retained_bytes(&config.edge_padding_low),
541 vec_retained_bytes(&config.edge_padding_high),
542 vec_retained_bytes(&config.interior_padding),
543 ])
544}
545
546fn exec_op_retained_bytes(op: &ExecOp) -> usize {
547 match op {
548 ExecOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
549 ExecOp::Extension(extension) => size_of_val(extension),
550 _ => 0,
551 }
552}
553
554fn exec_instruction_retained_bytes(inst: &ExecInstruction) -> usize {
555 saturating_sum([
556 size_of::<ExecInstruction>(),
557 exec_op_retained_bytes(&inst.op),
558 vec_retained_bytes(&inst.input_slots),
559 vec_retained_bytes(&inst.output_slots),
560 vec_of_vec_retained_bytes(&inst.output_shapes),
561 vec_of_vec_retained_bytes(&inst.output_extents),
562 vec_retained_bytes(&inst.last_use),
563 ])
564}
565
566fn exec_program_retained_bytes(program: &ExecProgram) -> usize {
567 saturating_sum([
568 size_of::<ExecProgram>(),
569 vec_retained_bytes(&program.instructions),
570 saturating_sum(
571 program
572 .instructions
573 .iter()
574 .map(exec_instruction_retained_bytes),
575 ),
576 vec_retained_bytes(&program.input_slots),
577 vec_retained_bytes(&program.output_slots),
578 ])
579}
580
581pub(crate) fn compile_cache_stats(cache: &LruCache<CacheKey, ExecProgram>) -> CacheStats {
582 CacheStats {
583 entries: cache.len(),
584 retained_bytes: cache
585 .iter()
586 .map(|(key, program)| {
587 saturating_sum([
588 cache_key_retained_bytes(key),
589 exec_program_retained_bytes(program),
590 ])
591 })
592 .fold(0usize, usize::saturating_add),
593 }
594}
595
/// Adds up `values`, clamping at `usize::MAX` instead of overflowing.
/// An empty input yields `0`.
fn saturating_sum(values: impl IntoIterator<Item = usize>) -> usize {
    let mut total = 0usize;
    for value in values {
        total = total.saturating_add(value);
    }
    total
}
599
600#[cfg(test)]
601mod tests;