tidu/rules/primitive_op.rs
1use super::{ADKey, ADRuleResult, PrimitiveBuilder, PrimitiveValue};
2use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
3
4/// Extends `GraphOperation` with primitive JVP and transpose rules for AD.
5///
6/// - `try_jvp_rule` is called by [`crate::try_linearize`]
7/// - `try_linear_transpose_rule` is called by [`crate::try_linear_transpose`]
8///
9/// Both methods add new primitive applications through a [`PrimitiveBuilder`]. The downstream
10/// implementor is responsible for ensuring closure: every op emitted must also
11/// implement `Primitive`.
12///
13/// # Examples
14///
15/// ```
16/// use computegraph::{ValueKey, GraphOperation, LocalValueId, OperationRole};
17/// use tidu::{ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue};
18///
19/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
20/// enum Key { Base(String), Tan(Box<Key>, DiffPassId) }
21///
22/// impl ADKey for Key {
23/// fn tangent_of(&self, p: DiffPassId) -> Self { Key::Tan(Box::new(self.clone()), p) }
24/// }
25///
26/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
27/// struct AddOp;
28///
29/// impl GraphOperation for AddOp {
30/// type Operand = f64;
31/// type Context = ();
32/// type InputKey = Key;
33/// fn input_count(&self) -> usize { 2 }
34/// fn output_count(&self) -> usize { 1 }
35/// }
36///
37/// impl Primitive for AddOp {
38/// type ADContext = ();
39///
40/// fn add() -> Self { AddOp }
41/// fn jvp_rule(
42/// &self, _b: &mut impl PrimitiveBuilder<Self>,
43/// _pi: &[ValueKey<Self>], _po: &[ValueKey<Self>],
44/// t: &[Option<LocalValueId>],
45/// _ctx: &mut (),
46/// ) -> Vec<Option<LocalValueId>> {
47/// vec![t[0].or(t[1])]
48/// }
49/// fn transpose_rule(
50/// &self, _builder: &mut impl PrimitiveBuilder<Self>,
51/// ct: &[Option<LocalValueId>], _i: &[PrimitiveValue<Self>], _m: &OperationRole,
52/// _ctx: &mut (),
53/// ) -> Vec<Option<LocalValueId>> {
54/// vec![ct[0], ct[0]]
55/// }
56/// }
57/// ```
58pub trait Primitive: GraphOperation
59where
60 Self::InputKey: ADKey,
61{
62 /// Runtime AD context threaded through linearization and transposition.
63 ///
64 /// This can carry information such as concrete shapes or guard decisions
65 /// that influence how AD rules emit graph structure.
66 type ADContext: Default;
67
68 /// Returns the addition operation used for cotangent accumulation
69 /// in [`crate::linear_transpose`]. When multiple cotangents flow to the same
70 /// `ValueKey`, `linear_transpose` emits `Op::add()` nodes to sum them.
71 fn add() -> Self
72 where
73 Self: Sized;
74
75 /// Emit the JVP rule for this primitive.
76 ///
77 /// Must be linear in tangent inputs. May reference primal inputs/outputs
78 /// through `External(ValueKey)`. Must emit ops in `OperationRole::Linearized`.
79 fn jvp_rule(
80 &self,
81 builder: &mut impl PrimitiveBuilder<Self>,
82 primal_inputs: &[ValueKey<Self>],
83 primal_outputs: &[ValueKey<Self>],
84 tangent_inputs: &[Option<LocalValueId>],
85 ctx: &mut Self::ADContext,
86 ) -> Vec<Option<LocalValueId>>
87 where
88 Self: Sized;
89
90 /// Fallible variant of [`Primitive::jvp_rule`].
91 ///
92 /// Implementors that can encounter missing extension rules should override
93 /// this method and return [`super::ADRuleError`] instead of panicking. The
94 /// default implementation preserves the infallible contract for existing
95 /// primitive sets.
96 fn try_jvp_rule(
97 &self,
98 builder: &mut impl PrimitiveBuilder<Self>,
99 primal_inputs: &[ValueKey<Self>],
100 primal_outputs: &[ValueKey<Self>],
101 tangent_inputs: &[Option<LocalValueId>],
102 ctx: &mut Self::ADContext,
103 ) -> ADRuleResult<Vec<Option<LocalValueId>>>
104 where
105 Self: Sized,
106 {
107 Ok(self.jvp_rule(builder, primal_inputs, primal_outputs, tangent_inputs, ctx))
108 }
109
110 /// Emit the transpose rule for this linear primitive.
111 ///
112 /// Receives cotangent outputs and produces cotangent inputs.
113 /// Must only emit ops that themselves implement `Primitive`.
114 fn transpose_rule(
115 &self,
116 builder: &mut impl PrimitiveBuilder<Self>,
117 cotangent_outputs: &[Option<LocalValueId>],
118 inputs: &[PrimitiveValue<Self>],
119 role: &OperationRole,
120 ctx: &mut Self::ADContext,
121 ) -> Vec<Option<LocalValueId>>
122 where
123 Self: Sized;
124
125 /// Fallible variant of [`Primitive::transpose_rule`].
126 ///
127 /// Implementors that can encounter missing extension rules should override
128 /// this method and return [`super::ADRuleError`] instead of panicking. The
129 /// default implementation preserves the infallible contract for existing
130 /// primitive sets.
131 fn try_linear_transpose_rule(
132 &self,
133 builder: &mut impl PrimitiveBuilder<Self>,
134 cotangent_outputs: &[Option<LocalValueId>],
135 inputs: &[PrimitiveValue<Self>],
136 role: &OperationRole,
137 ctx: &mut Self::ADContext,
138 ) -> ADRuleResult<Vec<Option<LocalValueId>>>
139 where
140 Self: Sized,
141 {
142 Ok(self.transpose_rule(builder, cotangent_outputs, inputs, role, ctx))
143 }
144}