// tenferro_ops/ad/mod.rs

1pub mod context;
2
3mod analytic;
4mod contraction;
5mod diagonal;
6mod dynamic;
7mod elementwise_tier2;
8mod linalg;
9mod semiring;
10mod structural;
11
12use computegraph::fragment::FragmentBuilder;
13use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
14use computegraph::OpEmitter;
15
16use crate::std_tensor_op::StdTensorOp;
17
18fn linearize_non_semiring(
19    op: &StdTensorOp,
20    builder: &mut FragmentBuilder<StdTensorOp>,
21    primal_in: &[GlobalValKey<StdTensorOp>],
22    primal_out: &[GlobalValKey<StdTensorOp>],
23    tangent_in: &[Option<LocalValId>],
24    ctx: &mut context::ShapeGuardContext,
25) -> Option<Vec<Option<LocalValId>>> {
26    Some(match op {
27        StdTensorOp::Div => {
28            elementwise_tier2::linearize_div(builder, primal_in, primal_out, tangent_in)
29        }
30        StdTensorOp::Abs => elementwise_tier2::linearize_abs(builder, primal_in, tangent_in),
31        StdTensorOp::Sign => elementwise_tier2::linearize_sign(builder, tangent_in),
32        StdTensorOp::Constant { .. } => vec![None],
33        StdTensorOp::Compare(_) => vec![None],
34        StdTensorOp::Exp => analytic::linearize_exp(builder, primal_out, tangent_in),
35        StdTensorOp::Log => analytic::linearize_log(builder, primal_in, tangent_in),
36        StdTensorOp::Sin => analytic::linearize_sin(builder, primal_in, tangent_in),
37        StdTensorOp::Cos => analytic::linearize_cos(builder, primal_in, tangent_in),
38        StdTensorOp::Tanh => analytic::linearize_tanh(builder, primal_out, tangent_in),
39        StdTensorOp::Sqrt => analytic::linearize_sqrt(builder, primal_out, tangent_in),
40        StdTensorOp::Rsqrt => analytic::linearize_rsqrt(builder, primal_in, primal_out, tangent_in),
41        StdTensorOp::Pow => analytic::linearize_pow(builder, primal_in, primal_out, tangent_in),
42        StdTensorOp::Expm1 => analytic::linearize_expm1(builder, primal_out, tangent_in),
43        StdTensorOp::Log1p => analytic::linearize_log1p(builder, primal_in, tangent_in),
44        StdTensorOp::DotGeneral(config) => {
45            contraction::linearize_dot_general(builder, primal_in, tangent_in, config)
46        }
47        StdTensorOp::NaryEinsum { subscripts, .. } => {
48            contraction::linearize_nary_einsum(builder, primal_in, tangent_in, subscripts)
49        }
50        StdTensorOp::ReduceSum { axes, .. } => {
51            contraction::linearize_reduce_sum(builder, tangent_in, op, axes)
52        }
53        StdTensorOp::ReduceProd { axes, input_shape } => contraction::linearize_reduce_prod(
54            builder,
55            primal_in,
56            primal_out,
57            tangent_in,
58            axes,
59            input_shape,
60        ),
61        StdTensorOp::ReduceMax { axes, input_shape }
62        | StdTensorOp::ReduceMin { axes, input_shape } => contraction::linearize_reduce_chooser(
63            builder,
64            primal_in,
65            primal_out,
66            tangent_in,
67            axes,
68            input_shape,
69        ),
70        StdTensorOp::Transpose { perm } => {
71            structural::linearize_transpose(builder, tangent_in, perm)
72        }
73        StdTensorOp::Reshape { .. } => {
74            structural::linearize_reshape(builder, primal_in, tangent_in, op)
75        }
76        StdTensorOp::BroadcastInDim { shape, dims } => {
77            structural::linearize_broadcast_in_dim(builder, primal_in, tangent_in, shape, dims)
78        }
79        StdTensorOp::Convert { from, to } => {
80            structural::linearize_convert(builder, tangent_in, *from, *to)
81        }
82        StdTensorOp::ExtractDiag { axis_a, axis_b } => {
83            diagonal::linearize_extract_diag(builder, tangent_in, *axis_a, *axis_b)
84        }
85        StdTensorOp::EmbedDiag { axis_a, axis_b } => {
86            diagonal::linearize_embed_diag(builder, tangent_in, *axis_a, *axis_b)
87        }
88        StdTensorOp::Tril { k } => structural::linearize_tril(builder, tangent_in, *k),
89        StdTensorOp::Triu { k } => structural::linearize_triu(builder, tangent_in, *k),
90        StdTensorOp::Pad(config) => structural::linearize_pad(builder, tangent_in, config),
91        StdTensorOp::DynamicTruncate { axis } => {
92            dynamic::linearize_dynamic_truncate(builder, primal_in, tangent_in, *axis)
93        }
94        StdTensorOp::PadToMatch { axis } => {
95            dynamic::linearize_pad_to_match(builder, primal_in, tangent_in, *axis)
96        }
97        StdTensorOp::ShapeOf { .. } => vec![None],
98        StdTensorOp::Lu { input_shape } => {
99            linalg::linearize_lu(builder, primal_out, tangent_in, input_shape, ctx)
100        }
101        StdTensorOp::TriangularSolve {
102            left_side,
103            lower,
104            transpose_a,
105            unit_diagonal,
106            lhs_shape,
107            rhs_shape,
108        } => linalg::linearize_triangular_solve(
109            builder,
110            primal_in,
111            primal_out,
112            tangent_in,
113            *left_side,
114            *lower,
115            *transpose_a,
116            *unit_diagonal,
117            lhs_shape,
118            rhs_shape,
119        ),
120        StdTensorOp::Cholesky { input_shape } => {
121            linalg::linearize_cholesky(builder, primal_out, tangent_in, input_shape)
122        }
123        StdTensorOp::Svd { eps, input_shape } => {
124            linalg::linearize_svd(builder, primal_out, tangent_in, *eps, input_shape, ctx)
125        }
126        StdTensorOp::Qr { input_shape } => {
127            linalg::linearize_qr(builder, primal_out, tangent_in, input_shape, ctx)
128        }
129        StdTensorOp::Eigh { eps, input_shape } => {
130            linalg::linearize_eigh(builder, primal_out, tangent_in, *eps, input_shape)
131        }
132        StdTensorOp::Eig {
133            input_dtype,
134            input_shape,
135        } => linalg::linearize_eig(builder, primal_out, tangent_in, *input_dtype, input_shape),
136        StdTensorOp::ValidateNonsingular { .. } => vec![tangent_in[0]],
137        _ => return None,
138    })
139}
140
141fn linearize_semiring(
142    op: &StdTensorOp,
143    builder: &mut FragmentBuilder<StdTensorOp>,
144    primal_in: &[GlobalValKey<StdTensorOp>],
145    tangent_in: &[Option<LocalValId>],
146    _ctx: &mut context::ShapeGuardContext,
147) -> Option<Vec<Option<LocalValId>>> {
148    Some(match op {
149        StdTensorOp::Add => semiring::linearize_add(builder, tangent_in),
150        StdTensorOp::Mul => semiring::linearize_mul(builder, primal_in, tangent_in),
151        StdTensorOp::Neg => semiring::linearize_neg(builder, tangent_in),
152        StdTensorOp::Conj => semiring::linearize_conj(builder, tangent_in),
153        _ => return None,
154    })
155}
156
157fn transpose_non_semiring(
158    op: &StdTensorOp,
159    emitter: &mut impl OpEmitter<StdTensorOp>,
160    cotangent_out: &[Option<LocalValId>],
161    inputs: &[ValRef<StdTensorOp>],
162    mode: &OpMode,
163    ctx: &mut context::ShapeGuardContext,
164) -> Option<Vec<Option<LocalValId>>> {
165    let _ = ctx;
166    Some(match op {
167        StdTensorOp::Div => elementwise_tier2::transpose_div(emitter, cotangent_out, inputs, mode),
168        StdTensorOp::Abs => elementwise_tier2::transpose_abs(emitter, cotangent_out, inputs, mode),
169        StdTensorOp::Sign => elementwise_tier2::transpose_sign(emitter, cotangent_out, mode),
170        StdTensorOp::Constant { .. } => vec![],
171        StdTensorOp::Compare(_) => vec![None, None],
172        StdTensorOp::Exp => analytic::transpose_exp(emitter, cotangent_out, inputs, mode),
173        StdTensorOp::Log => analytic::transpose_log(emitter, cotangent_out, inputs, mode),
174        StdTensorOp::Sin => analytic::transpose_sin(emitter, cotangent_out, inputs, mode),
175        StdTensorOp::Cos => analytic::transpose_cos(emitter, cotangent_out, inputs, mode),
176        StdTensorOp::Tanh => analytic::transpose_tanh(emitter, cotangent_out, inputs, mode),
177        StdTensorOp::Sqrt => analytic::transpose_sqrt(emitter, cotangent_out, inputs, mode),
178        StdTensorOp::Rsqrt => analytic::transpose_rsqrt(emitter, cotangent_out, inputs, mode),
179        StdTensorOp::Pow => analytic::transpose_pow(emitter, cotangent_out, inputs, mode),
180        StdTensorOp::Expm1 => analytic::transpose_expm1(emitter, cotangent_out, inputs, mode),
181        StdTensorOp::Log1p => analytic::transpose_log1p(emitter, cotangent_out, inputs, mode),
182        StdTensorOp::DotGeneral(config) => {
183            contraction::transpose_dot_general(emitter, cotangent_out, inputs, mode, config)
184        }
185        StdTensorOp::NaryEinsum {
186            subscripts,
187            n_inputs,
188        } => contraction::transpose_nary_einsum(
189            emitter,
190            cotangent_out,
191            inputs,
192            mode,
193            subscripts,
194            *n_inputs,
195        ),
196        StdTensorOp::ReduceSum { .. } => {
197            contraction::transpose_reduce_sum(emitter, cotangent_out, op, inputs)
198        }
199        StdTensorOp::ReduceProd { .. } => {
200            contraction::transpose_reduce_prod(emitter, cotangent_out, inputs, op)
201        }
202        StdTensorOp::ReduceMax { .. } | StdTensorOp::ReduceMin { .. } => {
203            contraction::transpose_reduce_chooser(emitter, cotangent_out, inputs, op)
204        }
205        StdTensorOp::Transpose { perm } => {
206            structural::transpose_transpose(emitter, cotangent_out, perm)
207        }
208        StdTensorOp::Reshape { .. } => {
209            structural::transpose_reshape(emitter, cotangent_out, op, inputs)
210        }
211        StdTensorOp::BroadcastInDim { shape, dims } => {
212            structural::transpose_broadcast_in_dim(emitter, cotangent_out, shape, dims)
213        }
214        StdTensorOp::Convert { from, to } => {
215            structural::transpose_convert(emitter, cotangent_out, mode, *from, *to)
216        }
217        StdTensorOp::ExtractDiag { axis_a, axis_b } => {
218            diagonal::transpose_extract_diag(emitter, cotangent_out, *axis_a, *axis_b)
219        }
220        StdTensorOp::EmbedDiag { axis_a, axis_b } => {
221            diagonal::transpose_embed_diag(emitter, cotangent_out, *axis_a, *axis_b)
222        }
223        StdTensorOp::Tril { k } => structural::transpose_tril(emitter, cotangent_out, *k),
224        StdTensorOp::Triu { k } => structural::transpose_triu(emitter, cotangent_out, *k),
225        StdTensorOp::DynamicTruncate { axis } => {
226            dynamic::transpose_dynamic_truncate(emitter, cotangent_out, inputs, *axis)
227        }
228        StdTensorOp::PadToMatch { axis } => {
229            dynamic::transpose_pad_to_match(emitter, cotangent_out, inputs, *axis)
230        }
231        StdTensorOp::ShapeOf { .. } => vec![None],
232        StdTensorOp::TriangularSolve {
233            left_side,
234            lower,
235            transpose_a,
236            unit_diagonal,
237            lhs_shape,
238            rhs_shape,
239        } => linalg::transpose_triangular_solve(
240            emitter,
241            cotangent_out,
242            inputs,
243            mode,
244            *left_side,
245            *lower,
246            *transpose_a,
247            *unit_diagonal,
248            lhs_shape,
249            rhs_shape,
250        ),
251        StdTensorOp::ValidateNonsingular { .. } => vec![cotangent_out[0]],
252        _ => return None,
253    })
254}
255
256fn transpose_semiring(
257    op: &StdTensorOp,
258    emitter: &mut impl OpEmitter<StdTensorOp>,
259    cotangent_out: &[Option<LocalValId>],
260    inputs: &[ValRef<StdTensorOp>],
261    mode: &OpMode,
262    _ctx: &mut context::ShapeGuardContext,
263) -> Option<Vec<Option<LocalValId>>> {
264    Some(match op {
265        StdTensorOp::Add => semiring::transpose_add(cotangent_out),
266        StdTensorOp::Mul => semiring::transpose_mul(emitter, cotangent_out, inputs, mode),
267        StdTensorOp::Neg => semiring::transpose_neg(emitter, cotangent_out),
268        StdTensorOp::Conj => semiring::transpose_conj(emitter, cotangent_out),
269        _ => return None,
270    })
271}
272
273fn todo_linearize(op: &StdTensorOp) -> ! {
274    todo!("linearize not implemented for {:?}", op)
275}
276
277fn todo_transpose_rule(op: &StdTensorOp) -> ! {
278    todo!("transpose_rule not implemented for {:?}", op)
279}
280
281pub fn linearize(
282    op: &StdTensorOp,
283    builder: &mut FragmentBuilder<StdTensorOp>,
284    primal_in: &[GlobalValKey<StdTensorOp>],
285    primal_out: &[GlobalValKey<StdTensorOp>],
286    tangent_in: &[Option<LocalValId>],
287    ctx: &mut context::ShapeGuardContext,
288) -> Vec<Option<LocalValId>> {
289    if let Some(result) =
290        linearize_non_semiring(op, builder, primal_in, primal_out, tangent_in, ctx)
291    {
292        return result;
293    }
294    if let Some(result) = linearize_semiring(op, builder, primal_in, tangent_in, ctx) {
295        return result;
296    }
297    todo_linearize(op)
298}
299
300pub fn transpose_rule(
301    op: &StdTensorOp,
302    emitter: &mut impl OpEmitter<StdTensorOp>,
303    cotangent_out: &[Option<LocalValId>],
304    inputs: &[ValRef<StdTensorOp>],
305    mode: &OpMode,
306    ctx: &mut context::ShapeGuardContext,
307) -> Vec<Option<LocalValId>> {
308    if let Some(result) = transpose_non_semiring(op, emitter, cotangent_out, inputs, mode, ctx) {
309        return result;
310    }
311    if let Some(result) = transpose_semiring(op, emitter, cotangent_out, inputs, mode, ctx) {
312        return result;
313    }
314    todo_transpose_rule(op)
315}