1pub mod context;
2
3mod analytic;
4mod contraction;
5mod diagonal;
6mod dynamic;
7mod elementwise_tier2;
8mod linalg;
9mod semiring;
10mod structural;
11
12use computegraph::fragment::FragmentBuilder;
13use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
14use computegraph::OpEmitter;
15
16use crate::std_tensor_op::StdTensorOp;
17
18fn linearize_non_semiring(
19 op: &StdTensorOp,
20 builder: &mut FragmentBuilder<StdTensorOp>,
21 primal_in: &[GlobalValKey<StdTensorOp>],
22 primal_out: &[GlobalValKey<StdTensorOp>],
23 tangent_in: &[Option<LocalValId>],
24 ctx: &mut context::ShapeGuardContext,
25) -> Option<Vec<Option<LocalValId>>> {
26 Some(match op {
27 StdTensorOp::Div => {
28 elementwise_tier2::linearize_div(builder, primal_in, primal_out, tangent_in)
29 }
30 StdTensorOp::Abs => elementwise_tier2::linearize_abs(builder, primal_in, tangent_in),
31 StdTensorOp::Sign => elementwise_tier2::linearize_sign(builder, tangent_in),
32 StdTensorOp::Constant { .. } => vec![None],
33 StdTensorOp::Compare(_) => vec![None],
34 StdTensorOp::Exp => analytic::linearize_exp(builder, primal_out, tangent_in),
35 StdTensorOp::Log => analytic::linearize_log(builder, primal_in, tangent_in),
36 StdTensorOp::Sin => analytic::linearize_sin(builder, primal_in, tangent_in),
37 StdTensorOp::Cos => analytic::linearize_cos(builder, primal_in, tangent_in),
38 StdTensorOp::Tanh => analytic::linearize_tanh(builder, primal_out, tangent_in),
39 StdTensorOp::Sqrt => analytic::linearize_sqrt(builder, primal_out, tangent_in),
40 StdTensorOp::Rsqrt => analytic::linearize_rsqrt(builder, primal_in, primal_out, tangent_in),
41 StdTensorOp::Pow => analytic::linearize_pow(builder, primal_in, primal_out, tangent_in),
42 StdTensorOp::Expm1 => analytic::linearize_expm1(builder, primal_out, tangent_in),
43 StdTensorOp::Log1p => analytic::linearize_log1p(builder, primal_in, tangent_in),
44 StdTensorOp::DotGeneral(config) => {
45 contraction::linearize_dot_general(builder, primal_in, tangent_in, config)
46 }
47 StdTensorOp::NaryEinsum { subscripts, .. } => {
48 contraction::linearize_nary_einsum(builder, primal_in, tangent_in, subscripts)
49 }
50 StdTensorOp::ReduceSum { axes, .. } => {
51 contraction::linearize_reduce_sum(builder, tangent_in, op, axes)
52 }
53 StdTensorOp::ReduceProd { axes, input_shape } => contraction::linearize_reduce_prod(
54 builder,
55 primal_in,
56 primal_out,
57 tangent_in,
58 axes,
59 input_shape,
60 ),
61 StdTensorOp::ReduceMax { axes, input_shape }
62 | StdTensorOp::ReduceMin { axes, input_shape } => contraction::linearize_reduce_chooser(
63 builder,
64 primal_in,
65 primal_out,
66 tangent_in,
67 axes,
68 input_shape,
69 ),
70 StdTensorOp::Transpose { perm } => {
71 structural::linearize_transpose(builder, tangent_in, perm)
72 }
73 StdTensorOp::Reshape { .. } => {
74 structural::linearize_reshape(builder, primal_in, tangent_in, op)
75 }
76 StdTensorOp::BroadcastInDim { shape, dims } => {
77 structural::linearize_broadcast_in_dim(builder, primal_in, tangent_in, shape, dims)
78 }
79 StdTensorOp::Convert { from, to } => {
80 structural::linearize_convert(builder, tangent_in, *from, *to)
81 }
82 StdTensorOp::ExtractDiag { axis_a, axis_b } => {
83 diagonal::linearize_extract_diag(builder, tangent_in, *axis_a, *axis_b)
84 }
85 StdTensorOp::EmbedDiag { axis_a, axis_b } => {
86 diagonal::linearize_embed_diag(builder, tangent_in, *axis_a, *axis_b)
87 }
88 StdTensorOp::Tril { k } => structural::linearize_tril(builder, tangent_in, *k),
89 StdTensorOp::Triu { k } => structural::linearize_triu(builder, tangent_in, *k),
90 StdTensorOp::Pad(config) => structural::linearize_pad(builder, tangent_in, config),
91 StdTensorOp::DynamicTruncate { axis } => {
92 dynamic::linearize_dynamic_truncate(builder, primal_in, tangent_in, *axis)
93 }
94 StdTensorOp::PadToMatch { axis } => {
95 dynamic::linearize_pad_to_match(builder, primal_in, tangent_in, *axis)
96 }
97 StdTensorOp::ShapeOf { .. } => vec![None],
98 StdTensorOp::Lu { input_shape } => {
99 linalg::linearize_lu(builder, primal_out, tangent_in, input_shape, ctx)
100 }
101 StdTensorOp::TriangularSolve {
102 left_side,
103 lower,
104 transpose_a,
105 unit_diagonal,
106 lhs_shape,
107 rhs_shape,
108 } => linalg::linearize_triangular_solve(
109 builder,
110 primal_in,
111 primal_out,
112 tangent_in,
113 *left_side,
114 *lower,
115 *transpose_a,
116 *unit_diagonal,
117 lhs_shape,
118 rhs_shape,
119 ),
120 StdTensorOp::Cholesky { input_shape } => {
121 linalg::linearize_cholesky(builder, primal_out, tangent_in, input_shape)
122 }
123 StdTensorOp::Svd { eps, input_shape } => {
124 linalg::linearize_svd(builder, primal_out, tangent_in, *eps, input_shape, ctx)
125 }
126 StdTensorOp::Qr { input_shape } => {
127 linalg::linearize_qr(builder, primal_out, tangent_in, input_shape, ctx)
128 }
129 StdTensorOp::Eigh { eps, input_shape } => {
130 linalg::linearize_eigh(builder, primal_out, tangent_in, *eps, input_shape)
131 }
132 StdTensorOp::Eig {
133 input_dtype,
134 input_shape,
135 } => linalg::linearize_eig(builder, primal_out, tangent_in, *input_dtype, input_shape),
136 StdTensorOp::ValidateNonsingular { .. } => vec![tangent_in[0]],
137 _ => return None,
138 })
139}
140
141fn linearize_semiring(
142 op: &StdTensorOp,
143 builder: &mut FragmentBuilder<StdTensorOp>,
144 primal_in: &[GlobalValKey<StdTensorOp>],
145 tangent_in: &[Option<LocalValId>],
146 _ctx: &mut context::ShapeGuardContext,
147) -> Option<Vec<Option<LocalValId>>> {
148 Some(match op {
149 StdTensorOp::Add => semiring::linearize_add(builder, tangent_in),
150 StdTensorOp::Mul => semiring::linearize_mul(builder, primal_in, tangent_in),
151 StdTensorOp::Neg => semiring::linearize_neg(builder, tangent_in),
152 StdTensorOp::Conj => semiring::linearize_conj(builder, tangent_in),
153 _ => return None,
154 })
155}
156
157fn transpose_non_semiring(
158 op: &StdTensorOp,
159 emitter: &mut impl OpEmitter<StdTensorOp>,
160 cotangent_out: &[Option<LocalValId>],
161 inputs: &[ValRef<StdTensorOp>],
162 mode: &OpMode,
163 ctx: &mut context::ShapeGuardContext,
164) -> Option<Vec<Option<LocalValId>>> {
165 let _ = ctx;
166 Some(match op {
167 StdTensorOp::Div => elementwise_tier2::transpose_div(emitter, cotangent_out, inputs, mode),
168 StdTensorOp::Abs => elementwise_tier2::transpose_abs(emitter, cotangent_out, inputs, mode),
169 StdTensorOp::Sign => elementwise_tier2::transpose_sign(emitter, cotangent_out, mode),
170 StdTensorOp::Constant { .. } => vec![],
171 StdTensorOp::Compare(_) => vec![None, None],
172 StdTensorOp::Exp => analytic::transpose_exp(emitter, cotangent_out, inputs, mode),
173 StdTensorOp::Log => analytic::transpose_log(emitter, cotangent_out, inputs, mode),
174 StdTensorOp::Sin => analytic::transpose_sin(emitter, cotangent_out, inputs, mode),
175 StdTensorOp::Cos => analytic::transpose_cos(emitter, cotangent_out, inputs, mode),
176 StdTensorOp::Tanh => analytic::transpose_tanh(emitter, cotangent_out, inputs, mode),
177 StdTensorOp::Sqrt => analytic::transpose_sqrt(emitter, cotangent_out, inputs, mode),
178 StdTensorOp::Rsqrt => analytic::transpose_rsqrt(emitter, cotangent_out, inputs, mode),
179 StdTensorOp::Pow => analytic::transpose_pow(emitter, cotangent_out, inputs, mode),
180 StdTensorOp::Expm1 => analytic::transpose_expm1(emitter, cotangent_out, inputs, mode),
181 StdTensorOp::Log1p => analytic::transpose_log1p(emitter, cotangent_out, inputs, mode),
182 StdTensorOp::DotGeneral(config) => {
183 contraction::transpose_dot_general(emitter, cotangent_out, inputs, mode, config)
184 }
185 StdTensorOp::NaryEinsum {
186 subscripts,
187 n_inputs,
188 } => contraction::transpose_nary_einsum(
189 emitter,
190 cotangent_out,
191 inputs,
192 mode,
193 subscripts,
194 *n_inputs,
195 ),
196 StdTensorOp::ReduceSum { .. } => {
197 contraction::transpose_reduce_sum(emitter, cotangent_out, op, inputs)
198 }
199 StdTensorOp::ReduceProd { .. } => {
200 contraction::transpose_reduce_prod(emitter, cotangent_out, inputs, op)
201 }
202 StdTensorOp::ReduceMax { .. } | StdTensorOp::ReduceMin { .. } => {
203 contraction::transpose_reduce_chooser(emitter, cotangent_out, inputs, op)
204 }
205 StdTensorOp::Transpose { perm } => {
206 structural::transpose_transpose(emitter, cotangent_out, perm)
207 }
208 StdTensorOp::Reshape { .. } => {
209 structural::transpose_reshape(emitter, cotangent_out, op, inputs)
210 }
211 StdTensorOp::BroadcastInDim { shape, dims } => {
212 structural::transpose_broadcast_in_dim(emitter, cotangent_out, shape, dims)
213 }
214 StdTensorOp::Convert { from, to } => {
215 structural::transpose_convert(emitter, cotangent_out, mode, *from, *to)
216 }
217 StdTensorOp::ExtractDiag { axis_a, axis_b } => {
218 diagonal::transpose_extract_diag(emitter, cotangent_out, *axis_a, *axis_b)
219 }
220 StdTensorOp::EmbedDiag { axis_a, axis_b } => {
221 diagonal::transpose_embed_diag(emitter, cotangent_out, *axis_a, *axis_b)
222 }
223 StdTensorOp::Tril { k } => structural::transpose_tril(emitter, cotangent_out, *k),
224 StdTensorOp::Triu { k } => structural::transpose_triu(emitter, cotangent_out, *k),
225 StdTensorOp::DynamicTruncate { axis } => {
226 dynamic::transpose_dynamic_truncate(emitter, cotangent_out, inputs, *axis)
227 }
228 StdTensorOp::PadToMatch { axis } => {
229 dynamic::transpose_pad_to_match(emitter, cotangent_out, inputs, *axis)
230 }
231 StdTensorOp::ShapeOf { .. } => vec![None],
232 StdTensorOp::TriangularSolve {
233 left_side,
234 lower,
235 transpose_a,
236 unit_diagonal,
237 lhs_shape,
238 rhs_shape,
239 } => linalg::transpose_triangular_solve(
240 emitter,
241 cotangent_out,
242 inputs,
243 mode,
244 *left_side,
245 *lower,
246 *transpose_a,
247 *unit_diagonal,
248 lhs_shape,
249 rhs_shape,
250 ),
251 StdTensorOp::ValidateNonsingular { .. } => vec![cotangent_out[0]],
252 _ => return None,
253 })
254}
255
256fn transpose_semiring(
257 op: &StdTensorOp,
258 emitter: &mut impl OpEmitter<StdTensorOp>,
259 cotangent_out: &[Option<LocalValId>],
260 inputs: &[ValRef<StdTensorOp>],
261 mode: &OpMode,
262 _ctx: &mut context::ShapeGuardContext,
263) -> Option<Vec<Option<LocalValId>>> {
264 Some(match op {
265 StdTensorOp::Add => semiring::transpose_add(cotangent_out),
266 StdTensorOp::Mul => semiring::transpose_mul(emitter, cotangent_out, inputs, mode),
267 StdTensorOp::Neg => semiring::transpose_neg(emitter, cotangent_out),
268 StdTensorOp::Conj => semiring::transpose_conj(emitter, cotangent_out),
269 _ => return None,
270 })
271}
272
273fn todo_linearize(op: &StdTensorOp) -> ! {
274 todo!("linearize not implemented for {:?}", op)
275}
276
277fn todo_transpose_rule(op: &StdTensorOp) -> ! {
278 todo!("transpose_rule not implemented for {:?}", op)
279}
280
281pub fn linearize(
282 op: &StdTensorOp,
283 builder: &mut FragmentBuilder<StdTensorOp>,
284 primal_in: &[GlobalValKey<StdTensorOp>],
285 primal_out: &[GlobalValKey<StdTensorOp>],
286 tangent_in: &[Option<LocalValId>],
287 ctx: &mut context::ShapeGuardContext,
288) -> Vec<Option<LocalValId>> {
289 if let Some(result) =
290 linearize_non_semiring(op, builder, primal_in, primal_out, tangent_in, ctx)
291 {
292 return result;
293 }
294 if let Some(result) = linearize_semiring(op, builder, primal_in, tangent_in, ctx) {
295 return result;
296 }
297 todo_linearize(op)
298}
299
300pub fn transpose_rule(
301 op: &StdTensorOp,
302 emitter: &mut impl OpEmitter<StdTensorOp>,
303 cotangent_out: &[Option<LocalValId>],
304 inputs: &[ValRef<StdTensorOp>],
305 mode: &OpMode,
306 ctx: &mut context::ShapeGuardContext,
307) -> Vec<Option<LocalValId>> {
308 if let Some(result) = transpose_non_semiring(op, emitter, cotangent_out, inputs, mode, ctx) {
309 return result;
310 }
311 if let Some(result) = transpose_semiring(op, emitter, cotangent_out, inputs, mode, ctx) {
312 return result;
313 }
314 todo_transpose_rule(op)
315}