Skip to main content

tenferro_runtime/graph/
cache.rs

1use std::hash::{Hash, Hasher};
2use std::mem::{size_of, size_of_val};
3use std::sync::Arc;
4
5use lru::LruCache;
6use tenferro_ops::ext_op::ExtensionOp;
7use tenferro_tensor::CacheStats;
8
9use crate::exec::{ExecInstruction, ExecOp, ExecOutputExtents, ExecOutputShapes, ExecProgram};
10
/// Default number of compiled graph programs retained by a
/// [`GraphCompiler`](super::GraphCompiler).
// This public name is the documented default. Implementation code inside the crate reads the
// value through the crate-local alias below, so the public item may look unused.
#[allow(dead_code)]
pub const DEFAULT_GRAPH_COMPILE_CACHE_CAPACITY: usize = 256;

/// Crate-local alias that keeps the name used by the existing engine cache helper.
pub(crate) const DEFAULT_COMPILE_CACHE_CAPACITY: usize = DEFAULT_GRAPH_COMPILE_CACHE_CAPACITY;
19
20/// Stats for caches owned by a [`GraphCompiler`](super::GraphCompiler).
21///
22/// `retained_bytes` fields are logical payload estimates, not process RSS.
23///
24/// # Examples
25///
26/// ```
27/// use tenferro_runtime::{CacheStats, GraphCompilerCacheStats};
28///
29/// let stats = GraphCompilerCacheStats {
30///     compile: CacheStats::empty(),
31///     extensions: CacheStats::empty(),
32/// };
33/// assert_eq!(stats.compile.entries, 0);
34/// ```
35#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
36pub struct GraphCompilerCacheStats {
37    /// Compiled execution-program cache.
38    pub compile: CacheStats,
39    /// Generic extension compile-time caches.
40    pub extensions: CacheStats,
41}
42
43/// Stats for runtime caches owned by a graph executor.
44///
45/// `retained_bytes` fields are logical payload estimates, not process RSS.
46///
47/// # Examples
48///
49/// ```
50/// use tenferro_runtime::{CacheStats, GraphExecutorCacheStats};
51///
52/// let stats = GraphExecutorCacheStats {
53///     extensions: CacheStats::empty(),
54///     backend: CacheStats::empty(),
55/// };
56/// assert_eq!(stats.extensions.entries, 0);
57/// ```
58#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
59pub struct GraphExecutorCacheStats {
60    /// Generic extension runtime caches.
61    pub extensions: CacheStats,
62    /// Backend-specific runtime analysis cache.
63    pub backend: CacheStats,
64}
65
66/// Cache key derived from compiled graph topology and execution metadata.
67#[derive(Clone, Debug)]
68pub(crate) struct CacheKey {
69    fingerprint: ExecProgramKey,
70    extensions: Vec<Arc<dyn ExtensionOp>>,
71}
72
73impl PartialEq for CacheKey {
74    fn eq(&self, other: &Self) -> bool {
75        self.fingerprint == other.fingerprint
76            && self.extensions.len() == other.extensions.len()
77            && self
78                .extensions
79                .iter()
80                .zip(&other.extensions)
81                .all(|(lhs, rhs)| {
82                    lhs.family_id() == rhs.family_id() && lhs.payload_eq(rhs.as_ref())
83                })
84    }
85}
86
87impl Eq for CacheKey {}
88
89impl Hash for CacheKey {
90    fn hash<H: Hasher>(&self, state: &mut H) {
91        self.fingerprint.hash(state);
92    }
93}
94
95pub(crate) fn compute_cache_key(exec: &ExecProgram) -> CacheKey {
96    let mut extensions = Vec::new();
97    let fingerprint = exec_program_key(exec, &mut extensions);
98    CacheKey {
99        fingerprint,
100        extensions,
101    }
102}
103
104fn cache_key_retained_bytes(key: &CacheKey) -> usize {
105    saturating_sum([
106        size_of::<CacheKey>(),
107        exec_program_key_retained_bytes(&key.fingerprint),
108        key.extensions
109            .capacity()
110            .saturating_mul(size_of::<Arc<dyn ExtensionOp>>()),
111    ])
112}
113
114#[derive(Clone, Debug, Hash, PartialEq, Eq)]
115struct ExecProgramKey {
116    instructions: Vec<ExecInstructionKey>,
117    input_slots: Vec<usize>,
118    output_slots: Vec<usize>,
119    n_slots: usize,
120}
121
122#[derive(Clone, Debug, Hash, PartialEq, Eq)]
123struct ExecInstructionKey {
124    op: ExecOpKey,
125    input_slots: Vec<usize>,
126    output_slots: Vec<usize>,
127    dtype: tenferro_tensor::DType,
128    output_shapes: ExecOutputShapes,
129    output_extents: ExecOutputExtents,
130    last_use: Vec<bool>,
131}
132
133#[derive(Clone, Debug, Hash, PartialEq, Eq)]
134enum ExecOpKey {
135    Transpose {
136        perm: Vec<usize>,
137    },
138    Reshape {
139        shape: Vec<tenferro_ops::dim_expr::DimExpr>,
140    },
141    BroadcastInDim {
142        shape: Vec<tenferro_ops::dim_expr::DimExpr>,
143        dims: Vec<usize>,
144    },
145    Convert {
146        to: tenferro_tensor::DType,
147    },
148    Constant {
149        dtype: tenferro_tensor::DType,
150        bytes: Vec<u8>,
151    },
152    DotGeneral(tenferro_tensor::DotGeneralConfig),
153    DotGeneralWithConj {
154        config: tenferro_tensor::DotGeneralConfig,
155        lhs_conj: bool,
156        rhs_conj: bool,
157    },
158    ReduceSum {
159        axes: Vec<usize>,
160    },
161    ExtractDiag {
162        axis_a: usize,
163        axis_b: usize,
164    },
165    EmbedDiag {
166        axis_a: usize,
167        axis_b: usize,
168    },
169    Tril {
170        k: i64,
171    },
172    Triu {
173        k: i64,
174    },
175    Add,
176    Multiply,
177    Negate,
178    Conj,
179    Divide,
180    Abs,
181    Sign,
182    Maximum,
183    Minimum,
184    Compare(tenferro_tensor::CompareDir),
185    Select,
186    Clamp,
187    Exp,
188    Log,
189    Sin,
190    Cos,
191    Tanh,
192    Sqrt,
193    Rsqrt,
194    Pow,
195    Expm1,
196    Log1p,
197    Gather(tenferro_tensor::GatherConfig),
198    GatherDynamicSliceSizes {
199        offset_dims: Vec<usize>,
200        collapsed_slice_dims: Vec<usize>,
201        start_index_map: Vec<usize>,
202        index_vector_dim: usize,
203        slice_sizes: Vec<tenferro_ops::dim_expr::DimExpr>,
204    },
205    Scatter(tenferro_tensor::ScatterConfig),
206    Slice(tenferro_tensor::SliceConfig),
207    DynamicSlice {
208        slice_sizes: Vec<usize>,
209    },
210    DynamicUpdateSlice,
211    Pad(tenferro_tensor::PadConfig),
212    Concatenate {
213        axis: usize,
214    },
215    Reverse {
216        axes: Vec<usize>,
217    },
218    ShapeOf {
219        axis: usize,
220    },
221    DynamicTruncate {
222        axis: usize,
223    },
224    PadToMatch {
225        axis: usize,
226    },
227    ReduceProd {
228        axes: Vec<usize>,
229    },
230    ReduceMax {
231        axes: Vec<usize>,
232    },
233    ReduceMin {
234        axes: Vec<usize>,
235    },
236    Extension {
237        family_id: &'static str,
238        payload_hash: u64,
239    },
240}
241
242fn exec_program_key(
243    exec: &ExecProgram,
244    extensions: &mut Vec<Arc<dyn ExtensionOp>>,
245) -> ExecProgramKey {
246    ExecProgramKey {
247        instructions: exec
248            .instructions
249            .iter()
250            .map(|inst| exec_instruction_key(inst, extensions))
251            .collect(),
252        input_slots: exec.input_slots.clone(),
253        output_slots: exec.output_slots.clone(),
254        n_slots: exec.n_slots,
255    }
256}
257
258fn exec_instruction_key(
259    inst: &ExecInstruction,
260    extensions: &mut Vec<Arc<dyn ExtensionOp>>,
261) -> ExecInstructionKey {
262    ExecInstructionKey {
263        op: exec_op_key(&inst.op, extensions),
264        input_slots: inst.input_slots.clone(),
265        output_slots: inst.output_slots.clone(),
266        dtype: inst.dtype,
267        output_shapes: inst.output_shapes.clone(),
268        output_extents: inst.output_extents.clone(),
269        last_use: inst.last_use.clone(),
270    }
271}
272
273fn exec_op_key(op: &ExecOp, extensions: &mut Vec<Arc<dyn ExtensionOp>>) -> ExecOpKey {
274    match op {
275        ExecOp::Transpose { perm } => ExecOpKey::Transpose { perm: perm.clone() },
276        ExecOp::Reshape { shape } => ExecOpKey::Reshape {
277            shape: shape.clone(),
278        },
279        ExecOp::BroadcastInDim { shape, dims } => ExecOpKey::BroadcastInDim {
280            shape: shape.clone(),
281            dims: dims.clone(),
282        },
283        ExecOp::Convert { to } => ExecOpKey::Convert { to: *to },
284        ExecOp::Constant { dtype, bytes } => ExecOpKey::Constant {
285            dtype: *dtype,
286            bytes: bytes.clone(),
287        },
288        ExecOp::DotGeneral(config) => ExecOpKey::DotGeneral(config.clone()),
289        ExecOp::DotGeneralWithConj {
290            config,
291            lhs_conj,
292            rhs_conj,
293        } => ExecOpKey::DotGeneralWithConj {
294            config: config.clone(),
295            lhs_conj: *lhs_conj,
296            rhs_conj: *rhs_conj,
297        },
298        ExecOp::ReduceSum { axes } => ExecOpKey::ReduceSum { axes: axes.clone() },
299        ExecOp::ExtractDiag { axis_a, axis_b } => ExecOpKey::ExtractDiag {
300            axis_a: *axis_a,
301            axis_b: *axis_b,
302        },
303        ExecOp::EmbedDiag { axis_a, axis_b } => ExecOpKey::EmbedDiag {
304            axis_a: *axis_a,
305            axis_b: *axis_b,
306        },
307        ExecOp::Tril { k } => ExecOpKey::Tril { k: *k },
308        ExecOp::Triu { k } => ExecOpKey::Triu { k: *k },
309        ExecOp::Add => ExecOpKey::Add,
310        ExecOp::Multiply => ExecOpKey::Multiply,
311        ExecOp::Negate => ExecOpKey::Negate,
312        ExecOp::Conj => ExecOpKey::Conj,
313        ExecOp::Divide => ExecOpKey::Divide,
314        ExecOp::Abs => ExecOpKey::Abs,
315        ExecOp::Sign => ExecOpKey::Sign,
316        ExecOp::Maximum => ExecOpKey::Maximum,
317        ExecOp::Minimum => ExecOpKey::Minimum,
318        ExecOp::Compare(dir) => ExecOpKey::Compare(dir.clone()),
319        ExecOp::Select => ExecOpKey::Select,
320        ExecOp::Clamp => ExecOpKey::Clamp,
321        ExecOp::Exp => ExecOpKey::Exp,
322        ExecOp::Log => ExecOpKey::Log,
323        ExecOp::Sin => ExecOpKey::Sin,
324        ExecOp::Cos => ExecOpKey::Cos,
325        ExecOp::Tanh => ExecOpKey::Tanh,
326        ExecOp::Sqrt => ExecOpKey::Sqrt,
327        ExecOp::Rsqrt => ExecOpKey::Rsqrt,
328        ExecOp::Pow => ExecOpKey::Pow,
329        ExecOp::Expm1 => ExecOpKey::Expm1,
330        ExecOp::Log1p => ExecOpKey::Log1p,
331        ExecOp::Gather(config) => ExecOpKey::Gather(config.clone()),
332        ExecOp::GatherDynamicSliceSizes {
333            offset_dims,
334            collapsed_slice_dims,
335            start_index_map,
336            index_vector_dim,
337            slice_sizes,
338        } => ExecOpKey::GatherDynamicSliceSizes {
339            offset_dims: offset_dims.clone(),
340            collapsed_slice_dims: collapsed_slice_dims.clone(),
341            start_index_map: start_index_map.clone(),
342            index_vector_dim: *index_vector_dim,
343            slice_sizes: slice_sizes.clone(),
344        },
345        ExecOp::Scatter(config) => ExecOpKey::Scatter(config.clone()),
346        ExecOp::Slice(config) => ExecOpKey::Slice(config.clone()),
347        ExecOp::DynamicSlice { slice_sizes } => ExecOpKey::DynamicSlice {
348            slice_sizes: slice_sizes.clone(),
349        },
350        ExecOp::DynamicUpdateSlice => ExecOpKey::DynamicUpdateSlice,
351        ExecOp::Pad(config) => ExecOpKey::Pad(config.clone()),
352        ExecOp::Concatenate { axis } => ExecOpKey::Concatenate { axis: *axis },
353        ExecOp::Reverse { axes } => ExecOpKey::Reverse { axes: axes.clone() },
354        ExecOp::ShapeOf { axis } => ExecOpKey::ShapeOf { axis: *axis },
355        ExecOp::DynamicTruncate { axis } => ExecOpKey::DynamicTruncate { axis: *axis },
356        ExecOp::PadToMatch { axis } => ExecOpKey::PadToMatch { axis: *axis },
357        ExecOp::ReduceProd { axes } => ExecOpKey::ReduceProd { axes: axes.clone() },
358        ExecOp::ReduceMax { axes } => ExecOpKey::ReduceMax { axes: axes.clone() },
359        ExecOp::ReduceMin { axes } => ExecOpKey::ReduceMin { axes: axes.clone() },
360        ExecOp::Extension(extension) => {
361            let key = ExecOpKey::Extension {
362                family_id: extension.family_id(),
363                payload_hash: extension_payload_hash(extension.as_ref()),
364            };
365            extensions.push(Arc::clone(extension));
366            key
367        }
368    }
369}
370
371fn extension_payload_hash(extension: &dyn ExtensionOp) -> u64 {
372    let mut hasher = std::collections::hash_map::DefaultHasher::new();
373    extension.payload_hash(&mut DynHasherProxy::new(&mut hasher));
374    hasher.finish()
375}
376
/// Thin [`Hasher`] adapter that forwards to a borrowed, possibly unsized, hasher.
///
/// Only `write` and `finish` are overridden. Every other `write_*` method keeps the trait's
/// default implementation, which feeds its bytes into `write`.
struct DynHasherProxy<'a, H: Hasher + ?Sized> {
    inner: &'a mut H,
}

