Skip to main content

tenferro_ops/ad/
mod.rs

1//! Automatic differentiation rules for [`StdTensorOp`].
2//!
3//! `linearize` and `transpose_rule` are separate graph-level contracts.
4//! Core ops keep their rules here; extension ops own their own AD support
5//! through the extension trait.
6
7pub mod context;
8
9#[cfg(feature = "autodiff")]
10mod analytic;
11#[cfg(feature = "autodiff")]
12mod contraction;
13#[cfg(feature = "autodiff")]
14mod diagonal;
15#[cfg(feature = "autodiff")]
16mod dynamic;
17#[cfg(feature = "autodiff")]
18mod elementwise;
19#[cfg(feature = "autodiff")]
20mod indexing;
21#[cfg(feature = "autodiff")]
22pub(crate) mod registry;
23#[cfg(feature = "autodiff")]
24mod semiring;
25#[cfg(feature = "autodiff")]
26mod structural;
27#[cfg(feature = "autodiff")]
28#[doc(hidden)]
29pub mod support;
30#[cfg(feature = "autodiff")]
31mod zeros;
32
33#[cfg(feature = "autodiff")]
34use computegraph::graph::GraphBuilder;
35#[cfg(feature = "autodiff")]
36use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
37#[cfg(feature = "autodiff")]
38use tidu::{ADRuleError, ADRuleKind, ADRuleResult, PrimitiveBuilder, PrimitiveValue};
39
40#[cfg(feature = "autodiff")]
41use crate::ext_op::{linearize_extension_rule, transpose_extension_rule};
42#[cfg(feature = "autodiff")]
43use crate::std_tensor_op::StdTensorOp;
44
/// Build the error reported when a non-extension op yields no primitive kind.
///
/// `rule` records which AD rule (JVP or transpose) was being dispatched, and
/// the `Debug` rendering of `op` is embedded in the message.
#[cfg(feature = "autodiff")]
fn missing_primitive_kind(op: &StdTensorOp, rule: ADRuleKind) -> ADRuleError {
    let detail = format!("non-extension operation has no primitive kind: {op:?}");
    ADRuleError::invalid_input("tenferro-internal-ops primitive AD dispatch", rule, detail)
}
53
/// Graph-construction surface through which tenferro AD rules emit operations.
///
/// A rule calls [`PrimitiveRuleBuilder::add_operation`] to append a
/// [`StdTensorOp`] node and receives the local ids of the values it produces.
///
/// # Examples
///
/// ```
/// use computegraph::graph::GraphBuilder;
/// use computegraph::{OperationRole, ValueRef};
/// use tenferro_ops::ad::PrimitiveRuleBuilder;
/// use tenferro_ops::input_key::TensorInputKey;
/// use tenferro_ops::std_tensor_op::StdTensorOp;
///
/// let mut builder = GraphBuilder::<StdTensorOp>::new();
/// let x = builder.add_input(TensorInputKey::User { id: 1 });
/// let out = PrimitiveRuleBuilder::add_operation(
///     &mut builder,
///     StdTensorOp::Neg,
///     vec![ValueRef::Local(x)],
///     OperationRole::Primary,
/// );
/// assert_eq!(out.len(), 1);
/// ```
#[cfg(feature = "autodiff")]
pub trait PrimitiveRuleBuilder {
    /// Append `operation` to the graph, fed by `inputs` and tagged with `role`.
    ///
    /// Returns the local ids of the operation's outputs.
    fn add_operation(
        &mut self,
        operation: StdTensorOp,
        inputs: Vec<ValueRef<StdTensorOp>>,
        role: OperationRole,
    ) -> Vec<LocalValueId>;
}
85
// Blanket adapter: every `PrimitiveBuilder<StdTensorOp>` (sized or not) can be
// used wherever a `PrimitiveRuleBuilder` is expected.
#[cfg(feature = "autodiff")]
impl<B: PrimitiveBuilder<StdTensorOp> + ?Sized> PrimitiveRuleBuilder for B {
    fn add_operation(
        &mut self,
        operation: StdTensorOp,
        inputs: Vec<ValueRef<StdTensorOp>>,
        role: OperationRole,
    ) -> Vec<LocalValueId> {
        // Convert each graph value reference into a `PrimitiveValue` before
        // handing the operation to the underlying primitive builder.
        let primitive_inputs = inputs.into_iter().map(PrimitiveValue::from).collect();
        <B as PrimitiveBuilder<StdTensorOp>>::add_primitive(
            self,
            operation,
            primitive_inputs,
            role,
        )
    }
}
101
// Direct adapter for the plain graph builder: forwards to
// `GraphBuilder::add_operation` with the arguments unchanged.
//
// NOTE(review): this impl can only coexist with the blanket impl over
// `PrimitiveBuilder<StdTensorOp>` if `GraphBuilder<StdTensorOp>` does not
// itself implement `PrimitiveBuilder<StdTensorOp>` — confirm against `tidu`.
#[cfg(feature = "autodiff")]
impl PrimitiveRuleBuilder for GraphBuilder<StdTensorOp> {
    fn add_operation(
        &mut self,
        operation: StdTensorOp,
        inputs: Vec<ValueRef<StdTensorOp>>,
        role: OperationRole,
    ) -> Vec<LocalValueId> {
        GraphBuilder::<StdTensorOp>::add_operation(self, operation, inputs, role)
    }
}
113
/// Forward-mode AD (JVP) entry point for [`StdTensorOp`].
///
/// Given the primal op and its tangent inputs, the selected rule emits the
/// linearized graph into `builder` and returns the output tangents.
///
/// Dispatch:
/// - `StdTensorOp::Extension(_)` is forwarded to `linearize_extension_rule`,
///   so extension ops supply their own rule through the extension trait.
/// - Every other op is resolved through its primitive kind to a rule in the
///   `registry`. The per-op rules live in the category submodules
///   (`semiring`, `analytic`, `elementwise`, `structural`, `contraction`,
///   `indexing`, `diagonal`, `dynamic`).
///
/// # Errors
///
/// Returns an error when a non-extension op has no primitive kind, when no
/// rule is registered for that kind, or when the selected rule itself fails.
#[cfg(feature = "autodiff")]
pub fn linearize(
    op: &StdTensorOp,
    builder: &mut dyn PrimitiveRuleBuilder,
    primal_in: &[ValueKey<StdTensorOp>],
    primal_out: &[ValueKey<StdTensorOp>],
    tangent_in: &[Option<LocalValueId>],
    ctx: &mut context::ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    match op {
        // Extension ops carry their own JVP implementation.
        StdTensorOp::Extension(extension) => linearize_extension_rule(
            extension.as_ref(),
            builder,
            primal_in,
            primal_out,
            tangent_in,
            ctx,
        ),
        // Core ops: primitive kind -> registered rule -> rule's JVP.
        _ => {
            let primitive = op
                .primitive_kind()
                .ok_or_else(|| missing_primitive_kind(op, ADRuleKind::Jvp))?;
            let jvp_rule = registry::primitive_ad_rule(primitive)
                .ok_or_else(|| registry::missing_rule(primitive, ADRuleKind::Jvp))?;
            jvp_rule.linearize(op, builder, primal_in, primal_out, tangent_in, ctx)
        }
    }
}
148
/// Reverse-mode AD (VJP) entry point for [`StdTensorOp`].
///
/// Given the primal op, its inputs, and the output cotangent, the selected
/// rule emits the transposed graph into `builder` and returns the input
/// cotangents.
///
/// Dispatch mirrors [`linearize`]: extension ops are forwarded to
/// `transpose_extension_rule`, and every other op is resolved through its
/// primitive kind to a rule in the `registry`. See [`linearize`] for the
/// category split; the same categories appear here.
///
/// # Errors
///
/// Returns an error when a non-extension op has no primitive kind, when no
/// rule is registered for that kind, or when the selected rule itself fails.
#[cfg(feature = "autodiff")]
pub fn transpose_rule(
    op: &StdTensorOp,
    builder: &mut impl PrimitiveRuleBuilder,
    cotangent_out: &[Option<LocalValueId>],
    inputs: &[ValueRef<StdTensorOp>],
    mode: &OperationRole,
    ctx: &mut context::ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    // Both dispatch targets are handed a trait object, so erase the concrete
    // builder type once up front.
    let builder: &mut dyn PrimitiveRuleBuilder = builder;

    match op {
        // Extension ops carry their own transpose implementation.
        StdTensorOp::Extension(extension) => transpose_extension_rule(
            extension.as_ref(),
            builder,
            cotangent_out,
            inputs,
            mode,
            ctx,
        ),
        // Core ops: primitive kind -> registered rule -> rule's transpose.
        _ => {
            let primitive = op
                .primitive_kind()
                .ok_or_else(|| missing_primitive_kind(op, ADRuleKind::Transpose))?;
            let vjp_rule = registry::primitive_ad_rule(primitive)
                .ok_or_else(|| registry::missing_rule(primitive, ADRuleKind::Transpose))?;
            vjp_rule.transpose_rule(op, builder, cotangent_out, inputs, mode, ctx)
        }
    }
}
184
185#[cfg(test)]
186mod tests;