1pub mod context;
8
9#[cfg(feature = "autodiff")]
10mod analytic;
11#[cfg(feature = "autodiff")]
12mod contraction;
13#[cfg(feature = "autodiff")]
14mod diagonal;
15#[cfg(feature = "autodiff")]
16mod dynamic;
17#[cfg(feature = "autodiff")]
18mod elementwise;
19#[cfg(feature = "autodiff")]
20mod indexing;
21#[cfg(feature = "autodiff")]
22pub(crate) mod registry;
23#[cfg(feature = "autodiff")]
24mod semiring;
25#[cfg(feature = "autodiff")]
26mod structural;
27#[cfg(feature = "autodiff")]
28#[doc(hidden)]
29pub mod support;
30#[cfg(feature = "autodiff")]
31mod zeros;
32
33#[cfg(feature = "autodiff")]
34use computegraph::graph::GraphBuilder;
35#[cfg(feature = "autodiff")]
36use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
37#[cfg(feature = "autodiff")]
38use tidu::{ADRuleError, ADRuleKind, ADRuleResult, PrimitiveBuilder, PrimitiveValue};
39
40#[cfg(feature = "autodiff")]
41use crate::ext_op::{linearize_extension_rule, transpose_extension_rule};
42#[cfg(feature = "autodiff")]
43use crate::std_tensor_op::StdTensorOp;
44
45#[cfg(feature = "autodiff")]
46fn missing_primitive_kind(op: &StdTensorOp, rule: ADRuleKind) -> ADRuleError {
47 ADRuleError::invalid_input(
48 "tenferro-internal-ops primitive AD dispatch",
49 rule,
50 format!("non-extension operation has no primitive kind: {op:?}"),
51 )
52}
53
54#[cfg(feature = "autodiff")]
76pub trait PrimitiveRuleBuilder {
77 fn add_operation(
79 &mut self,
80 operation: StdTensorOp,
81 inputs: Vec<ValueRef<StdTensorOp>>,
82 role: OperationRole,
83 ) -> Vec<LocalValueId>;
84}
85
86#[cfg(feature = "autodiff")]
87impl<B> PrimitiveRuleBuilder for B
88where
89 B: PrimitiveBuilder<StdTensorOp> + ?Sized,
90{
91 fn add_operation(
92 &mut self,
93 operation: StdTensorOp,
94 inputs: Vec<ValueRef<StdTensorOp>>,
95 role: OperationRole,
96 ) -> Vec<LocalValueId> {
97 let inputs = inputs.into_iter().map(PrimitiveValue::from).collect();
98 PrimitiveBuilder::add_primitive(self, operation, inputs, role)
99 }
100}
101
102#[cfg(feature = "autodiff")]
103impl PrimitiveRuleBuilder for GraphBuilder<StdTensorOp> {
104 fn add_operation(
105 &mut self,
106 operation: StdTensorOp,
107 inputs: Vec<ValueRef<StdTensorOp>>,
108 role: OperationRole,
109 ) -> Vec<LocalValueId> {
110 GraphBuilder::add_operation(self, operation, inputs, role)
111 }
112}
113
114#[cfg(feature = "autodiff")]
122pub fn linearize(
123 op: &StdTensorOp,
124 builder: &mut dyn PrimitiveRuleBuilder,
125 primal_in: &[ValueKey<StdTensorOp>],
126 primal_out: &[ValueKey<StdTensorOp>],
127 tangent_in: &[Option<LocalValueId>],
128 ctx: &mut context::ShapeGuardContext,
129) -> ADRuleResult<Vec<Option<LocalValueId>>> {
130 if let StdTensorOp::Extension(ext) = op {
131 return linearize_extension_rule(
132 ext.as_ref(),
133 builder,
134 primal_in,
135 primal_out,
136 tangent_in,
137 ctx,
138 );
139 }
140
141 let kind = op
142 .primitive_kind()
143 .ok_or_else(|| missing_primitive_kind(op, ADRuleKind::Jvp))?;
144 let rule = registry::primitive_ad_rule(kind)
145 .ok_or_else(|| registry::missing_rule(kind, ADRuleKind::Jvp))?;
146 rule.linearize(op, builder, primal_in, primal_out, tangent_in, ctx)
147}
148
149#[cfg(feature = "autodiff")]
156pub fn transpose_rule(
157 op: &StdTensorOp,
158 builder: &mut impl PrimitiveRuleBuilder,
159 cotangent_out: &[Option<LocalValueId>],
160 inputs: &[ValueRef<StdTensorOp>],
161 mode: &OperationRole,
162 ctx: &mut context::ShapeGuardContext,
163) -> ADRuleResult<Vec<Option<LocalValueId>>> {
164 if let StdTensorOp::Extension(ext) = op {
165 let builder_dyn: &mut dyn PrimitiveRuleBuilder = builder;
166 return transpose_extension_rule(
167 ext.as_ref(),
168 builder_dyn,
169 cotangent_out,
170 inputs,
171 mode,
172 ctx,
173 );
174 }
175
176 let kind = op
177 .primitive_kind()
178 .ok_or_else(|| missing_primitive_kind(op, ADRuleKind::Transpose))?;
179 let rule = registry::primitive_ad_rule(kind)
180 .ok_or_else(|| registry::missing_rule(kind, ADRuleKind::Transpose))?;
181 let builder_dyn: &mut dyn PrimitiveRuleBuilder = builder;
182 rule.transpose_rule(op, builder_dyn, cotangent_out, inputs, mode, ctx)
183}
184
185#[cfg(test)]
186mod tests;