//! Standard tensor operation set for `tenferro_ops` (`std_tensor_op.rs`).
1use std::hash::{Hash, Hasher};
2
3use chainrules_core::PrimitiveOp;
4use computegraph::fragment::FragmentBuilder;
5use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
6use computegraph::{GraphOp, OpEmitter};
7use num_complex::{Complex32, Complex64};
8
9use crate::dim_expr::DimExpr;
10use crate::input_key::TensorInputKey;
11use crate::semiring_ops::SemiringOps;
12use tenferro_tensor::{
13    CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
14};
15
/// The standard tensor operation vocabulary carried on graph nodes.
///
/// Variants are grouped into tiers (see the inline group comments). Shape
/// information is carried as [`DimExpr`]s so that ops can be symbolic over
/// input dimensions; the arity of several ops is derived from the highest
/// input index those expressions reference (see `n_inputs` in the
/// `GraphOp` impl below).
///
/// Equality is structural (`derive(PartialEq)`); `Eq` and `Hash` are
/// implemented manually further down because some variants carry `f64`
/// fields (`eps`).
#[derive(Clone, Debug, PartialEq)]
pub enum StdTensorOp {
    // Tier 1: semiring
    Add,
    Mul,
    Neg,
    Conj,
    DotGeneral(DotGeneralConfig),
    /// Axis permutation; `perm[i]` is the input axis placed at output axis `i`
    /// — TODO(review): confirm direction convention against the executor.
    Transpose {
        perm: Vec<usize>,
    },
    /// Reshape between two (possibly symbolic) shapes.
    Reshape {
        from_shape: Vec<DimExpr>,
        to_shape: Vec<DimExpr>,
    },
    /// Broadcast into `shape`, mapping each input axis to output axis `dims[i]`
    /// — presumably XLA-style `broadcast_in_dim` semantics; verify at call sites.
    BroadcastInDim {
        shape: Vec<DimExpr>,
        dims: Vec<usize>,
    },
    /// Dtype conversion.
    Convert {
        from: DType,
        to: DType,
    },
    /// Scalar constant, stored as raw little-endian bytes (see the
    /// `constant_*` helpers below).
    Constant {
        dtype: DType,
        bytes: Vec<u8>,
    },
    ReduceSum {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },

    // Tier 2: elementwise
    Div,
    Abs,
    Sign,
    Maximum,
    Minimum,
    Compare(CompareDir),
    Select,
    Clamp,

    // Tier 2: analytic
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,

    // Tier 1: diagonal extraction / embedding (AD-closed pair)
    ExtractDiag {
        axis_a: usize,
        axis_b: usize,
    },
    EmbedDiag {
        axis_a: usize,
        axis_b: usize,
    },
    /// Lower-triangular mask with diagonal offset `k`.
    Tril {
        k: i64,
    },
    /// Upper-triangular mask with diagonal offset `k`.
    Triu {
        k: i64,
    },

    // Tier 2: indexing
    Gather(GatherConfig),
    Scatter(ScatterConfig),
    Slice(SliceConfig),
    /// Slice whose start offsets arrive as a runtime input; only the sizes
    /// are static.
    DynamicSlice {
        slice_sizes: Vec<usize>,
    },
    Pad(PadConfig),
    /// N-ary einsum kept as a single graph node.
    /// Contraction path is optimized at execution time from actual input shapes.
    NaryEinsum {
        subscripts: String,
        n_inputs: usize,
    },
    /// Variable-arity concatenation along `axis`; arity is not yet derivable
    /// (see the `todo!` arms in the `GraphOp` impl).
    Concatenate {
        axis: usize,
    },
    Reverse {
        axes: Vec<usize>,
    },
    ShapeOf {
        axis: usize,
    },
    DynamicTruncate {
        axis: usize,
    },
    PadToMatch {
        axis: usize,
    },

    // Tier 2: reductions
    ReduceProd {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },
    ReduceMax {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },
    ReduceMin {
        axes: Vec<usize>,
        input_shape: Vec<DimExpr>,
    },