impl<'a, H: Hasher + ?Sized> DynHasherProxy<'a, H> {
    /// Wraps `inner` without taking ownership of it.
    fn new(inner: &'a mut H) -> Self {
        DynHasherProxy { inner }
    }
}

impl<H> Hasher for DynHasherProxy<'_, H>
where
    H: Hasher + ?Sized,
{
    fn finish(&self) -> u64 {
        Hasher::finish(&*self.inner)
    }

    fn write(&mut self, bytes: &[u8]) {
        Hasher::write(&mut *self.inner, bytes);
    }
}
396
/// Heap bytes held by a `Vec`'s buffer.
///
/// Computed as allocated capacity (not length) times element size, saturating at
/// `usize::MAX`.
fn vec_retained_bytes<T>(values: &Vec<T>) -> usize {
    let element_size = size_of::<T>();
    element_size.saturating_mul(values.capacity())
}
400
/// Sum of the heap buffers held by each inner `Vec`, saturating at `usize::MAX`.
///
/// Only the inner buffers are counted. The outer container's own storage is not.
fn vec_of_vec_retained_bytes<T>(values: &[Vec<T>]) -> usize {
    let element_size = size_of::<T>();
    values
        .iter()
        .map(|inner| inner.capacity().saturating_mul(element_size))
        .fold(0usize, usize::saturating_add)
}
404
405fn exec_program_key_retained_bytes(key: &ExecProgramKey) -> usize {
406    saturating_sum([
407        size_of::<ExecProgramKey>(),
408        vec_retained_bytes(&key.instructions),
409        saturating_sum(
410            key.instructions
411                .iter()
412                .map(exec_instruction_key_retained_bytes),
413        ),
414        vec_retained_bytes(&key.input_slots),
415        vec_retained_bytes(&key.output_slots),
416    ])
417}
418
419fn exec_instruction_key_retained_bytes(key: &ExecInstructionKey) -> usize {
420    saturating_sum([
421        size_of::<ExecInstructionKey>(),
422        exec_op_key_retained_bytes(&key.op),
423        vec_retained_bytes(&key.input_slots),
424        vec_retained_bytes(&key.output_slots),
425        vec_of_vec_retained_bytes(&key.output_shapes),
426        vec_of_vec_retained_bytes(&key.output_extents),
427        vec_retained_bytes(&key.last_use),
428    ])
429}
430
431fn exec_op_key_retained_bytes(key: &ExecOpKey) -> usize {
432    saturating_sum([
433        size_of::<ExecOpKey>(),
434        match key {
435            ExecOpKey::Transpose { perm } => vec_retained_bytes(perm),
436            ExecOpKey::Reshape { shape } => vec_retained_bytes(shape),
437            ExecOpKey::BroadcastInDim { shape, dims } => {
438                saturating_sum([vec_retained_bytes(shape), vec_retained_bytes(dims)])
439            }
440            ExecOpKey::Constant { bytes, .. } => vec_retained_bytes(bytes),
441            ExecOpKey::DotGeneral(config) => dot_general_config_retained_bytes(config),
442            ExecOpKey::DotGeneralWithConj { config, .. } => {
443                dot_general_config_retained_bytes(config)
444            }
445            ExecOpKey::ReduceSum { axes }
446            | ExecOpKey::Reverse { axes }
447            | ExecOpKey::ReduceProd { axes }
448            | ExecOpKey::ReduceMax { axes }
449            | ExecOpKey::ReduceMin { axes } => vec_retained_bytes(axes),
450            ExecOpKey::Gather(config) => gather_config_retained_bytes(config),
451            ExecOpKey::GatherDynamicSliceSizes {
452                offset_dims,
453                collapsed_slice_dims,
454                start_index_map,
455                slice_sizes,
456                ..
457            } => saturating_sum([
458                vec_retained_bytes(offset_dims),
459                vec_retained_bytes(collapsed_slice_dims),
460                vec_retained_bytes(start_index_map),
461                vec_retained_bytes(slice_sizes),
462            ]),
463            ExecOpKey::Scatter(config) => scatter_config_retained_bytes(config),
464            ExecOpKey::Slice(config) => slice_config_retained_bytes(config),
465            ExecOpKey::DynamicSlice { slice_sizes } => vec_retained_bytes(slice_sizes),
466            ExecOpKey::Pad(config) => pad_config_retained_bytes(config),
467            ExecOpKey::Convert { .. }
468            | ExecOpKey::ExtractDiag { .. }
469            | ExecOpKey::EmbedDiag { .. }
470            | ExecOpKey::Tril { .. }
471            | ExecOpKey::Triu { .. }
472            | ExecOpKey::Add
473            | ExecOpKey::Multiply
474            | ExecOpKey::Negate
475            | ExecOpKey::Conj
476            | ExecOpKey::Divide
477            | ExecOpKey::Abs
478            | ExecOpKey::Sign
479            | ExecOpKey::Maximum
480            | ExecOpKey::Minimum
481            | ExecOpKey::Compare(_)
482            | ExecOpKey::Select
483            | ExecOpKey::Clamp
484            | ExecOpKey::Exp
485            | ExecOpKey::Log
486            | ExecOpKey::Sin
487            | ExecOpKey::Cos
488            | ExecOpKey::Tanh
489            | ExecOpKey::Sqrt
490            | ExecOpKey::Rsqrt
491            | ExecOpKey::Pow
492            | ExecOpKey::Expm1
493            | ExecOpKey::Log1p
494            | ExecOpKey::DynamicUpdateSlice
495            | ExecOpKey::Concatenate { .. }
496            | ExecOpKey::ShapeOf { .. }
497            | ExecOpKey::DynamicTruncate { .. }
498            | ExecOpKey::PadToMatch { .. }
499            | ExecOpKey::Extension { .. } => 0,
500        },
501    ])
502}
503
504fn dot_general_config_retained_bytes(config: &tenferro_tensor::DotGeneralConfig) -> usize {
505    saturating_sum([
506        vec_retained_bytes(&config.lhs_contracting_dims),
507        vec_retained_bytes(&config.rhs_contracting_dims),
508        vec_retained_bytes(&config.lhs_batch_dims),
509        vec_retained_bytes(&config.rhs_batch_dims),
510    ])
511}
512
513fn gather_config_retained_bytes(config: &tenferro_tensor::GatherConfig) -> usize {
514    saturating_sum([
515        vec_retained_bytes(&config.offset_dims),
516        vec_retained_bytes(&config.collapsed_slice_dims),
517        vec_retained_bytes(&config.start_index_map),
518        vec_retained_bytes(&config.slice_sizes),
519    ])
520}
521
522fn scatter_config_retained_bytes(config: &tenferro_tensor::ScatterConfig) -> usize {
523    saturating_sum([
524        vec_retained_bytes(&config.update_window_dims),
525        vec_retained_bytes(&config.inserted_window_dims),
526        vec_retained_bytes(&config.scatter_dims_to_operand_dims),
527    ])
528}
529
530fn slice_config_retained_bytes(config: &tenferro_tensor::SliceConfig) -> usize {
531    saturating_sum([
532        vec_retained_bytes(&config.starts),
533        vec_retained_bytes(&config.limits),
534        vec_retained_bytes(&config.strides),
535    ])
536}
537
538fn pad_config_retained_bytes(config: &tenferro_tensor::PadConfig) -> usize {
539    saturating_sum([
540        vec_retained_bytes(&config.edge_padding_low),
541        vec_retained_bytes(&config.edge_padding_high),
542        vec_retained_bytes(&config.interior_padding),
543    ])
544}
545
546fn exec_op_retained_bytes(op: &ExecOp) -> usize {
547    match op {
548        ExecOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
549        ExecOp::Extension(extension) => size_of_val(extension),
550        _ => 0,
551    }
552}
553
554fn exec_instruction_retained_bytes(inst: &ExecInstruction) -> usize {
555    saturating_sum([
556        size_of::<ExecInstruction>(),
557        exec_op_retained_bytes(&inst.op),
558        vec_retained_bytes(&inst.input_slots),
559        vec_retained_bytes(&inst.output_slots),
560        vec_of_vec_retained_bytes(&inst.output_shapes),
561        vec_of_vec_retained_bytes(&inst.output_extents),
562        vec_retained_bytes(&inst.last_use),
563    ])
564}
565
566fn exec_program_retained_bytes(program: &ExecProgram) -> usize {
567    saturating_sum([
568        size_of::<ExecProgram>(),
569        vec_retained_bytes(&program.instructions),
570        saturating_sum(
571            program
572                .instructions
573                .iter()
574                .map(exec_instruction_retained_bytes),
575        ),
576        vec_retained_bytes(&program.input_slots),
577        vec_retained_bytes(&program.output_slots),
578    ])
579}
580
581pub(crate) fn compile_cache_stats(cache: &LruCache<CacheKey, ExecProgram>) -> CacheStats {
582    CacheStats {
583        entries: cache.len(),
584        retained_bytes: cache
585            .iter()
586            .map(|(key, program)| {
587                saturating_sum([
588                    cache_key_retained_bytes(key),
589                    exec_program_retained_bytes(program),
590                ])
591            })
592            .fold(0usize, usize::saturating_add),
593    }
594}
595
/// Adds up all `values`, clamping at `usize::MAX` instead of overflowing.
fn saturating_sum(values: impl IntoIterator<Item = usize>) -> usize {
    let mut total = 0usize;
    for value in values {
        total = total.saturating_add(value);
    }
    total
}
599
// Unit tests for this module live in a separate `tests` module file.
#[cfg(test)]
mod tests;