Skip to main content

tenferro_runtime/graph/
lowering_view.rs

1use std::fmt;
2
3use tenferro_ops::ext_op::ExtensionOp;
4use tenferro_ops::ShapeExtent;
5use tenferro_tensor::{DType, DotGeneralConfig};
6
7use crate::exec::{ExecInstruction, ExecOp, ExecProgram};
8
9/// Read-only lowering view over a compiled graph program.
10///
11/// This view is for peer executor crates that need to translate a
12/// [`GraphProgram`](super::GraphProgram) without mutating the runtime-owned
13/// execution program.
14///
15/// # Examples
16///
17/// ```
18/// use tenferro_runtime::{GraphCompiler, TracedTensor};
19///
20/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
21/// let mut compiler = GraphCompiler::new();
22/// let program = compiler.compile(&x.neg()).unwrap();
23/// let view = program.lowering_view();
24/// assert_eq!(view.output_slots().len(), 1);
25/// ```
26#[derive(Clone, Copy)]
27pub struct GraphProgramLoweringView<'a> {
28    exec: &'a ExecProgram,
29}
30
31impl fmt::Debug for GraphProgramLoweringView<'_> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        f.debug_struct("GraphProgramLoweringView")
34            .field("slot_count", &self.slot_count())
35            .field("input_count", &self.input_slots().len())
36            .field("output_count", &self.output_slots().len())
37            .field("instruction_count", &self.exec.instructions.len())
38            .finish()
39    }
40}
41
42impl<'a> GraphProgramLoweringView<'a> {
43    pub(crate) fn new(exec: &'a ExecProgram) -> Self {
44        Self { exec }
45    }
46
47    /// Return the number of execution slots used by the program.
48    ///
49    /// # Examples
50    ///
51    /// ```
52    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
53    ///
54    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
55    /// let mut compiler = GraphCompiler::new();
56    /// let program = compiler.compile(&x.neg()).unwrap();
57    /// assert!(program.lowering_view().slot_count() >= 1);
58    /// ```
59    pub fn slot_count(&self) -> usize {
60        self.exec.n_slots
61    }
62
63    /// Return the execution slots populated by graph inputs.
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
69    ///
70    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
71    /// let mut compiler = GraphCompiler::new();
72    /// let program = compiler.compile(&x.neg()).unwrap();
73    /// assert_eq!(program.lowering_view().input_slots().len(), 1);
74    /// ```
75    pub fn input_slots(&self) -> &'a [usize] {
76        &self.exec.input_slots
77    }
78
79    /// Return the execution slots used as program outputs.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
85    ///
86    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
87    /// let mut compiler = GraphCompiler::new();
88    /// let program = compiler.compile(&x.neg()).unwrap();
89    /// assert_eq!(program.lowering_view().output_slots().len(), 1);
90    /// ```
91    pub fn output_slots(&self) -> &'a [usize] {
92        &self.exec.output_slots
93    }
94
95    /// Iterate over read-only instruction views in execution order.
96    ///
97    /// # Examples
98    ///
99    /// ```
100    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
101    ///
102    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
103    /// let mut compiler = GraphCompiler::new();
104    /// let program = compiler.compile(&x.neg()).unwrap();
105    /// assert!(program.lowering_view().instructions().count() >= 1);
106    /// ```
107    pub fn instructions(&self) -> impl ExactSizeIterator<Item = GraphInstructionView<'a>> + '_ {
108        self.exec.instructions.iter().map(GraphInstructionView::new)
109    }
110}
111
112/// Read-only lowering view over one execution instruction.
113///
114/// # Examples
115///
116/// ```
117/// use tenferro_runtime::{GraphCompiler, TracedTensor};
118///
119/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
120/// let mut compiler = GraphCompiler::new();
121/// let program = compiler.compile(&x.neg()).unwrap();
122/// let inst = program.lowering_view().instructions().next().unwrap();
123/// assert_eq!(inst.output_slots().len(), 1);
124/// ```
125#[derive(Clone, Copy)]
126pub struct GraphInstructionView<'a> {
127    inst: &'a ExecInstruction,
128}
129
130impl fmt::Debug for GraphInstructionView<'_> {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("GraphInstructionView")
133            .field("op", &self.op_name())
134            .field("input_count", &self.input_slots().len())
135            .field("output_count", &self.output_slots().len())
136            .field("dtype", &self.dtype())
137            .finish()
138    }
139}
140
141impl<'a> GraphInstructionView<'a> {
142    fn new(inst: &'a ExecInstruction) -> Self {
143        Self { inst }
144    }
145
146    /// Return the operation view for this instruction.
147    ///
148    /// # Examples
149    ///
150    /// ```
151    /// use tenferro_runtime::{GraphCompiler, GraphOpView, TracedTensor};
152    ///
153    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
154    /// let mut compiler = GraphCompiler::new();
155    /// let program = compiler.compile(&x.neg()).unwrap();
156    /// let inst = program.lowering_view().instructions().next().unwrap();
157    /// assert!(matches!(inst.op(), GraphOpView::Negate));
158    /// ```
159    pub fn op(&self) -> GraphOpView<'a> {
160        match &self.inst.op {
161            ExecOp::Constant { dtype, bytes } => GraphOpView::Constant {
162                dtype: *dtype,
163                bytes,
164            },
165            ExecOp::Add => GraphOpView::Add,
166            ExecOp::Multiply => GraphOpView::Multiply,
167            ExecOp::Negate => GraphOpView::Negate,
168            ExecOp::Divide => GraphOpView::Divide,
169            ExecOp::Abs => GraphOpView::Abs,
170            ExecOp::Exp => GraphOpView::Exp,
171            ExecOp::Log => GraphOpView::Log,
172            ExecOp::Sin => GraphOpView::Sin,
173            ExecOp::Cos => GraphOpView::Cos,
174            ExecOp::Tanh => GraphOpView::Tanh,
175            ExecOp::Sqrt => GraphOpView::Sqrt,
176            ExecOp::Rsqrt => GraphOpView::Rsqrt,
177            ExecOp::Pow => GraphOpView::Pow,
178            ExecOp::Expm1 => GraphOpView::Expm1,
179            ExecOp::Log1p => GraphOpView::Log1p,
180            ExecOp::Convert { to } => GraphOpView::Convert { to: *to },
181            ExecOp::Reshape { .. } => GraphOpView::Reshape,
182            ExecOp::BroadcastInDim { dims, .. } => GraphOpView::BroadcastInDim { dims },
183            ExecOp::Transpose { perm } => GraphOpView::Transpose { perm },
184            ExecOp::ReduceSum { axes } => GraphOpView::ReduceSum { axes },
185            ExecOp::DotGeneral(config) => GraphOpView::DotGeneral { config },
186            ExecOp::Extension(op) => GraphOpView::Extension { op: op.as_ref() },
187            other => GraphOpView::Unsupported {
188                name: exec_op_name(other),
189            },
190        }
191    }
192
193    /// Return a stable operation name for diagnostics.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
199    ///
200    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
201    /// let mut compiler = GraphCompiler::new();
202    /// let program = compiler.compile(&x.neg()).unwrap();
203    /// let inst = program.lowering_view().instructions().next().unwrap();
204    /// assert_eq!(inst.op_name(), "Negate");
205    /// ```
206    pub fn op_name(&self) -> &'static str {
207        self.op().name()
208    }
209
210    /// Return the input slots consumed by this instruction.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
216    ///
217    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
218    /// let y = (&x + &x).unwrap();
219    /// let mut compiler = GraphCompiler::new();
220    /// let program = compiler.compile(&y).unwrap();
221    /// let inst = program.lowering_view().instructions().next().unwrap();
222    /// assert_eq!(inst.input_slots().len(), 2);
223    /// ```
224    pub fn input_slots(&self) -> &'a [usize] {
225        &self.inst.input_slots
226    }
227
228    /// Return the output slots written by this instruction.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
234    ///
235    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
236    /// let mut compiler = GraphCompiler::new();
237    /// let program = compiler.compile(&x.neg()).unwrap();
238    /// let inst = program.lowering_view().instructions().next().unwrap();
239    /// assert_eq!(inst.output_slots().len(), 1);
240    /// ```
241    pub fn output_slots(&self) -> &'a [usize] {
242        &self.inst.output_slots
243    }
244
245    /// Return the dtype of this instruction's output.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
251    ///
252    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
253    /// let mut compiler = GraphCompiler::new();
254    /// let program = compiler.compile(&x.neg()).unwrap();
255    /// let inst = program.lowering_view().instructions().next().unwrap();
256    /// assert_eq!(inst.dtype(), DType::F64);
257    /// ```
258    pub fn dtype(&self) -> DType {
259        self.inst.dtype
260    }
261
262    /// Resolve an exact static output shape for this instruction.
263    ///
264    /// `input_shapes` must be ordered the same way as [`Self::input_slots`].
265    ///
266    /// # Examples
267    ///
268    /// ```
269    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
270    ///
271    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
272    /// let mut compiler = GraphCompiler::new();
273    /// let program = compiler.compile(&x.neg()).unwrap();
274    /// let inst = program.lowering_view().instructions().next().unwrap();
275    /// assert_eq!(inst.static_output_shape(0, &[&[1]]).unwrap(), vec![1]);
276    /// ```
277    pub fn static_output_shape(
278        &self,
279        output_index: usize,
280        input_shapes: &[&[usize]],
281    ) -> std::result::Result<Vec<usize>, GraphProgramLoweringShapeError> {
282        let extents = self.inst.output_extents.get(output_index).ok_or(
283            GraphProgramLoweringShapeError::MissingOutput {
284                op: self.op_name(),
285                output_index,
286            },
287        )?;
288        let mut shape = Vec::with_capacity(extents.len());
289        for (axis, extent) in extents.iter().enumerate() {
290            match extent {
291                ShapeExtent::Exact(dim) => shape.push(dim.eval(input_shapes).map_err(|err| {
292                    GraphProgramLoweringShapeError::InvalidDimExpr {
293                        op: self.op_name(),
294                        output_index,
295                        axis,
296                        source: err,
297                    }
298                })?),
299                ShapeExtent::UpperBound(_) => {
300                    return Err(GraphProgramLoweringShapeError::NonStatic {
301                        op: self.op_name(),
302                        output_index,
303                        axis,
304                        kind: "an upper bound",
305                    });
306                }
307                ShapeExtent::Unknown => {
308                    return Err(GraphProgramLoweringShapeError::NonStatic {
309                        op: self.op_name(),
310                        output_index,
311                        axis,
312                        kind: "unknown",
313                    });
314                }
315            }
316        }
317        Ok(shape)
318    }
319}
320
321/// Read-only operation view for graph lowering integrations.
322///
323/// Unsupported operation families are represented as [`GraphOpView::Unsupported`]
324/// so peer executors can emit precise diagnostics without depending on the raw
325/// execution IR.
326///
327/// # Examples
328///
329/// ```
330/// use tenferro_runtime::{GraphCompiler, GraphOpView, TracedTensor};
331///
332/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
333/// let mut compiler = GraphCompiler::new();
334/// let program = compiler.compile(&x.neg()).unwrap();
335/// let op = program.lowering_view().instructions().next().unwrap().op();
336/// assert!(matches!(op, GraphOpView::Negate));
337/// ```
338#[derive(Clone, Copy)]
339pub enum GraphOpView<'a> {
340    /// Scalar constant payload.
341    Constant { dtype: DType, bytes: &'a [u8] },
342    /// Elementwise addition.
343    Add,
344    /// Elementwise multiplication.
345    Multiply,
346    /// Elementwise negation.
347    Negate,
348    /// Elementwise division.
349    Divide,
350    /// Elementwise absolute value.
351    Abs,
352    /// Elementwise exponential.
353    Exp,
354    /// Elementwise natural logarithm.
355    Log,
356    /// Elementwise sine.
357    Sin,
358    /// Elementwise cosine.
359    Cos,
360    /// Elementwise hyperbolic tangent.
361    Tanh,
362    /// Elementwise square root.
363    Sqrt,
364    /// Elementwise reciprocal square root.
365    Rsqrt,
366    /// Elementwise power.
367    Pow,
368    /// Elementwise exponential minus one.
369    Expm1,
370    /// Elementwise natural logarithm of one plus input.
371    Log1p,
372    /// Dtype conversion.
373    Convert { to: DType },
374    /// Shape-only reshape.
375    Reshape,
376    /// Broadcast with output-to-input dimension mapping.
377    BroadcastInDim { dims: &'a [usize] },
378    /// Transpose with output dimension permutation.
379    Transpose { perm: &'a [usize] },
380    /// Sum reduction.
381    ReduceSum { axes: &'a [usize] },
382    /// General dot/contraction.
383    DotGeneral { config: &'a DotGeneralConfig },
384    /// Extension operation with an owner-provided optional standard-op lowering.
385    Extension { op: &'a dyn ExtensionOp },
386    /// Operation outside the stable public lowering view.
387    Unsupported { name: &'static str },
388}
389
390impl fmt::Debug for GraphOpView<'_> {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        match self {
393            Self::Constant { dtype, bytes } => f
394                .debug_struct("Constant")
395                .field("dtype", dtype)
396                .field("byte_len", &bytes.len())
397                .finish(),
398            Self::Add => f.write_str("Add"),
399            Self::Multiply => f.write_str("Multiply"),
400            Self::Negate => f.write_str("Negate"),
401            Self::Divide => f.write_str("Divide"),
402            Self::Abs => f.write_str("Abs"),
403            Self::Exp => f.write_str("Exp"),
404            Self::Log => f.write_str("Log"),
405            Self::Sin => f.write_str("Sin"),
406            Self::Cos => f.write_str("Cos"),
407            Self::Tanh => f.write_str("Tanh"),
408            Self::Sqrt => f.write_str("Sqrt"),
409            Self::Rsqrt => f.write_str("Rsqrt"),
410            Self::Pow => f.write_str("Pow"),
411            Self::Expm1 => f.write_str("Expm1"),
412            Self::Log1p => f.write_str("Log1p"),
413            Self::Convert { to } => f.debug_struct("Convert").field("to", to).finish(),
414            Self::Reshape => f.write_str("Reshape"),
415            Self::BroadcastInDim { dims } => f
416                .debug_struct("BroadcastInDim")
417                .field("dims", dims)
418                .finish(),
419            Self::Transpose { perm } => f.debug_struct("Transpose").field("perm", perm).finish(),
420            Self::ReduceSum { axes } => f.debug_struct("ReduceSum").field("axes", axes).finish(),
421            Self::DotGeneral { config } => f
422                .debug_struct("DotGeneral")
423                .field("config", config)
424                .finish(),
425            Self::Extension { op } => f
426                .debug_struct("Extension")
427                .field("family_id", &op.family_id())
428                .finish(),
429            Self::Unsupported { name } => {
430                f.debug_struct("Unsupported").field("name", name).finish()
431            }
432        }
433    }
434}
435
436impl GraphOpView<'_> {
437    /// Return the stable operation name used in diagnostics.
438    ///
439    /// # Examples
440    ///
441    /// ```
442    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
443    ///
444    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
445    /// let mut compiler = GraphCompiler::new();
446    /// let program = compiler.compile(&x.neg()).unwrap();
447    /// let op = program.lowering_view().instructions().next().unwrap().op();
448    /// assert_eq!(op.name(), "Negate");
449    /// ```
450    pub fn name(&self) -> &'static str {
451        match self {
452            Self::Constant { .. } => "Constant",
453            Self::Add => "Add",
454            Self::Multiply => "Multiply",
455            Self::Negate => "Negate",
456            Self::Divide => "Divide",
457            Self::Abs => "Abs",
458            Self::Exp => "Exp",
459            Self::Log => "Log",
460            Self::Sin => "Sin",
461            Self::Cos => "Cos",
462            Self::Tanh => "Tanh",
463            Self::Sqrt => "Sqrt",
464            Self::Rsqrt => "Rsqrt",
465            Self::Pow => "Pow",
466            Self::Expm1 => "Expm1",
467            Self::Log1p => "Log1p",
468            Self::Convert { .. } => "Convert",
469            Self::Reshape => "Reshape",
470            Self::BroadcastInDim { .. } => "BroadcastInDim",
471            Self::Transpose { .. } => "Transpose",
472            Self::ReduceSum { .. } => "ReduceSum",
473            Self::DotGeneral { .. } => "DotGeneral",
474            Self::Extension { .. } => "Extension",
475            Self::Unsupported { name } => name,
476        }
477    }
478}
479
480/// Error returned when a lowering view cannot resolve an exact output shape.
481///
482/// # Examples
483///
484/// ```
485/// use tenferro_runtime::GraphProgramLoweringShapeError;
486///
487/// let err = GraphProgramLoweringShapeError::MissingOutput {
488///     op: "Example",
489///     output_index: 0,
490/// };
491/// assert!(err.to_string().contains("Example"));
492/// ```
493#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
494pub enum GraphProgramLoweringShapeError {
495    /// The instruction has no metadata for the requested output.
496    #[error("ExecOp::{op} missing output_extents for output {output_index}")]
497    MissingOutput {
498        op: &'static str,
499        output_index: usize,
500    },
501    /// The instruction output has dynamic or unknown extent metadata.
502    #[error("ExecOp::{op} output {output_index} axis {axis} has non-static extent: {kind}")]
503    NonStatic {
504        op: &'static str,
505        output_index: usize,
506        axis: usize,
507        kind: &'static str,
508    },
509    /// Static shape evaluation failed for an exact dimension expression.
510    #[error(
511        "ExecOp::{op} output {output_index} axis {axis} has invalid dimension expression: {source}"
512    )]
513    InvalidDimExpr {
514        op: &'static str,
515        output_index: usize,
516        axis: usize,
517        source: tenferro_ops::dim_expr::DimExprEvalError,
518    },
519}
520
521fn exec_op_name(op: &ExecOp) -> &'static str {
522    match op {
523        ExecOp::Transpose { .. } => "Transpose",
524        ExecOp::Reshape { .. } => "Reshape",
525        ExecOp::BroadcastInDim { .. } => "BroadcastInDim",
526        ExecOp::Convert { .. } => "Convert",
527        ExecOp::Constant { .. } => "Constant",
528        ExecOp::DotGeneral(_) => "DotGeneral",
529        ExecOp::DotGeneralWithConj { .. } => "DotGeneralWithConj",
530        ExecOp::ReduceSum { .. } => "ReduceSum",
531        ExecOp::ExtractDiag { .. } => "ExtractDiag",
532        ExecOp::EmbedDiag { .. } => "EmbedDiag",
533        ExecOp::Tril { .. } => "Tril",
534        ExecOp::Triu { .. } => "Triu",
535        ExecOp::Add => "Add",
536        ExecOp::Multiply => "Multiply",
537        ExecOp::Negate => "Negate",
538        ExecOp::Conj => "Conj",
539        ExecOp::Divide => "Divide",
540        ExecOp::Abs => "Abs",
541        ExecOp::Sign => "Sign",
542        ExecOp::Maximum => "Maximum",
543        ExecOp::Minimum => "Minimum",
544        ExecOp::Compare(_) => "Compare",
545        ExecOp::Select => "Select",
546        ExecOp::Clamp => "Clamp",
547        ExecOp::Exp => "Exp",
548        ExecOp::Log => "Log",
549        ExecOp::Sin => "Sin",
550        ExecOp::Cos => "Cos",
551        ExecOp::Tanh => "Tanh",
552        ExecOp::Sqrt => "Sqrt",
553        ExecOp::Rsqrt => "Rsqrt",
554        ExecOp::Pow => "Pow",
555        ExecOp::Expm1 => "Expm1",
556        ExecOp::Log1p => "Log1p",
557        ExecOp::Gather(_) => "Gather",
558        ExecOp::GatherDynamicSliceSizes { .. } => "GatherDynamicSliceSizes",
559        ExecOp::Scatter(_) => "Scatter",
560        ExecOp::Slice(_) => "Slice",
561        ExecOp::DynamicSlice { .. } => "DynamicSlice",
562        ExecOp::DynamicUpdateSlice => "DynamicUpdateSlice",
563        ExecOp::Pad(_) => "Pad",
564        ExecOp::Concatenate { .. } => "Concatenate",
565        ExecOp::Reverse { .. } => "Reverse",
566        ExecOp::ShapeOf { .. } => "ShapeOf",
567        ExecOp::DynamicTruncate { .. } => "DynamicTruncate",
568        ExecOp::PadToMatch { .. } => "PadToMatch",
569        ExecOp::ReduceProd { .. } => "ReduceProd",
570        ExecOp::ReduceMax { .. } => "ReduceMax",
571        ExecOp::ReduceMin { .. } => "ReduceMin",
572        ExecOp::Extension(_) => "Extension",
573    }
574}
575
// Unit tests for this module, declared out of line in the `tests` submodule file
// and compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;