    // Linalg
    Cholesky {
        input_shape: Vec<DimExpr>,
    },
    // Four outputs (see `n_outputs`) — composition not documented here.
    Lu {
        input_shape: Vec<DimExpr>,
    },
    /// Singular value decomposition; three outputs (U, S, Vt).
    Svd {
        eps: f64,
        input_shape: Vec<DimExpr>,
    },
    /// QR decomposition; two outputs (Q, R).
    Qr {
        input_shape: Vec<DimExpr>,
    },
    /// Hermitian eigendecomposition; two outputs (eigenvalues, eigenvectors).
    Eigh {
        eps: f64,
        input_shape: Vec<DimExpr>,
    },
    /// General eigendecomposition; two outputs (eigenvalues, eigenvectors).
    Eig {
        input_dtype: DType,
        input_shape: Vec<DimExpr>,
    },
    TriangularSolve {
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
        lhs_shape: Vec<DimExpr>,
        rhs_shape: Vec<DimExpr>,
    },
    ValidateNonsingular {
        input_shape: Vec<DimExpr>,
    },
}
164
165impl StdTensorOp {
166    /// Create an `f64` scalar constant op.
167    ///
168    /// # Examples
169    ///
170    /// ```ignore
171    /// use tenferro_ops::std_tensor_op::StdTensorOp;
172    ///
173    /// let op = StdTensorOp::constant_f64(1.5);
174    /// ```
175    pub fn constant_f64(value: f64) -> Self {
176        Self::Constant {
177            dtype: DType::F64,
178            bytes: value.to_le_bytes().to_vec(),
179        }
180    }
181
182    /// Create an `f32` scalar constant op.
183    ///
184    /// # Examples
185    ///
186    /// ```ignore
187    /// use tenferro_ops::std_tensor_op::StdTensorOp;
188    ///
189    /// let op = StdTensorOp::constant_f32(1.5_f32);
190    /// ```
191    pub fn constant_f32(value: f32) -> Self {
192        Self::Constant {
193            dtype: DType::F32,
194            bytes: value.to_le_bytes().to_vec(),
195        }
196    }
197
198    /// Create a `Complex64` scalar constant op.
199    ///
200    /// # Examples
201    ///
202    /// ```ignore
203    /// use num_complex::Complex64;
204    /// use tenferro_ops::std_tensor_op::StdTensorOp;
205    ///
206    /// let op = StdTensorOp::constant_c64(Complex64::new(1.0, -2.0));
207    /// ```
208    pub fn constant_c64(value: Complex64) -> Self {
209        let mut bytes = Vec::with_capacity(16);
210        bytes.extend_from_slice(&value.re.to_le_bytes());
211        bytes.extend_from_slice(&value.im.to_le_bytes());
212        Self::Constant {
213            dtype: DType::C64,
214            bytes,
215        }
216    }
217
218    /// Create a `Complex32` scalar constant op.
219    ///
220    /// # Examples
221    ///
222    /// ```ignore
223    /// use num_complex::Complex32;
224    /// use tenferro_ops::std_tensor_op::StdTensorOp;
225    ///
226    /// let op = StdTensorOp::constant_c32(Complex32::new(1.0, -2.0));
227    /// ```
228    pub fn constant_c32(value: Complex32) -> Self {
229        let mut bytes = Vec::with_capacity(8);
230        bytes.extend_from_slice(&value.re.to_le_bytes());
231        bytes.extend_from_slice(&value.im.to_le_bytes());
232        Self::Constant {
233            dtype: DType::C32,
234            bytes,
235        }
236    }
237}
238
// Manual `Eq`: `f64` fields (`eps` in `Svd`/`Eigh`) are only `PartialEq`,
// so `Eq` cannot be derived. This asserts equality is total, which holds
// only if `eps` is never NaN — TODO(review): confirm constructors uphold that.
impl Eq for StdTensorOp {}
240
impl Hash for StdTensorOp {
    /// Hashes the variant discriminant plus every payload field.
    ///
    /// Must stay consistent with the derived `PartialEq`: ops that compare
    /// equal hash identically. `f64` fields (`eps`) go through [`hash_f64`],
    /// which collapses `+0.0`/`-0.0` to one bit pattern so hashing agrees
    /// with `==` on zeros.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Discriminant first: distinguishes variants that share (or lack)
        // payload shape, e.g. Tril/Triu or the payload-free arms below.
        std::mem::discriminant(self).hash(state);
        match self {
            // Payload-free variants: the discriminant alone suffices.
            Self::Add
            | Self::Mul
            | Self::Neg
            | Self::Conj
            | Self::Div
            | Self::Abs
            | Self::Sign
            | Self::Maximum
            | Self::Minimum
            | Self::Select
            | Self::Clamp
            | Self::Exp
            | Self::Log
            | Self::Sin
            | Self::Cos
            | Self::Tanh
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Pow
            | Self::Expm1
            | Self::Log1p => {}
            Self::Svd { eps, input_shape } => {
                hash_f64(*eps, state);
                input_shape.hash(state);
            }
            Self::Qr { input_shape }
            | Self::Cholesky { input_shape }
            | Self::Lu { input_shape } => {
                input_shape.hash(state);
            }
            Self::Eig {
                input_dtype,
                input_shape,
            } => {
                input_dtype.hash(state);
                input_shape.hash(state);
            }
            Self::Eigh { eps, input_shape } => {
                hash_f64(*eps, state);
                input_shape.hash(state);
            }
            Self::DotGeneral(config) => config.hash(state),
            Self::Transpose { perm } => perm.hash(state),
            Self::Reshape {
                from_shape,
                to_shape,
            } => {
                from_shape.hash(state);
                to_shape.hash(state);
            }
            Self::BroadcastInDim { shape, dims } => {
                shape.hash(state);
                dims.hash(state);
            }
            Self::Convert { from, to } => {
                from.hash(state);
                to.hash(state);
            }
            Self::Constant { dtype, bytes } => {
                dtype.hash(state);
                bytes.hash(state);
            }
            Self::ReduceSum { axes, input_shape } => {
                axes.hash(state);
                input_shape.hash(state);
            }
            Self::Compare(dir) => dir.hash(state),
            // Shared arms below are safe: the discriminant hashed above
            // already separates the variants.
            Self::ExtractDiag { axis_a, axis_b } | Self::EmbedDiag { axis_a, axis_b } => {
                axis_a.hash(state);
                axis_b.hash(state);
            }
            Self::Tril { k } | Self::Triu { k } => k.hash(state),
            Self::Gather(config) => config.hash(state),
            Self::Scatter(config) => config.hash(state),
            Self::Slice(config) => config.hash(state),
            Self::DynamicSlice { slice_sizes } => slice_sizes.hash(state),
            Self::Pad(config) => config.hash(state),
            Self::NaryEinsum {
                subscripts,
                n_inputs,
            } => {
                subscripts.hash(state);
                n_inputs.hash(state);
            }
            Self::Concatenate { axis } => axis.hash(state),
            Self::Reverse { axes } => axes.hash(state),
            Self::ShapeOf { axis } | Self::DynamicTruncate { axis } | Self::PadToMatch { axis } => {
                axis.hash(state)
            }
            Self::ReduceProd { axes, input_shape }
            | Self::ReduceMax { axes, input_shape }
            | Self::ReduceMin { axes, input_shape } => {
                axes.hash(state);
                input_shape.hash(state);
            }
            Self::TriangularSolve {
                left_side,
                lower,
                transpose_a,
                unit_diagonal,
                lhs_shape,
                rhs_shape,
            } => {
                left_side.hash(state);
                lower.hash(state);
                transpose_a.hash(state);
                unit_diagonal.hash(state);
                lhs_shape.hash(state);
                rhs_shape.hash(state);
            }
            Self::ValidateNonsingular { input_shape } => {
                input_shape.hash(state);
            }
        }
    }
}
361
/// Hash an `f64` by its IEEE-754 bit pattern.
///
/// `+0.0` and `-0.0` compare equal but have different bit patterns, so both
/// are normalized to the same value to keep `Hash` consistent with the
/// derived `PartialEq`.
fn hash_f64<H: Hasher>(value: f64, state: &mut H) {
    let bits = match value {
        // `v == 0.0` matches both signed zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    };
    state.write_u64(bits);
}
366
367fn n_inputs_from_dim_exprs(min_inputs: usize, exprs: &[&[DimExpr]]) -> usize {
368    let max_idx = exprs
369        .iter()
370        .flat_map(|exprs| exprs.iter())
371        .filter_map(DimExpr::max_input_idx)
372        .max()
373        .map_or(0, |max_idx| max_idx + 1);
374    max_idx.max(min_inputs)
375}
376
impl GraphOp for StdTensorOp {
    type Operand = tenferro_tensor::Tensor;
    type Context = ();
    type InputKey = TensorInputKey;

    /// Number of graph inputs this op consumes.
    ///
    /// Shape-carrying ops delegate to [`n_inputs_from_dim_exprs`] so that
    /// symbolic `DimExpr`s referencing extra inputs raise the arity above
    /// the structural minimum. Variable-arity `Concatenate` is still
    /// unimplemented (`todo!`).
    fn n_inputs(&self) -> usize {
        match self {
            // Binary ops (plus Gather: two inputs).
            Self::Add | Self::Mul | Self::DotGeneral(_) | Self::Gather(_) => 2,
            // Unary structural/elementwise ops.
            Self::Neg
            | Self::Conj
            | Self::Transpose { .. }
            | Self::Convert { .. }
            | Self::ExtractDiag { .. }
            | Self::EmbedDiag { .. }
            | Self::Tril { .. }
            | Self::Triu { .. }
            | Self::Slice(_)
            | Self::Pad(_)
            | Self::Reverse { .. }
            | Self::ShapeOf { .. } => 1,
            Self::DynamicTruncate { .. } | Self::PadToMatch { .. } => 2,
            // Arity may grow if the shape expressions reference extra inputs.
            Self::Reshape {
                from_shape,
                to_shape,
            } => n_inputs_from_dim_exprs(1, &[from_shape, to_shape]),
            Self::BroadcastInDim { shape, .. } => n_inputs_from_dim_exprs(1, &[shape]),
            Self::ReduceSum { input_shape, .. }
            | Self::ReduceProd { input_shape, .. }
            | Self::ReduceMax { input_shape, .. }
            | Self::ReduceMin { input_shape, .. } => n_inputs_from_dim_exprs(1, &[input_shape]),
            Self::Div | Self::Maximum | Self::Minimum | Self::Pow | Self::DynamicSlice { .. } => 2,
            // Constants are sources: no inputs.
            Self::Constant { .. } => 0,
            Self::Scatter(_) => 3,
            // Arity is stored explicitly on the node.
            Self::NaryEinsum { n_inputs, .. } => *n_inputs,
            Self::Concatenate { .. } => {
                todo!(
                    "n_inputs not yet implemented for variable-arity op {:?}",
                    self
                )
            }
            // Unary analytic/elementwise ops.
            Self::Abs
            | Self::Sign
            | Self::Exp
            | Self::Log
            | Self::Sin
            | Self::Cos
            | Self::Tanh
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Expm1
            | Self::Log1p => 1,
            Self::Select | Self::Clamp => 3,
            Self::Compare(_) => 2,
            // Linalg ops: one matrix input, plus any shape-expression inputs.
            Self::Cholesky { input_shape }
            | Self::Lu { input_shape }
            | Self::Svd { input_shape, .. }
            | Self::Qr { input_shape }
            | Self::Eigh { input_shape, .. }
            | Self::Eig { input_shape, .. } => n_inputs_from_dim_exprs(1, &[input_shape]),
            Self::TriangularSolve {
                lhs_shape,
                rhs_shape,
                ..
            } => n_inputs_from_dim_exprs(2, &[lhs_shape, rhs_shape]),
            Self::ValidateNonsingular { input_shape } => n_inputs_from_dim_exprs(1, &[input_shape]),
        }
    }

    /// Number of graph outputs this op produces.
    ///
    /// Everything yields a single tensor except the multi-output linalg
    /// decompositions; `Concatenate` is still unimplemented (`todo!`).
    fn n_outputs(&self) -> usize {
        match self {
            Self::Add
            | Self::Mul
            | Self::Neg
            | Self::Conj
            | Self::DotGeneral(_)
            | Self::Transpose { .. }
            | Self::Reshape { .. }
            | Self::BroadcastInDim { .. }
            | Self::Convert { .. }
            | Self::ReduceSum { .. }
            | Self::Div
            | Self::Abs
            | Self::Sign
            | Self::Maximum
            | Self::Minimum
            | Self::Compare(_)
            | Self::Select
            | Self::Clamp
            | Self::Constant { .. }
            | Self::Exp
            | Self::Log
            | Self::Sin
            | Self::Cos
            | Self::Tanh
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Pow
            | Self::Expm1
            | Self::Log1p
            | Self::ExtractDiag { .. }
            | Self::EmbedDiag { .. }
            | Self::Tril { .. }
            | Self::Triu { .. }
            | Self::Gather(_)
            | Self::Scatter(_)
            | Self::Slice(_)
            | Self::DynamicSlice { .. }
            | Self::Pad(_)
            | Self::NaryEinsum { .. }
            | Self::Reverse { .. }
            | Self::ShapeOf { .. }
            | Self::DynamicTruncate { .. }
            | Self::PadToMatch { .. }
            | Self::ReduceProd { .. }
            | Self::ReduceMax { .. }
            | Self::ReduceMin { .. } => 1,
            Self::Cholesky { .. }
            | Self::TriangularSolve { .. }
            | Self::ValidateNonsingular { .. } => 1,
            // Multi-output decompositions.
            Self::Lu { .. } => 4,
            Self::Svd { .. } => 3,  // U, S, Vt
            Self::Qr { .. } => 2,   // Q, R
            Self::Eigh { .. } => 2, // eigenvalues, eigenvectors
            Self::Eig { .. } => 2,  // eigenvalues, eigenvectors
            Self::Concatenate { .. } => todo!(
                "n_outputs not yet implemented for variable-arity op {:?}",
                self
            ),
        }
    }
}
508
impl PrimitiveOp for StdTensorOp {
    // Context threaded through both AD passes; presumably guards shape
    // consistency — see `crate::ad::context` for the actual contract.
    type ADContext = crate::ad::context::ShapeGuardContext;

    /// The additive primitive used by the AD machinery (e.g. to accumulate
    /// tangents/cotangents).
    fn add() -> Self {
        StdTensorOp::Add
    }

    /// Forward-mode linearization; the per-op rules live in
    /// `crate::ad::linearize`, this just dispatches.
    fn linearize(
        &self,
        builder: &mut FragmentBuilder<Self>,
        primal_in: &[GlobalValKey<Self>],
        primal_out: &[GlobalValKey<Self>],
        tangent_in: &[Option<LocalValId>],
        ctx: &mut Self::ADContext,
    ) -> Vec<Option<LocalValId>> {
        crate::ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
    }

    /// Transpose (reverse-mode) rule; the per-op rules live in
    /// `crate::ad::transpose_rule`, this just dispatches.
    fn transpose_rule(
        &self,
        emitter: &mut impl OpEmitter<Self>,
        cotangent_out: &[Option<LocalValId>],
        inputs: &[ValRef<Self>],
        mode: &OpMode,
        ctx: &mut Self::ADContext,
    ) -> Vec<Option<LocalValId>> {
        crate::ad::transpose_rule(self, emitter, cotangent_out, inputs, mode, ctx)
    }
}
538
539impl SemiringOps for StdTensorOp {
540    fn add_op() -> Self {
541        StdTensorOp::Add
542    }
543
544    fn mul_op() -> Self {
545        StdTensorOp::Mul
546    }
547
548    fn dot_general(config: DotGeneralConfig) -> Self {
549        StdTensorOp::DotGeneral(config)
550    }
551
552    fn reduce_sum(axes: Vec<usize>, input_shape: Vec<DimExpr>) -> Self {
553        StdTensorOp::ReduceSum { axes, input_shape }
554    }
555
556    fn transpose_op(perm: Vec<usize>) -> Self {
557        StdTensorOp::Transpose { perm }
558    }
559
560    fn reshape(from_shape: Vec<DimExpr>, to_shape: Vec<DimExpr>) -> Self {
561        StdTensorOp::Reshape {
562            from_shape,
563            to_shape,
564        }
565    }
566
567    fn broadcast_in_dim(shape: Vec<DimExpr>, dims: Vec<usize>) -> Self {
568        StdTensorOp::BroadcastInDim { shape, dims }
569    }
570
571    fn extract_diag(axis_a: usize, axis_b: usize) -> Self {
572        StdTensorOp::ExtractDiag { axis_a, axis_b }
573    }
574
575    fn embed_diag(axis_a: usize, axis_b: usize) -> Self {
576        StdTensorOp::EmbedDiag { axis_a, axis_b }
577    }
578}