tenferro_runtime/graph/executor.rs
1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use tenferro_ops::input_key::TensorInputKey;
6use tenferro_tensor::{
7 DType, RuntimeCacheControl, Tensor, TensorBackend, TensorRead, TensorValue, TypedTensor,
8};
9
10use super::cache::GraphExecutorCacheStats;
11use super::program::{GraphProgram, GraphProgramInput};
12use crate::error::{Error, Result};
13use crate::exec::{ExecProgram, ExecSlot};
14use crate::extension_runtime::{ExtensionExecutor, ExtensionRuntimeRegistryError};
15use crate::traced::TracedTensor;
16
17/// Executes compiled graph programs on a concrete tensor backend.
18///
19/// A graph executor owns backend execution state only: backend runtime caches,
20/// extension runtime state, and reusable execution workspace. Compilation
21/// state lives in [`GraphCompiler`](super::GraphCompiler).
22///
23/// # Examples
24///
25/// ```
26/// use tenferro_cpu::CpuBackend;
27/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
28///
29/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
30/// let y = (&x + &x).unwrap();
31/// let mut compiler = GraphCompiler::new();
32/// let program = compiler.compile(&y).unwrap();
33///
34/// let mut executor = GraphExecutor::new(CpuBackend::new());
35/// let out = executor.run(&program).unwrap();
36/// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
37/// ```
38pub struct GraphExecutor<B: TensorBackend + 'static> {
39 backend: B,
40 backend_cache: B::RuntimeCache,
41 extension_executor: ExtensionExecutor<B>,
42 slot_workspace: Vec<Option<ExecSlot<'static>>>,
43 borrowed_slot_workspace_capacity: usize,
44}
45
46impl<B: TensorBackend + 'static> fmt::Debug for GraphExecutor<B> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("GraphExecutor")
49 .field("backend_type", &std::any::type_name::<B>())
50 .field("cache_stats", &self.cache_stats())
51 .field("slot_workspace_len", &self.slot_workspace.len())
52 .field(
53 "borrowed_slot_workspace_capacity",
54 &self.borrowed_slot_workspace_capacity,
55 )
56 .finish_non_exhaustive()
57 }
58}
59
60impl<B: TensorBackend + 'static> GraphExecutor<B> {
61 /// Create an executor with the given backend and bounded default caches.
62 ///
63 /// # Examples
64 ///
65 /// ```
66 /// use tenferro_cpu::CpuBackend;
67 /// use tenferro_runtime::{GraphExecutor};
68 ///
69 /// let executor = GraphExecutor::new(CpuBackend::new());
70 /// assert_eq!(executor.cache_stats().extensions.entries, 0);
71 /// ```
72 pub fn new(backend: B) -> Self {
73 Self {
74 backend,
75 backend_cache: B::RuntimeCache::default(),
76 extension_executor: ExtensionExecutor::new(),
77 slot_workspace: Vec::new(),
78 borrowed_slot_workspace_capacity: 0,
79 }
80 }
81
82 /// Borrow the backend used by this executor.
83 ///
84 /// # Examples
85 ///
86 /// ```
87 /// use tenferro_cpu::CpuBackend;
88 /// use tenferro_runtime::{GraphExecutor};
89 ///
90 /// let executor = GraphExecutor::new(CpuBackend::new());
91 /// let _backend = executor.backend();
92 /// ```
93 pub fn backend(&self) -> &B {
94 &self.backend
95 }
96
97 /// Return output tensors to the executor backend's reusable buffer pool.
98 ///
99 /// This is useful for tight benchmark or serving loops that consume an
100 /// output before the next run and want backend-level output allocation
101 /// behavior to match caching allocators.
102 ///
103 /// # Examples
104 ///
105 /// ```
106 /// use tenferro_cpu::CpuBackend;
107 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
108 ///
109 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
110 /// let mut compiler = GraphCompiler::new();
111 /// let program = compiler.compile(&x.neg()).unwrap();
112 /// let mut executor = GraphExecutor::new(CpuBackend::new());
113 ///
114 /// let out = executor.run(&program).unwrap();
115 /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-2.0]);
116 /// executor.reclaim_outputs(vec![out]);
117 /// ```
118 pub fn reclaim_outputs(&mut self, outputs: Vec<Tensor>) {
119 for tensor in outputs {
120 self.backend.reclaim_buffer(tensor);
121 }
122 }
123
124 /// Return compact value outputs to the backend pool when ownership is unique.
125 ///
126 /// Lazy owned views are intentionally ignored because their base storage may
127 /// be aliased by view metadata.
128 ///
129 /// # Examples
130 ///
131 /// ```
132 /// use tenferro_cpu::CpuBackend;
133 /// use tenferro_runtime::{GraphExecutor, Tensor, TensorValue};
134 ///
135 /// let tensor = Tensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
136 /// let mut executor = GraphExecutor::new(CpuBackend::new());
137 /// executor.reclaim_value_outputs(vec![TensorValue::from_tensor(tensor)]);
138 /// ```
139 pub fn reclaim_value_outputs(&mut self, outputs: Vec<TensorValue>) {
140 for value in outputs {
141 if let TensorValue::Tensor(tensor) = value {
142 if let Ok(tensor) = Arc::try_unwrap(tensor) {
143 self.backend.reclaim_buffer(tensor);
144 }
145 }
146 }
147 }
148
149 /// Borrow the extension runtime executor owned by this graph executor.
150 ///
151 /// # Examples
152 ///
153 /// ```
154 /// use tenferro_cpu::CpuBackend;
155 /// use tenferro_runtime::{GraphExecutor};
156 ///
157 /// let executor = GraphExecutor::new(CpuBackend::new());
158 /// assert_eq!(executor.extension_executor().cache_stats().entries, 0);
159 /// ```
160 pub fn extension_executor(&self) -> &ExtensionExecutor<B> {
161 &self.extension_executor
162 }
163
164 /// Mutably borrow the extension runtime executor owned by this graph executor.
165 ///
166 /// # Examples
167 ///
168 /// ```
169 /// use tenferro_cpu::CpuBackend;
170 /// use tenferro_runtime::{GraphExecutor};
171 ///
172 /// let mut executor = GraphExecutor::new(CpuBackend::new());
173 /// executor.extension_executor_mut().clear_caches();
174 /// ```
175 pub fn extension_executor_mut(&mut self) -> &mut ExtensionExecutor<B> {
176 &mut self.extension_executor
177 }
178
179 /// Register one extension runtime on this executor.
180 ///
181 /// # Examples
182 ///
183 /// ```
184 /// use tenferro_cpu::CpuBackend;
185 /// use tenferro_runtime::GraphExecutor;
186 ///
187 /// let mut executor = GraphExecutor::new(CpuBackend::new());
188 /// executor.register_extension(|_| Ok(())).unwrap();
189 /// ```
190 pub fn register_extension(
191 &mut self,
192 register: impl FnOnce(
193 &mut ExtensionExecutor<B>,
194 ) -> std::result::Result<(), ExtensionRuntimeRegistryError>,
195 ) -> std::result::Result<(), ExtensionRuntimeRegistryError> {
196 register(&mut self.extension_executor)
197 }
198
199 /// Run a one-output program using the program's default input tensors.
200 ///
201 /// # Examples
202 ///
203 /// ```
204 /// use tenferro_cpu::CpuBackend;
205 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
206 ///
207 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
208 /// let mut compiler = GraphCompiler::new();
209 /// let program = compiler.compile(&x.neg()).unwrap();
210 /// let mut executor = GraphExecutor::new(CpuBackend::new());
211 /// let out = executor.run(&program).unwrap();
212 /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-3.0]);
213 /// ```
214 pub fn run(&mut self, program: &GraphProgram) -> Result<Tensor> {
215 let mut outputs = self.run_many(program)?;
216 expect_single_output(&mut outputs)
217 }
218
219 /// Run a one-output program and preserve lazy owned output views.
220 ///
221 /// # Examples
222 ///
223 /// ```
224 /// use tenferro_cpu::CpuBackend;
225 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
226 ///
227 /// let x = TracedTensor::from_vec_col_major(
228 /// vec![2, 3],
229 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
230 /// )
231 /// .unwrap();
232 /// let y = x.transpose(&[1, 0]).unwrap();
233 /// let mut compiler = GraphCompiler::new();
234 /// let program = compiler.compile(&y).unwrap();
235 ///
236 /// let mut executor = GraphExecutor::new(CpuBackend::new());
237 /// let value = executor.run_value(&program).unwrap();
238 /// assert!(matches!(&value, TensorValue::View(_)));
239 /// assert_eq!(value.shape(), &[3, 2]);
240 /// assert_eq!(
241 /// value.to_tensor().unwrap().as_slice::<f64>().unwrap(),
242 /// &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
243 /// );
244 /// ```
245 pub fn run_value(&mut self, program: &GraphProgram) -> Result<TensorValue> {
246 let mut outputs = self.run_many_values(program)?;
247 expect_single_value(&mut outputs)
248 }
249
250 /// Run a program using the program's default input tensors.
251 ///
252 /// # Examples
253 ///
254 /// ```
255 /// use tenferro_cpu::CpuBackend;
256 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
257 ///
258 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
259 /// let y = x.neg();
260 /// let mut compiler = GraphCompiler::new();
261 /// let program = compiler.compile_many(&[&x, &y]).unwrap();
262 /// let mut executor = GraphExecutor::new(CpuBackend::new());
263 /// let outputs = executor.run_many(&program).unwrap();
264 /// assert_eq!(outputs.len(), 2);
265 /// ```
266 pub fn run_many(&mut self, program: &GraphProgram) -> Result<Vec<Tensor>> {
267 self.run_many_with_inputs(program, &[])
268 }
269
270 /// Run a program and preserve lazy owned output views.
271 ///
272 /// # Examples
273 ///
274 /// ```
275 /// use tenferro_cpu::CpuBackend;
276 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
277 ///
278 /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
279 /// let y = x.transpose(&[1, 0]).unwrap();
280 /// let mut compiler = GraphCompiler::new();
281 /// let program = compiler.compile_many(&[&y]).unwrap();
282 ///
283 /// let mut executor = GraphExecutor::new(CpuBackend::new());
284 /// let outputs = executor.run_many_values(&program).unwrap();
285 /// assert_eq!(outputs.len(), 1);
286 /// assert!(matches!(&outputs[0], TensorValue::View(_)));
287 /// assert_eq!(outputs[0].shape(), &[2, 2]);
288 /// ```
289 pub fn run_many_values(&mut self, program: &GraphProgram) -> Result<Vec<TensorValue>> {
290 self.run_many_values_with_inputs(program, &[])
291 }
292
293 /// Run a one-output program with explicit runtime placeholder bindings.
294 ///
295 /// Explicit bindings override program defaults and are validated against
296 /// the ordered input specs captured in the compiled program.
297 ///
298 /// # Examples
299 ///
300 /// ```
301 /// use tenferro_cpu::CpuBackend;
302 /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TracedTensor};
303 ///
304 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
305 /// let y = (&x + &x).unwrap();
306 /// let mut compiler = GraphCompiler::new();
307 /// let program = compiler
308 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
309 /// .unwrap();
310 /// let bound = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
311 /// let mut executor = GraphExecutor::new(CpuBackend::new());
312 /// let out = executor.run_with_inputs(&program, &[(&x, &bound)]).unwrap();
313 /// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
314 /// ```
315 pub fn run_with_inputs(
316 &mut self,
317 program: &GraphProgram,
318 bindings: &[(&TracedTensor, &Tensor)],
319 ) -> Result<Tensor> {
320 let mut outputs = self.run_many_with_inputs(program, bindings)?;
321 expect_single_output(&mut outputs)
322 }
323
324 /// Run a one-output program with explicit bindings and preserve lazy output views.
325 ///
326 /// # Examples
327 ///
328 /// ```
329 /// use tenferro_cpu::CpuBackend;
330 /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TensorValue, TracedTensor};
331 ///
332 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
333 /// let y = x.transpose(&[1, 0]).unwrap();
334 /// let mut compiler = GraphCompiler::new();
335 /// let program = compiler
336 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
337 /// .unwrap();
338 /// let bound = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
339 /// let mut executor = GraphExecutor::new(CpuBackend::new());
340 ///
341 /// let value = executor.run_value_with_inputs(&program, &[(&x, &bound)]).unwrap();
342 /// assert!(matches!(&value, TensorValue::View(_)));
343 /// assert_eq!(value.to_tensor().unwrap().as_slice::<f64>().unwrap(), &[1.0, 3.0, 2.0, 4.0]);
344 /// ```
345 pub fn run_value_with_inputs(
346 &mut self,
347 program: &GraphProgram,
348 bindings: &[(&TracedTensor, &Tensor)],
349 ) -> Result<TensorValue> {
350 let mut outputs = self.run_many_values_with_inputs(program, bindings)?;
351 expect_single_value(&mut outputs)
352 }
353
354 /// Run a one-output program with explicit borrowed runtime placeholder bindings.
355 ///
356 /// Unlike [`run_with_inputs`](Self::run_with_inputs), caller-owned input
357 /// tensors are read through [`TensorRead`] and are not cloned into executor
358 /// slots.
359 ///
360 /// # Examples
361 ///
362 /// ```
363 /// use tenferro_cpu::CpuBackend;
364 /// use tenferro_runtime::{
365 /// DType, GraphCompiler, GraphExecutor, TensorRead, TensorView, TracedTensor, TypedTensorView,
366 /// };
367 ///
368 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
369 /// let y = (&x + &x).unwrap();
370 /// let mut compiler = GraphCompiler::new();
371 /// let program = compiler
372 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
373 /// .unwrap();
374 /// let data = [1.0_f64, 99.0, 2.0];
375 /// let view = TypedTensorView::from_slice([2], [2], 0, &data).unwrap();
376 /// let read = TensorRead::from_view(TensorView::F64(view));
377 /// let mut executor = GraphExecutor::new(CpuBackend::new());
378 ///
379 /// let out = executor.run_with_input_reads(&program, &[(&x, read)]).unwrap();
380 /// assert_eq!(out.as_slice::<f64>().unwrap(), &[2.0, 4.0]);
381 /// ```
382 pub fn run_with_input_reads<'a>(
383 &mut self,
384 program: &'a GraphProgram,
385 bindings: &[(&TracedTensor, TensorRead<'a>)],
386 ) -> Result<Tensor> {
387 let mut outputs = self.run_many_with_input_reads(program, bindings)?;
388 expect_single_output(&mut outputs)
389 }
390
391 /// Run a one-output program with borrowed bindings and preserve lazy output views.
392 ///
393 /// # Examples
394 ///
395 /// ```
396 /// use tenferro_cpu::CpuBackend;
397 /// use tenferro_runtime::{
398 /// DType, GraphCompiler, GraphExecutor, TensorRead, TensorValue, TensorView, TracedTensor,
399 /// TypedTensorView,
400 /// };
401 ///
402 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
403 /// let y = x.transpose(&[1, 0]).unwrap();
404 /// let mut compiler = GraphCompiler::new();
405 /// let program = compiler
406 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
407 /// .unwrap();
408 /// let data = [1.0_f64, 2.0, 3.0, 4.0];
409 /// let view = TypedTensorView::from_slice([2, 2], [1, 2], 0, &data).unwrap();
410 /// let read = TensorRead::from_view(TensorView::F64(view));
411 /// let mut executor = GraphExecutor::new(CpuBackend::new());
412 ///
413 /// let value = executor
414 /// .run_value_with_input_reads(&program, &[(&x, read)])
415 /// .unwrap();
416 /// assert!(matches!(&value, TensorValue::View(_)));
417 /// assert_eq!(value.shape(), &[2, 2]);
418 /// ```
419 pub fn run_value_with_input_reads<'a>(
420 &mut self,
421 program: &'a GraphProgram,
422 bindings: &[(&TracedTensor, TensorRead<'a>)],
423 ) -> Result<TensorValue> {
424 let mut outputs = self.run_many_values_with_input_reads(program, bindings)?;
425 expect_single_value(&mut outputs)
426 }
427
428 /// Run a program with explicit runtime placeholder bindings.
429 ///
430 /// # Examples
431 ///
432 /// ```
433 /// use tenferro_cpu::CpuBackend;
434 /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TracedTensor};
435 ///
436 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
437 /// let sum = (&x + &x).unwrap();
438 /// let mut compiler = GraphCompiler::new();
439 /// let program = compiler
440 /// .compile_with_input_specs(&sum, &[(&x, DType::F64, &[2])])
441 /// .unwrap();
442 /// let bound = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
443 /// let mut executor = GraphExecutor::new(CpuBackend::new());
444 /// let outputs = executor.run_many_with_inputs(&program, &[(&x, &bound)]).unwrap();
445 /// assert_eq!(outputs.len(), 1);
446 /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[2.0, 4.0]);
447 /// ```
448 pub fn run_many_with_inputs(
449 &mut self,
450 program: &GraphProgram,
451 bindings: &[(&TracedTensor, &Tensor)],
452 ) -> Result<Vec<Tensor>> {
453 let input_tensors = resolve_inputs(program, bindings, &mut self.backend)?;
454 self.eval_exec_ir(&program.exec, input_tensors)
455 }
456
457 /// Run a program with explicit bindings and preserve lazy output views.
458 ///
459 /// # Examples
460 ///
461 /// ```
462 /// use tenferro_cpu::CpuBackend;
463 /// use tenferro_runtime::{DType, GraphCompiler, GraphExecutor, Tensor, TensorValue, TracedTensor};
464 ///
465 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
466 /// let y = x.transpose(&[1, 0]).unwrap();
467 /// let mut compiler = GraphCompiler::new();
468 /// let program = compiler
469 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
470 /// .unwrap();
471 /// let bound = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
472 /// let mut executor = GraphExecutor::new(CpuBackend::new());
473 ///
474 /// let outputs = executor
475 /// .run_many_values_with_inputs(&program, &[(&x, &bound)])
476 /// .unwrap();
477 /// assert_eq!(outputs.len(), 1);
478 /// assert!(matches!(&outputs[0], TensorValue::View(_)));
479 /// ```
480 pub fn run_many_values_with_inputs(
481 &mut self,
482 program: &GraphProgram,
483 bindings: &[(&TracedTensor, &Tensor)],
484 ) -> Result<Vec<TensorValue>> {
485 let input_tensors = resolve_inputs(program, bindings, &mut self.backend)?;
486 self.eval_exec_ir_values(&program.exec, input_tensors)
487 }
488
489 /// Run a program with explicit borrowed runtime placeholder bindings.
490 ///
491 /// Bindings override program defaults and are validated against the input
492 /// specs captured in the compiled program. Bound tensors are borrowed by
493 /// the executor for this call instead of cloned into input slots.
494 ///
495 /// # Examples
496 ///
497 /// ```
498 /// use tenferro_cpu::CpuBackend;
499 /// use tenferro_runtime::{
500 /// DType, GraphCompiler, GraphExecutor, TensorRead, TensorView, TracedTensor, TypedTensorView,
501 /// };
502 ///
503 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
504 /// let y = (&x + &x).unwrap();
505 /// let mut compiler = GraphCompiler::new();
506 /// let program = compiler
507 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2])])
508 /// .unwrap();
509 /// let data = [1.0_f64, 99.0, 2.0];
510 /// let view = TypedTensorView::from_slice([2], [2], 0, &data).unwrap();
511 /// let read = TensorRead::from_view(TensorView::F64(view));
512 /// let mut executor = GraphExecutor::new(CpuBackend::new());
513 ///
514 /// let outputs = executor.run_many_with_input_reads(&program, &[(&x, read)]).unwrap();
515 /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[2.0, 4.0]);
516 /// ```
517 pub fn run_many_with_input_reads<'a>(
518 &mut self,
519 program: &'a GraphProgram,
520 bindings: &[(&TracedTensor, TensorRead<'a>)],
521 ) -> Result<Vec<Tensor>> {
522 let inputs = resolve_input_reads(program, bindings, &mut self.backend)?;
523 self.eval_exec_ir_slots(&program.exec, inputs)
524 }
525
526 /// Run a program with borrowed bindings and preserve lazy output views.
527 ///
528 /// # Examples
529 ///
530 /// ```
531 /// use tenferro_cpu::CpuBackend;
532 /// use tenferro_runtime::{
533 /// DType, GraphCompiler, GraphExecutor, TensorRead, TensorValue, TensorView, TracedTensor,
534 /// TypedTensorView,
535 /// };
536 ///
537 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
538 /// let y = x.transpose(&[1, 0]).unwrap();
539 /// let mut compiler = GraphCompiler::new();
540 /// let program = compiler
541 /// .compile_with_input_specs(&y, &[(&x, DType::F64, &[2, 2])])
542 /// .unwrap();
543 /// let data = [1.0_f64, 2.0, 3.0, 4.0];
544 /// let view = TypedTensorView::from_slice([2, 2], [1, 2], 0, &data).unwrap();
545 /// let read = TensorRead::from_view(TensorView::F64(view));
546 /// let mut executor = GraphExecutor::new(CpuBackend::new());
547 ///
548 /// let outputs = executor
549 /// .run_many_values_with_input_reads(&program, &[(&x, read)])
550 /// .unwrap();
551 /// assert_eq!(outputs.len(), 1);
552 /// assert!(matches!(&outputs[0], TensorValue::View(_)));
553 /// ```
554 pub fn run_many_values_with_input_reads<'a>(
555 &mut self,
556 program: &'a GraphProgram,
557 bindings: &[(&TracedTensor, TensorRead<'a>)],
558 ) -> Result<Vec<TensorValue>> {
559 let inputs = resolve_input_reads(program, bindings, &mut self.backend)?;
560 self.eval_exec_ir_slot_values(&program.exec, inputs)
561 }
562
563 /// Evaluate an execution program through this executor's backend state.
564 ///
565 /// This lower-level entry point is intended for code that already owns an
566 /// execution program and concrete ordered input tensors.
567 ///
568 /// # Examples
569 ///
570 /// ```
571 /// use tenferro_cpu::CpuBackend;
572 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
573 ///
574 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
575 /// let mut compiler = GraphCompiler::new();
576 /// let program = compiler.compile(&x.neg()).unwrap();
577 /// let mut executor = GraphExecutor::new(CpuBackend::new());
578 /// let out = executor.run(&program).unwrap();
579 /// assert_eq!(out.as_slice::<f64>().unwrap(), &[-2.0]);
580 /// ```
581 pub fn eval_exec_ir(
582 &mut self,
583 program: &ExecProgram,
584 inputs: Vec<Tensor>,
585 ) -> Result<Vec<Tensor>> {
586 validate_exec_input_count(program, inputs.len())?;
587 crate::segment::eval_exec_segmented_with_cache_and_workspace(
588 &mut self.backend,
589 program,
590 inputs,
591 &mut self.slot_workspace,
592 &mut self.backend_cache,
593 Some(&mut self.extension_executor),
594 )
595 }
596
597 /// Evaluate an execution program and preserve lazy owned output views.
598 ///
599 /// # Examples
600 ///
601 /// ```
602 /// use tenferro_cpu::CpuBackend;
603 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
604 ///
605 /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
606 /// let y = x.transpose(&[1, 0]).unwrap();
607 /// let mut compiler = GraphCompiler::new();
608 /// let program = compiler.compile(&y).unwrap();
609 /// let mut executor = GraphExecutor::new(CpuBackend::new());
610 ///
611 /// let value = executor.run_value(&program).unwrap();
612 /// assert!(matches!(&value, TensorValue::View(_)));
613 /// assert_eq!(value.shape(), &[2, 2]);
614 /// ```
615 pub fn eval_exec_ir_values(
616 &mut self,
617 program: &ExecProgram,
618 inputs: Vec<Tensor>,
619 ) -> Result<Vec<TensorValue>> {
620 validate_exec_input_count(program, inputs.len())?;
621 let inputs = inputs.into_iter().map(ExecSlot::Owned).collect();
622 crate::segment::eval_exec_segmented_slot_values_with_cache_and_workspace(
623 &mut self.backend,
624 program,
625 inputs,
626 &mut self.slot_workspace,
627 &mut self.backend_cache,
628 Some(&mut self.extension_executor),
629 )
630 }
631
632 /// Evaluate an execution program without consuming caller-owned inputs.
633 ///
634 /// # Examples
635 ///
636 /// ```
637 /// use tenferro_cpu::CpuBackend;
638 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
639 ///
640 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
641 /// let mut compiler = GraphCompiler::new();
642 /// let program = compiler.compile(&x.neg()).unwrap();
643 /// let mut executor = GraphExecutor::new(CpuBackend::new());
644 /// let out = executor.run(&program).unwrap();
645 /// assert_eq!(out.shape(), &[1]);
646 /// ```
647 pub fn eval_exec_ir_non_consuming(
648 &mut self,
649 program: &ExecProgram,
650 inputs: &[Tensor],
651 ) -> Result<Vec<Tensor>> {
652 let inputs = inputs
653 .iter()
654 .map(|tensor| ExecSlot::Read(TensorRead::from_tensor(tensor)))
655 .collect();
656 self.eval_exec_ir_slots(program, inputs)
657 }
658
659 /// Evaluate an execution program without consuming inputs and preserve lazy outputs.
660 ///
661 /// # Examples
662 ///
663 /// ```
664 /// use tenferro_cpu::CpuBackend;
665 /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TensorValue, TracedTensor};
666 ///
667 /// let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
668 /// let y = x.transpose(&[1, 0]).unwrap();
669 /// let mut compiler = GraphCompiler::new();
670 /// let program = compiler.compile(&y).unwrap();
671 /// let mut executor = GraphExecutor::new(CpuBackend::new());
672 ///
673 /// let value = executor.run_value(&program).unwrap();
674 /// assert!(matches!(&value, TensorValue::View(_)));
675 /// assert_eq!(value.to_tensor().unwrap().shape(), &[2, 2]);
676 /// ```
677 pub fn eval_exec_ir_non_consuming_values(
678 &mut self,
679 program: &ExecProgram,
680 inputs: &[Tensor],
681 ) -> Result<Vec<TensorValue>> {
682 let inputs = inputs
683 .iter()
684 .map(|tensor| ExecSlot::Read(TensorRead::from_tensor(tensor)))
685 .collect();
686 self.eval_exec_ir_slot_values(program, inputs)
687 }
688
689 fn eval_exec_ir_slots<'a>(
690 &mut self,
691 program: &ExecProgram,
692 inputs: Vec<ExecSlot<'a>>,
693 ) -> Result<Vec<Tensor>> {
694 validate_exec_input_count(program, inputs.len())?;
695 let mut slot_workspace = Vec::with_capacity(self.borrowed_slot_workspace_capacity);
696 let result = crate::segment::eval_exec_segmented_slots_with_cache_and_workspace(
697 &mut self.backend,
698 program,
699 inputs,
700 &mut slot_workspace,
701 &mut self.backend_cache,
702 Some(&mut self.extension_executor),
703 );
704 self.borrowed_slot_workspace_capacity = slot_workspace.capacity();
705 result
706 }
707
708 fn eval_exec_ir_slot_values<'a>(
709 &mut self,
710 program: &ExecProgram,
711 inputs: Vec<ExecSlot<'a>>,
712 ) -> Result<Vec<TensorValue>> {
713 validate_exec_input_count(program, inputs.len())?;
714 let mut slot_workspace = Vec::with_capacity(self.borrowed_slot_workspace_capacity);
715 let result = crate::segment::eval_exec_segmented_slot_values_with_cache_and_workspace(
716 &mut self.backend,
717 program,
718 inputs,
719 &mut slot_workspace,
720 &mut self.backend_cache,
721 Some(&mut self.extension_executor),
722 );
723 self.borrowed_slot_workspace_capacity = slot_workspace.capacity();
724 result
725 }
726
727 /// Clear backend-specific runtime analysis cache entries.
728 ///
729 /// # Examples
730 ///
731 /// ```
732 /// use tenferro_cpu::CpuBackend;
733 /// use tenferro_runtime::{GraphExecutor};
734 ///
735 /// let mut executor = GraphExecutor::new(CpuBackend::new());
736 /// executor.clear_backend_cache();
737 /// assert_eq!(executor.cache_stats().backend.entries, 0);
738 /// ```
739 pub fn clear_backend_cache(&mut self) {
740 self.backend_cache.clear();
741 }
742
743 /// Clear generic extension runtime cache entries.
744 ///
745 /// # Examples
746 ///
747 /// ```
748 /// use tenferro_cpu::CpuBackend;
749 /// use tenferro_runtime::{GraphExecutor};
750 ///
751 /// let mut executor = GraphExecutor::new(CpuBackend::new());
752 /// executor.clear_extension_caches();
753 /// assert_eq!(executor.cache_stats().extensions.entries, 0);
754 /// ```
755 pub fn clear_extension_caches(&mut self) {
756 self.extension_executor.clear_caches();
757 }
758
759 /// Clear every executor-owned runtime cache.
760 ///
761 /// # Examples
762 ///
763 /// ```
764 /// use tenferro_cpu::CpuBackend;
765 /// use tenferro_runtime::{GraphExecutor};
766 ///
767 /// let mut executor = GraphExecutor::new(CpuBackend::new());
768 /// executor.clear_caches();
769 /// assert_eq!(executor.cache_stats().backend.entries, 0);
770 /// ```
771 pub fn clear_caches(&mut self) {
772 self.clear_extension_caches();
773 self.clear_backend_cache();
774 }
775
776 /// Return executor runtime cache-entry and retained-byte stats.
777 ///
778 /// # Examples
779 ///
780 /// ```
781 /// use tenferro_cpu::CpuBackend;
782 /// use tenferro_runtime::{GraphExecutor};
783 ///
784 /// let executor = GraphExecutor::new(CpuBackend::new());
785 /// let stats = executor.cache_stats();
786 /// assert_eq!(stats.extensions.entries, 0);
787 /// ```
788 pub fn cache_stats(&self) -> GraphExecutorCacheStats {
789 GraphExecutorCacheStats {
790 extensions: self.extension_executor.cache_stats(),
791 backend: self.backend_cache.stats(),
792 }
793 }
794}
795
796impl<B: TensorBackend + 'static> Default for GraphExecutor<B>
797where
798 B: Default,
799{
800 fn default() -> Self {
801 Self::new(B::default())
802 }
803}
804
805fn validate_exec_input_count(program: &ExecProgram, actual: usize) -> Result<()> {
806 let expected = program.input_slots.len();
807 if actual != expected {
808 return Err(Error::Internal(format!(
809 "expected {expected} inputs for execution program, got {actual}"
810 )));
811 }
812 Ok(())
813}
814
815fn expect_single_output(outputs: &mut Vec<Tensor>) -> Result<Tensor> {
816 if outputs.len() != 1 {
817 return Err(Error::Internal(format!(
818 "expected 1 output, got {}",
819 outputs.len()
820 )));
821 }
822 outputs
823 .pop()
824 .ok_or_else(|| Error::Internal("missing graph output".to_string()))
825}
826
827fn expect_single_value(outputs: &mut Vec<TensorValue>) -> Result<TensorValue> {
828 if outputs.len() != 1 {
829 return Err(Error::Internal(format!(
830 "expected 1 output, got {}",
831 outputs.len()
832 )));
833 }
834 outputs
835 .pop()
836 .ok_or_else(|| Error::Internal("missing graph output".to_string()))
837}
838
839fn resolve_inputs(
840 program: &GraphProgram,
841 bindings: &[(&TracedTensor, &Tensor)],
842 backend: &mut impl TensorBackend,
843) -> Result<Vec<Tensor>> {
844 let program_keys: HashSet<_> = program
845 .inputs
846 .iter()
847 .map(|input| input.key.clone())
848 .collect();
849 let tangent_root_specs = tangent_root_specs(&program.inputs);
850 let default_map: HashMap<_, _> = program
851 .inputs
852 .iter()
853 .filter_map(|input| {
854 input
855 .default_tensor
856 .as_ref()
857 .map(|tensor| (input.key.clone(), tensor.as_ref()))
858 })
859 .collect();
860 let mut binding_map = HashMap::new();
861 for (index, (placeholder, tensor)) in bindings.iter().enumerate() {
862 if placeholder.data.is_some() {
863 return Err(Error::UnexpectedBinding {
864 binding_index: index,
865 });
866 }
867 let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
868 binding_index: index,
869 })?;
870 validate_binding_placeholder(index, placeholder, tensor)?;
871 let is_program_input = program_keys.contains(&key);
872 if !is_program_input && !tangent_root_specs.contains_key(&key) {
873 return Err(Error::UnexpectedBinding {
874 binding_index: index,
875 });
876 }
877 if binding_map.insert(key.clone(), *tensor).is_some() {
878 return Err(Error::DuplicateBinding {
879 input_key: format!("{:?}", key),
880 });
881 }
882 }
883
884 program
885 .inputs
886 .iter()
887 .map(|input| resolve_input(input, &binding_map, &default_map, backend))
888 .collect()
889}
890
891fn resolve_input_reads<'a>(
892 program: &'a GraphProgram,
893 bindings: &[(&TracedTensor, TensorRead<'a>)],
894 backend: &mut impl TensorBackend,
895) -> Result<Vec<ExecSlot<'a>>> {
896 let program_keys: HashSet<_> = program
897 .inputs
898 .iter()
899 .map(|input| input.key.clone())
900 .collect();
901 let tangent_root_specs = tangent_root_specs(&program.inputs);
902 let default_map: HashMap<_, _> = program
903 .inputs
904 .iter()
905 .filter_map(|input| {
906 input
907 .default_tensor
908 .as_ref()
909 .map(|tensor| (input.key.clone(), tensor.as_ref()))
910 })
911 .collect();
912 let mut binding_map = HashMap::new();
913 for (index, (placeholder, read)) in bindings.iter().enumerate() {
914 if placeholder.data.is_some() {
915 return Err(Error::UnexpectedBinding {
916 binding_index: index,
917 });
918 }
919 let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
920 binding_index: index,
921 })?;
922 validate_binding_placeholder_read(index, placeholder, read)?;
923 let is_program_input = program_keys.contains(&key);
924 if !is_program_input && !tangent_root_specs.contains_key(&key) {
925 return Err(Error::UnexpectedBinding {
926 binding_index: index,
927 });
928 }
929 if binding_map.insert(key.clone(), read.clone()).is_some() {
930 return Err(Error::DuplicateBinding {
931 input_key: format!("{:?}", key),
932 });
933 }
934 }
935
936 program
937 .inputs
938 .iter()
939 .map(|input| resolve_input_read(input, &binding_map, &default_map, backend))
940 .collect()
941}
942
943fn tangent_root_specs(inputs: &[GraphProgramInput]) -> HashMap<TensorInputKey, &GraphProgramInput> {
944 let mut specs = HashMap::new();
945 for input in inputs {
946 if !matches!(input.key, TensorInputKey::User { .. }) {
947 specs
948 .entry(tangent_primal_root(&input.key).clone())
949 .or_insert(input);
950 }
951 }
952 specs
953}
954
955fn resolve_input(
956 input: &GraphProgramInput,
957 bindings: &HashMap<TensorInputKey, &Tensor>,
958 defaults: &HashMap<TensorInputKey, &Tensor>,
959 backend: &mut impl TensorBackend,
960) -> Result<Tensor> {
961 let tensor = if let Some(bound) = bindings.get(&input.key) {
962 (*bound).clone()
963 } else if let Some(default) = &input.default_tensor {
964 resolve_default_tensor(default.as_ref(), backend)?
965 } else if let Some(zero) = deferred_zero_for_tangent_key(&input.key, bindings, defaults)? {
966 zero
967 } else {
968 return Err(Error::UnboundPlaceholder {
969 input_key: format!("{:?}", input.key),
970 });
971 };
972 validate_input_tensor(input, &tensor)?;
973 Ok(tensor)
974}
975
976fn resolve_input_read<'a>(
977 input: &GraphProgramInput,
978 bindings: &HashMap<TensorInputKey, TensorRead<'a>>,
979 defaults: &HashMap<TensorInputKey, &'a Tensor>,
980 backend: &mut impl TensorBackend,
981) -> Result<ExecSlot<'a>> {
982 let slot = if let Some(bound) = bindings.get(&input.key) {
983 ExecSlot::Read(bound.clone())
984 } else if let Some(default) = defaults.get(&input.key) {
985 if should_upload_default_tensor(default) {
986 ExecSlot::Owned(backend.upload_host_tensor(default)?)
987 } else {
988 ExecSlot::Read(TensorRead::from_tensor(default))
989 }
990 } else if let Some(zero) = deferred_zero_for_tangent_key_read(&input.key, bindings, defaults)? {
991 ExecSlot::Owned(zero)
992 } else {
993 return Err(Error::UnboundPlaceholder {
994 input_key: format!("{:?}", input.key),
995 });
996 };
997 validate_input_slot(input, &slot)?;
998 Ok(slot)
999}
1000
1001fn resolve_default_tensor(default: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
1002 if should_upload_default_tensor(default) {
1003 Ok(backend.upload_host_tensor(default)?)
1004 } else {
1005 Ok(default.clone())
1006 }
1007}
1008
1009fn should_upload_default_tensor(default: &Tensor) -> bool {
1010 default.shape().is_empty() && tensor_has_host_buffer(default)
1011}
1012
1013fn tensor_has_host_buffer(tensor: &Tensor) -> bool {
1014 !tensor.is_backend_buffer()
1015}
1016
1017fn validate_binding_placeholder(
1018 index: usize,
1019 placeholder: &TracedTensor,
1020 tensor: &Tensor,
1021) -> Result<()> {
1022 if placeholder.data.is_some() {
1023 return Err(Error::UnexpectedBinding {
1024 binding_index: index,
1025 });
1026 }
1027 if placeholder.dtype != tensor.dtype() {
1028 return Err(Error::PlaceholderDtypeMismatch {
1029 expected: placeholder.dtype,
1030 actual: tensor.dtype(),
1031 });
1032 }
1033 match placeholder.try_concrete_shape() {
1034 Some(expected_shape) => {
1035 if expected_shape.as_slice() != tensor.shape() {
1036 return Err(Error::PlaceholderShapeMismatch {
1037 expected: expected_shape,
1038 actual: tensor.shape().to_vec(),
1039 });
1040 }
1041 }
1042 None => {
1043 if placeholder.rank != tensor.shape().len() {
1044 return Err(Error::PlaceholderRankMismatch {
1045 expected: placeholder.rank,
1046 actual: tensor.shape().len(),
1047 });
1048 }
1049 }
1050 }
1051 Ok(())
1052}
1053
1054fn validate_binding_placeholder_read(
1055 index: usize,
1056 placeholder: &TracedTensor,
1057 read: &TensorRead<'_>,
1058) -> Result<()> {
1059 if placeholder.data.is_some() {
1060 return Err(Error::UnexpectedBinding {
1061 binding_index: index,
1062 });
1063 }
1064 if placeholder.dtype != read.dtype() {
1065 return Err(Error::PlaceholderDtypeMismatch {
1066 expected: placeholder.dtype,
1067 actual: read.dtype(),
1068 });
1069 }
1070 match placeholder.try_concrete_shape() {
1071 Some(expected_shape) => {
1072 if expected_shape.as_slice() != read.shape() {
1073 return Err(Error::PlaceholderShapeMismatch {
1074 expected: expected_shape,
1075 actual: read.shape().to_vec(),
1076 });
1077 }
1078 }
1079 None => {
1080 if placeholder.rank != read.shape().len() {
1081 return Err(Error::PlaceholderRankMismatch {
1082 expected: placeholder.rank,
1083 actual: read.shape().len(),
1084 });
1085 }
1086 }
1087 }
1088 Ok(())
1089}
1090
1091fn validate_input_tensor(input: &GraphProgramInput, tensor: &Tensor) -> Result<()> {
1092 if input.dtype != tensor.dtype() {
1093 return Err(Error::PlaceholderDtypeMismatch {
1094 expected: input.dtype,
1095 actual: tensor.dtype(),
1096 });
1097 }
1098 if input.shape.as_slice() != tensor.shape() {
1099 return Err(Error::PlaceholderShapeMismatch {
1100 expected: input.shape.clone(),
1101 actual: tensor.shape().to_vec(),
1102 });
1103 }
1104 Ok(())
1105}
1106
1107fn validate_input_slot(input: &GraphProgramInput, slot: &ExecSlot<'_>) -> Result<()> {
1108 if input.dtype != slot.dtype() {
1109 return Err(Error::PlaceholderDtypeMismatch {
1110 expected: input.dtype,
1111 actual: slot.dtype(),
1112 });
1113 }
1114 if input.shape.as_slice() != slot.shape() {
1115 return Err(Error::PlaceholderShapeMismatch {
1116 expected: input.shape.clone(),
1117 actual: slot.shape().to_vec(),
1118 });
1119 }
1120 Ok(())
1121}
1122
1123fn deferred_zero_for_tangent_key(
1124 key: &TensorInputKey,
1125 bindings: &HashMap<TensorInputKey, &Tensor>,
1126 defaults: &HashMap<TensorInputKey, &Tensor>,
1127) -> Result<Option<Tensor>> {
1128 if !key.is_tangent() {
1129 return Ok(None);
1130 }
1131 let root = tangent_primal_root(key);
1132 let Some(primal) = bindings.get(root).or_else(|| defaults.get(root)) else {
1133 return Ok(None);
1134 };
1135 zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some)
1136}
1137
1138fn deferred_zero_for_tangent_key_read<'a>(
1139 key: &TensorInputKey,
1140 bindings: &HashMap<TensorInputKey, TensorRead<'a>>,
1141 defaults: &HashMap<TensorInputKey, &'a Tensor>,
1142) -> Result<Option<Tensor>> {
1143 if !key.is_tangent() {
1144 return Ok(None);
1145 }
1146 let root = tangent_primal_root(key);
1147 if let Some(primal) = bindings.get(root) {
1148 return zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some);
1149 }
1150 let Some(primal) = defaults.get(root) else {
1151 return Ok(None);
1152 };
1153 zeros_tensor(primal.dtype(), primal.shape().to_vec()).map(Some)
1154}
1155
1156fn tangent_primal_root(key: &TensorInputKey) -> &TensorInputKey {
1157 key.primal_root()
1158}
1159
1160fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Result<Tensor> {
1161 match dtype {
1162 DType::F32 => Ok(Tensor::F32(TypedTensor::zeros(shape)?)),
1163 DType::F64 => Ok(Tensor::F64(TypedTensor::zeros(shape)?)),
1164 DType::I32 => Ok(Tensor::I32(TypedTensor::zeros(shape)?)),
1165 DType::I64 => Ok(Tensor::I64(TypedTensor::zeros(shape)?)),
1166 DType::Bool => {
1167 let len = checked_default_element_count(&shape)?;
1168 Ok(Tensor::Bool(TypedTensor::from_vec_col_major(
1169 shape,
1170 vec![false; len],
1171 )?))
1172 }
1173 DType::C32 => Ok(Tensor::C32(TypedTensor::zeros(shape)?)),
1174 DType::C64 => Ok(Tensor::C64(TypedTensor::zeros(shape)?)),
1175 }
1176}
1177
1178fn checked_default_element_count(shape: &[usize]) -> Result<usize> {
1179 shape.iter().try_fold(1usize, |acc, &dim| {
1180 acc.checked_mul(dim)
1181 .ok_or_else(|| Error::InvalidCompiledGraph {
1182 message: format!("deferred zero shape product overflows usize for shape {shape:?}"),
1183 })
1184 })
1185}
1186
// Unit tests for this module. The module body is loaded from a separate file
// and is compiled only when `cfg(test)` is active.
#[cfg(test)]
mod tests;