Skip to main content

tenferro_runtime/graph/
program.rs

1use std::sync::Arc;
2
3use tenferro_ops::dim_expr::DimExpr;
4use tenferro_ops::input_key::TensorInputKey;
5use tenferro_tensor::{DType, Tensor};
6
7use crate::exec::ExecProgram;
8use crate::graph::lowering_view::GraphProgramLoweringView;
9
10/// A compiled traced graph, independent of any execution backend.
11///
12/// # Examples
13///
14/// ```
15/// use tenferro_runtime::{GraphCompiler, TracedTensor};
16///
17/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
18/// let y = (&x + &x).unwrap();
19/// let mut compiler = GraphCompiler::new();
20/// let program = compiler.compile(&y).unwrap();
21/// assert_eq!(program.input_count(), 1);
22/// ```
23#[derive(Clone, Debug)]
24pub struct GraphProgram {
25    pub(crate) exec: ExecProgram,
26    pub(crate) inputs: Vec<GraphProgramInput>,
27}
28
29impl GraphProgram {
30    pub(crate) fn new(exec: ExecProgram, inputs: Vec<GraphProgramInput>) -> Self {
31        Self { exec, inputs }
32    }
33
34    /// Return the number of graph inputs expected by this program.
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
40    ///
41    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
42    /// let mut compiler = GraphCompiler::new();
43    /// let program = compiler.compile(&x.neg()).unwrap();
44    /// assert_eq!(program.input_count(), 1);
45    /// ```
46    #[inline(never)]
47    pub fn input_count(&self) -> usize {
48        self.inputs.len()
49    }
50
51    /// Return the number of graph outputs produced by this program.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
57    ///
58    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
59    /// let mut compiler = GraphCompiler::new();
60    /// let program = compiler.compile(&x.neg()).unwrap();
61    /// assert_eq!(program.output_count(), 1);
62    /// ```
63    #[inline(never)]
64    pub fn output_count(&self) -> usize {
65        self.exec.output_slots.len()
66    }
67
68    /// Return the ordered input specs expected by this program.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
74    ///
75    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
76    /// let mut compiler = GraphCompiler::new();
77    /// let program = compiler
78    ///     .compile_with_input_specs(&x.neg(), &[(&x, DType::F64, &[4])])
79    ///     .unwrap();
80    /// assert_eq!(program.input_specs()[0].shape(), &[4]);
81    /// ```
82    #[inline(never)]
83    pub fn input_specs(&self) -> &[GraphProgramInput] {
84        &self.inputs
85    }
86
87    /// Return a read-only lowering view for peer executor integrations.
88    ///
89    /// The view exposes only immutable, lowering-oriented program metadata.
90    /// Native execution and mutation remain owned by [`GraphExecutor`](super::GraphExecutor).
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
96    ///
97    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
98    /// let mut compiler = GraphCompiler::new();
99    /// let program = compiler.compile(&x.neg()).unwrap();
100    /// assert_eq!(program.lowering_view().output_slots().len(), 1);
101    /// ```
102    #[inline(never)]
103    pub fn lowering_view(&self) -> GraphProgramLoweringView<'_> {
104        GraphProgramLoweringView::new(&self.exec)
105    }
106}
107
108/// A single ordered input required by a [`GraphProgram`].
109///
110/// # Examples
111///
112/// ```
113/// use tenferro_runtime::{GraphCompiler, TracedTensor};
114///
115/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
116/// let mut compiler = GraphCompiler::new();
117/// let program = compiler.compile(&x.neg()).unwrap();
118/// let input = &program.input_specs()[0];
119/// assert_eq!(input.shape(), &[2]);
120/// ```
121#[derive(Clone, Debug)]
122pub struct GraphProgramInput {
123    pub(crate) key: TensorInputKey,
124    pub(crate) dtype: DType,
125    pub(crate) shape: Vec<usize>,
126    // Preserved for symbolic-shape diagnostics and future graph-input metadata
127    // without exposing `DimExpr` through the stable input-spec accessor.
128    #[allow(dead_code)]
129    pub(crate) dim_expr_shape: Vec<DimExpr>,
130    pub(crate) default_tensor: Option<Arc<Tensor>>,
131}
132
133impl GraphProgramInput {
134    pub(crate) fn new(
135        key: TensorInputKey,
136        dtype: DType,
137        shape: Vec<usize>,
138        dim_expr_shape: Vec<DimExpr>,
139        default_tensor: Option<Arc<Tensor>>,
140    ) -> Self {
141        Self {
142            key,
143            dtype,
144            shape,
145            dim_expr_shape,
146            default_tensor,
147        }
148    }
149
150    /// Return the dtype expected for this input.
151    ///
152    /// # Examples
153    ///
154    /// ```
155    /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
156    ///
157    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
158    /// let mut compiler = GraphCompiler::new();
159    /// let program = compiler
160    ///     .compile_with_input_specs(&x, &[(&x, DType::F64, &[2])])
161    ///     .unwrap();
162    /// assert_eq!(program.input_specs()[0].dtype(), DType::F64);
163    /// ```
164    #[inline(never)]
165    pub fn dtype(&self) -> DType {
166        self.dtype
167    }
168
169    /// Return the concrete shape expected for this input.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
175    ///
176    /// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
177    /// let mut compiler = GraphCompiler::new();
178    /// let program = compiler.compile(&x).unwrap();
179    /// assert_eq!(program.input_specs()[0].shape(), &[2]);
180    /// ```
181    #[inline(never)]
182    pub fn shape(&self) -> &[usize] {
183        &self.shape
184    }
185}