tenferro_runtime/graph/lowering_view.rs
1use std::fmt;
2
3use tenferro_ops::ext_op::ExtensionOp;
4use tenferro_ops::ShapeExtent;
5use tenferro_tensor::{DType, DotGeneralConfig};
6
7use crate::exec::{ExecInstruction, ExecOp, ExecProgram};
8
9/// Read-only lowering view over a compiled graph program.
10///
11/// This view is for peer executor crates that need to translate a
12/// [`GraphProgram`](super::GraphProgram) without mutating the runtime-owned
13/// execution program.
14///
15/// # Examples
16///
17/// ```
18/// use tenferro_runtime::{GraphCompiler, TracedTensor};
19///
20/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
21/// let mut compiler = GraphCompiler::new();
22/// let program = compiler.compile(&x.neg()).unwrap();
23/// let view = program.lowering_view();
24/// assert_eq!(view.output_slots().len(), 1);
25/// ```
26#[derive(Clone, Copy)]
27pub struct GraphProgramLoweringView<'a> {
28 exec: &'a ExecProgram,
29}
30
31impl fmt::Debug for GraphProgramLoweringView<'_> {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.debug_struct("GraphProgramLoweringView")
34 .field("slot_count", &self.slot_count())
35 .field("input_count", &self.input_slots().len())
36 .field("output_count", &self.output_slots().len())
37 .field("instruction_count", &self.exec.instructions.len())
38 .finish()
39 }
40}
41
42impl<'a> GraphProgramLoweringView<'a> {
43 pub(crate) fn new(exec: &'a ExecProgram) -> Self {
44 Self { exec }
45 }
46
47 /// Return the number of execution slots used by the program.
48 ///
49 /// # Examples
50 ///
51 /// ```
52 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
53 ///
54 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
55 /// let mut compiler = GraphCompiler::new();
56 /// let program = compiler.compile(&x.neg()).unwrap();
57 /// assert!(program.lowering_view().slot_count() >= 1);
58 /// ```
59 pub fn slot_count(&self) -> usize {
60 self.exec.n_slots
61 }
62
63 /// Return the execution slots populated by graph inputs.
64 ///
65 /// # Examples
66 ///
67 /// ```
68 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
69 ///
70 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
71 /// let mut compiler = GraphCompiler::new();
72 /// let program = compiler.compile(&x.neg()).unwrap();
73 /// assert_eq!(program.lowering_view().input_slots().len(), 1);
74 /// ```
75 pub fn input_slots(&self) -> &'a [usize] {
76 &self.exec.input_slots
77 }
78
79 /// Return the execution slots used as program outputs.
80 ///
81 /// # Examples
82 ///
83 /// ```
84 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
85 ///
86 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
87 /// let mut compiler = GraphCompiler::new();
88 /// let program = compiler.compile(&x.neg()).unwrap();
89 /// assert_eq!(program.lowering_view().output_slots().len(), 1);
90 /// ```
91 pub fn output_slots(&self) -> &'a [usize] {
92 &self.exec.output_slots
93 }
94
95 /// Iterate over read-only instruction views in execution order.
96 ///
97 /// # Examples
98 ///
99 /// ```
100 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
101 ///
102 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
103 /// let mut compiler = GraphCompiler::new();
104 /// let program = compiler.compile(&x.neg()).unwrap();
105 /// assert!(program.lowering_view().instructions().count() >= 1);
106 /// ```
107 pub fn instructions(&self) -> impl ExactSizeIterator<Item = GraphInstructionView<'a>> + '_ {
108 self.exec.instructions.iter().map(GraphInstructionView::new)
109 }
110}
111
112/// Read-only lowering view over one execution instruction.
113///
114/// # Examples
115///
116/// ```
117/// use tenferro_runtime::{GraphCompiler, TracedTensor};
118///
119/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
120/// let mut compiler = GraphCompiler::new();
121/// let program = compiler.compile(&x.neg()).unwrap();
122/// let inst = program.lowering_view().instructions().next().unwrap();
123/// assert_eq!(inst.output_slots().len(), 1);
124/// ```
125#[derive(Clone, Copy)]
126pub struct GraphInstructionView<'a> {
127 inst: &'a ExecInstruction,
128}
129
130impl fmt::Debug for GraphInstructionView<'_> {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("GraphInstructionView")
133 .field("op", &self.op_name())
134 .field("input_count", &self.input_slots().len())
135 .field("output_count", &self.output_slots().len())
136 .field("dtype", &self.dtype())
137 .finish()
138 }
139}
140
141impl<'a> GraphInstructionView<'a> {
142 fn new(inst: &'a ExecInstruction) -> Self {
143 Self { inst }
144 }
145
146 /// Return the operation view for this instruction.
147 ///
148 /// # Examples
149 ///
150 /// ```
151 /// use tenferro_runtime::{GraphCompiler, GraphOpView, TracedTensor};
152 ///
153 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
154 /// let mut compiler = GraphCompiler::new();
155 /// let program = compiler.compile(&x.neg()).unwrap();
156 /// let inst = program.lowering_view().instructions().next().unwrap();
157 /// assert!(matches!(inst.op(), GraphOpView::Negate));
158 /// ```
159 pub fn op(&self) -> GraphOpView<'a> {
160 match &self.inst.op {
161 ExecOp::Constant { dtype, bytes } => GraphOpView::Constant {
162 dtype: *dtype,
163 bytes,
164 },
165 ExecOp::Add => GraphOpView::Add,
166 ExecOp::Multiply => GraphOpView::Multiply,
167 ExecOp::Negate => GraphOpView::Negate,
168 ExecOp::Divide => GraphOpView::Divide,
169 ExecOp::Abs => GraphOpView::Abs,
170 ExecOp::Exp => GraphOpView::Exp,
171 ExecOp::Log => GraphOpView::Log,
172 ExecOp::Sin => GraphOpView::Sin,
173 ExecOp::Cos => GraphOpView::Cos,
174 ExecOp::Tanh => GraphOpView::Tanh,
175 ExecOp::Sqrt => GraphOpView::Sqrt,
176 ExecOp::Rsqrt => GraphOpView::Rsqrt,
177 ExecOp::Pow => GraphOpView::Pow,
178 ExecOp::Expm1 => GraphOpView::Expm1,
179 ExecOp::Log1p => GraphOpView::Log1p,
180 ExecOp::Convert { to } => GraphOpView::Convert { to: *to },
181 ExecOp::Reshape { .. } => GraphOpView::Reshape,
182 ExecOp::BroadcastInDim { dims, .. } => GraphOpView::BroadcastInDim { dims },
183 ExecOp::Transpose { perm } => GraphOpView::Transpose { perm },
184 ExecOp::ReduceSum { axes } => GraphOpView::ReduceSum { axes },
185 ExecOp::DotGeneral(config) => GraphOpView::DotGeneral { config },
186 ExecOp::Extension(op) => GraphOpView::Extension { op: op.as_ref() },
187 other => GraphOpView::Unsupported {
188 name: exec_op_name(other),
189 },
190 }
191 }
192
193 /// Return a stable operation name for diagnostics.
194 ///
195 /// # Examples
196 ///
197 /// ```
198 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
199 ///
200 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
201 /// let mut compiler = GraphCompiler::new();
202 /// let program = compiler.compile(&x.neg()).unwrap();
203 /// let inst = program.lowering_view().instructions().next().unwrap();
204 /// assert_eq!(inst.op_name(), "Negate");
205 /// ```
206 pub fn op_name(&self) -> &'static str {
207 self.op().name()
208 }
209
210 /// Return the input slots consumed by this instruction.
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
216 ///
217 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
218 /// let y = (&x + &x).unwrap();
219 /// let mut compiler = GraphCompiler::new();
220 /// let program = compiler.compile(&y).unwrap();
221 /// let inst = program.lowering_view().instructions().next().unwrap();
222 /// assert_eq!(inst.input_slots().len(), 2);
223 /// ```
224 pub fn input_slots(&self) -> &'a [usize] {
225 &self.inst.input_slots
226 }
227
228 /// Return the output slots written by this instruction.
229 ///
230 /// # Examples
231 ///
232 /// ```
233 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
234 ///
235 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
236 /// let mut compiler = GraphCompiler::new();
237 /// let program = compiler.compile(&x.neg()).unwrap();
238 /// let inst = program.lowering_view().instructions().next().unwrap();
239 /// assert_eq!(inst.output_slots().len(), 1);
240 /// ```
241 pub fn output_slots(&self) -> &'a [usize] {
242 &self.inst.output_slots
243 }
244
245 /// Return the dtype of this instruction's output.
246 ///
247 /// # Examples
248 ///
249 /// ```
250 /// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
251 ///
252 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
253 /// let mut compiler = GraphCompiler::new();
254 /// let program = compiler.compile(&x.neg()).unwrap();
255 /// let inst = program.lowering_view().instructions().next().unwrap();
256 /// assert_eq!(inst.dtype(), DType::F64);
257 /// ```
258 pub fn dtype(&self) -> DType {
259 self.inst.dtype
260 }
261
262 /// Resolve an exact static output shape for this instruction.
263 ///
264 /// `input_shapes` must be ordered the same way as [`Self::input_slots`].
265 ///
266 /// # Examples
267 ///
268 /// ```
269 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
270 ///
271 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
272 /// let mut compiler = GraphCompiler::new();
273 /// let program = compiler.compile(&x.neg()).unwrap();
274 /// let inst = program.lowering_view().instructions().next().unwrap();
275 /// assert_eq!(inst.static_output_shape(0, &[&[1]]).unwrap(), vec![1]);
276 /// ```
277 pub fn static_output_shape(
278 &self,
279 output_index: usize,
280 input_shapes: &[&[usize]],
281 ) -> std::result::Result<Vec<usize>, GraphProgramLoweringShapeError> {
282 let extents = self.inst.output_extents.get(output_index).ok_or(
283 GraphProgramLoweringShapeError::MissingOutput {
284 op: self.op_name(),
285 output_index,
286 },
287 )?;
288 let mut shape = Vec::with_capacity(extents.len());
289 for (axis, extent) in extents.iter().enumerate() {
290 match extent {
291 ShapeExtent::Exact(dim) => shape.push(dim.eval(input_shapes).map_err(|err| {
292 GraphProgramLoweringShapeError::InvalidDimExpr {
293 op: self.op_name(),
294 output_index,
295 axis,
296 source: err,
297 }
298 })?),
299 ShapeExtent::UpperBound(_) => {
300 return Err(GraphProgramLoweringShapeError::NonStatic {
301 op: self.op_name(),
302 output_index,
303 axis,
304 kind: "an upper bound",
305 });
306 }
307 ShapeExtent::Unknown => {
308 return Err(GraphProgramLoweringShapeError::NonStatic {
309 op: self.op_name(),
310 output_index,
311 axis,
312 kind: "unknown",
313 });
314 }
315 }
316 }
317 Ok(shape)
318 }
319}
320
321/// Read-only operation view for graph lowering integrations.
322///
323/// Unsupported operation families are represented as [`GraphOpView::Unsupported`]
324/// so peer executors can emit precise diagnostics without depending on the raw
325/// execution IR.
326///
327/// # Examples
328///
329/// ```
330/// use tenferro_runtime::{GraphCompiler, GraphOpView, TracedTensor};
331///
332/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
333/// let mut compiler = GraphCompiler::new();
334/// let program = compiler.compile(&x.neg()).unwrap();
335/// let op = program.lowering_view().instructions().next().unwrap().op();
336/// assert!(matches!(op, GraphOpView::Negate));
337/// ```
338#[derive(Clone, Copy)]
339pub enum GraphOpView<'a> {
340 /// Scalar constant payload.
341 Constant { dtype: DType, bytes: &'a [u8] },
342 /// Elementwise addition.
343 Add,
344 /// Elementwise multiplication.
345 Multiply,
346 /// Elementwise negation.
347 Negate,
348 /// Elementwise division.
349 Divide,
350 /// Elementwise absolute value.
351 Abs,
352 /// Elementwise exponential.
353 Exp,
354 /// Elementwise natural logarithm.
355 Log,
356 /// Elementwise sine.
357 Sin,
358 /// Elementwise cosine.
359 Cos,
360 /// Elementwise hyperbolic tangent.
361 Tanh,
362 /// Elementwise square root.
363 Sqrt,
364 /// Elementwise reciprocal square root.
365 Rsqrt,
366 /// Elementwise power.
367 Pow,
368 /// Elementwise exponential minus one.
369 Expm1,
370 /// Elementwise natural logarithm of one plus input.
371 Log1p,
372 /// Dtype conversion.
373 Convert { to: DType },
374 /// Shape-only reshape.
375 Reshape,
376 /// Broadcast with output-to-input dimension mapping.
377 BroadcastInDim { dims: &'a [usize] },
378 /// Transpose with output dimension permutation.
379 Transpose { perm: &'a [usize] },
380 /// Sum reduction.
381 ReduceSum { axes: &'a [usize] },
382 /// General dot/contraction.
383 DotGeneral { config: &'a DotGeneralConfig },
384 /// Extension operation with an owner-provided optional standard-op lowering.
385 Extension { op: &'a dyn ExtensionOp },
386 /// Operation outside the stable public lowering view.
387 Unsupported { name: &'static str },
388}
389
390impl fmt::Debug for GraphOpView<'_> {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 match self {
393 Self::Constant { dtype, bytes } => f
394 .debug_struct("Constant")
395 .field("dtype", dtype)
396 .field("byte_len", &bytes.len())
397 .finish(),
398 Self::Add => f.write_str("Add"),
399 Self::Multiply => f.write_str("Multiply"),
400 Self::Negate => f.write_str("Negate"),
401 Self::Divide => f.write_str("Divide"),
402 Self::Abs => f.write_str("Abs"),
403 Self::Exp => f.write_str("Exp"),
404 Self::Log => f.write_str("Log"),
405 Self::Sin => f.write_str("Sin"),
406 Self::Cos => f.write_str("Cos"),
407 Self::Tanh => f.write_str("Tanh"),
408 Self::Sqrt => f.write_str("Sqrt"),
409 Self::Rsqrt => f.write_str("Rsqrt"),
410 Self::Pow => f.write_str("Pow"),
411 Self::Expm1 => f.write_str("Expm1"),
412 Self::Log1p => f.write_str("Log1p"),
413 Self::Convert { to } => f.debug_struct("Convert").field("to", to).finish(),
414 Self::Reshape => f.write_str("Reshape"),
415 Self::BroadcastInDim { dims } => f
416 .debug_struct("BroadcastInDim")
417 .field("dims", dims)
418 .finish(),
419 Self::Transpose { perm } => f.debug_struct("Transpose").field("perm", perm).finish(),
420 Self::ReduceSum { axes } => f.debug_struct("ReduceSum").field("axes", axes).finish(),
421 Self::DotGeneral { config } => f
422 .debug_struct("DotGeneral")
423 .field("config", config)
424 .finish(),
425 Self::Extension { op } => f
426 .debug_struct("Extension")
427 .field("family_id", &op.family_id())
428 .finish(),
429 Self::Unsupported { name } => {
430 f.debug_struct("Unsupported").field("name", name).finish()
431 }
432 }
433 }
434}
435
436impl GraphOpView<'_> {
437 /// Return the stable operation name used in diagnostics.
438 ///
439 /// # Examples
440 ///
441 /// ```
442 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
443 ///
444 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
445 /// let mut compiler = GraphCompiler::new();
446 /// let program = compiler.compile(&x.neg()).unwrap();
447 /// let op = program.lowering_view().instructions().next().unwrap().op();
448 /// assert_eq!(op.name(), "Negate");
449 /// ```
450 pub fn name(&self) -> &'static str {
451 match self {
452 Self::Constant { .. } => "Constant",
453 Self::Add => "Add",
454 Self::Multiply => "Multiply",
455 Self::Negate => "Negate",
456 Self::Divide => "Divide",
457 Self::Abs => "Abs",
458 Self::Exp => "Exp",
459 Self::Log => "Log",
460 Self::Sin => "Sin",
461 Self::Cos => "Cos",
462 Self::Tanh => "Tanh",
463 Self::Sqrt => "Sqrt",
464 Self::Rsqrt => "Rsqrt",
465 Self::Pow => "Pow",
466 Self::Expm1 => "Expm1",
467 Self::Log1p => "Log1p",
468 Self::Convert { .. } => "Convert",
469 Self::Reshape => "Reshape",
470 Self::BroadcastInDim { .. } => "BroadcastInDim",
471 Self::Transpose { .. } => "Transpose",
472 Self::ReduceSum { .. } => "ReduceSum",
473 Self::DotGeneral { .. } => "DotGeneral",
474 Self::Extension { .. } => "Extension",
475 Self::Unsupported { name } => name,
476 }
477 }
478}
479
480/// Error returned when a lowering view cannot resolve an exact output shape.
481///
482/// # Examples
483///
484/// ```
485/// use tenferro_runtime::GraphProgramLoweringShapeError;
486///
487/// let err = GraphProgramLoweringShapeError::MissingOutput {
488/// op: "Example",
489/// output_index: 0,
490/// };
491/// assert!(err.to_string().contains("Example"));
492/// ```
493#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
494pub enum GraphProgramLoweringShapeError {
495 /// The instruction has no metadata for the requested output.
496 #[error("ExecOp::{op} missing output_extents for output {output_index}")]
497 MissingOutput {
498 op: &'static str,
499 output_index: usize,
500 },
501 /// The instruction output has dynamic or unknown extent metadata.
502 #[error("ExecOp::{op} output {output_index} axis {axis} has non-static extent: {kind}")]
503 NonStatic {
504 op: &'static str,
505 output_index: usize,
506 axis: usize,
507 kind: &'static str,
508 },
509 /// Static shape evaluation failed for an exact dimension expression.
510 #[error(
511 "ExecOp::{op} output {output_index} axis {axis} has invalid dimension expression: {source}"
512 )]
513 InvalidDimExpr {
514 op: &'static str,
515 output_index: usize,
516 axis: usize,
517 source: tenferro_ops::dim_expr::DimExprEvalError,
518 },
519}
520
521fn exec_op_name(op: &ExecOp) -> &'static str {
522 match op {
523 ExecOp::Transpose { .. } => "Transpose",
524 ExecOp::Reshape { .. } => "Reshape",
525 ExecOp::BroadcastInDim { .. } => "BroadcastInDim",
526 ExecOp::Convert { .. } => "Convert",
527 ExecOp::Constant { .. } => "Constant",
528 ExecOp::DotGeneral(_) => "DotGeneral",
529 ExecOp::DotGeneralWithConj { .. } => "DotGeneralWithConj",
530 ExecOp::ReduceSum { .. } => "ReduceSum",
531 ExecOp::ExtractDiag { .. } => "ExtractDiag",
532 ExecOp::EmbedDiag { .. } => "EmbedDiag",
533 ExecOp::Tril { .. } => "Tril",
534 ExecOp::Triu { .. } => "Triu",
535 ExecOp::Add => "Add",
536 ExecOp::Multiply => "Multiply",
537 ExecOp::Negate => "Negate",
538 ExecOp::Conj => "Conj",
539 ExecOp::Divide => "Divide",
540 ExecOp::Abs => "Abs",
541 ExecOp::Sign => "Sign",
542 ExecOp::Maximum => "Maximum",
543 ExecOp::Minimum => "Minimum",
544 ExecOp::Compare(_) => "Compare",
545 ExecOp::Select => "Select",
546 ExecOp::Clamp => "Clamp",
547 ExecOp::Exp => "Exp",
548 ExecOp::Log => "Log",
549 ExecOp::Sin => "Sin",
550 ExecOp::Cos => "Cos",
551 ExecOp::Tanh => "Tanh",
552 ExecOp::Sqrt => "Sqrt",
553 ExecOp::Rsqrt => "Rsqrt",
554 ExecOp::Pow => "Pow",
555 ExecOp::Expm1 => "Expm1",
556 ExecOp::Log1p => "Log1p",
557 ExecOp::Gather(_) => "Gather",
558 ExecOp::GatherDynamicSliceSizes { .. } => "GatherDynamicSliceSizes",
559 ExecOp::Scatter(_) => "Scatter",
560 ExecOp::Slice(_) => "Slice",
561 ExecOp::DynamicSlice { .. } => "DynamicSlice",
562 ExecOp::DynamicUpdateSlice => "DynamicUpdateSlice",
563 ExecOp::Pad(_) => "Pad",
564 ExecOp::Concatenate { .. } => "Concatenate",
565 ExecOp::Reverse { .. } => "Reverse",
566 ExecOp::ShapeOf { .. } => "ShapeOf",
567 ExecOp::DynamicTruncate { .. } => "DynamicTruncate",
568 ExecOp::PadToMatch { .. } => "PadToMatch",
569 ExecOp::ReduceProd { .. } => "ReduceProd",
570 ExecOp::ReduceMax { .. } => "ReduceMax",
571 ExecOp::ReduceMin { .. } => "ReduceMin",
572 ExecOp::Extension(_) => "Extension",
573 }
574}
575
576#[cfg(test)]
577mod tests;