tenferro_runtime/graph/program.rs
1use std::sync::Arc;
2
3use tenferro_ops::dim_expr::DimExpr;
4use tenferro_ops::input_key::TensorInputKey;
5use tenferro_tensor::{DType, Tensor};
6
7use crate::exec::ExecProgram;
8use crate::graph::lowering_view::GraphProgramLoweringView;
9
10/// A compiled traced graph, independent of any execution backend.
11///
12/// # Examples
13///
14/// ```
15/// use tenferro_runtime::{GraphCompiler, TracedTensor};
16///
17/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
18/// let y = (&x + &x).unwrap();
19/// let mut compiler = GraphCompiler::new();
20/// let program = compiler.compile(&y).unwrap();
21/// assert_eq!(program.input_count(), 1);
22/// ```
23#[derive(Clone, Debug)]
24pub struct GraphProgram {
25 pub(crate) exec: ExecProgram,
26 pub(crate) inputs: Vec<GraphProgramInput>,
27}
28
29impl GraphProgram {
30 pub(crate) fn new(exec: ExecProgram, inputs: Vec<GraphProgramInput>) -> Self {
31 Self { exec, inputs }
32 }
33
34 /// Return the number of graph inputs expected by this program.
35 ///
36 /// # Examples
37 ///
38 /// ```
39 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
40 ///
41 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
42 /// let mut compiler = GraphCompiler::new();
43 /// let program = compiler.compile(&x.neg()).unwrap();
44 /// assert_eq!(program.input_count(), 1);
45 /// ```
46 #[inline(never)]
47 pub fn input_count(&self) -> usize {
48 self.inputs.len()
49 }
50
51 /// Return the number of graph outputs produced by this program.
52 ///
53 /// # Examples
54 ///
55 /// ```
56 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
57 ///
58 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
59 /// let mut compiler = GraphCompiler::new();
60 /// let program = compiler.compile(&x.neg()).unwrap();
61 /// assert_eq!(program.output_count(), 1);
62 /// ```
63 #[inline(never)]
64 pub fn output_count(&self) -> usize {
65 self.exec.output_slots.len()
66 }
67
68 /// Return the ordered input specs expected by this program.
69 ///
70 /// # Examples
71 ///
72 /// ```
73 /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
74 ///
75 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
76 /// let mut compiler = GraphCompiler::new();
77 /// let program = compiler
78 /// .compile_with_input_specs(&x.neg(), &[(&x, DType::F64, &[4])])
79 /// .unwrap();
80 /// assert_eq!(program.input_specs()[0].shape(), &[4]);
81 /// ```
82 #[inline(never)]
83 pub fn input_specs(&self) -> &[GraphProgramInput] {
84 &self.inputs
85 }
86
87 /// Return a read-only lowering view for peer executor integrations.
88 ///
89 /// The view exposes only immutable, lowering-oriented program metadata.
90 /// Native execution and mutation remain owned by [`GraphExecutor`](super::GraphExecutor).
91 ///
92 /// # Examples
93 ///
94 /// ```
95 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
96 ///
97 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
98 /// let mut compiler = GraphCompiler::new();
99 /// let program = compiler.compile(&x.neg()).unwrap();
100 /// assert_eq!(program.lowering_view().output_slots().len(), 1);
101 /// ```
102 #[inline(never)]
103 pub fn lowering_view(&self) -> GraphProgramLoweringView<'_> {
104 GraphProgramLoweringView::new(&self.exec)
105 }
106}
107
108/// A single ordered input required by a [`GraphProgram`].
109///
110/// # Examples
111///
112/// ```
113/// use tenferro_runtime::{GraphCompiler, TracedTensor};
114///
115/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
116/// let mut compiler = GraphCompiler::new();
117/// let program = compiler.compile(&x.neg()).unwrap();
118/// let input = &program.input_specs()[0];
119/// assert_eq!(input.shape(), &[2]);
120/// ```
121#[derive(Clone, Debug)]
122pub struct GraphProgramInput {
123 pub(crate) key: TensorInputKey,
124 pub(crate) dtype: DType,
125 pub(crate) shape: Vec<usize>,
126 // Preserved for symbolic-shape diagnostics and future graph-input metadata
127 // without exposing `DimExpr` through the stable input-spec accessor.
128 #[allow(dead_code)]
129 pub(crate) dim_expr_shape: Vec<DimExpr>,
130 pub(crate) default_tensor: Option<Arc<Tensor>>,
131}
132
133impl GraphProgramInput {
134 pub(crate) fn new(
135 key: TensorInputKey,
136 dtype: DType,
137 shape: Vec<usize>,
138 dim_expr_shape: Vec<DimExpr>,
139 default_tensor: Option<Arc<Tensor>>,
140 ) -> Self {
141 Self {
142 key,
143 dtype,
144 shape,
145 dim_expr_shape,
146 default_tensor,
147 }
148 }
149
150 /// Return the dtype expected for this input.
151 ///
152 /// # Examples
153 ///
154 /// ```
155 /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
156 ///
157 /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
158 /// let mut compiler = GraphCompiler::new();
159 /// let program = compiler
160 /// .compile_with_input_specs(&x, &[(&x, DType::F64, &[2])])
161 /// .unwrap();
162 /// assert_eq!(program.input_specs()[0].dtype(), DType::F64);
163 /// ```
164 #[inline(never)]
165 pub fn dtype(&self) -> DType {
166 self.dtype
167 }
168
169 /// Return the concrete shape expected for this input.
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
175 ///
176 /// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
177 /// let mut compiler = GraphCompiler::new();
178 /// let program = compiler.compile(&x).unwrap();
179 /// assert_eq!(program.input_specs()[0].shape(), &[2]);
180 /// ```
181 #[inline(never)]
182 pub fn shape(&self) -> &[usize] {
183 &self.shape
184 }
185}