// tenferro_ops/std_tensor_op.rs

1use std::hash::{Hash, Hasher};
2use std::sync::Arc;
3
4#[cfg(feature = "autodiff")]
5use computegraph::types::{LocalValueId, OperationRole, ValueKey};
6use computegraph::GraphOperation;
7use num_complex::{Complex32, Complex64};
8#[cfg(feature = "autodiff")]
9use tidu::{ADRuleResult, Primitive, PrimitiveBuilder, PrimitiveValue};
10
11use crate::dim_expr::DimExpr;
12use crate::ext_op::{ext_op_eq, hash_extension, ExtensionOp};
13use crate::input_key::TensorInputKey;
14use tenferro_tensor::{
15    CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
16    TensorScalar,
17};
18
19/// Scalar values that can be encoded as tensor constant operations.
20///
21/// # Examples
22///
23/// ```rust
24/// use tenferro_ops::std_tensor_op::ConstantScalar;
25///
26/// assert_eq!(1.0_f64.constant_bytes(), 1.0_f64.to_le_bytes().to_vec());
27/// ```
28pub trait ConstantScalar: TensorScalar + private::Sealed {
29    /// Encode the scalar value as little-endian constant bytes.
30    ///
31    /// # Examples
32    ///
33    /// ```rust
34    /// use tenferro_ops::std_tensor_op::ConstantScalar;
35    ///
36    /// assert_eq!(true.constant_bytes(), vec![1]);
37    /// ```
38    fn constant_bytes(self) -> Vec<u8>;
39}
40
41mod private {
42    pub trait Sealed {}
43
44    impl Sealed for f64 {}
45    impl Sealed for f32 {}
46    impl Sealed for i64 {}
47    impl Sealed for i32 {}
48    impl Sealed for bool {}
49    impl Sealed for num_complex::Complex64 {}
50    impl Sealed for num_complex::Complex32 {}
51}
52
53impl ConstantScalar for f64 {
54    fn constant_bytes(self) -> Vec<u8> {
55        self.to_le_bytes().to_vec()
56    }
57}
58
59impl ConstantScalar for f32 {
60    fn constant_bytes(self) -> Vec<u8> {
61        self.to_le_bytes().to_vec()
62    }
63}
64
65impl ConstantScalar for i64 {
66    fn constant_bytes(self) -> Vec<u8> {
67        self.to_le_bytes().to_vec()
68    }
69}
70
71impl ConstantScalar for i32 {
72    fn constant_bytes(self) -> Vec<u8> {
73        self.to_le_bytes().to_vec()
74    }
75}
76
77impl ConstantScalar for bool {
78    fn constant_bytes(self) -> Vec<u8> {
79        vec![u8::from(self)]
80    }
81}
82
83impl ConstantScalar for Complex64 {
84    fn constant_bytes(self) -> Vec<u8> {
85        let mut bytes = Vec::with_capacity(16);
86        bytes.extend_from_slice(&self.re.to_le_bytes());
87        bytes.extend_from_slice(&self.im.to_le_bytes());
88        bytes
89    }
90}
91
92impl ConstantScalar for Complex32 {
93    fn constant_bytes(self) -> Vec<u8> {
94        let mut bytes = Vec::with_capacity(8);
95        bytes.extend_from_slice(&self.re.to_le_bytes());
96        bytes.extend_from_slice(&self.im.to_le_bytes());
97        bytes
98    }
99}
100
// Invokes the `define_std_tensor_op!` macro from `tenferro_core_ops`.
// NOTE(review): presumably this expands to the `StdTensorOp` enum that the
// impls in this file operate on — confirm against the macro definition.
tenferro_core_ops::define_std_tensor_op!();
102
103impl StdTensorOp {
104    /// Create a scalar constant op from any supported tensor scalar.
105    ///
106    /// # Examples
107    ///
108    /// ```rust
109    /// use num_complex::Complex64;
110    /// use tenferro_ops::std_tensor_op::StdTensorOp;
111    /// use tenferro_tensor::DType;
112    ///
113    /// let real = StdTensorOp::constant(1.5_f64);
114    /// let complex = StdTensorOp::constant(Complex64::new(1.0, -2.0));
115    ///
116    /// assert!(matches!(real, StdTensorOp::Constant { dtype: DType::F64, .. }));
117    /// assert!(matches!(complex, StdTensorOp::Constant { dtype: DType::C64, .. }));
118    /// ```
119    pub fn constant<T: ConstantScalar>(value: T) -> Self {
120        Self::Constant {
121            dtype: T::dtype(),
122            bytes: value.constant_bytes(),
123        }
124    }
125}
126
127impl PartialEq for StdTensorOp {
128    fn eq(&self, other: &Self) -> bool {
129        if std::mem::discriminant(self) != std::mem::discriminant(other) {
130            return false;
131        }
132        match (self, other) {
133            (Self::Add, Self::Add)
134            | (Self::Mul, Self::Mul)
135            | (Self::Neg, Self::Neg)
136            | (Self::Conj, Self::Conj)
137            | (Self::Div, Self::Div)
138            | (Self::Abs, Self::Abs)
139            | (Self::Sign, Self::Sign)
140            | (Self::Maximum, Self::Maximum)
141            | (Self::Minimum, Self::Minimum)
142            | (Self::Select, Self::Select)
143            | (Self::Clamp, Self::Clamp)
144            | (Self::Exp, Self::Exp)
145            | (Self::Log, Self::Log)
146            | (Self::Sin, Self::Sin)
147            | (Self::Cos, Self::Cos)
148            | (Self::Tanh, Self::Tanh)
149            | (Self::Sqrt, Self::Sqrt)
150            | (Self::Rsqrt, Self::Rsqrt)
151            | (Self::Pow, Self::Pow)
152            | (Self::Expm1, Self::Expm1)
153            | (Self::Log1p, Self::Log1p)
154            | (Self::DynamicUpdateSlice, Self::DynamicUpdateSlice) => true,
155            (Self::DotGeneral { config: a }, Self::DotGeneral { config: b }) => a == b,
156            (Self::Transpose { perm: a }, Self::Transpose { perm: b }) => a == b,
157            (Self::Reshape { to_shape: a }, Self::Reshape { to_shape: b }) => a == b,
158            (
159                Self::BroadcastInDim {
160                    shape: sa,
161                    dims: da,
162                },
163                Self::BroadcastInDim {
164                    shape: sb,
165                    dims: db,
166                },
167            ) => sa == sb && da == db,
168            (Self::Convert { from: fa, to: ta }, Self::Convert { from: fb, to: tb }) => {
169                fa == fb && ta == tb
170            }
171            (
172                Self::Constant {
173                    dtype: da,
174                    bytes: ba,
175                },
176                Self::Constant {
177                    dtype: db,
178                    bytes: bb,
179                },
180            ) => da == db && ba == bb,
181            (Self::ReduceSum { axes: a }, Self::ReduceSum { axes: b })
182            | (Self::ReduceProd { axes: a }, Self::ReduceProd { axes: b })
183            | (Self::ReduceMax { axes: a }, Self::ReduceMax { axes: b })
184            | (Self::ReduceMin { axes: a }, Self::ReduceMin { axes: b })
185            | (Self::Reverse { axes: a }, Self::Reverse { axes: b }) => a == b,
186            (Self::Compare(a), Self::Compare(b)) => a == b,
187            (
188                Self::ExtractDiag {
189                    axis_a: aa,
190                    axis_b: ba,
191                },
192                Self::ExtractDiag {
193                    axis_a: ab,
194                    axis_b: bb,
195                },
196            )
197            | (
198                Self::EmbedDiag {
199                    axis_a: aa,
200                    axis_b: ba,
201                },
202                Self::EmbedDiag {
203                    axis_a: ab,
204                    axis_b: bb,
205                },
206            ) => aa == ab && ba == bb,
207            (Self::Tril { k: a }, Self::Tril { k: b })
208            | (Self::Triu { k: a }, Self::Triu { k: b }) => a == b,
209            (Self::Gather(a), Self::Gather(b)) => a == b,
210            (
211                Self::GatherDynamicSliceSizes {
212                    offset_dims: oa,
213                    collapsed_slice_dims: ca,
214                    start_index_map: sa,
215                    index_vector_dim: ia,
216                    slice_sizes: za,
217                },
218                Self::GatherDynamicSliceSizes {
219                    offset_dims: ob,
220                    collapsed_slice_dims: cb,
221                    start_index_map: sb,
222                    index_vector_dim: ib,
223                    slice_sizes: zb,
224                },
225            ) => oa == ob && ca == cb && sa == sb && ia == ib && za == zb,
226            (Self::Scatter(a), Self::Scatter(b)) => a == b,
227            (Self::Slice(a), Self::Slice(b)) => a == b,
228            (Self::DynamicSlice { slice_sizes: a }, Self::DynamicSlice { slice_sizes: b }) => {
229                a == b
230            }
231            (Self::Pad(a), Self::Pad(b)) => a == b,
232            (
233                Self::Concatenate {
234                    axis: a,
235                    input_count: na,
236                },
237                Self::Concatenate {
238                    axis: b,
239                    input_count: nb,
240                },
241            ) => a == b && na == nb,
242            (Self::ShapeOf { axis: a }, Self::ShapeOf { axis: b })
243            | (Self::DynamicTruncate { axis: a }, Self::DynamicTruncate { axis: b })
244            | (Self::PadToMatch { axis: a }, Self::PadToMatch { axis: b }) => a == b,
245            (Self::Extension(a), Self::Extension(b)) => ext_op_eq(a.as_ref(), b.as_ref()),
246            _ => false,
247        }
248    }
249}
250
251impl Eq for StdTensorOp {}
252
253impl Hash for StdTensorOp {
254    fn hash<H: Hasher>(&self, state: &mut H) {
255        std::mem::discriminant(self).hash(state);
256        match self {
257            Self::Add
258            | Self::Mul
259            | Self::Neg
260            | Self::Conj
261            | Self::Div
262            | Self::Abs
263            | Self::Sign
264            | Self::Maximum
265            | Self::Minimum
266            | Self::Select
267            | Self::Clamp
268            | Self::Exp
269            | Self::Log
270            | Self::Sin
271            | Self::Cos
272            | Self::Tanh
273            | Self::Sqrt
274            | Self::Rsqrt
275            | Self::Pow
276            | Self::Expm1
277            | Self::Log1p => {}
278            Self::DotGeneral { config } => {
279                config.hash(state);
280            }
281            Self::Transpose { perm } => perm.hash(state),
282            Self::Reshape { to_shape } => {
283                to_shape.hash(state);
284            }
285            Self::BroadcastInDim { shape, dims } => {
286                shape.hash(state);
287                dims.hash(state);
288            }
289            Self::Convert { from, to } => {
290                from.hash(state);
291                to.hash(state);
292            }
293            Self::Constant { dtype, bytes } => {
294                dtype.hash(state);
295                bytes.hash(state);
296            }
297            Self::ReduceSum { axes } => {
298                axes.hash(state);
299            }
300            Self::Compare(dir) => dir.hash(state),
301            Self::ExtractDiag { axis_a, axis_b } | Self::EmbedDiag { axis_a, axis_b } => {
302                axis_a.hash(state);
303                axis_b.hash(state);
304            }
305            Self::Tril { k } | Self::Triu { k } => k.hash(state),
306            Self::Gather(config) => config.hash(state),
307            Self::GatherDynamicSliceSizes {
308                offset_dims,
309                collapsed_slice_dims,
310                start_index_map,
311                index_vector_dim,
312                slice_sizes,
313            } => {
314                offset_dims.hash(state);
315                collapsed_slice_dims.hash(state);
316                start_index_map.hash(state);
317                index_vector_dim.hash(state);
318                slice_sizes.hash(state);
319            }
320            Self::Scatter(config) => config.hash(state),
321            Self::Slice(config) => config.hash(state),
322            Self::DynamicSlice { slice_sizes } => slice_sizes.hash(state),
323            Self::DynamicUpdateSlice => {}
324            Self::Pad(config) => config.hash(state),
325            Self::Concatenate { axis, input_count } => {
326                axis.hash(state);
327                input_count.hash(state);
328            }
329            Self::Reverse { axes } => axes.hash(state),
330            Self::ShapeOf { axis } | Self::DynamicTruncate { axis } | Self::PadToMatch { axis } => {
331                axis.hash(state)
332            }
333            Self::ReduceProd { axes } | Self::ReduceMax { axes } | Self::ReduceMin { axes } => {
334                axes.hash(state);
335            }
336            Self::Extension(op) => hash_extension(op.as_ref(), state),
337        }
338    }
339}
340
341fn n_inputs_from_dim_exprs(min_inputs: usize, exprs: &[&[DimExpr]]) -> usize {
342    let max_idx = exprs
343        .iter()
344        .flat_map(|exprs| exprs.iter())
345        .filter_map(DimExpr::max_input_idx)
346        .max()
347        .map_or(0, |max_idx| max_idx + 1);
348    max_idx.max(min_inputs)
349}
350
351impl GraphOperation for StdTensorOp {
352    type Operand = tenferro_tensor::Tensor;
353    type Context = ();
354    type InputKey = TensorInputKey;
355
356    fn input_count(&self) -> usize {
357        match self {
358            Self::Add | Self::Mul | Self::DotGeneral { .. } | Self::Gather(_) => 2,
359            Self::GatherDynamicSliceSizes { slice_sizes, .. } => {
360                n_inputs_from_dim_exprs(2, &[slice_sizes])
361            }
362            Self::Neg
363            | Self::Conj
364            | Self::Transpose { .. }
365            | Self::Convert { .. }
366            | Self::ExtractDiag { .. }
367            | Self::EmbedDiag { .. }
368            | Self::Tril { .. }
369            | Self::Triu { .. }
370            | Self::Slice(_)
371            | Self::Pad(_)
372            | Self::Reverse { .. }
373            | Self::ShapeOf { .. } => 1,
374            Self::DynamicTruncate { .. } | Self::PadToMatch { .. } => 2,
375            Self::Reshape { to_shape } => n_inputs_from_dim_exprs(1, &[to_shape]),
376            Self::BroadcastInDim { shape, .. } => n_inputs_from_dim_exprs(1, &[shape]),
377            Self::ReduceSum { .. }
378            | Self::ReduceProd { .. }
379            | Self::ReduceMax { .. }
380            | Self::ReduceMin { .. } => 1,
381            Self::Div | Self::Maximum | Self::Minimum | Self::Pow | Self::DynamicSlice { .. } => 2,
382            Self::Constant { .. } => 0,
383            Self::Scatter(_) | Self::DynamicUpdateSlice => 3,
384            Self::Concatenate { input_count, .. } => *input_count,
385            Self::Abs
386            | Self::Sign
387            | Self::Exp
388            | Self::Log
389            | Self::Sin
390            | Self::Cos
391            | Self::Tanh
392            | Self::Sqrt
393            | Self::Rsqrt
394            | Self::Expm1
395            | Self::Log1p => 1,
396            Self::Select | Self::Clamp => 3,
397            Self::Compare(_) => 2,
398            Self::Extension(op) => ExtensionOp::input_count(op.as_ref()),
399        }
400    }
401
402    fn output_count(&self) -> usize {
403        match self {
404            Self::Add
405            | Self::Mul
406            | Self::Neg
407            | Self::Conj
408            | Self::DotGeneral { .. }
409            | Self::Transpose { .. }
410            | Self::Reshape { .. }
411            | Self::BroadcastInDim { .. }
412            | Self::Convert { .. }
413            | Self::ReduceSum { .. }
414            | Self::Div
415            | Self::Abs
416            | Self::Sign
417            | Self::Maximum
418            | Self::Minimum
419            | Self::Compare(_)
420            | Self::Select
421            | Self::Clamp
422            | Self::Constant { .. }
423            | Self::Exp
424            | Self::Log
425            | Self::Sin
426            | Self::Cos
427            | Self::Tanh
428            | Self::Sqrt
429            | Self::Rsqrt
430            | Self::Pow
431            | Self::Expm1
432            | Self::Log1p
433            | Self::ExtractDiag { .. }
434            | Self::EmbedDiag { .. }
435            | Self::Tril { .. }
436            | Self::Triu { .. }
437            | Self::Gather(_)
438            | Self::GatherDynamicSliceSizes { .. }
439            | Self::Scatter(_)
440            | Self::Slice(_)
441            | Self::DynamicSlice { .. }
442            | Self::DynamicUpdateSlice
443            | Self::Pad(_)
444            | Self::Reverse { .. }
445            | Self::ShapeOf { .. }
446            | Self::DynamicTruncate { .. }
447            | Self::PadToMatch { .. }
448            | Self::ReduceProd { .. }
449            | Self::ReduceMax { .. }
450            | Self::ReduceMin { .. } => 1,
451            Self::Concatenate { .. } => 1,
452            Self::Extension(op) => ExtensionOp::output_count(op.as_ref()),
453        }
454    }
455}
456
#[cfg(feature = "autodiff")]
impl Primitive for StdTensorOp {
    type ADContext = crate::ad::context::ShapeGuardContext;

