Skip to main content

tidu/rules/
primitive_op.rs

1use super::{ADKey, ADRuleResult, PrimitiveBuilder, PrimitiveValue};
2use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
3
4/// Extends `GraphOperation` with primitive JVP and transpose rules for AD.
5///
6/// - `try_jvp_rule` is called by [`crate::try_linearize`]
7/// - `try_linear_transpose_rule` is called by [`crate::try_linear_transpose`]
8///
9/// Both methods add new primitive applications through a [`PrimitiveBuilder`]. The downstream
10/// implementor is responsible for ensuring closure: every op emitted must also
11/// implement `Primitive`.
12///
13/// # Examples
14///
15/// ```
16/// use computegraph::{ValueKey, GraphOperation, LocalValueId, OperationRole};
17/// use tidu::{ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue};
18///
19/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
20/// enum Key { Base(String), Tan(Box<Key>, DiffPassId) }
21///
22/// impl ADKey for Key {
23///     fn tangent_of(&self, p: DiffPassId) -> Self { Key::Tan(Box::new(self.clone()), p) }
24/// }
25///
26/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
27/// struct AddOp;
28///
29/// impl GraphOperation for AddOp {
30///     type Operand = f64;
31///     type Context = ();
32///     type InputKey = Key;
33///     fn input_count(&self) -> usize { 2 }
34///     fn output_count(&self) -> usize { 1 }
35/// }
36///
37/// impl Primitive for AddOp {
38///     type ADContext = ();
39///
40///     fn add() -> Self { AddOp }
41///     fn jvp_rule(
42///         &self, _b: &mut impl PrimitiveBuilder<Self>,
43///         _pi: &[ValueKey<Self>], _po: &[ValueKey<Self>],
44///         t: &[Option<LocalValueId>],
45///         _ctx: &mut (),
46///     ) -> Vec<Option<LocalValueId>> {
47///         vec![t[0].or(t[1])]
48///     }
49///     fn transpose_rule(
50///         &self, _builder: &mut impl PrimitiveBuilder<Self>,
51///         ct: &[Option<LocalValueId>], _i: &[PrimitiveValue<Self>], _m: &OperationRole,
52///         _ctx: &mut (),
53///     ) -> Vec<Option<LocalValueId>> {
54///         vec![ct[0], ct[0]]
55///     }
56/// }
57/// ```
58pub trait Primitive: GraphOperation
59where
60    Self::InputKey: ADKey,
61{
62    /// Runtime AD context threaded through linearization and transposition.
63    ///
64    /// This can carry information such as concrete shapes or guard decisions
65    /// that influence how AD rules emit graph structure.
66    type ADContext: Default;
67
68    /// Returns the addition operation used for cotangent accumulation
69    /// in [`crate::linear_transpose`]. When multiple cotangents flow to the same
70    /// `ValueKey`, `linear_transpose` emits `Op::add()` nodes to sum them.
71    fn add() -> Self
72    where
73        Self: Sized;
74
75    /// Emit the JVP rule for this primitive.
76    ///
77    /// Must be linear in tangent inputs. May reference primal inputs/outputs
78    /// through `External(ValueKey)`. Must emit ops in `OperationRole::Linearized`.
79    fn jvp_rule(
80        &self,
81        builder: &mut impl PrimitiveBuilder<Self>,
82        primal_inputs: &[ValueKey<Self>],
83        primal_outputs: &[ValueKey<Self>],
84        tangent_inputs: &[Option<LocalValueId>],
85        ctx: &mut Self::ADContext,
86    ) -> Vec<Option<LocalValueId>>
87    where
88        Self: Sized;
89
90    /// Fallible variant of [`Primitive::jvp_rule`].
91    ///
92    /// Implementors that can encounter missing extension rules should override
93    /// this method and return [`super::ADRuleError`] instead of panicking. The
94    /// default implementation preserves the infallible contract for existing
95    /// primitive sets.
96    fn try_jvp_rule(
97        &self,
98        builder: &mut impl PrimitiveBuilder<Self>,
99        primal_inputs: &[ValueKey<Self>],
100        primal_outputs: &[ValueKey<Self>],
101        tangent_inputs: &[Option<LocalValueId>],
102        ctx: &mut Self::ADContext,
103    ) -> ADRuleResult<Vec<Option<LocalValueId>>>
104    where
105        Self: Sized,
106    {
107        Ok(self.jvp_rule(builder, primal_inputs, primal_outputs, tangent_inputs, ctx))
108    }
109
110    /// Emit the transpose rule for this linear primitive.
111    ///
112    /// Receives cotangent outputs and produces cotangent inputs.
113    /// Must only emit ops that themselves implement `Primitive`.
114    fn transpose_rule(
115        &self,
116        builder: &mut impl PrimitiveBuilder<Self>,
117        cotangent_outputs: &[Option<LocalValueId>],
118        inputs: &[PrimitiveValue<Self>],
119        role: &OperationRole,
120        ctx: &mut Self::ADContext,
121    ) -> Vec<Option<LocalValueId>>
122    where
123        Self: Sized;
124
125    /// Fallible variant of [`Primitive::transpose_rule`].
126    ///
127    /// Implementors that can encounter missing extension rules should override
128    /// this method and return [`super::ADRuleError`] instead of panicking. The
129    /// default implementation preserves the infallible contract for existing
130    /// primitive sets.
131    fn try_linear_transpose_rule(
132        &self,
133        builder: &mut impl PrimitiveBuilder<Self>,
134        cotangent_outputs: &[Option<LocalValueId>],
135        inputs: &[PrimitiveValue<Self>],
136        role: &OperationRole,
137        ctx: &mut Self::ADContext,
138    ) -> ADRuleResult<Vec<Option<LocalValueId>>>
139    where
140        Self: Sized,
141    {
142        Ok(self.transpose_rule(builder, cotangent_outputs, inputs, role, ctx))
143    }
144}