Skip to main content

tenferro_runtime/graph/
compiler.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use computegraph::compile::compile;
7use computegraph::materialize::materialize_merge;
8use computegraph::resolve::resolve;
9use computegraph::types::ValueKey;
10use lru::LruCache;
11use num_complex::{Complex32, Complex64};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_tensor::{DType, Tensor, TensorScalar};
15
16use super::cache::{
17    compile_cache_stats, compute_cache_key, CacheKey, GraphCompilerCacheStats,
18    DEFAULT_COMPILE_CACHE_CAPACITY,
19};
20use super::program::{GraphProgram, GraphProgramInput};
21use crate::compiler::{compile_std_to_exec_with_options, CompilerOptions};
22use crate::error::{Error, Result};
23use crate::exec::ExecProgram;
24use crate::extension_cache::{ExtensionCacheSelector, ExtensionCacheStore};
25use crate::traced::{try_concrete_shape, TracedTensor};
26
27#[derive(Clone)]
28struct InputDescriptor {
29    key: TensorInputKey,
30    dtype: DType,
31    shape: Vec<usize>,
32    default_tensor: Option<Arc<Tensor>>,
33}
34
35/// Compiler for traced tensor graphs.
36///
37/// A graph compiler lowers one or more [`TracedTensor`] outputs to a reusable
38/// [`GraphProgram`] without requiring a backend.
39///
40/// # Examples
41///
42/// ```
43/// use tenferro_runtime::{GraphCompiler, TracedTensor};
44///
45/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
46/// let y = (&x + &x).unwrap();
47/// let mut compiler = GraphCompiler::new();
48/// let program = compiler.compile(&y).unwrap();
49/// assert_eq!(program.output_count(), 1);
50/// ```
51pub struct GraphCompiler {
52    compile_cache: LruCache<CacheKey, ExecProgram>,
53    extension_cache: ExtensionCacheStore,
54    compiler_options: CompilerOptions,
55}
56
57impl fmt::Debug for GraphCompiler {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        f.debug_struct("GraphCompiler")
60            .field("cache_stats", &self.cache_stats())
61            .field("compile_cache_capacity", &self.compile_cache_capacity())
62            .field("compiler_options", &self.compiler_options)
63            .field("extension_cache", &self.extension_cache)
64            .finish_non_exhaustive()
65    }
66}
67
68impl GraphCompiler {
69    /// Create a compiler with bounded default caches.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use tenferro_runtime::GraphCompiler;
75    ///
76    /// let compiler = GraphCompiler::new();
77    /// assert_eq!(compiler.compile_cache_len(), 0);
78    /// ```
79    pub fn new() -> Self {
80        Self {
81            compile_cache: LruCache::new(
82                NonZeroUsize::new(DEFAULT_COMPILE_CACHE_CAPACITY).unwrap_or(NonZeroUsize::MIN),
83            ),
84            extension_cache: ExtensionCacheStore::new(),
85            compiler_options: CompilerOptions::default(),
86        }
87    }
88
89    /// Create a compiler with explicit lowering and optimizer options.
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use tenferro_runtime::{CompilerOptions, OptimizerConfig};
95    /// use tenferro_runtime::GraphCompiler;
96    ///
97    /// let compiler = GraphCompiler::with_compiler_options(CompilerOptions {
98    ///     optimizer: OptimizerConfig {
99    ///         dot_decomposer: true,
100    ///         ..OptimizerConfig::default()
101    ///     },
102    /// });
103    /// assert!(compiler.compiler_options().optimizer.dot_decomposer);
104    /// ```
105    pub fn with_compiler_options(compiler_options: CompilerOptions) -> Self {
106        Self {
107            compile_cache: LruCache::new(
108                NonZeroUsize::new(DEFAULT_COMPILE_CACHE_CAPACITY).unwrap_or(NonZeroUsize::MIN),
109            ),
110            extension_cache: ExtensionCacheStore::new(),
111            compiler_options,
112        }
113    }
114
115    /// Compile one traced output into a graph program.
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
121    ///
122    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
123    /// let mut compiler = GraphCompiler::new();
124    /// let program = compiler.compile(&x.neg()).unwrap();
125    /// assert_eq!(program.input_count(), 1);
126    /// ```
127    pub fn compile(&mut self, output: &TracedTensor) -> Result<GraphProgram> {
128        self.compile_many(&[output])
129    }
130
131    /// Compile multiple traced outputs into one graph program.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
137    ///
138    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
139    /// let y = x.neg();
140    /// let mut compiler = GraphCompiler::new();
141    /// let program = compiler.compile_many(&[&x, &y]).unwrap();
142    /// assert_eq!(program.output_count(), 2);
143    /// ```
144    pub fn compile_many(&mut self, outputs: &[&TracedTensor]) -> Result<GraphProgram> {
145        let mut all_inputs = HashMap::new();
146        for output in outputs {
147            for (key, tensor) in output.inputs_map.iter() {
148                if let Some(existing) = all_inputs.get(key) {
149                    if !default_tensors_equivalent(existing, tensor) {
150                        return Err(Error::DuplicateBinding {
151                            input_key: format!("{:?}", key),
152                        });
153                    }
154                    continue;
155                }
156                all_inputs.insert(key.clone(), tensor.clone());
157            }
158        }
159        self.compile_many_with_descriptors(outputs, &HashMap::new(), &all_inputs)
160    }
161
162    /// Compile one traced output with concrete placeholder specs.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
168    ///
169    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
170    /// let mut compiler = GraphCompiler::new();
171    /// let program = compiler
172    ///     .compile_with_input_specs(&x.neg(), &[(&x, DType::F64, &[3])])
173    ///     .unwrap();
174    /// assert_eq!(program.input_specs()[0].shape(), &[3]);
175    /// ```
176    pub fn compile_with_input_specs(
177        &mut self,
178        output: &TracedTensor,
179        bindings: &[(&TracedTensor, DType, &[usize])],
180    ) -> Result<GraphProgram> {
181        let mut binding_specs = HashMap::new();
182        for (index, (placeholder, dtype, shape)) in bindings.iter().enumerate() {
183            validate_placeholder_spec(index, placeholder, *dtype, shape)?;
184            let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
185                binding_index: index,
186            })?;
187            if binding_specs
188                .insert(
189                    key.clone(),
190                    InputDescriptor {
191                        key: key.clone(),
192                        dtype: *dtype,
193                        shape: (*shape).to_vec(),
194                        default_tensor: None,
195                    },
196                )
197                .is_some()
198            {
199                return Err(Error::DuplicateBinding {
200                    input_key: format!("{:?}", key),
201                });
202            }
203        }
204
205        self.compile_many_with_descriptors(&[output], &binding_specs, output.inputs_map.as_ref())
206    }
207
208    /// Number of compiled programs currently retained.
209    ///
210    /// # Examples
211    ///
212    /// ```
213    /// use tenferro_runtime::GraphCompiler;
214    ///
215    /// let compiler = GraphCompiler::new();
216    /// assert_eq!(compiler.compile_cache_len(), 0);
217    /// ```
218    pub fn compile_cache_len(&self) -> usize {
219        self.compile_cache.len()
220    }
221
222    /// Current compiled-program cache capacity.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use tenferro_runtime::GraphCompiler;
228    ///
229    /// let compiler = GraphCompiler::new();
230    /// assert!(compiler.compile_cache_capacity().get() > 0);
231    /// ```
232    pub fn compile_cache_capacity(&self) -> NonZeroUsize {
233        self.compile_cache.cap()
234    }
235
236    /// Resize the compiled-program cache.
237    ///
238    /// # Examples
239    ///
240    /// ```
241    /// use std::num::NonZeroUsize;
242    /// use tenferro_runtime::GraphCompiler;
243    ///
244    /// let mut compiler = GraphCompiler::new();
245    /// compiler.set_compile_cache_capacity(NonZeroUsize::new(2).unwrap());
246    /// assert_eq!(compiler.compile_cache_capacity().get(), 2);
247    /// ```
248    pub fn set_compile_cache_capacity(&mut self, capacity: NonZeroUsize) {
249        self.compile_cache.resize(capacity);
250    }
251
252    /// Return the compiler options used for future graph lowerings.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use tenferro_runtime::CompilerOptions;
258    /// use tenferro_runtime::GraphCompiler;
259    ///
260    /// let compiler = GraphCompiler::new();
261    /// assert_eq!(compiler.compiler_options(), CompilerOptions::default());
262    /// ```
263    pub fn compiler_options(&self) -> CompilerOptions {
264        self.compiler_options
265    }
266
267    /// Replace compiler options and clear compiled graph cache entries.
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use tenferro_runtime::{CompilerOptions, OptimizerConfig};
273    /// use tenferro_runtime::GraphCompiler;
274    ///
275    /// let mut compiler = GraphCompiler::new();
276    /// let options = CompilerOptions {
277    ///     optimizer: OptimizerConfig {
278    ///         dot_decomposer: true,
279    ///         ..OptimizerConfig::default()
280    ///     },
281    /// };
282    /// compiler.set_compiler_options(options);
283    /// assert_eq!(compiler.compiler_options(), options);
284    /// assert_eq!(compiler.compile_cache_len(), 0);
285    /// ```
286    pub fn set_compiler_options(&mut self, compiler_options: CompilerOptions) {
287        if self.compiler_options == compiler_options {
288            return;
289        }
290        self.compiler_options = compiler_options;
291        self.clear_compile_cache();
292    }
293
294    /// Clear the compiled-program cache.
295    ///
296    /// # Examples
297    ///
298    /// ```
299    /// use tenferro_runtime::GraphCompiler;
300    ///
301    /// let mut compiler = GraphCompiler::new();
302    /// compiler.clear_compile_cache();
303    /// assert_eq!(compiler.compile_cache_len(), 0);
304    /// ```
305    pub fn clear_compile_cache(&mut self) {
306        self.compile_cache.clear();
307    }
308
309    /// Clear generic extension compile-time cache entries.
310    ///
311    /// # Examples
312    ///
313    /// ```
314    /// use tenferro_runtime::GraphCompiler;
315    ///
316    /// let mut compiler = GraphCompiler::new();
317    /// compiler.clear_extension_caches();
318    /// assert_eq!(compiler.cache_stats().extensions.entries, 0);
319    /// ```
320    pub fn clear_extension_caches(&mut self) {
321        self.extension_cache.clear();
322    }
323
324    /// Clear every cache owned by the compiler.
325    ///
326    /// # Examples
327    ///
328    /// ```
329    /// use tenferro_runtime::GraphCompiler;
330    ///
331    /// let mut compiler = GraphCompiler::new();
332    /// compiler.clear_caches();
333    /// assert_eq!(compiler.cache_stats().compile.entries, 0);
334    /// ```
335    pub fn clear_caches(&mut self) {
336        self.clear_compile_cache();
337        self.clear_extension_caches();
338    }
339
340    /// Return cache-entry and retained-byte stats.
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use tenferro_runtime::GraphCompiler;
346    ///
347    /// let compiler = GraphCompiler::new();
348    /// let stats = compiler.cache_stats();
349    /// assert_eq!(stats.compile.entries, 0);
350    /// ```
351    pub fn cache_stats(&self) -> GraphCompilerCacheStats {
352        GraphCompilerCacheStats {
353            compile: compile_cache_stats(&self.compile_cache),
354            extensions: self.extension_cache.stats(ExtensionCacheSelector::All),
355        }
356    }
357
358    /// Borrow generic extension compile-time cache storage.
359    ///
360    /// # Examples
361    ///
362    /// ```
363    /// use tenferro_runtime::GraphCompiler;
364    ///
365    /// let compiler = GraphCompiler::new();
366    /// assert!(compiler.extension_caches().is_empty());
367    /// ```
368    pub fn extension_caches(&self) -> &ExtensionCacheStore {
369        &self.extension_cache
370    }
371
372    /// Mutably borrow generic extension compile-time cache storage.
373    ///
374    /// # Examples
375    ///
376    /// ```
377    /// use tenferro_runtime::GraphCompiler;
378    ///
379    /// let mut compiler = GraphCompiler::new();
380    /// compiler.extension_caches_mut().clear();
381    /// ```
382    pub fn extension_caches_mut(&mut self) -> &mut ExtensionCacheStore {
383        &mut self.extension_cache
384    }
385
386    fn compile_many_with_descriptors(
387        &mut self,
388        outputs: &[&TracedTensor],
389        binding_specs: &HashMap<TensorInputKey, InputDescriptor>,
390        default_inputs: &HashMap<TensorInputKey, Arc<Tensor>>,
391    ) -> Result<GraphProgram> {
392        let mut roots = Vec::new();
393        let mut output_keys = Vec::with_capacity(outputs.len());
394        for output in outputs {
395            roots.extend(output.resolve_roots());
396            output_keys.push(output.graph.values()[output.val].key.clone());
397        }
398
399        let view = resolve(roots);
400        let graph = materialize_merge(&view, &output_keys);
401        let compiled = compile(&graph);
402
403        let mut descriptors = Vec::with_capacity(graph.inputs.len());
404        let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
405        let mut input_shapes = Vec::with_capacity(graph.inputs.len());
406        for key in &graph.inputs {
407            let ValueKey::Input(input_key) = key else {
408                return Err(Error::Internal(
409                    "expected Input key in graph inputs".to_string(),
410                ));
411            };
412            let descriptor = descriptor_for_input(input_key, binding_specs, default_inputs)?;
413            input_dtypes.push(descriptor.dtype);
414            input_shapes.push(DimExpr::from_concrete(&descriptor.shape));
415            descriptors.push(GraphProgramInput::new(
416                descriptor.key,
417                descriptor.dtype,
418                descriptor.shape.clone(),
419                DimExpr::from_concrete(&descriptor.shape),
420                descriptor.default_tensor,
421            ));
422        }
423
424        let exec = compile_std_to_exec_with_options(
425            &compiled,
426            &input_dtypes,
427            &input_shapes,
428            self.compiler_options,
429        )?;
430        let exec = self.get_or_compile(exec);
431        Ok(GraphProgram::new(exec, descriptors))
432    }
433
434    fn get_or_compile(&mut self, exec: ExecProgram) -> ExecProgram {
435        let key = compute_cache_key(&exec);
436        if let Some(cached) = self.compile_cache.get(&key) {
437            return cached.clone();
438        }
439        self.compile_cache.put(key, exec.clone());
440        exec
441    }
442}
443
444impl Default for GraphCompiler {
445    fn default() -> Self {
446        Self::new()
447    }
448}
449
450fn validate_placeholder_spec(
451    index: usize,
452    placeholder: &TracedTensor,
453    dtype: DType,
454    shape: &[usize],
455) -> Result<()> {
456    if placeholder.data.is_some() {
457        return Err(Error::UnexpectedBinding {
458            binding_index: index,
459        });
460    }
461    placeholder.input_key().ok_or(Error::UnexpectedBinding {
462        binding_index: index,
463    })?;
464
465    if placeholder.dtype != dtype {
466        return Err(Error::PlaceholderDtypeMismatch {
467            expected: placeholder.dtype,
468            actual: dtype,
469        });
470    }
471    validate_placeholder_shape(placeholder, shape)
472}
473
474fn validate_placeholder_shape(placeholder: &TracedTensor, shape: &[usize]) -> Result<()> {
475    match try_concrete_shape(placeholder) {
476        Some(expected_shape) => {
477            if expected_shape.as_slice() != shape {
478                return Err(Error::PlaceholderShapeMismatch {
479                    expected: expected_shape,
480                    actual: shape.to_vec(),
481                });
482            }
483        }
484        None => {
485            if placeholder.rank != shape.len() {
486                return Err(Error::PlaceholderRankMismatch {
487                    expected: placeholder.rank,
488                    actual: shape.len(),
489                });
490            }
491        }
492    }
493    Ok(())
494}
495
496fn descriptor_for_input(
497    key: &TensorInputKey,
498    binding_specs: &HashMap<TensorInputKey, InputDescriptor>,
499    default_inputs: &HashMap<TensorInputKey, Arc<Tensor>>,
500) -> Result<InputDescriptor> {
501    if let Some(tensor) = default_inputs.get(key) {
502        return Ok(InputDescriptor {
503            key: key.clone(),
504            dtype: tensor.dtype(),
505            shape: tensor.shape().to_vec(),
506            default_tensor: Some(tensor.clone()),
507        });
508    }
509    if let Some(spec) = binding_specs.get(key) {
510        return Ok(spec.clone());
511    }
512    if !matches!(key, TensorInputKey::User { .. }) {
513        let root = tangent_primal_root(key);
514        if let Some(tensor) = default_inputs.get(root) {
515            return Ok(InputDescriptor {
516                key: key.clone(),
517                dtype: tensor.dtype(),
518                shape: tensor.shape().to_vec(),
519                default_tensor: Some(Arc::new(zeros_tensor(
520                    tensor.dtype(),
521                    tensor.shape().to_vec(),
522                )?)),
523            });
524        }
525        if let Some(spec) = binding_specs.get(root) {
526            return Ok(InputDescriptor {
527                key: key.clone(),
528                dtype: spec.dtype,
529                shape: spec.shape.clone(),
530                default_tensor: spec
531                    .default_tensor
532                    .as_ref()
533                    .map(|tensor| {
534                        zeros_tensor(tensor.dtype(), tensor.shape().to_vec()).map(Arc::new)
535                    })
536                    .transpose()?,
537            });
538        }
539    }
540    Err(Error::UnboundPlaceholder {
541        input_key: format!("{:?}", key),
542    })
543}
544
545fn default_tensors_equivalent(lhs: &Arc<Tensor>, rhs: &Arc<Tensor>) -> bool {
546    if Arc::ptr_eq(lhs, rhs) {
547        return true;
548    }
549    if lhs.dtype() != rhs.dtype() || lhs.shape() != rhs.shape() {
550        return false;
551    }
552    match lhs.dtype() {
553        DType::F32 => default_slices_equivalent::<f32>(lhs, rhs),
554        DType::F64 => default_slices_equivalent::<f64>(lhs, rhs),
555        DType::I32 => default_slices_equivalent::<i32>(lhs, rhs),
556        DType::I64 => default_slices_equivalent::<i64>(lhs, rhs),
557        DType::Bool => default_slices_equivalent::<bool>(lhs, rhs),
558        DType::C32 => default_slices_equivalent::<Complex32>(lhs, rhs),
559        DType::C64 => default_slices_equivalent::<Complex64>(lhs, rhs),
560    }
561}
562
563fn default_slices_equivalent<T: TensorScalar + PartialEq>(lhs: &Tensor, rhs: &Tensor) -> bool {
564    match (lhs.as_slice::<T>(), rhs.as_slice::<T>()) {
565        (Ok(lhs), Ok(rhs)) => lhs == rhs,
566        // Backend-resident defaults cannot be inspected here; only the same
567        // Arc<Tensor> is considered equivalent by `default_tensors_equivalent`.
568        _ => false,
569    }
570}
571
572fn tangent_primal_root(key: &TensorInputKey) -> &TensorInputKey {
573    key.primal_root()
574}
575
576fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Result<Tensor> {
577    match dtype {
578        DType::F32 => Ok(Tensor::F32(tenferro_tensor::TypedTensor::zeros(shape)?)),
579        DType::F64 => Ok(Tensor::F64(tenferro_tensor::TypedTensor::zeros(shape)?)),
580        DType::I32 => Ok(Tensor::I32(tenferro_tensor::TypedTensor::zeros(shape)?)),
581        DType::I64 => Ok(Tensor::I64(tenferro_tensor::TypedTensor::zeros(shape)?)),
582        DType::Bool => {
583            let len = checked_default_element_count(&shape)?;
584            Ok(Tensor::Bool(
585                tenferro_tensor::TypedTensor::from_vec_col_major(shape, vec![false; len])?,
586            ))
587        }
588        DType::C32 => Ok(Tensor::C32(tenferro_tensor::TypedTensor::zeros(shape)?)),
589        DType::C64 => Ok(Tensor::C64(tenferro_tensor::TypedTensor::zeros(shape)?)),
590    }
591}
592
593fn checked_default_element_count(shape: &[usize]) -> Result<usize> {
594    shape.iter().try_fold(1usize, |acc, &dim| {
595        acc.checked_mul(dim)
596            .ok_or_else(|| Error::InvalidCompiledGraph {
597                message: format!(
598                    "default tensor shape product overflows usize for shape {shape:?}"
599                ),
600            })
601    })
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tenferro_tensor::{
        Buffer, BufferHandle, DeviceId, DeviceKind, GpuBackendKind, MemoryKind, Placement,
        TypedTensor,
    };

    /// Two outputs that disagree on the default tensor for a shared input key
    /// must be rejected with `DuplicateBinding` naming that key.
    #[test]
    fn compile_many_rejects_conflicting_default_inputs_for_same_key() {
        let source = TracedTensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap();
        let first = source.neg();
        let mut second = source.neg();
        let shared_key = source
            .input_key()
            .expect("concrete traced tensor has input key");

        // Rebind the shared key in `second` to a different concrete value.
        let conflicting =
            Arc::new(Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap());
        let mut rebound = (*second.inputs_map).clone();
        rebound.insert(shared_key.clone(), conflicting);
        second.inputs_map = Arc::new(rebound);

        let mut compiler = GraphCompiler::new();
        let err = compiler.compile_many(&[&first, &second]).unwrap_err();

        let key_text = format!("{shared_key:?}");
        assert!(matches!(
            err,
            Error::DuplicateBinding { ref input_key } if input_key.contains(&key_text)
        ));
    }

    /// Backend-resident tensors cannot be read on the host, so two distinct
    /// ones must only be considered equivalent when they are the same `Arc`.
    #[test]
    fn default_tensors_equivalent_rejects_distinct_backend_buffers() {
        let cuda0 = Placement {
            memory_kind: MemoryKind::Device,
            device: Some(DeviceId {
                kind: DeviceKind::Gpu(GpuBackendKind::Cuda),
                ordinal: 0,
            }),
        };
        let backend_f64 = |handle_id, placement: Placement| {
            let buffer = Buffer::Backend(Arc::new(BufferHandle::<f64>::new_with_len(handle_id, 2)));
            Arc::new(Tensor::F64(
                TypedTensor::from_buffer_col_major(vec![2], buffer, placement).unwrap(),
            ))
        };

        let lhs = backend_f64(1, cuda0.clone());
        let rhs = backend_f64(2, cuda0);

        assert!(
            !default_tensors_equivalent(&lhs, &rhs),
            "distinct backend-resident default tensors must not compare equal just because both are unreadable on host"
        );
        assert!(default_tensors_equivalent(&lhs, &lhs));
    }
}