1use std::collections::HashMap;
20use std::sync::Arc;
21
22use computegraph::graph::GraphBuilder;
23use computegraph::types::{OperationRole, ValueRef};
24use tenferro_ops::std_tensor_op::StdTensorOp;
25use tenferro_ops::SymDim;
26use tenferro_tensor::{Tensor, TensorBackend};
27
28use crate::checkpoint::CheckpointNode;
29use crate::error::{Error, Result};
30use crate::metadata::{push_metadata_scope, register_scoped_graph_metadata};
31use crate::traced::{next_traced_id, TracedTensor};
32
33pub use crate::compiler::CompilerOptions;
34#[doc(hidden)]
35pub use crate::compiler::{compile_std_to_exec, compile_std_to_exec_with_options};
36#[doc(hidden)]
37pub use crate::exec::{ExecInstruction, ExecOp, ExecOutputExtents, ExecOutputShapes, ExecProgram};
38#[doc(hidden)]
39pub use crate::shape_infer::{
40 infer_output_dtype, infer_output_extents, infer_output_shapes, promote_dtype,
41 promote_dtype_div_like, promote_dtype_for_binary_op, promote_dtypes,
42};
43pub use tenferro_ops::ext_op::ExtensionOp;
44pub use tenferro_ops::ExtensionFamilyId;
45
46pub use crate::extension_cache::{
47 ExtensionCacheKey, ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore,
48};
49pub use crate::extension_runtime::{
50 ExtensionExecutionContext, ExtensionExecutor, ExtensionRegistry, ExtensionRuntime,
51 ExtensionRuntimeRegistryError,
52};
53
54#[doc(hidden)]
60pub fn execute_lowered_program_with_backend_cache<B: TensorBackend + 'static>(
61 backend: &mut B,
62 program: &ExecProgram,
63 inputs: Vec<Tensor>,
64 backend_cache: &mut B::RuntimeCache,
65) -> Result<Vec<Tensor>> {
66 crate::exec::ensure_core_exec_program(
67 program,
68 "extension::execute_lowered_program_with_backend_cache",
69 )?;
70 crate::exec::eval_exec_ir_with_backend_cache(backend, program, inputs, backend_cache)
71}
72
73pub fn apply(op: Arc<dyn ExtensionOp>, inputs: &[&TracedTensor]) -> Result<Vec<TracedTensor>> {
130 if inputs.len() != op.input_count() {
131 return Err(Error::InvalidGraphBuild {
132 op: "extension::apply",
133 message: format!(
134 "op family {:?} expects {} inputs, got {}",
135 op.family_id(),
136 op.input_count(),
137 inputs.len()
138 ),
139 });
140 }
141
142 let input_dtypes: Vec<_> = inputs.iter().map(|t| t.dtype).collect();
148 let input_shape_storage: Vec<Vec<SymDim>> = inputs
149 .iter()
150 .map(|t| {
151 if let Some(hint) = t.shape_hint.clone() {
152 hint
153 } else {
154 (0..t.rank)
155 .map(|axis| SymDim::tensor_axis(t.id, axis))
156 .collect()
157 }
158 })
159 .collect();
160 let input_shape_refs: Vec<&[SymDim]> = input_shape_storage.iter().map(Vec::as_slice).collect();
161
162 let output_metas = op.infer_output_meta(&input_dtypes, &input_shape_refs);
163 if output_metas.len() != op.output_count() {
164 return Err(Error::InvalidGraphBuild {
165 op: "extension::apply",
166 message: format!(
167 "op family {:?}: infer_output_meta produced {} output metadata entries; op declared {} outputs",
168 op.family_id(),
169 output_metas.len(),
170 op.output_count()
171 ),
172 });
173 }
174
175 let mut builder = GraphBuilder::<StdTensorOp>::new();
177 for input in inputs {
178 builder.add_parent(input.graph.clone());
179 }
180 let op_inputs: Vec<ValueRef<StdTensorOp>> = inputs
181 .iter()
182 .map(|t| ValueRef::External(t.graph.values()[t.val].key.clone()))
183 .collect();
184 let carrier = StdTensorOp::Extension(op.clone());
185 let outputs = builder.add_operation(carrier, op_inputs, OperationRole::Primary);
186 builder.set_outputs(outputs.clone());
187 let graph = Arc::new(builder.build());
188 traced_outputs_from_graph(inputs, graph, &outputs, output_metas)
189}
190
191pub fn apply_expanded_graph(
197 inputs: &[&TracedTensor],
198 output_metas: Vec<(tenferro_tensor::DType, Vec<SymDim>)>,
199 build: impl FnOnce(&mut GraphBuilder<StdTensorOp>, &[ValueRef<StdTensorOp>]) -> Result<Vec<usize>>,
200) -> Result<Vec<TracedTensor>> {
201 let mut builder = GraphBuilder::<StdTensorOp>::new();
202 for input in inputs {
203 builder.add_parent(input.graph.clone());
204 }
205 let op_inputs: Vec<ValueRef<StdTensorOp>> = inputs
206 .iter()
207 .map(|t| ValueRef::External(t.graph.values()[t.val].key.clone()))
208 .collect();
209 let outputs = build(&mut builder, &op_inputs)?;
210 if outputs.len() != output_metas.len() {
211 return Err(Error::InvalidGraphBuild {
212 op: "extension::apply_expanded_graph",
213 message: format!(
214 "extension expanded graph returned {} outputs for {} output metadata entries",
215 outputs.len(),
216 output_metas.len()
217 ),
218 });
219 }
220 builder.set_outputs(outputs.clone());
221 let graph = Arc::new(builder.build());
222 traced_outputs_from_graph(inputs, graph, &outputs, output_metas)
223}
224
225fn traced_outputs_from_graph(
226 inputs: &[&TracedTensor],
227 graph: Arc<computegraph::graph::Graph<StdTensorOp>>,
228 outputs: &[usize],
229 output_metas: Vec<(tenferro_tensor::DType, Vec<SymDim>)>,
230) -> Result<Vec<TracedTensor>> {
231 let metadata_scope = Arc::new(register_scoped_graph_metadata(
232 graph.as_ref(),
233 std::iter::empty(),
234 )?);
235
236 let mut merged_map = HashMap::new();
237 let mut extra_roots = Vec::new();
238 let mut checkpoint_chain = None;
239 let mut metadata_scopes = vec![Arc::clone(&metadata_scope)];
240 for input in inputs {
241 merged_map.extend(input.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
242 extra_roots.extend(input.extra_roots.iter().cloned());
243 checkpoint_chain =
244 CheckpointNode::merge_chains(checkpoint_chain, input.checkpoint_chain.clone());
245 for scope in &input.metadata_scopes {
246 push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
247 }
248 }
249 let merged_map = Arc::new(merged_map);
250
251 let all_inputs_concrete = inputs.iter().all(|t| t.shape_hint.is_some());
252 Ok(outputs
253 .iter()
254 .zip(output_metas)
255 .map(|(&val, (dtype, shape))| {
256 let shape_hint = if all_inputs_concrete {
257 Some(shape.clone())
258 } else {
259 None
260 };
261 TracedTensor {
262 id: next_traced_id(),
263 rank: shape.len(),
264 dtype,
265 graph: graph.clone(),
266 val,
267 data: None,
268 shape_hint,
269 inputs_map: merged_map.clone(),
270 extra_roots: extra_roots.clone(),
271 checkpoint_chain: checkpoint_chain.clone(),
272 metadata_scopes: metadata_scopes.clone(),
273 }
274 })
275 .collect())
276}
277
278#[cfg(test)]
279mod tests;