1use std::sync::Arc;
4
5use computegraph::GraphOperation;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_runtime::ad_support::push_metadata_scope;
8use tenferro_runtime::{Error, Result};
9use tenferro_tensor::{Tensor, TensorValue};
10
11use crate::eager::{record_eager_outputs, EagerRuntime, EagerTensor};
12
13pub use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionRegistryError, ExtensionRuleSet};
14pub use tenferro_runtime::extension::{
15 apply, ExtensionCacheKey, ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore,
16 ExtensionExecutionContext, ExtensionExecutor, ExtensionFamilyId, ExtensionOp,
17 ExtensionRegistry, ExtensionRuntime, ExtensionRuntimeRegistryError,
18};
19
/// Builds an [`EagerTensor`] from an existing [`TensorValue`] in the eager runtime `ctx`.
///
/// This is a thin delegation to `EagerTensor::new_untracked_value_result`. As that name
/// suggests, the returned tensor is presumably not recorded for automatic differentiation.
/// NOTE(review): confirm the exact tracking semantics against `EagerTensor`.
#[must_use]
pub fn adopt_untracked_eager_value(ctx: Arc<EagerRuntime>, value: TensorValue) -> EagerTensor {
    EagerTensor::new_untracked_value_result(ctx, value)
}
48
49pub fn apply_eager(op: Arc<dyn ExtensionOp>, inputs: &[&EagerTensor]) -> Result<Vec<EagerTensor>> {
68 let Some(first) = inputs.first() else {
69 return Err(Error::Internal(
70 "extension::apply_eager requires at least one input tensor".to_string(),
71 ));
72 };
73 if inputs.len() != op.input_count() {
74 return Err(Error::Internal(format!(
75 "extension::apply_eager: op family {:?} expects {} inputs, got {}",
76 op.family_id(),
77 op.input_count(),
78 inputs.len()
79 )));
80 }
81
82 let ctx = Arc::clone(&first.ctx);
83 for tensor in inputs.iter().skip(1) {
84 if !first.same_context(tensor) {
85 return Err(Error::ContextMismatch {
86 lhs: first.ctx_id(),
87 rhs: tensor.ctx_id(),
88 });
89 }
90 }
91
92 let op = StdTensorOp::Extension(op);
93 let input_reads: Vec<_> = inputs.iter().map(|tensor| tensor.tensor_read()).collect();
94 let outputs = ctx.exec_outputs_read(&op, &input_reads)?;
95 if outputs.len() != op.output_count() {
96 return Err(Error::Internal(format!(
97 "expected {} eager outputs for {:?}, got {}",
98 op.output_count(),
99 op,
100 outputs.len()
101 )));
102 }
103
104 if !inputs.iter().any(|input| input.requires_grad) {
105 return outputs
106 .into_iter()
107 .map(|output| EagerTensor::new_untracked_result(Arc::clone(&ctx), output))
108 .collect();
109 }
110
111 let outputs: Vec<Arc<Tensor>> = outputs.into_iter().map(Arc::new).collect();
112 let recorded = record_eager_outputs(&op, &outputs, inputs)?;
113 if recorded.traces.len() != outputs.len() {
114 return Err(Error::Internal(format!(
115 "expected {} eager traces for {:?}, got {}",
116 outputs.len(),
117 op,
118 recorded.traces.len()
119 )));
120 }
121 let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
122 for input in inputs {
123 for scope in &input.metadata_scopes {
124 push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
125 }
126 }
127
128 recorded
129 .traces
130 .into_iter()
131 .zip(outputs)
132 .map(|(trace, output)| {
133 EagerTensor::new_result(
134 Arc::clone(&ctx),
135 trace.key,
136 output.as_ref().clone(),
137 trace.requires_grad,
138 trace.trace,
139 metadata_scopes.clone(),
140 )
141 })
142 .collect()
143}
144
145pub fn apply_standard_op(op: StdTensorOp, inputs: &[&EagerTensor]) -> Result<EagerTensor> {
150 if matches!(op, StdTensorOp::Extension(_)) {
151 return Err(Error::Internal(
152 "extension::apply_standard_op does not accept Extension ops".into(),
153 ));
154 }
155 EagerTensor::nary_op(inputs, op)
156}