    /// The op used as the addition primitive: [`StdTensorOp::Add`].
    fn add() -> Self {
        Self::Add
    }

    /// Forwards all arguments unchanged to `crate::ad::linearize`.
    fn jvp_rule(
        &self,
        builder: &mut impl PrimitiveBuilder<Self>,
        primal_in: &[ValueKey<Self>],
        primal_out: &[ValueKey<Self>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut Self::ADContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        use crate::ad;

        ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
    }

    /// Converts each input with `Into` and forwards to
    /// `crate::ad::transpose_rule`.
    fn transpose_rule(
        &self,
        builder: &mut impl PrimitiveBuilder<Self>,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[PrimitiveValue<Self>],
        mode: &OperationRole,
        ctx: &mut Self::ADContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        use crate::ad;

        // Clone each input and convert it into the element type that
        // `ad::transpose_rule` expects for its `inputs` slice.
        let converted: Vec<_> = inputs.iter().cloned().map(Into::into).collect();
        ad::transpose_rule(self, builder, cotangent_out, &converted, mode, ctx)
    }
}
488
/// Test-only inherent helpers, compiled only for test builds with the
/// `autodiff` feature enabled.
#[cfg(all(test, feature = "autodiff"))]
impl StdTensorOp {
    /// Forwards all arguments unchanged to `crate::ad::linearize`, using a
    /// concrete `GraphBuilder` as the builder.
    pub(crate) fn jvp_rule(
        &self,
        builder: &mut computegraph::graph::GraphBuilder<Self>,
        primal_in: &[ValueKey<Self>],
        primal_out: &[ValueKey<Self>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut crate::ad::context::ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        use crate::ad;

        ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
    }

    /// Forwards all arguments unchanged to `crate::ad::transpose_rule`; the
    /// inputs are passed through as given, with no conversion.
    pub(crate) fn transpose_rule(
        &self,
        builder: &mut impl crate::ad::PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[computegraph::ValueRef<Self>],
        mode: &OperationRole,
        ctx: &mut crate::ad::context::ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        use crate::ad;

        ad::transpose_rule(self, builder, cotangent_out, inputs, mode, ctx)
    }
}