Skip to main content

tenferro_ad/
extension.rs

1//! Eager AD support for out-of-tree extension primitives.
2
3use std::sync::Arc;
4
5use computegraph::GraphOperation;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_runtime::ad_support::push_metadata_scope;
8use tenferro_runtime::{Error, Result};
9use tenferro_tensor::{Tensor, TensorValue};
10
11use crate::eager::{record_eager_outputs, EagerRuntime, EagerTensor};
12
13pub use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionRegistryError, ExtensionRuleSet};
14pub use tenferro_runtime::extension::{
15    apply, ExtensionCacheKey, ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore,
16    ExtensionExecutionContext, ExtensionExecutor, ExtensionFamilyId, ExtensionOp,
17    ExtensionRegistry, ExtensionRuntime, ExtensionRuntimeRegistryError,
18};
19
20/// Adopt an untracked eager tensor value produced by this runtime's backend.
21///
22/// This is a low-level extension contract for eager composite operations that
23/// execute through [`EagerRuntime::with_backend_mut`] and receive a lazy
24/// [`TensorValue`] from the backend. The value must have been produced for the
25/// same eager runtime; this helper intentionally does not register gradient
26/// metadata and must not be used for tracked outputs.
27///
28/// # Examples
29///
30/// ```rust
31/// use tenferro_ad::extension::adopt_untracked_eager_value;
32/// use tenferro_ad::EagerRuntime;
33/// use tenferro_cpu::CpuBackend;
34/// use tenferro_tensor::{Tensor, TensorValue};
35///
36/// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
37/// let value = TensorValue::from_tensor(
38///     Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(),
39/// );
40/// let eager = adopt_untracked_eager_value(ctx, value);
41/// assert_eq!(eager.shape(), &[1]);
42/// assert!(!eager.tracks_grad());
43/// ```
44#[must_use]
45pub fn adopt_untracked_eager_value(ctx: Arc<EagerRuntime>, value: TensorValue) -> EagerTensor {
46    EagerTensor::new_untracked_value_result(ctx, value)
47}
48
49/// Apply an extension op to eager AD tensors.
50///
51/// # Examples
52///
53/// ```rust
54/// use tenferro_ad::extension::apply_eager;
55/// use tenferro_ad::{EagerRuntime, EagerTensor};
56/// use tenferro_cpu::CpuBackend;
57/// use tenferro_tensor::Tensor;
58///
59/// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
60/// let x = EagerTensor::from_tensor_in(
61///     Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(),
62///     ctx,
63/// ).unwrap();
64/// let _ = &x;
65/// let _apply = apply_eager;
66/// ```
67pub fn apply_eager(op: Arc<dyn ExtensionOp>, inputs: &[&EagerTensor]) -> Result<Vec<EagerTensor>> {
68    let Some(first) = inputs.first() else {
69        return Err(Error::Internal(
70            "extension::apply_eager requires at least one input tensor".to_string(),
71        ));
72    };
73    if inputs.len() != op.input_count() {
74        return Err(Error::Internal(format!(
75            "extension::apply_eager: op family {:?} expects {} inputs, got {}",
76            op.family_id(),
77            op.input_count(),
78            inputs.len()
79        )));
80    }
81
82    let ctx = Arc::clone(&first.ctx);
83    for tensor in inputs.iter().skip(1) {
84        if !first.same_context(tensor) {
85            return Err(Error::ContextMismatch {
86                lhs: first.ctx_id(),
87                rhs: tensor.ctx_id(),
88            });
89        }
90    }
91
92    let op = StdTensorOp::Extension(op);
93    let input_reads: Vec<_> = inputs.iter().map(|tensor| tensor.tensor_read()).collect();
94    let outputs = ctx.exec_outputs_read(&op, &input_reads)?;
95    if outputs.len() != op.output_count() {
96        return Err(Error::Internal(format!(
97            "expected {} eager outputs for {:?}, got {}",
98            op.output_count(),
99            op,
100            outputs.len()
101        )));
102    }
103
104    if !inputs.iter().any(|input| input.requires_grad) {
105        return outputs
106            .into_iter()
107            .map(|output| EagerTensor::new_untracked_result(Arc::clone(&ctx), output))
108            .collect();
109    }
110
111    let outputs: Vec<Arc<Tensor>> = outputs.into_iter().map(Arc::new).collect();
112    let recorded = record_eager_outputs(&op, &outputs, inputs)?;
113    if recorded.traces.len() != outputs.len() {
114        return Err(Error::Internal(format!(
115            "expected {} eager traces for {:?}, got {}",
116            outputs.len(),
117            op,
118            recorded.traces.len()
119        )));
120    }
121    let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
122    for input in inputs {
123        for scope in &input.metadata_scopes {
124            push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
125        }
126    }
127
128    recorded
129        .traces
130        .into_iter()
131        .zip(outputs)
132        .map(|(trace, output)| {
133            EagerTensor::new_result(
134                Arc::clone(&ctx),
135                trace.key,
136                output.as_ref().clone(),
137                trace.requires_grad,
138                trace.trace,
139                metadata_scopes.clone(),
140            )
141        })
142        .collect()
143}
144
145/// Apply one standard tensor op eagerly and record it for AD when needed.
146///
147/// Extension crates use this when an extension-level eager operation expands
148/// into ordinary `StdTensorOp` nodes instead of a custom extension primitive.
149pub fn apply_standard_op(op: StdTensorOp, inputs: &[&EagerTensor]) -> Result<EagerTensor> {
150    if matches!(op, StdTensorOp::Extension(_)) {
151        return Err(Error::Internal(
152            "extension::apply_standard_op does not accept Extension ops".into(),
153        ));
154    }
155    EagerTensor::nary_op(inputs, op)
156}