Skip to main content

tenferro_runtime/
extension.rs

1//! Public surface for out-of-tree extension primitives.
2//!
3//! This module exposes the Stage 6 `ExtensionOp` mechanism through the
4//! runtime crate. External crates implement
5//! [`tenferro_ops::ext_op::ExtensionOp`] and build traced graphs containing
6//! the extension via [`apply`].
7//!
8//! See `docs/spec/extension-op.md` for the normative contract.
9//!
10//! # Examples
11//!
12//! ```rust
13//! use tenferro_runtime::extension::{apply, ExtensionOp};
14//!
//! // Construct an `Arc<dyn ExtensionOp>` and call `apply(op, &[&input])`
//! // to record it in the traced graph; `apply` returns one `TracedTensor`
//! // per declared output slot of the extension.
17//! ```
18
19use std::collections::HashMap;
20use std::sync::Arc;
21
22use computegraph::graph::GraphBuilder;
23use computegraph::types::{OperationRole, ValueRef};
24use tenferro_ops::std_tensor_op::StdTensorOp;
25use tenferro_ops::SymDim;
26use tenferro_tensor::{Tensor, TensorBackend};
27
28use crate::checkpoint::CheckpointNode;
29use crate::error::{Error, Result};
30use crate::metadata::{push_metadata_scope, register_scoped_graph_metadata};
31use crate::traced::{next_traced_id, TracedTensor};
32
33pub use crate::compiler::CompilerOptions;
34#[doc(hidden)]
35pub use crate::compiler::{compile_std_to_exec, compile_std_to_exec_with_options};
36#[doc(hidden)]
37pub use crate::exec::{ExecInstruction, ExecOp, ExecOutputExtents, ExecOutputShapes, ExecProgram};
38#[doc(hidden)]
39pub use crate::shape_infer::{
40    infer_output_dtype, infer_output_extents, infer_output_shapes, promote_dtype,
41    promote_dtype_div_like, promote_dtype_for_binary_op, promote_dtypes,
42};
43pub use tenferro_ops::ext_op::ExtensionOp;
44pub use tenferro_ops::ExtensionFamilyId;
45
46pub use crate::extension_cache::{
47    ExtensionCacheKey, ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore,
48};
49pub use crate::extension_runtime::{
50    ExtensionExecutionContext, ExtensionExecutor, ExtensionRegistry, ExtensionRuntime,
51    ExtensionRuntimeRegistryError,
52};
53
54/// Execute a lowered core program with caller-owned backend runtime cache state.
55///
56/// This owner-scoped hook is for operation-family runtimes that expand an
57/// extension into core tensor operations and need to run that lowered program
58/// while preserving the runtime cache owned by the outer graph executor.
59#[doc(hidden)]
60pub fn execute_lowered_program_with_backend_cache<B: TensorBackend + 'static>(
61    backend: &mut B,
62    program: &ExecProgram,
63    inputs: Vec<Tensor>,
64    backend_cache: &mut B::RuntimeCache,
65) -> Result<Vec<Tensor>> {
66    crate::exec::ensure_core_exec_program(
67        program,
68        "extension::execute_lowered_program_with_backend_cache",
69    )?;
70    crate::exec::eval_exec_ir_with_backend_cache(backend, program, inputs, backend_cache)
71}
72
73/// Apply an extension op in the traced graph.
74///
75/// The `op` value is cloned into a `StdTensorOp::Extension(Arc<dyn ExtensionOp>)`
76/// carrier. The returned vector contains one [`TracedTensor`] per declared
77/// output slot of the extension. Output shapes are inferred via
78/// [`ExtensionOp::infer_output_meta`] using the input shape hints.
79///
80/// `inputs.len()` must equal `op.input_count()`, and each input's
81/// `shape_hint` must be present (i.e. the extension must be used on
82/// tensors whose rank is known at graph-build time). For symbolic-shape
83/// composition, bind the placeholder tensors via
84/// [`crate::GraphExecutor::run_with_inputs`] at evaluation time.
85///
86/// # Examples
87///
88/// ```rust
89/// # use std::any::Any;
90/// use std::sync::Arc;
91/// use tenferro_runtime::extension::{apply, ExtensionOp};
92/// use tenferro_runtime::{DType, SymDim, Tensor, TracedTensor};
93///
94/// # #[derive(Clone, Debug)]
95/// # struct IdentityExt;
96/// # impl ExtensionOp for IdentityExt {
97/// #     fn family_id(&self) -> &'static str { "example.identity.v1" }
98/// #     fn payload_hash(&self, _hasher: &mut dyn std::hash::Hasher) {}
99/// #     fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
100/// #         other.as_any().downcast_ref::<IdentityExt>().is_some()
101/// #     }
102/// #     fn clone_arc(&self) -> Arc<dyn ExtensionOp> { Arc::new(self.clone()) }
103/// #     fn as_any(&self) -> &dyn Any { self }
104/// #     fn input_count(&self) -> usize { 1 }
105/// #     fn output_count(&self) -> usize { 1 }
106/// #     fn infer_output_meta(
107/// #         &self,
108/// #         dtypes: &[DType],
109/// #         shapes: &[&[SymDim]],
110/// #     ) -> Vec<(DType, Vec<SymDim>)> {
111/// #         vec![(dtypes[0], shapes[0].to_vec())]
112/// #     }
113/// #     fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
114/// #         Ok(vec![inputs[0].clone()])
115/// #     }
116/// # }
117/// let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
118/// let a = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
119/// let outputs = apply(op, &[&a])?;
120/// assert_eq!(outputs.len(), 1);
121/// # Ok::<(), tenferro_runtime::Error>(())
122/// ```
123///
124/// # Errors
125///
126/// Returns [`Error::InvalidGraphBuild`] when the extension receives the wrong
127/// number of inputs or when [`ExtensionOp::infer_output_meta`] returns metadata
128/// whose count does not match [`ExtensionOp::output_count`].
129pub fn apply(op: Arc<dyn ExtensionOp>, inputs: &[&TracedTensor]) -> Result<Vec<TracedTensor>> {
130    if inputs.len() != op.input_count() {
131        return Err(Error::InvalidGraphBuild {
132            op: "extension::apply",
133            message: format!(
134                "op family {:?} expects {} inputs, got {}",
135                op.family_id(),
136                op.input_count(),
137                inputs.len()
138            ),
139        });
140    }
141
142    // Build the per-input dtype / shape slices the extension's
143    // `infer_output_meta` wants. Symbolic-shape inputs (shape_hint =
144    // None) use per-axis TensorAxis symbolic dims keyed by the input
145    // TracedTensor's id so downstream composition still resolves
146    // correctly via tenferro-internal-ops's SymDim API.
147    let input_dtypes: Vec<_> = inputs.iter().map(|t| t.dtype).collect();
148    let input_shape_storage: Vec<Vec<SymDim>> = inputs
149        .iter()
150        .map(|t| {
151            if let Some(hint) = t.shape_hint.clone() {
152                hint
153            } else {
154                (0..t.rank)
155                    .map(|axis| SymDim::tensor_axis(t.id, axis))
156                    .collect()
157            }
158        })
159        .collect();
160    let input_shape_refs: Vec<&[SymDim]> = input_shape_storage.iter().map(Vec::as_slice).collect();
161
162    let output_metas = op.infer_output_meta(&input_dtypes, &input_shape_refs);
163    if output_metas.len() != op.output_count() {
164        return Err(Error::InvalidGraphBuild {
165            op: "extension::apply",
166            message: format!(
167                "op family {:?}: infer_output_meta produced {} output metadata entries; op declared {} outputs",
168                op.family_id(),
169                output_metas.len(),
170                op.output_count()
171            ),
172        });
173    }
174
175    // Build the graph that carries the Extension op.
176    let mut builder = GraphBuilder::<StdTensorOp>::new();
177    for input in inputs {
178        builder.add_parent(input.graph.clone());
179    }
180    let op_inputs: Vec<ValueRef<StdTensorOp>> = inputs
181        .iter()
182        .map(|t| ValueRef::External(t.graph.values()[t.val].key.clone()))
183        .collect();
184    let carrier = StdTensorOp::Extension(op.clone());
185    let outputs = builder.add_operation(carrier, op_inputs, OperationRole::Primary);
186    builder.set_outputs(outputs.clone());
187    let graph = Arc::new(builder.build());
188    traced_outputs_from_graph(inputs, graph, &outputs, output_metas)
189}
190
191/// Apply an extension-provided lowering as ordinary traced graph operations.
192///
193/// This is for extension crates whose operation can be expanded at graph-build
194/// time. It preserves the same parent graph and metadata merging behavior as
195/// [`apply`], but does not insert a `StdTensorOp::Extension` carrier.
196pub fn apply_expanded_graph(
197    inputs: &[&TracedTensor],
198    output_metas: Vec<(tenferro_tensor::DType, Vec<SymDim>)>,
199    build: impl FnOnce(&mut GraphBuilder<StdTensorOp>, &[ValueRef<StdTensorOp>]) -> Result<Vec<usize>>,
200) -> Result<Vec<TracedTensor>> {
201    let mut builder = GraphBuilder::<StdTensorOp>::new();
202    for input in inputs {
203        builder.add_parent(input.graph.clone());
204    }
205    let op_inputs: Vec<ValueRef<StdTensorOp>> = inputs
206        .iter()
207        .map(|t| ValueRef::External(t.graph.values()[t.val].key.clone()))
208        .collect();
209    let outputs = build(&mut builder, &op_inputs)?;
210    if outputs.len() != output_metas.len() {
211        return Err(Error::InvalidGraphBuild {
212            op: "extension::apply_expanded_graph",
213            message: format!(
214                "extension expanded graph returned {} outputs for {} output metadata entries",
215                outputs.len(),
216                output_metas.len()
217            ),
218        });
219    }
220    builder.set_outputs(outputs.clone());
221    let graph = Arc::new(builder.build());
222    traced_outputs_from_graph(inputs, graph, &outputs, output_metas)
223}
224
225fn traced_outputs_from_graph(
226    inputs: &[&TracedTensor],
227    graph: Arc<computegraph::graph::Graph<StdTensorOp>>,
228    outputs: &[usize],
229    output_metas: Vec<(tenferro_tensor::DType, Vec<SymDim>)>,
230) -> Result<Vec<TracedTensor>> {
231    let metadata_scope = Arc::new(register_scoped_graph_metadata(
232        graph.as_ref(),
233        std::iter::empty(),
234    )?);
235
236    let mut merged_map = HashMap::new();
237    let mut extra_roots = Vec::new();
238    let mut checkpoint_chain = None;
239    let mut metadata_scopes = vec![Arc::clone(&metadata_scope)];
240    for input in inputs {
241        merged_map.extend(input.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
242        extra_roots.extend(input.extra_roots.iter().cloned());
243        checkpoint_chain =
244            CheckpointNode::merge_chains(checkpoint_chain, input.checkpoint_chain.clone());
245        for scope in &input.metadata_scopes {
246            push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
247        }
248    }
249    let merged_map = Arc::new(merged_map);
250
251    let all_inputs_concrete = inputs.iter().all(|t| t.shape_hint.is_some());
252    Ok(outputs
253        .iter()
254        .zip(output_metas)
255        .map(|(&val, (dtype, shape))| {
256            let shape_hint = if all_inputs_concrete {
257                Some(shape.clone())
258            } else {
259                None
260            };
261            TracedTensor {
262                id: next_traced_id(),
263                rank: shape.len(),
264                dtype,
265                graph: graph.clone(),
266                val,
267                data: None,
268                shape_hint,
269                inputs_map: merged_map.clone(),
270                extra_roots: extra_roots.clone(),
271                checkpoint_chain: checkpoint_chain.clone(),
272                metadata_scopes: metadata_scopes.clone(),
273            }
274        })
275        .collect())
276}
277
278#[cfg(test)]
279mod tests;