1use std::hash::{Hash, Hasher};
2
3use chainrules_core::PrimitiveOp;
4use computegraph::fragment::FragmentBuilder;
5use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
6use computegraph::{GraphOp, OpEmitter};
7use num_complex::{Complex32, Complex64};
8
9use crate::dim_expr::DimExpr;
10use crate::input_key::TensorInputKey;
11use crate::semiring_ops::SemiringOps;
12use tenferro_tensor::{
13 CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
14};
15
/// A primitive operation in the standard tensor-op dialect.
///
/// Variants are grouped (matching the blank-line groups below) into core
/// arithmetic/structural ops, elementwise binary/comparison ops, elementwise
/// transcendental ops, diagonal/triangular ops, indexing/shape ops,
/// reductions, and linear-algebra decompositions.
///
/// `PartialEq` is derived; `Eq` and `Hash` are implemented manually further
/// down in this file because `Svd` and `Eigh` carry `f64` fields (`eps`).
#[derive(Clone, Debug, PartialEq)]
pub enum StdTensorOp {
    Add,
    Mul,
    Neg,
    Conj,
    DotGeneral(DotGeneralConfig),
    Transpose {
        perm: Vec<usize>,
    },
    Reshape {
        from_shape: Vec<DimExpr>,
        to_shape: Vec<DimExpr>,
    },
    BroadcastInDim {
        shape: Vec<DimExpr>,
        dims: Vec<usize>,
    },
    Convert {
        from: DType,
        to: DType,
    },
    /// A literal value: `bytes` holds the raw little-endian encoding of a
    /// scalar of `dtype` (see the `constant_*` constructors in this file).
    Constant {
        dtype: DType,
        bytes: Vec<u8>,
    },
    ReduceSum {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },

    // Elementwise binary / comparison / selection ops.
    Div,
    Abs,
    Sign,
    Maximum,
    Minimum,
    Compare(CompareDir),
    Select,
    Clamp,

    // Elementwise transcendental ops.
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,

    // Diagonal / triangular structure ops.
    ExtractDiag {
        axis_a: usize,
        axis_b: usize,
    },
    EmbedDiag {
        axis_a: usize,
        axis_b: usize,
    },
    Tril {
        k: i64,
    },
    Triu {
        k: i64,
    },

    // Indexing / shape-manipulation ops.
    Gather(GatherConfig),
    Scatter(ScatterConfig),
    Slice(SliceConfig),
    DynamicSlice {
        slice_sizes: Vec<usize>,
    },
    Pad(PadConfig),
    /// Generalized einsum over `n_inputs` operands; `subscripts` is the
    /// einsum specification string.
    NaryEinsum {
        subscripts: String,
        n_inputs: usize,
    },
    Concatenate {
        axis: usize,
    },
    Reverse {
        axes: Vec<usize>,
    },
    // NOTE(review): presumably yields the runtime extent of `axis` — confirm.
    ShapeOf {
        axis: usize,
    },
    // NOTE(review): takes two inputs (see `n_inputs`); presumably truncates
    // one to match the other along `axis` — confirm against the executor.
    DynamicTruncate {
        axis: usize,
    },
    // NOTE(review): takes two inputs (see `n_inputs`); presumably pads one to
    // match the other along `axis` — confirm against the executor.
    PadToMatch {
        axis: usize,
    },

    // Reductions beyond sum.
    ReduceProd {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },
    ReduceMax {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },
    ReduceMin {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },

    // Linear-algebra decompositions and solves.
    Cholesky {
        input_shape: Vec<DimExpr>,
    },
    Lu {
        input_shape: Vec<DimExpr>,
    },
    Svd {
        eps: f64,
        input_shape: Vec<DimExpr>,
    },
    Qr {
        input_shape: Vec<DimExpr>,
    },
    Eigh {
        eps: f64,
        input_shape: Vec<DimExpr>,
    },
    Eig {
        input_dtype: DType,
        input_shape: Vec<DimExpr>,
    },
    TriangularSolve {
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
        lhs_shape: Vec<DimExpr>,
        rhs_shape: Vec<DimExpr>,
    },
    ValidateNonsingular {
        input_shape: Vec<DimExpr>,
    },
}
164
165impl StdTensorOp {
166 pub fn constant_f64(value: f64) -> Self {
176 Self::Constant {
177 dtype: DType::F64,
178 bytes: value.to_le_bytes().to_vec(),
179 }
180 }
181
182 pub fn constant_f32(value: f32) -> Self {
192 Self::Constant {
193 dtype: DType::F32,
194 bytes: value.to_le_bytes().to_vec(),
195 }
196 }
197
198 pub fn constant_c64(value: Complex64) -> Self {
209 let mut bytes = Vec::with_capacity(16);
210 bytes.extend_from_slice(&value.re.to_le_bytes());
211 bytes.extend_from_slice(&value.im.to_le_bytes());
212 Self::Constant {
213 dtype: DType::C64,
214 bytes,
215 }
216 }
217
218 pub fn constant_c32(value: Complex32) -> Self {
229 let mut bytes = Vec::with_capacity(8);
230 bytes.extend_from_slice(&value.re.to_le_bytes());
231 bytes.extend_from_slice(&value.im.to_le_bytes());
232 Self::Constant {
233 dtype: DType::C32,
234 bytes,
235 }
236 }
237}
238
// `Eq` is sound only if the `f64` fields (`eps` in `Svd`/`Eigh`) are never
// NaN: the derived `PartialEq` uses `f64 ==`, and `NaN != NaN` would break
// reflexivity. NOTE(review): assumes callers never construct ops with a NaN
// `eps` — confirm.
impl Eq for StdTensorOp {}
240
/// Manual `Hash` implementation, needed because `f64` fields are not
/// `Hash`-able via derive. It mirrors the derived `PartialEq`: the variant
/// discriminant is hashed first, then every field of the active variant.
/// `f64` fields go through `hash_f64`, which canonicalizes `±0.0` so values
/// equal under `==` hash identically (required for Hash/Eq consistency).
impl Hash for StdTensorOp {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the discriminant so payload-free variants still differ.
        std::mem::discriminant(self).hash(state);
        match self {
            // Field-less variants: the discriminant alone is sufficient.
            Self::Add
            | Self::Mul
            | Self::Neg
            | Self::Conj
            | Self::Div
            | Self::Abs
            | Self::Sign
            | Self::Maximum
            | Self::Minimum
            | Self::Select
            | Self::Clamp
            | Self::Exp
            | Self::Log
            | Self::Sin
            | Self::Cos
            | Self::Tanh
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Pow
            | Self::Expm1
            | Self::Log1p => {}
            // `eps` is an f64: hash via the ±0.0-canonicalizing helper.
            Self::Svd { eps, input_shape } => {
                hash_f64(*eps, state);
                input_shape.hash(state);
            }
            Self::Qr { input_shape }
            | Self::Cholesky { input_shape }
            | Self::Lu { input_shape } => {
                input_shape.hash(state);
            }
            Self::Eig {
                input_dtype,
                input_shape,
            } => {
                input_dtype.hash(state);
                input_shape.hash(state);
            }
            Self::Eigh { eps, input_shape } => {
                hash_f64(*eps, state);
                input_shape.hash(state);
            }
            Self::DotGeneral(config) => config.hash(state),
            Self::Transpose { perm } => perm.hash(state),
            Self::Reshape {
                from_shape,
                to_shape,
            } => {
                from_shape.hash(state);
                to_shape.hash(state);
            }
            Self::BroadcastInDim { shape, dims } => {
                shape.hash(state);
                dims.hash(state);
            }
            Self::Convert { from, to } => {
                from.hash(state);
                to.hash(state);
            }
            Self::Constant { dtype, bytes } => {
                dtype.hash(state);
                bytes.hash(state);
            }
            Self::ReduceSum { axes, input_shape } => {
                axes.hash(state);
                input_shape.hash(state);
            }
            Self::Compare(dir) => dir.hash(state),
            Self::ExtractDiag { axis_a, axis_b } | Self::EmbedDiag { axis_a, axis_b } => {
                axis_a.hash(state);
                axis_b.hash(state);
            }
            Self::Tril { k } | Self::Triu { k } => k.hash(state),
            Self::Gather(config) => config.hash(state),
            Self::Scatter(config) => config.hash(state),
            Self::Slice(config) => config.hash(state),
            Self::DynamicSlice { slice_sizes } => slice_sizes.hash(state),
            Self::Pad(config) => config.hash(state),
            Self::NaryEinsum {
                subscripts,
                n_inputs,
            } => {
                subscripts.hash(state);
                n_inputs.hash(state);
            }
            Self::Concatenate { axis } => axis.hash(state),
            Self::Reverse { axes } => axes.hash(state),
            Self::ShapeOf { axis } | Self::DynamicTruncate { axis } | Self::PadToMatch { axis } => {
                axis.hash(state)
            }
            Self::ReduceProd { axes, input_shape }
            | Self::ReduceMax { axes, input_shape }
            | Self::ReduceMin { axes, input_shape } => {
                axes.hash(state);
                input_shape.hash(state);
            }
            Self::TriangularSolve {
                left_side,
                lower,
                transpose_a,
                unit_diagonal,
                lhs_shape,
                rhs_shape,
            } => {
                left_side.hash(state);
                lower.hash(state);
                transpose_a.hash(state);
                unit_diagonal.hash(state);
                lhs_shape.hash(state);
                rhs_shape.hash(state);
            }
            Self::ValidateNonsingular { input_shape } => {
                input_shape.hash(state);
            }
        }
    }
}
361
/// Hashes an `f64` by its bit pattern, canonicalizing both zeros to the same
/// value so that `+0.0` and `-0.0` — equal under `==` — hash identically.
/// (NaNs with different payloads hash differently, which is fine: they also
/// compare unequal.)
fn hash_f64<H: Hasher>(value: f64, state: &mut H) {
    let canonical = if value == 0.0 { 0u64 } else { value.to_bits() };
    state.write_u64(canonical);
}
366
367fn n_inputs_from_dim_exprs(min_inputs: usize, exprs: &[&[DimExpr]]) -> usize {
368 let max_idx = exprs
369 .iter()
370 .flat_map(|exprs| exprs.iter())
371 .filter_map(DimExpr::max_input_idx)
372 .max()
373 .map_or(0, |max_idx| max_idx + 1);
374 max_idx.max(min_inputs)
375}
376
377impl GraphOp for StdTensorOp {
378 type Operand = tenferro_tensor::Tensor;
379 type Context = ();
380 type InputKey = TensorInputKey;
381
382 fn n_inputs(&self) -> usize {
383 match self {
384 Self::Add | Self::Mul | Self::DotGeneral(_) | Self::Gather(_) => 2,
385 Self::Neg
386 | Self::Conj
387 | Self::Transpose { .. }
388 | Self::Convert { .. }
389 | Self::ExtractDiag { .. }
390 | Self::EmbedDiag { .. }
391 | Self::Tril { .. }
392 | Self::Triu { .. }
393 | Self::Slice(_)
394 | Self::Pad(_)
395 | Self::Reverse { .. }
396 | Self::ShapeOf { .. } => 1,
397 Self::DynamicTruncate { .. } | Self::PadToMatch { .. } => 2,
398 Self::Reshape {
399 from_shape,
400 to_shape,
401 } => n_inputs_from_dim_exprs(1, &[from_shape, to_shape]),
402 Self::BroadcastInDim { shape, .. } => n_inputs_from_dim_exprs(1, &[shape]),
403 Self::ReduceSum { input_shape, .. }
404 | Self::ReduceProd { input_shape, .. }
405 | Self::ReduceMax { input_shape, .. }
406 | Self::ReduceMin { input_shape, .. } => n_inputs_from_dim_exprs(1, &[input_shape]),
407 Self::Div | Self::Maximum | Self::Minimum | Self::Pow | Self::DynamicSlice { .. } => 2,
408 Self::Constant { .. } => 0,
409 Self::Scatter(_) => 3,
410 Self::NaryEinsum { n_inputs, .. } => *n_inputs,
411 Self::Concatenate { .. } => {
412 todo!(
413 "n_inputs not yet implemented for variable-arity op {:?}",
414 self
415 )
416 }
417 Self::Abs
418 | Self::Sign
419 | Self::Exp
420 | Self::Log
421 | Self::Sin
422 | Self::Cos
423 | Self::Tanh
424 | Self::Sqrt
425 | Self::Rsqrt
426 | Self::Expm1
427 | Self::Log1p => 1,
428 Self::Select | Self::Clamp => 3,
429 Self::Compare(_) => 2,
430 Self::Cholesky { input_shape }
431 | Self::Lu { input_shape }
432 | Self::Svd { input_shape, .. }
433 | Self::Qr { input_shape }
434 | Self::Eigh { input_shape, .. }
435 | Self::Eig { input_shape, .. } => n_inputs_from_dim_exprs(1, &[input_shape]),
436 Self::TriangularSolve {
437 lhs_shape,
438 rhs_shape,
439 ..
440 } => n_inputs_from_dim_exprs(2, &[lhs_shape, rhs_shape]),
441 Self::ValidateNonsingular { input_shape } => n_inputs_from_dim_exprs(1, &[input_shape]),
442 }
443 }
444
445 fn n_outputs(&self) -> usize {
446 match self {
447 Self::Add
448 | Self::Mul
449 | Self::Neg
450 | Self::Conj
451 | Self::DotGeneral(_)
452 | Self::Transpose { .. }
453 | Self::Reshape { .. }
454 | Self::BroadcastInDim { .. }
455 | Self::Convert { .. }
456 | Self::ReduceSum { .. }
457 | Self::Div
458 | Self::Abs
459 | Self::Sign
460 | Self::Maximum
461 | Self::Minimum
462 | Self::Compare(_)
463 | Self::Select
464 | Self::Clamp
465 | Self::Constant { .. }
466 | Self::Exp
467 | Self::Log
468 | Self::Sin
469 | Self::Cos
470 | Self::Tanh
471 | Self::Sqrt
472 | Self::Rsqrt
473 | Self::Pow
474 | Self::Expm1
475 | Self::Log1p
476 | Self::ExtractDiag { .. }
477 | Self::EmbedDiag { .. }
478 | Self::Tril { .. }
479 | Self::Triu { .. }
480 | Self::Gather(_)
481 | Self::Scatter(_)
482 | Self::Slice(_)
483 | Self::DynamicSlice { .. }
484 | Self::Pad(_)
485 | Self::NaryEinsum { .. }
486 | Self::Reverse { .. }
487 | Self::ShapeOf { .. }
488 | Self::DynamicTruncate { .. }
489 | Self::PadToMatch { .. }
490 | Self::ReduceProd { .. }
491 | Self::ReduceMax { .. }
492 | Self::ReduceMin { .. } => 1,
493 Self::Cholesky { .. }
494 | Self::TriangularSolve { .. }
495 | Self::ValidateNonsingular { .. } => 1,
496 Self::Lu { .. } => 4,
497 Self::Svd { .. } => 3, Self::Qr { .. } => 2, Self::Eigh { .. } => 2, Self::Eig { .. } => 2, Self::Concatenate { .. } => todo!(
502 "n_outputs not yet implemented for variable-arity op {:?}",
503 self
504 ),
505 }
506 }
507}
508
509impl PrimitiveOp for StdTensorOp {
510 type ADContext = crate::ad::context::ShapeGuardContext;
511
512 fn add() -> Self {
513 StdTensorOp::Add
514 }
515
516 fn linearize(
517 &self,
518 builder: &mut FragmentBuilder<Self>,
519 primal_in: &[GlobalValKey<Self>],
520 primal_out: &[GlobalValKey<Self>],
521 tangent_in: &[Option<LocalValId>],
522 ctx: &mut Self::ADContext,
523 ) -> Vec<Option<LocalValId>> {
524 crate::ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
525 }
526
527 fn transpose_rule(
528 &self,
529 emitter: &mut impl OpEmitter<Self>,
530 cotangent_out: &[Option<LocalValId>],
531 inputs: &[ValRef<Self>],
532 mode: &OpMode,
533 ctx: &mut Self::ADContext,
534 ) -> Vec<Option<LocalValId>> {
535 crate::ad::transpose_rule(self, emitter, cotangent_out, inputs, mode, ctx)
536 }
537}
538
539impl SemiringOps for StdTensorOp {
540 fn add_op() -> Self {
541 StdTensorOp::Add
542 }
543
544 fn mul_op() -> Self {
545 StdTensorOp::Mul
546 }
547
548 fn dot_general(config: DotGeneralConfig) -> Self {
549 StdTensorOp::DotGeneral(config)
550 }
551
552 fn reduce_sum(axes: Vec<usize>, input_shape: Vec<DimExpr>) -> Self {
553 StdTensorOp::ReduceSum { axes, input_shape }
554 }
555
556 fn transpose_op(perm: Vec<usize>) -> Self {
557 StdTensorOp::Transpose { perm }
558 }
559
560 fn reshape(from_shape: Vec<DimExpr>, to_shape: Vec<DimExpr>) -> Self {
561 StdTensorOp::Reshape {
562 from_shape,
563 to_shape,
564 }
565 }
566
567 fn broadcast_in_dim(shape: Vec<DimExpr>, dims: Vec<usize>) -> Self {
568 StdTensorOp::BroadcastInDim { shape, dims }
569 }
570
571 fn extract_diag(axis_a: usize, axis_b: usize) -> Self {
572 StdTensorOp::ExtractDiag { axis_a, axis_b }
573 }
574
575 fn embed_diag(axis_a: usize, axis_b: usize) -> Self {
576 StdTensorOp::EmbedDiag { axis_a, axis_b }
577 }
578}