1use std::collections::HashMap;
2use std::fmt;
3use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use computegraph::compile::compile;
7use computegraph::materialize::materialize_merge;
8use computegraph::resolve::resolve;
9use computegraph::types::ValueKey;
10use lru::LruCache;
11use num_complex::{Complex32, Complex64};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_tensor::{DType, Tensor, TensorScalar};
15
16use super::cache::{
17 compile_cache_stats, compute_cache_key, CacheKey, GraphCompilerCacheStats,
18 DEFAULT_COMPILE_CACHE_CAPACITY,
19};
20use super::program::{GraphProgram, GraphProgramInput};
21use crate::compiler::{compile_std_to_exec_with_options, CompilerOptions};
22use crate::error::{Error, Result};
23use crate::exec::ExecProgram;
24use crate::extension_cache::{ExtensionCacheSelector, ExtensionCacheStore};
25use crate::traced::{try_concrete_shape, TracedTensor};
26
27#[derive(Clone)]
28struct InputDescriptor {
29 key: TensorInputKey,
30 dtype: DType,
31 shape: Vec<usize>,
32 default_tensor: Option<Arc<Tensor>>,
33}
34
35pub struct GraphCompiler {
52 compile_cache: LruCache<CacheKey, ExecProgram>,
53 extension_cache: ExtensionCacheStore,
54 compiler_options: CompilerOptions,
55}
56
57impl fmt::Debug for GraphCompiler {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.debug_struct("GraphCompiler")
60 .field("cache_stats", &self.cache_stats())
61 .field("compile_cache_capacity", &self.compile_cache_capacity())
62 .field("compiler_options", &self.compiler_options)
63 .field("extension_cache", &self.extension_cache)
64 .finish_non_exhaustive()
65 }
66}
67
68impl GraphCompiler {
69 pub fn new() -> Self {
80 Self {
81 compile_cache: LruCache::new(
82 NonZeroUsize::new(DEFAULT_COMPILE_CACHE_CAPACITY).unwrap_or(NonZeroUsize::MIN),
83 ),
84 extension_cache: ExtensionCacheStore::new(),
85 compiler_options: CompilerOptions::default(),
86 }
87 }
88
89 pub fn with_compiler_options(compiler_options: CompilerOptions) -> Self {
106 Self {
107 compile_cache: LruCache::new(
108 NonZeroUsize::new(DEFAULT_COMPILE_CACHE_CAPACITY).unwrap_or(NonZeroUsize::MIN),
109 ),
110 extension_cache: ExtensionCacheStore::new(),
111 compiler_options,
112 }
113 }
114
115 pub fn compile(&mut self, output: &TracedTensor) -> Result<GraphProgram> {
128 self.compile_many(&[output])
129 }
130
131 pub fn compile_many(&mut self, outputs: &[&TracedTensor]) -> Result<GraphProgram> {
145 let mut all_inputs = HashMap::new();
146 for output in outputs {
147 for (key, tensor) in output.inputs_map.iter() {
148 if let Some(existing) = all_inputs.get(key) {
149 if !default_tensors_equivalent(existing, tensor) {
150 return Err(Error::DuplicateBinding {
151 input_key: format!("{:?}", key),
152 });
153 }
154 continue;
155 }
156 all_inputs.insert(key.clone(), tensor.clone());
157 }
158 }
159 self.compile_many_with_descriptors(outputs, &HashMap::new(), &all_inputs)
160 }
161
162 pub fn compile_with_input_specs(
177 &mut self,
178 output: &TracedTensor,
179 bindings: &[(&TracedTensor, DType, &[usize])],
180 ) -> Result<GraphProgram> {
181 let mut binding_specs = HashMap::new();
182 for (index, (placeholder, dtype, shape)) in bindings.iter().enumerate() {
183 validate_placeholder_spec(index, placeholder, *dtype, shape)?;
184 let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
185 binding_index: index,
186 })?;
187 if binding_specs
188 .insert(
189 key.clone(),
190 InputDescriptor {
191 key: key.clone(),
192 dtype: *dtype,
193 shape: (*shape).to_vec(),
194 default_tensor: None,
195 },
196 )
197 .is_some()
198 {
199 return Err(Error::DuplicateBinding {
200 input_key: format!("{:?}", key),
201 });
202 }
203 }
204
205 self.compile_many_with_descriptors(&[output], &binding_specs, output.inputs_map.as_ref())
206 }
207
208 pub fn compile_cache_len(&self) -> usize {
219 self.compile_cache.len()
220 }
221
222 pub fn compile_cache_capacity(&self) -> NonZeroUsize {
233 self.compile_cache.cap()
234 }
235
236 pub fn set_compile_cache_capacity(&mut self, capacity: NonZeroUsize) {
249 self.compile_cache.resize(capacity);
250 }
251
252 pub fn compiler_options(&self) -> CompilerOptions {
264 self.compiler_options
265 }
266
267 pub fn set_compiler_options(&mut self, compiler_options: CompilerOptions) {
287 if self.compiler_options == compiler_options {
288 return;
289 }
290 self.compiler_options = compiler_options;
291 self.clear_compile_cache();
292 }
293
294 pub fn clear_compile_cache(&mut self) {
306 self.compile_cache.clear();
307 }
308
309 pub fn clear_extension_caches(&mut self) {
321 self.extension_cache.clear();
322 }
323
324 pub fn clear_caches(&mut self) {
336 self.clear_compile_cache();
337 self.clear_extension_caches();
338 }
339
340 pub fn cache_stats(&self) -> GraphCompilerCacheStats {
352 GraphCompilerCacheStats {
353 compile: compile_cache_stats(&self.compile_cache),
354 extensions: self.extension_cache.stats(ExtensionCacheSelector::All),
355 }
356 }
357
358 pub fn extension_caches(&self) -> &ExtensionCacheStore {
369 &self.extension_cache
370 }
371
372 pub fn extension_caches_mut(&mut self) -> &mut ExtensionCacheStore {
383 &mut self.extension_cache
384 }
385
386 fn compile_many_with_descriptors(
387 &mut self,
388 outputs: &[&TracedTensor],
389 binding_specs: &HashMap<TensorInputKey, InputDescriptor>,
390 default_inputs: &HashMap<TensorInputKey, Arc<Tensor>>,
391 ) -> Result<GraphProgram> {
392 let mut roots = Vec::new();
393 let mut output_keys = Vec::with_capacity(outputs.len());
394 for output in outputs {
395 roots.extend(output.resolve_roots());
396 output_keys.push(output.graph.values()[output.val].key.clone());
397 }
398
399 let view = resolve(roots);
400 let graph = materialize_merge(&view, &output_keys);
401 let compiled = compile(&graph);
402
403 let mut descriptors = Vec::with_capacity(graph.inputs.len());
404 let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
405 let mut input_shapes = Vec::with_capacity(graph.inputs.len());
406 for key in &graph.inputs {
407 let ValueKey::Input(input_key) = key else {
408 return Err(Error::Internal(
409 "expected Input key in graph inputs".to_string(),
410 ));
411 };
412 let descriptor = descriptor_for_input(input_key, binding_specs, default_inputs)?;
413 input_dtypes.push(descriptor.dtype);
414 input_shapes.push(DimExpr::from_concrete(&descriptor.shape));
415 descriptors.push(GraphProgramInput::new(
416 descriptor.key,
417 descriptor.dtype,
418 descriptor.shape.clone(),
419 DimExpr::from_concrete(&descriptor.shape),
420 descriptor.default_tensor,
421 ));
422 }
423
424 let exec = compile_std_to_exec_with_options(
425 &compiled,
426 &input_dtypes,
427 &input_shapes,
428 self.compiler_options,
429 )?;
430 let exec = self.get_or_compile(exec);
431 Ok(GraphProgram::new(exec, descriptors))
432 }
433
434 fn get_or_compile(&mut self, exec: ExecProgram) -> ExecProgram {
435 let key = compute_cache_key(&exec);
436 if let Some(cached) = self.compile_cache.get(&key) {
437 return cached.clone();
438 }
439 self.compile_cache.put(key, exec.clone());
440 exec
441 }
442}
443
444impl Default for GraphCompiler {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
450fn validate_placeholder_spec(
451 index: usize,
452 placeholder: &TracedTensor,
453 dtype: DType,
454 shape: &[usize],
455) -> Result<()> {
456 if placeholder.data.is_some() {
457 return Err(Error::UnexpectedBinding {
458 binding_index: index,
459 });
460 }
461 placeholder.input_key().ok_or(Error::UnexpectedBinding {
462 binding_index: index,
463 })?;
464
465 if placeholder.dtype != dtype {
466 return Err(Error::PlaceholderDtypeMismatch {
467 expected: placeholder.dtype,
468 actual: dtype,
469 });
470 }
471 validate_placeholder_shape(placeholder, shape)
472}
473
474fn validate_placeholder_shape(placeholder: &TracedTensor, shape: &[usize]) -> Result<()> {
475 match try_concrete_shape(placeholder) {
476 Some(expected_shape) => {
477 if expected_shape.as_slice() != shape {
478 return Err(Error::PlaceholderShapeMismatch {
479 expected: expected_shape,
480 actual: shape.to_vec(),
481 });
482 }
483 }
484 None => {
485 if placeholder.rank != shape.len() {
486 return Err(Error::PlaceholderRankMismatch {
487 expected: placeholder.rank,
488 actual: shape.len(),
489 });
490 }
491 }
492 }
493 Ok(())
494}
495
496fn descriptor_for_input(
497 key: &TensorInputKey,
498 binding_specs: &HashMap<TensorInputKey, InputDescriptor>,
499 default_inputs: &HashMap<TensorInputKey, Arc<Tensor>>,
500) -> Result<InputDescriptor> {
501 if let Some(tensor) = default_inputs.get(key) {
502 return Ok(InputDescriptor {
503 key: key.clone(),
504 dtype: tensor.dtype(),
505 shape: tensor.shape().to_vec(),
506 default_tensor: Some(tensor.clone()),
507 });
508 }
509 if let Some(spec) = binding_specs.get(key) {
510 return Ok(spec.clone());
511 }
512 if !matches!(key, TensorInputKey::User { .. }) {
513 let root = tangent_primal_root(key);
514 if let Some(tensor) = default_inputs.get(root) {
515 return Ok(InputDescriptor {
516 key: key.clone(),
517 dtype: tensor.dtype(),
518 shape: tensor.shape().to_vec(),
519 default_tensor: Some(Arc::new(zeros_tensor(
520 tensor.dtype(),
521 tensor.shape().to_vec(),
522 )?)),
523 });
524 }
525 if let Some(spec) = binding_specs.get(root) {
526 return Ok(InputDescriptor {
527 key: key.clone(),
528 dtype: spec.dtype,
529 shape: spec.shape.clone(),
530 default_tensor: spec
531 .default_tensor
532 .as_ref()
533 .map(|tensor| {
534 zeros_tensor(tensor.dtype(), tensor.shape().to_vec()).map(Arc::new)
535 })
536 .transpose()?,
537 });
538 }
539 }
540 Err(Error::UnboundPlaceholder {
541 input_key: format!("{:?}", key),
542 })
543}
544
545fn default_tensors_equivalent(lhs: &Arc<Tensor>, rhs: &Arc<Tensor>) -> bool {
546 if Arc::ptr_eq(lhs, rhs) {
547 return true;
548 }
549 if lhs.dtype() != rhs.dtype() || lhs.shape() != rhs.shape() {
550 return false;
551 }
552 match lhs.dtype() {
553 DType::F32 => default_slices_equivalent::<f32>(lhs, rhs),
554 DType::F64 => default_slices_equivalent::<f64>(lhs, rhs),
555 DType::I32 => default_slices_equivalent::<i32>(lhs, rhs),
556 DType::I64 => default_slices_equivalent::<i64>(lhs, rhs),
557 DType::Bool => default_slices_equivalent::<bool>(lhs, rhs),
558 DType::C32 => default_slices_equivalent::<Complex32>(lhs, rhs),
559 DType::C64 => default_slices_equivalent::<Complex64>(lhs, rhs),
560 }
561}
562
563fn default_slices_equivalent<T: TensorScalar + PartialEq>(lhs: &Tensor, rhs: &Tensor) -> bool {
564 match (lhs.as_slice::<T>(), rhs.as_slice::<T>()) {
565 (Ok(lhs), Ok(rhs)) => lhs == rhs,
566 _ => false,
569 }
570}
571
572fn tangent_primal_root(key: &TensorInputKey) -> &TensorInputKey {
573 key.primal_root()
574}
575
576fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Result<Tensor> {
577 match dtype {
578 DType::F32 => Ok(Tensor::F32(tenferro_tensor::TypedTensor::zeros(shape)?)),
579 DType::F64 => Ok(Tensor::F64(tenferro_tensor::TypedTensor::zeros(shape)?)),
580 DType::I32 => Ok(Tensor::I32(tenferro_tensor::TypedTensor::zeros(shape)?)),
581 DType::I64 => Ok(Tensor::I64(tenferro_tensor::TypedTensor::zeros(shape)?)),
582 DType::Bool => {
583 let len = checked_default_element_count(&shape)?;
584 Ok(Tensor::Bool(
585 tenferro_tensor::TypedTensor::from_vec_col_major(shape, vec![false; len])?,
586 ))
587 }
588 DType::C32 => Ok(Tensor::C32(tenferro_tensor::TypedTensor::zeros(shape)?)),
589 DType::C64 => Ok(Tensor::C64(tenferro_tensor::TypedTensor::zeros(shape)?)),
590 }
591}
592
593fn checked_default_element_count(shape: &[usize]) -> Result<usize> {
594 shape.iter().try_fold(1usize, |acc, &dim| {
595 acc.checked_mul(dim)
596 .ok_or_else(|| Error::InvalidCompiledGraph {
597 message: format!(
598 "default tensor shape product overflows usize for shape {shape:?}"
599 ),
600 })
601 })
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tenferro_tensor::{
        Buffer, BufferHandle, DeviceId, DeviceKind, GpuBackendKind, MemoryKind, Placement,
        TypedTensor,
    };

    /// Two outputs that map the same input key to different default values
    /// must be rejected with `DuplicateBinding` naming that key.
    #[test]
    fn compile_many_rejects_conflicting_default_inputs_for_same_key() {
        let source = TracedTensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap();
        let first = source.neg();
        let mut second = source.neg();
        let key = source
            .input_key()
            .expect("concrete traced tensor has input key");

        // Rebind the shared key in the second output to a different value.
        let conflicting = Arc::new(Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap());
        let mut overridden = (*second.inputs_map).clone();
        overridden.insert(key.clone(), conflicting);
        second.inputs_map = Arc::new(overridden);

        let mut compiler = GraphCompiler::new();
        let err = compiler.compile_many(&[&first, &second]).unwrap_err();

        let expected_key = format!("{key:?}");
        assert!(matches!(
            err,
            Error::DuplicateBinding { ref input_key } if input_key.contains(&expected_key)
        ));
    }

    /// Tensors backed by different backend buffers are not equivalent, while a
    /// tensor is always equivalent to itself.
    #[test]
    fn default_tensors_equivalent_rejects_distinct_backend_buffers() {
        let device = DeviceId {
            kind: DeviceKind::Gpu(GpuBackendKind::Cuda),
            ordinal: 0,
        };
        let placement = Placement {
            memory_kind: MemoryKind::Device,
            device: Some(device),
        };

        let lhs_buffer = Buffer::Backend(Arc::new(BufferHandle::<f64>::new_with_len(1, 2)));
        let lhs_typed =
            TypedTensor::from_buffer_col_major(vec![2], lhs_buffer, placement.clone()).unwrap();
        let lhs = Arc::new(Tensor::F64(lhs_typed));

        let rhs_buffer = Buffer::Backend(Arc::new(BufferHandle::<f64>::new_with_len(2, 2)));
        let rhs_typed = TypedTensor::from_buffer_col_major(vec![2], rhs_buffer, placement).unwrap();
        let rhs = Arc::new(Tensor::F64(rhs_typed));

        assert!(
            !default_tensors_equivalent(&lhs, &rhs),
            "distinct backend-resident default tensors must not compare equal just because both are unreadable on host"
        );
        assert!(default_tensors_equivalent(&lhs, &lhs));
    }
}