Skip to main content

tenferro_runtime/graph/
executor.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use tenferro_ops::input_key::TensorInputKey;
6use tenferro_tensor::{
7    DType, RuntimeCacheControl, Tensor, TensorBackend, TensorRead, TensorValue, TypedTensor,
8};
9
10use super::cache::GraphExecutorCacheStats;
11use super::program::{GraphProgram, GraphProgramInput};
12use crate::error::{Error, Result};
13use crate::exec::{ExecProgram, ExecSlot};
14use crate::extension_runtime::{ExtensionExecutor, ExtensionRuntimeRegistryError};
15use crate::traced::TracedTensor;
16
17/// Executes compiled graph programs on a concrete tensor backend.
18///
19/// A graph executor owns backend execution state only: backend runtime caches,
20/// extension runtime state, and reusable execution workspace. Compilation
21/// state lives in [`GraphCompiler`](super::GraphCompiler).
22///
23/// # Examples
24///
25/// ```
26/// use tenferro_cpu::CpuBackend;
27/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
28///
29/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
30/// let y = (&x + &x).unwrap();
31/// let mut compiler = GraphCompiler::new();
32/// let program = compiler.compile(&y).unwrap();
33///
34/// let mut executor = GraphExecutor::new(CpuBackend::new());
35/// let out = executor.run(&program).unwrap();
36/// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
37/// ```
38pub struct GraphExecutor<B: TensorBackend + 'static> {
39    backend: B,
40    backend_cache: B::RuntimeCache,
41    extension_executor: ExtensionExecutor<B>,
42    slot_workspace: Vec<Option<ExecSlot<'static>>>,
43    borrowed_slot_workspace_capacity: usize,
44}
45
46impl<B: TensorBackend + 'static> fmt::Debug for GraphExecutor<B> {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("GraphExecutor")
49            .field("backend_type", &std::any::type_name::<B>())
50            .field("cache_stats", &self.cache_stats())
51            .field("slot_workspace_len", &self.slot_workspace.len())
52            .field(
53                "borrowed_slot_workspace_capacity",
54                &self.borrowed_slot_workspace_capacity,
55            )
56            .finish_non_exhaustive()
57    }
58}
59
60impl<B: TensorBackend + 'static> GraphExecutor<B> {
61    /// Create an executor with the given backend and bounded default caches.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use tenferro_cpu::CpuBackend;
67    /// use tenferro_runtime::{GraphExecutor};
68    ///
69    /// let executor = GraphExecutor::new(CpuBackend::new());
70    /// assert_eq!(executor.cache_stats().extensions.entries, 0);
71    /// ```
72    pub fn new(backend: B) -> Self {
73        Self {
74            backend,
75            backend_cache: B::RuntimeCache::default(),
76            extension_executor: ExtensionExecutor::new(),
77            slot_workspace: Vec::new(),
78            borrowed_slot_workspace_capacity: 0,
79        }
80    }
81
82    /// Borrow the backend used by this executor.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use tenferro_cpu::CpuBackend;
88    /// use tenferro_runtime::{GraphExecutor};
89    ///
90    /// let executor = GraphExecutor::new(CpuBackend::new());
91    /// let _backend = executor.backend();
92    /// ```
93    pub fn backend(&self) -> &B {
94        &self.backend
95    }
96
97    /// Return output tensors to the executor backend's reusable buffer pool.
98    ///
99    /// This is useful for tight benchmark or serving loops that consume an
100    /// output before the next run and want backend-level output allocation
101    /// behavior to match caching allocators.
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use tenferro_cpu::CpuBackend;
107    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
108    ///
109    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
110    /// let mut compiler = GraphCompiler::new();
111    /// let program = compiler.compile(&x.neg()).unwrap();
112    /// let mut executor = GraphExecutor::new(CpuBackend::new());
113    ///
114    /// let out = executor.run(&program).unwrap();
115    /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-2.0]);
116    /// executor.reclaim_outputs(vec![out]);
117    /// ```
118    pub fn reclaim_outputs(&mut self, outputs: Vec<Tensor>) {
119        for tensor in outputs {
120            self.backend.reclaim_buffer(tensor);
121        }
122    }
123
124    /// Return compact value outputs to the backend pool when ownership is unique.
125    ///
126    /// Lazy owned views are intentionally ignored because their base storage may
127    /// be aliased by view metadata.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use tenferro_cpu::CpuBackend;
133    /// use tenferro_runtime::{GraphExecutor, Tensor, TensorValue};
134    ///
135    /// let tensor = Tensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
136    /// let mut executor = GraphExecutor::new(CpuBackend::new());
137    /// executor.reclaim_value_outputs(vec![TensorValue::from_tensor(tensor)]);
138    /// ```
139    pub fn reclaim_value_outputs(&mut self, outputs: Vec<TensorValue>) {
140        for value in outputs {
141            if let TensorValue::Tensor(tensor) = value {
142                if let Ok(tensor) = Arc::try_unwrap(tensor) {
143                    self.backend.reclaim_buffer(tensor);
144                }
145            }
146        }
147    }
148
149    /// Borrow the extension runtime executor owned by this graph executor.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// use tenferro_cpu::CpuBackend;
155    /// use tenferro_runtime::{GraphExecutor};
156    ///
157    /// let executor = GraphExecutor::new(CpuBackend::new());
158    /// assert_eq!(executor.extension_executor().cache_stats().entries, 0);
159    /// ```
160    pub fn extension_executor(&self) -> &ExtensionExecutor<B> {
161        &self.extension_executor
162    }
163
164    /// Mutably borrow the extension runtime executor owned by this graph executor.
165    ///
166    /// # Examples
167    ///
168    /// ```
169    /// use tenferro_cpu::CpuBackend;
170    /// use tenferro_runtime::{GraphExecutor};
171    ///
172    /// let mut executor = GraphExecutor::new(CpuBackend::new());
173    /// executor.extension_executor_mut().clear_caches();
174    /// ```
175    pub fn extension_executor_mut(&mut self) -> &mut ExtensionExecutor<B> {
176        &mut self.extension_executor
177    }
178
179    /// Register one extension runtime on this executor.
180    ///
181    /// # Examples
182    ///
183    /// ```
184    /// use tenferro_cpu::CpuBackend;
185    /// use tenferro_runtime::GraphExecutor;
186    ///
187    /// let mut executor = GraphExecutor::new(CpuBackend::new());
188    /// executor.register_extension(|_| Ok(())).unwrap();
189    /// ```
190    pub fn register_extension(
191        &mut self,
192        register: impl FnOnce(
193            &mut ExtensionExecutor<B>,
194        ) -> std::result::Result<(), ExtensionRuntimeRegistryError>,
195    ) -> std::result::Result<(), ExtensionRuntimeRegistryError> {
196        register(&mut self.extension_executor)
197    }
198
199    /// Run a one-output program using the program's default input tensors.
200    ///
201    /// # Examples
202    ///
203    /// ```
204    /// use tenferro_cpu::CpuBackend;
205    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
206    ///
207    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
208    /// let mut compiler = GraphCompiler::new();
209    /// let program = compiler.compile(&x.neg()).unwrap();
210    /// let mut executor = GraphExecutor::new(CpuBackend::new());
211    /// let out = executor.run(&program).unwrap();
212    /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-3.0]);
213    /// ```
214    pub fn run(&mut self, program: &GraphProgram) -> Result<Tensor> {
215        let mut outputs = self.run_many(program)?;
216        expect_single_output(&mut outputs)
217    }
218
219    /// Run a one-output program and preserve lazy owned output views.
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// use tenferro_cpu::CpuBackend;
225    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
226    ///
227    /// let x = TracedTensor::from_vec_col_major(
228    ///     vec![2, 3],
229    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
230    /// )
231    /// .unwrap();
232    /// let y = x.transpose(&[1, 0]).unwrap();
233    /// let mut compiler = GraphCompiler::new();
234    /// let program = compiler.compile(&y).unwrap();
235    ///
236    /// let mut executor = GraphExecutor::new(CpuBackend::new());
237    /// let value = executor.run_value(&program).unwrap();
238    /// assert!(matches!(&value, TensorValue::View(_)));
239    /// assert_eq!(value.shape(), &[3, 2]);
240    /// assert_eq!(
241    ///     value.to_tensor().unwrap().as_slice::<f64>().unwrap(),
242    ///     &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
243    /// );
244    /// ```
245    pub fn run_value(&mut self, program: &GraphProgram) -> Result<TensorValue> {
246        let mut outputs = self.run_many_values(program)?;
247        expect_single_value(&mut outputs)
248    }
249
250    /// Run a program using the program's default input tensors.
251    ///
252    /// # Examples
253    ///
254    /// ```
255    /// use tenferro_cpu::CpuBackend;
256    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
257    ///
258    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
259    /// let y = x.neg();
260    /// let mut compiler = GraphCompiler::new();
261    /// let program = compiler.compile_many(&[&x, &y]).unwrap();
262    /// let mut executor = GraphExecutor::new(CpuBackend::new());
263    /// let outputs = executor.run_many(&program).unwrap();
264    /// assert_eq!(outputs.len(), 2);
265    /// ```
266    pub fn run_many(&mut self, program: &GraphProgram) -> Result<Vec<Tensor>> {
267        self.run_many_with_inputs(program, &[])
268    }
269
270    /// Run a program and preserve lazy owned output views.
271    ///
272    /// # Examples
273    ///
274    /// ```
275    /// use tenferro_cpu::CpuBackend;
276    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
277    ///
278    /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
279    /// let y = x.transpose(&[1, 0]).unwrap();
280    /// let mut compiler = GraphCompiler::new();
281    /// let program = compiler.compile_many(&[&y]).unwrap();
282    ///
283    /// let mut executor = GraphExecutor::new(CpuBackend::new());
284    /// let outputs = executor.run_many_values(&program).unwrap();
285    /// assert_eq!(outputs.len(), 1);
286    /// assert!(matches!(&outputs[0], TensorValue::View(_)));
287    /// assert_eq!(outputs[0].shape(), &[2, 2]);
288    /// ```
289    pub fn run_many_values(&mut self, program: &GraphProgram) -> Result<Vec<TensorValue>> {
290        self.run_many_values_with_inputs(program, &[])
291    }
292
293    /// Run a one-output program with explicit runtime placeholder bindings.
294    ///
295    /// Explicit bindings override program defaults and are validated against
296    /// the ordered input specs captured in the compiled program.
297    ///
298    /// # Examples
299    ///
300    /// ```
301    /// use tenferro_cpu::CpuBackend;
302    /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TracedTensor};
303    ///
304    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
305    /// let y = (&x + &x).unwrap();
306    /// let mut compiler = GraphCompiler::new();
307    /// let program = compiler
308    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
309    ///     .unwrap();
310    /// let bound = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
311    /// let mut executor = GraphExecutor::new(CpuBackend::new());
312    /// let out = executor.run_with_inputs(&program, &[(&x, &bound)]).unwrap();
313    /// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
314    /// ```
315    pub fn run_with_inputs(
316        &mut self,
317        program: &GraphProgram,
318        bindings: &[(&TracedTensor, &Tensor)],
319    ) -> Result<Tensor> {
320        let mut outputs = self.run_many_with_inputs(program, bindings)?;
321        expect_single_output(&mut outputs)
322    }
323
324    /// Run a one-output program with explicit bindings and preserve lazy output views.
325    ///
326    /// # Examples
327    ///
328    /// ```
329    /// use tenferro_cpu::CpuBackend;
330    /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TensorValue, TracedTensor};
331    ///
332    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
333    /// let y = x.transpose(&[1, 0]).unwrap();
334    /// let mut compiler = GraphCompiler::new();
335    /// let program = compiler
336    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
337    ///     .unwrap();
338    /// let bound = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
339    /// let mut executor = GraphExecutor::new(CpuBackend::new());
340    ///
341    /// let value = executor.run_value_with_inputs(&program, &[(&x, &bound)]).unwrap();
342    /// assert!(matches!(&value, TensorValue::View(_)));
343    /// assert_eq!(value.to_tensor().unwrap().as_slice::<f64>().unwrap(), &[1.0, 3.0, 2.0, 4.0]);
344    /// ```
345    pub fn run_value_with_inputs(
346        &mut self,
347        program: &GraphProgram,
348        bindings: &[(&TracedTensor, &Tensor)],
349    ) -> Result<TensorValue> {
350        let mut outputs = self.run_many_values_with_inputs(program, bindings)?;
351        expect_single_value(&mut outputs)
352    }
353
354    /// Run a one-output program with explicit borrowed runtime placeholder bindings.
355    ///
356    /// Unlike [`run_with_inputs`](Self::run_with_inputs), caller-owned input
357    /// tensors are read through [`TensorRead`] and are not cloned into executor
358    /// slots.
359    ///
360    /// # Examples
361    ///
362    /// ```
363    /// use tenferro_cpu::CpuBackend;
364    /// use tenferro_runtime::{
365    ///     DType, GraphCompiler, GraphExecutor, TensorRead, TensorView, TracedTensor, TypedTensorView,
366    /// };
367    ///
368    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
369    /// let y = (&x + &x).unwrap();
370    /// let mut compiler = GraphCompiler::new();
371    /// let program = compiler
372    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
373    ///     .unwrap();
374    /// let data = [1.0_f64, 99.0, 2.0];
375    /// let view = TypedTensorView::from_slice([2], [2], 0, &data).unwrap();
376    /// let read = TensorRead::from_view(TensorView::F64(view));
377    /// let mut executor = GraphExecutor::new(CpuBackend::new());
378    ///
379    /// let out = executor.run_with_input_reads(&program, &[(&x, read)]).unwrap();
380    /// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
381    /// ```
382    pub fn run_with_input_reads<'a>(
383        &mut self,
384        program: &'a GraphProgram,
385        bindings: &[(&TracedTensor, TensorRead<'a>)],
386    ) -> Result<Tensor> {
387        let mut outputs = self.run_many_with_input_reads(program, bindings)?;
388        expect_single_output(&mut outputs)
389    }
390
391    /// Run a one-output program with borrowed bindings and preserve lazy output views.
392    ///
393    /// # Examples
394    ///
395    /// ```
396    /// use tenferro_cpu::CpuBackend;
397    /// use tenferro_runtime::{
398    ///     DType, GraphCompiler, GraphExecutor, TensorRead, TensorValue, TensorView, TracedTensor,
399    ///     TypedTensorView,
400    /// };
401    ///
402    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
403    /// let y = x.transpose(&[1, 0]).unwrap();
404    /// let mut compiler = GraphCompiler::new();
405    /// let program = compiler
406    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
407    ///     .unwrap();
408    /// let data = [1.0_f64, 2.0, 3.0, 4.0];
409    /// let view = TypedTensorView::from_slice([2, 2], [1, 2], 0, &data).unwrap();
410    /// let read = TensorRead::from_view(TensorView::F64(view));
411    /// let mut executor = GraphExecutor::new(CpuBackend::new());
412    ///
413    /// let value = executor
414    ///     .run_value_with_input_reads(&program, &[(&x, read)])
415    ///     .unwrap();
416    /// assert!(matches!(&value, TensorValue::View(_)));
417    /// assert_eq!(value.shape(), &[2, 2]);
418    /// ```
419    pub fn run_value_with_input_reads<'a>(
420        &mut self,
421        program: &'a GraphProgram,
422        bindings: &[(&TracedTensor, TensorRead<'a>)],
423    ) -> Result<TensorValue> {
424        let mut outputs = self.run_many_values_with_input_reads(program, bindings)?;
425        expect_single_value(&mut outputs)
426    }
427
428    /// Run a program with explicit runtime placeholder bindings.
429    ///
430    /// # Examples
431    ///
432    /// ```
433    /// use tenferro_cpu::CpuBackend;
434    /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TracedTensor};
435    ///
436    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
437    /// let sum = (&x + &x).unwrap();
438    /// let mut compiler = GraphCompiler::new();
439    /// let program = compiler
440    ///     .compile_with_input_specs(&sum, &[(&x, DType::F64, &[2])])
441    ///     .unwrap();
442    /// let bound = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
443    /// let mut executor = GraphExecutor::new(CpuBackend::new());
444    /// let outputs = executor.run_many_with_inputs(&program, &[(&x, &bound)]).unwrap();
445    /// assert_eq!(outputs.len(), 1);
446    /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[2.0, 4.0]);
447    /// ```
448    pub fn run_many_with_inputs(
449        &mut self,
450        program: &GraphProgram,
451        bindings: &[(&TracedTensor, &Tensor)],
452    ) -> Result<Vec<Tensor>> {
453        let input_tensors = resolve_inputs(program, bindings, &mut self.backend)?;
454        self.eval_exec_ir(&program.exec, input_tensors)
455    }
456
457    /// Run a program with explicit bindings and preserve lazy output views.
458    ///
459    /// # Examples
460    ///
461    /// ```
462    /// use tenferro_cpu::CpuBackend;
463    /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TensorValue, TracedTensor};
464    ///
465    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
466    /// let y = x.transpose(&[1, 0]).unwrap();
467    /// let mut compiler = GraphCompiler::new();
468    /// let program = compiler
469    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
470    ///     .unwrap();
471    /// let bound = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
472    /// let mut executor = GraphExecutor::new(CpuBackend::new());
473    ///
474    /// let outputs = executor
475    ///     .run_many_values_with_inputs(&program, &[(&x, &bound)])
476    ///     .unwrap();
477    /// assert_eq!(outputs.len(), 1);
478    /// assert!(matches!(&outputs[0], TensorValue::View(_)));
479    /// ```
480    pub fn run_many_values_with_inputs(
481        &mut self,
482        program: &GraphProgram,
483        bindings: &[(&TracedTensor, &Tensor)],
484    ) -> Result<Vec<TensorValue>> {
485        let input_tensors = resolve_inputs(program, bindings, &mut self.backend)?;
486        self.eval_exec_ir_values(&program.exec, input_tensors)
487    }
488
489    /// Run a program with explicit borrowed runtime placeholder bindings.
490    ///
491    /// Bindings override program defaults and are validated against the input
492    /// specs captured in the compiled program. Bound tensors are borrowed by
493    /// the executor for this call instead of cloned into input slots.
494    ///
495    /// # Examples
496    ///
497    /// ```
498    /// use tenferro_cpu::CpuBackend;
499    /// use tenferro_runtime::{
500    ///     DType, GraphCompiler, GraphExecutor, TensorRead, TensorView, TracedTensor, TypedTensorView,
501    /// };
502    ///
503    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
504    /// let y = (&x + &x).unwrap();
505    /// let mut compiler = GraphCompiler::new();
506    /// let program = compiler
507    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
508    ///     .unwrap();
509    /// let data = [1.0_f64, 99.0, 2.0];
510    /// let view = TypedTensorView::from_slice([2], [2], 0, &data).unwrap();
511    /// let read = TensorRead::from_view(TensorView::F64(view));
512    /// let mut executor = GraphExecutor::new(CpuBackend::new());
513    ///
514    /// let outputs = executor.run_many_with_input_reads(&program, &[(&x, read)]).unwrap();
515    /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[2.0, 4.0]);
516    /// ```
517    pub fn run_many_with_input_reads<'a>(
518        &mut self,
519        program: &'a GraphProgram,
520        bindings: &[(&TracedTensor, TensorRead<'a>)],
521    ) -> Result<Vec<Tensor>> {
522        let inputs = resolve_input_reads(program, bindings, &mut self.backend)?;
523        self.eval_exec_ir_slots(&program.exec, inputs)
524    }
525
526    /// Run a program with borrowed bindings and preserve lazy output views.
527    ///
528    /// # Examples
529    ///
530    /// ```
531    /// use tenferro_cpu::CpuBackend;
532    /// use tenferro_runtime::{
533    ///     DType, GraphCompiler, GraphExecutor, TensorRead, TensorValue, TensorView, TracedTensor,
534    ///     TypedTensorView,
535    /// };
536    ///
537    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
538    /// let y = x.transpose(&[1, 0]).unwrap();
539    /// let mut compiler = GraphCompiler::new();
540    /// let program = compiler
541    ///     .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
542    ///     .unwrap();
543    /// let data = [1.0_f64, 2.0, 3.0, 4.0];
544    /// let view = TypedTensorView::from_slice([2, 2], [1, 2], 0, &data).unwrap();
545    /// let read = TensorRead::from_view(TensorView::F64(view));
546    /// let mut executor = GraphExecutor::new(CpuBackend::new());
547    ///
548    /// let outputs = executor
549    ///     .run_many_values_with_input_reads(&program, &[(&x, read)])
550    ///     .unwrap();
551    /// assert_eq!(outputs.len(), 1);
552    /// assert!(matches!(&outputs[0], TensorValue::View(_)));
553    /// ```
554    pub fn run_many_values_with_input_reads<'a>(
555        &mut self,
556        program: &'a GraphProgram,
557        bindings: &[(&TracedTensor, TensorRead<'a>)],
558    ) -> Result<Vec<TensorValue>> {
559        let inputs = resolve_input_reads(program, bindings, &mut self.backend)?;
560        self.eval_exec_ir_slot_values(&program.exec, inputs)
561    }
562
563    /// Evaluate an execution program through this executor's backend state.
564    ///
565    /// This lower-level entry point is intended for code that already owns an
566    /// execution program and concrete ordered input tensors.
567    ///
568    /// # Examples
569    ///
570    /// ```
571    /// use tenferro_cpu::CpuBackend;
572    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
573    ///
574    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
575    /// let mut compiler = GraphCompiler::new();
576    /// let program = compiler.compile(&x.neg()).unwrap();
577    /// let mut executor = GraphExecutor::new(CpuBackend::new());
578    /// let out = executor.run(&program).unwrap();
579    /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-2.0]);
580    /// ```
581    pub fn eval_exec_ir(
582        &mut self,
583        program: &ExecProgram,
584        inputs: Vec<Tensor>,
585    ) -> Result<Vec<Tensor>> {
586        validate_exec_input_count(program, inputs.len())?;
587        crate::segment::eval_exec_segmented_with_cache_and_workspace(
588            &mut self.backend,
589            program,
590            inputs,
591            &mut self.slot_workspace,
592            &mut self.backend_cache,
593            Some(&mut self.extension_executor),
594        )
595    }
596
597    /// Evaluate an execution program and preserve lazy owned output views.
598    ///
599    /// # Examples
600    ///
601    /// ```
602    /// use tenferro_cpu::CpuBackend;
603    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
604    ///
605    /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
606    /// let y = x.transpose(&[1, 0]).unwrap();
607    /// let mut compiler = GraphCompiler::new();
608    /// let program = compiler.compile(&y).unwrap();
609    /// let mut executor = GraphExecutor::new(CpuBackend::new());
610    ///
611    /// let value = executor.run_value(&program).unwrap();
612    /// assert!(matches!(&value, TensorValue::View(_)));
613    /// assert_eq!(value.shape(), &[2, 2]);
614    /// ```
615    pub fn eval_exec_ir_values(
616        &mut self,
617        program: &ExecProgram,
618        inputs: Vec<Tensor>,
619    ) -> Result<Vec<TensorValue>> {
620        validate_exec_input_count(program, inputs.len())?;
621        let inputs = inputs.into_iter().map(ExecSlot::Owned).collect();
622        crate::segment::eval_exec_segmented_slot_values_with_cache_and_workspace(
623            &mut self.backend,
624            program,
625            inputs,
626            &mut self.slot_workspace,
627            &mut self.backend_cache,
628            Some(&mut self.extension_executor),
629        )
630    }
631
632    /// Evaluate an execution program without consuming caller-owned inputs.
633    ///
634    /// # Examples
635    ///
636    /// ```
637    /// use tenferro_cpu::CpuBackend;
638    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
639    ///
640    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
641    /// let mut compiler = GraphCompiler::new();
642    /// let program = compiler.compile(&x.neg()).unwrap();
643    /// let mut executor = GraphExecutor::new(CpuBackend::new());
644    /// let out = executor.run(&program).unwrap();
645    /// assert_eq!(out.shape(), &[1]);
646    /// ```
647    pub fn eval_exec_ir_non_consuming(
648        &mut self,
649        program: &ExecProgram,
650        inputs: &[Tensor],
651    ) -> Result<Vec<Tensor>> {
652        let inputs = inputs
653            .iter()
654            .map(|tensor| ExecSlot::Read(TensorRead::from_tensor(tensor)))
655            .collect();
656        self.eval_exec_ir_slots(program, inputs)
657    }
658
659    /// Evaluate an execution program without consuming inputs and preserve lazy outputs.
660    ///
661    /// # Examples
662    ///
663    /// ```
664    /// use tenferro_cpu::CpuBackend;
665    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
666    ///
667    /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
668    /// let y = x.transpose(&[1, 0]).unwrap();
669    /// let mut compiler = GraphCompiler::new();
670    /// let program = compiler.compile(&y).unwrap();
671    /// let mut executor = GraphExecutor::new(CpuBackend::new());
672    ///
673    /// let value = executor.run_value(&program).unwrap();
674    /// assert!(matches!(&value, TensorValue::View(_)));
675    /// assert_eq!(value.to_tensor().unwrap().shape(), &[2, 2]);
676    /// ```
677    pub fn eval_exec_ir_non_consuming_values(
678        &mut self,
679        program: &ExecProgram,
680        inputs: &[Tensor],
681    ) -> Result<Vec<TensorValue>> {
682        let inputs = inputs
683            .iter()
684            .map(|tensor| ExecSlot::Read(TensorRead::from_tensor(tensor)))
685            .collect();
686        self.eval_exec_ir_slot_values(program, inputs)
687    }
688
689    fn eval_exec_ir_slots<'a>(
690        &mut self,
691        program: &ExecProgram,
692        inputs: Vec<ExecSlot<'a>>,
693    ) -> Result<Vec<Tensor>> {
694        validate_exec_input_count(program, inputs.len())?;
695        let mut slot_workspace = Vec::with_capacity(self.borrowed_slot_workspace_capacity);
696        let result = crate::segment::eval_exec_segmented_slots_with_cache_and_workspace(
697            &mut self.backend,
698            program,
699            inputs,
700            &mut slot_workspace,
701            &mut self.backend_cache,
702            Some(&mut self.extension_executor),
703        );
704        self.borrowed_slot_workspace_capacity = slot_workspace.capacity();
705        result
706    }
707
708    fn eval_exec_ir_slot_values<'a>(
709        &mut self,
710        program: &ExecProgram,
711        inputs: Vec<ExecSlot<'a>>,
712    ) -> Result<Vec<TensorValue>> {
713        validate_exec_input_count(program, inputs.len())?;
714        let mut slot_workspace = Vec::with_capacity(self.borrowed_slot_workspace_capacity);
715        let result = crate::segment::eval_exec_segmented_slot_values_with_cache_and_workspace(
716            &mut self.backend,
717            program,
718            inputs,
719            &mut slot_workspace,
720            &mut self.backend_cache,
721            Some(&mut self.extension_executor),
722        );
723        self.borrowed_slot_workspace_capacity = slot_workspace.capacity();
724        result
725    }
726
727    /// Clear backend-specific runtime analysis cache entries.
728    ///
729    /// # Examples
730    ///
731    /// ```
732    /// use tenferro_cpu::CpuBackend;
733    /// use tenferro_runtime::{GraphExecutor};
734    ///
735    /// let mut executor = GraphExecutor::new(CpuBackend::new());
736    /// executor.clear_backend_cache();
737    /// assert_eq!(executor.cache_stats().backend.entries, 0);
738    /// ```
739    pub fn clear_backend_cache(&mut self) {
740        self.backend_cache.clear();
741    }
742
743    /// Clear generic extension runtime cache entries.
744    ///
745    /// # Examples
746    ///
747    /// ```
748    /// use tenferro_cpu::CpuBackend;
749    /// use tenferro_runtime::{GraphExecutor};
750    ///
751    /// let mut executor = GraphExecutor::new(CpuBackend::new());
752    /// executor.clear_extension_caches();
753    /// assert_eq!(executor.cache_stats().extensions.entries, 0);
754    /// ```
755    pub fn clear_extension_caches(&mut self) {
756        self.extension_executor.clear_caches();
757    }
758
759    /// Clear every executor-owned runtime cache.
760    ///
761    /// # Examples
762    ///
763    /// ```
764    /// use tenferro_cpu::CpuBackend;
765    /// use tenferro_runtime::{GraphExecutor};
766    ///
767    /// let mut executor = GraphExecutor::new(CpuBackend::new());
768    /// executor.clear_caches();
769    /// assert_eq!(executor.cache_stats().backend.entries, 0);
770    /// ```
771    pub fn clear_caches(&mut self) {
772        self.clear_extension_caches();
773        self.clear_backend_cache();
774    }
775
776    /// Return executor runtime cache-entry and retained-byte stats.
777    ///
778    /// # Examples
779    ///
780    /// ```
781    /// use tenferro_cpu::CpuBackend;
782    /// use tenferro_runtime::{GraphExecutor};
783    ///
784    /// let executor = GraphExecutor::new(CpuBackend::new());
785    /// let stats = executor.cache_stats();
786    /// assert_eq!(stats.extensions.entries, 0);
787    /// ```
788    pub fn cache_stats(&self) -> GraphExecutorCacheStats {
789        GraphExecutorCacheStats {
790            extensions: self.extension_executor.cache_stats(),
791            backend: self.backend_cache.stats(),
792        }
793    }
794}
795
796impl<B: TensorBackend + 'static> Default for GraphExecutor<B>
797where
798    B: Default,
799{
800    fn default() -> Self {
801        Self::new(B::default())
802    }
803}
804
805fn validate_exec_input_count(program: &ExecProgram, actual: usize) -> Result<()> {
806    let expected = program.input_slots.len();
807    if actual != expected {
808        return Err(Error::Internal(format!(
809            "expected {expected} inputs for execution program, got {actual}"
810        )));
811    }
812    Ok(())
813}
814
815fn expect_single_output(outputs: &mut Vec<Tensor>) -> Result<Tensor> {
816    if outputs.len() != 1 {
817        return Err(Error::Internal(format!(
818            "expected 1 output, got {}",
819            outputs.len()
820        )));
821    }
822    outputs
823        .pop()
824        .ok_or_else(|| Error::Internal("missing graph output".to_string()))
825}
826
827fn expect_single_value(outputs: &mut Vec<TensorValue>) -> Result<TensorValue> {
828    if outputs.len() != 1 {
829        return Err(Error::Internal(format!(
830            "expected 1 output, got {}",
831            outputs.len()
832        )));
833    }
834    outputs
835        .pop()
836        .ok_or_else(|| Error::Internal("missing graph output".to_string()))
837}
838
839fn resolve_inputs(
840    program: &GraphProgram,
841    bindings: &[(&TracedTensor, &Tensor)],
842    backend: &mut impl TensorBackend,
843) -> Result<Vec<Tensor>> {
844    let program_keys: HashSet<_> = program
845        .inputs
846        .iter()
847        .map(|input| input.key.clone())
848        .collect();
849    let tangent_root_specs = tangent_root_specs(&program.inputs);
850    let default_map: HashMap<_, _> = program
851        .inputs
852        .iter()
853        .filter_map(|input| {
854            input
855                .default_tensor
856                .as_ref()
857                .map(|tensor| (input.key.clone(), tensor.as_ref()))
858        })
859        .collect();
860    let mut binding_map = HashMap::new();
861    for (index, (placeholder, tensor)) in bindings.iter().enumerate() {
862        if placeholder.data.is_some() {
863            return Err(Error::UnexpectedBinding {
864                binding_index: index,
865            });
866        }
867        let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
868            binding_index: index,
869        })?;
870        validate_binding_placeholder(index, placeholder, tensor)?;
871        let is_program_input = program_keys.contains(&key);
872        if !is_program_input && !tangent_root_specs.contains_key(&key) {
873            return Err(Error::UnexpectedBinding {
874                binding_index: index,
875            });
876        }
877        if binding_map.insert(key.clone(), *tensor).is_some() {
878            return Err(Error::DuplicateBinding {
879                input_key: format!("{:?}", key),
880            });
881        }
882    }
883
884    program
885        .inputs
886        .iter()
887        .map(|input| resolve_input(input, &binding_map, &default_map, backend))
888        .collect()
889}
890
891fn resolve_input_reads<'a>(
892    program: &'a GraphProgram,
893    bindings: &[(&TracedTensor, TensorRead<'a>)],
894    backend: &mut impl TensorBackend,
895) -> Result<Vec<ExecSlot<'a>>> {
896    let program_keys: HashSet<_> = program
897        .inputs
898        .iter()
899        .map(|input| input.key.clone())
900        .collect();
901    let tangent_root_specs = tangent_root_specs(&program.inputs);
902    let default_map: HashMap<_, _> = program
903        .inputs
904        .iter()
905        .filter_map(|input| {
906            input
907                .default_tensor
908                .as_ref()
909                .map(|tensor| (input.key.clone(), tensor.as_ref()))
910        })
911        .collect();
912    let mut binding_map = HashMap::new();
913    for (index, (placeholder, read)) in bindings.iter().enumerate() {
914        if placeholder.data.is_some() {
915            return Err(Error::UnexpectedBinding {
916                binding_index: index,
917            });
918        }
919        let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
920            binding_index: index,
921        })?;
922        validate_binding_placeholder_read(index, placeholder, read)?;
923        let is_program_input = program_keys.contains(&key);
924        if !is_program_input && !tangent_root_specs.contains_key(&key) {
925            return Err(Error::UnexpectedBinding {
926                binding_index: index,
927            });
928        }
929        if binding_map.insert(key.clone(), read.clone()).is_some() {
930            return Err(Error::DuplicateBinding {
931                input_key: format!("{:?}", key),
932            });
933        }
934    }
935
936    program
937        .inputs
938        .iter()
939        .map(|input| resolve_input_read(input, &binding_map, &default_map, backend))
940        .collect()
941}
942
943fn tangent_root_specs(inputs: &[GraphProgramInput]) -> HashMap<TensorInputKey, &GraphProgramInput> {
944    let mut specs = HashMap::new();
945    for input in inputs {
946        if !matches!(input.key, TensorInputKey::User { .. }) {
947            specs
948                .entry(tangent_primal_root(&input.key).clone())
949                .or_insert(input);
950        }
951    }
952    specs
953}
954
955fn resolve_input(
956    input: &GraphProgramInput,
957    bindings: &HashMap<TensorInputKey, &Tensor>,
958    defaults: &HashMap<TensorInputKey, &Tensor>,
959    backend: &mut impl TensorBackend,
960) -> Result<Tensor> {
961    let tensor = if let Some(bound) = bindings.get(&input.key) {
962        (*bound).clone()
963    } else if let Some(default) = &input.default_tensor {
964        resolve_default_tensor(default.as_ref(), backend)?
965    } else if let Some(zero) = deferred_zero_for_tangent_key(&input.key, bindings, defaults)? {
966        zero
967    } else {
968        return Err(Error::UnboundPlaceholder {
969            input_key: format!("{:?}", input.key),
970        });
971    };
972    validate_input_tensor(input, &tensor)?;
973    Ok(tensor)
974}
975
976fn resolve_input_read<'a>(
977    input: &GraphProgramInput,
978    bindings: &HashMap<TensorInputKey, TensorRead<'a>>,
979    defaults: &HashMap<TensorInputKey, &'a Tensor>,
980    backend: &mut impl TensorBackend,
981) -> Result<ExecSlot<'a>> {
982    let slot = if let Some(bound) = bindings.get(&input.key) {
983        ExecSlot::Read(bound.clone())
984    } else if let Some(default) = defaults.get(&input.key) {
985        if should_upload_default_tensor(default) {
986            ExecSlot::Owned(backend.upload_host_tensor(default)?)
987        } else {
988            ExecSlot::Read(TensorRead::from_tensor(default))
989        }
990    } else if let Some(zero) = deferred_zero_for_tangent_key_read(&input.key, bindings, defaults)? {
991        ExecSlot::Owned(zero)
992    } else {
993        return Err(Error::UnboundPlaceholder {
994            input_key: format!("{:?}", input.key),
995        });
996    };
997    validate_input_slot(input, &slot)?;
998    Ok(slot)
999}
1000
1001fn resolve_default_tensor(default: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
1002    if should_upload_default_tensor(default) {
1003        Ok(backend.upload_host_tensor(default)?)
1004    } else {
1005        Ok(default.clone())
1006    }
1007}
1008
1009fn should_upload_default_tensor(default: &Tensor) -> bool {
1010    default.shape().is_empty() && tensor_has_host_buffer(default)
1011}
1012
1013fn tensor_has_host_buffer(tensor: &Tensor) -> bool {
1014    !tensor.is_backend_buffer()
1015}
1016
1017fn validate_binding_placeholder(
1018    index: usize,
1019    placeholder: &TracedTensor,
1020    tensor: &Tensor,
1021) -> Result<()> {
1022    if placeholder.data.is_some() {
1023        return Err(Error::UnexpectedBinding {
1024            binding_index: index,
1025        });
1026    }
1027    if placeholder.dtype != tensor.dtype() {
1028        return Err(Error::PlaceholderDtypeMismatch {
1029            expected: placeholder.dtype,
1030            actual: tensor.dtype(),
1031        });
1032    }
1033    match placeholder.try_concrete_shape() {
1034        Some(expected_shape) => {
1035            if expected_shape.as_slice() != tensor.shape() {
1036                return Err(Error::PlaceholderShapeMismatch {
1037                    expected: expected_shape,
1038                    actual: tensor.shape().to_vec(),
1039                });
1040            }
1041        }
1042        None => {
1043            if placeholder.rank != tensor.shape().len() {
1044                return Err(Error::PlaceholderRankMismatch {
1045                    expected: placeholder.rank,
1046                    actual: tensor.shape().len(),
1047                });
1048            }
1049        }
1050    }
1051    Ok(())
1052}
1053
1054fn validate_binding_placeholder_read(
1055    index: usize,
1056    placeholder: &TracedTensor,
1057    read: &TensorRead<'_>,
1058) -> Result<()> {
1059    if placeholder.data.is_some() {
1060        return Err(Error::UnexpectedBinding {
1061            binding_index: index,
1062        });
1063    }
1064    if placeholder.dtype != read.dtype() {
1065        return Err(Error::PlaceholderDtypeMismatch {
1066            expected: placeholder.dtype,
1067            actual: read.dtype(),
1068        });
1069    }
1070    match placeholder.try_concrete_shape() {
1071        Some(expected_shape) => {
1072            if expected_shape.as_slice() != read.shape() {
1073                return Err(Error::PlaceholderShapeMismatch {
1074                    expected: expected_shape,
1075                    actual: read.shape().to_vec(),
1076                });
1077            }
1078        }
1079        None => {
1080            if placeholder.rank != read.shape().len() {
1081                return Err(Error::PlaceholderRankMismatch {
1082                    expected: placeholder.rank,
1083                    actual: read.shape().len(),
1084                });
1085            }
1086        }
1087    }
1088    Ok(())
1089}
1090
1091fn validate_input_tensor(input: &GraphProgramInput, tensor: &Tensor) -> Result<()> {
1092    if input.dtype != tensor.dtype() {
1093        return Err(Error::PlaceholderDtypeMismatch {
1094            expected: input.dtype,
1095            actual: tensor.dtype(),
1096        });
1097    }
1098    if input.shape.as_slice() != tensor.shape() {
1099        return Err(Error::PlaceholderShapeMismatch {
1100            expected: input.shape.clone(),
1101            actual: tensor.shape().to_vec(),
1102        });
1103    }
1104    Ok(())
1105}
1106
1107fn validate_input_slot(input: &GraphProgramInput, slot: &ExecSlot<'_>) -> Result<()> {
1108    if input.dtype != slot.dtype() {
1109        return Err(Error::PlaceholderDtypeMismatch {
1110            expected: input.dtype,
1111            actual: slot.dtype(),
1112        });
1113    }
1114    if input.shape.as_slice() != slot.shape() {
1115        return Err(Error::PlaceholderShapeMismatch {
1116            expected: input.shape.clone(),
1117            actual: slot.shape().to_vec(),
1118        });
1119    }
1120    Ok(())
1121}
1122
1123fn deferred_zero_for_tangent_key(
1124    key: &TensorInputKey,
1125    bindings: &HashMap<TensorInputKey, &Tensor>,
1126    defaults: &HashMap<TensorInputKey, &Tensor>,
1127) -> Result<Option<Tensor>> {
1128    if !key.is_tangent() {
1129        return Ok(None);
1130    }
1131    let root = tangent_primal_root(key);
1132    let Some(primal) = bindings.get(root).or_else(|| defaults.get(root)) else {
1133        return Ok(None);
1134    };
1135    zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some)
1136}
1137
1138fn deferred_zero_for_tangent_key_read<'a>(
1139    key: &TensorInputKey,
1140    bindings: &HashMap<TensorInputKey, TensorRead<'a>>,
1141    defaults: &HashMap<TensorInputKey, &'a Tensor>,
1142) -> Result<Option<Tensor>> {
1143    if !key.is_tangent() {
1144        return Ok(None);
1145    }
1146    let root = tangent_primal_root(key);
1147    if let Some(primal) = bindings.get(root) {
1148        return zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some);
1149    }
1150    let Some(primal) = defaults.get(root) else {
1151        return Ok(None);
1152    };
1153    zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some)
1154}
1155
1156fn tangent_primal_root(key: &TensorInputKey) -> &TensorInputKey {
1157    key.primal_root()
1158}
1159
1160fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Result<Tensor> {
1161    match dtype {
1162        DType::F32 => Ok(Tensor::F32(TypedTensor::zeros(shape)?)),
1163        DType::F64 => Ok(Tensor::F64(TypedTensor::zeros(shape)?)),
1164        DType::I32 => Ok(Tensor::I32(TypedTensor::zeros(shape)?)),
1165        DType::I64 => Ok(Tensor::I64(TypedTensor::zeros(shape)?)),
1166        DType::Bool => {
1167            let len = checked_default_element_count(&shape)?;
1168            Ok(Tensor::Bool(TypedTensor::from_vec_col_major(
1169                shape,
1170                vec![false; len],
1171            )?))
1172        }
1173        DType::C32 => Ok(Tensor::C32(TypedTensor::zeros(shape)?)),
1174        DType::C64 => Ok(Tensor::C64(TypedTensor::zeros(shape)?)),
1175    }
1176}
1177
1178fn checked_default_element_count(shape: &[usize]) -> Result<usize> {
1179    shape.iter().try_fold(1usize, |acc, &dim| {
1180        acc.checked_mul(dim)
1181            .ok_or_else(|| Error::InvalidCompiledGraph {
1182                message: format!("deferred zero shape product overflows usize for shape {shape:?}"),
1183            })
1184    })
1185}
1186
1187#[cfg(test)]
1188mod tests;