1use std::hash::{Hash, Hasher};
2use std::sync::Arc;
3
4#[cfg(feature = "autodiff")]
5use computegraph::types::{LocalValueId, OperationRole, ValueKey};
6use computegraph::GraphOperation;
7use num_complex::{Complex32, Complex64};
8#[cfg(feature = "autodiff")]
9use tidu::{ADRuleResult, Primitive, PrimitiveBuilder, PrimitiveValue};
10
11use crate::dim_expr::DimExpr;
12use crate::ext_op::{ext_op_eq, hash_extension, ExtensionOp};
13use crate::input_key::TensorInputKey;
14use tenferro_tensor::{
15 CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
16 TensorScalar,
17};
18
/// Scalar types that can be embedded as the payload of a `StdTensorOp::Constant`.
///
/// The trait is sealed through `private::Sealed`. Only the scalar types implemented
/// in that private module (`f64`, `f32`, `i64`, `i32`, `bool`, `Complex64`,
/// `Complex32`) can implement it. Code that cannot name `private::Sealed` cannot add
/// new implementations.
pub trait ConstantScalar: TensorScalar + private::Sealed {
    /// Encodes `self` as the raw bytes stored in a `StdTensorOp::Constant`.
    ///
    /// The implementations in this module use little-endian byte order:
    /// - complex values are the real part followed by the imaginary part;
    /// - `bool` is a single `0`/`1` byte.
    fn constant_bytes(self) -> Vec<u8>;
}
40
41mod private {
42 pub trait Sealed {}
43
44 impl Sealed for f64 {}
45 impl Sealed for f32 {}
46 impl Sealed for i64 {}
47 impl Sealed for i32 {}
48 impl Sealed for bool {}
49 impl Sealed for num_complex::Complex64 {}
50 impl Sealed for num_complex::Complex32 {}
51}
52
53impl ConstantScalar for f64 {
54 fn constant_bytes(self) -> Vec<u8> {
55 self.to_le_bytes().to_vec()
56 }
57}
58
59impl ConstantScalar for f32 {
60 fn constant_bytes(self) -> Vec<u8> {
61 self.to_le_bytes().to_vec()
62 }
63}
64
65impl ConstantScalar for i64 {
66 fn constant_bytes(self) -> Vec<u8> {
67 self.to_le_bytes().to_vec()
68 }
69}
70
71impl ConstantScalar for i32 {
72 fn constant_bytes(self) -> Vec<u8> {
73 self.to_le_bytes().to_vec()
74 }
75}
76
77impl ConstantScalar for bool {
78 fn constant_bytes(self) -> Vec<u8> {
79 vec![u8::from(self)]
80 }
81}
82
83impl ConstantScalar for Complex64 {
84 fn constant_bytes(self) -> Vec<u8> {
85 let mut bytes = Vec::with_capacity(16);
86 bytes.extend_from_slice(&self.re.to_le_bytes());
87 bytes.extend_from_slice(&self.im.to_le_bytes());
88 bytes
89 }
90}
91
92impl ConstantScalar for Complex32 {
93 fn constant_bytes(self) -> Vec<u8> {
94 let mut bytes = Vec::with_capacity(8);
95 bytes.extend_from_slice(&self.re.to_le_bytes());
96 bytes.extend_from_slice(&self.im.to_le_bytes());
97 bytes
98 }
99}
100
// Expands to the `StdTensorOp` operation type. This file never declares `StdTensorOp`
// itself, so the variants matched below (`Add`, `Constant`, `Extension`, ...)
// presumably come from this macro. TODO confirm in `tenferro_core_ops`.
// `PartialEq`, `Hash`, `GraphOperation` and `Primitive` are implemented by hand in
// this file rather than derived. That is presumably because `Extension` wraps a
// trait object needing `ext_op_eq` / `hash_extension`. TODO confirm.
tenferro_core_ops::define_std_tensor_op!();
102
impl StdTensorOp {
    /// Builds a `Constant` op holding the single scalar `value`.
    ///
    /// The op records `T::dtype()` as its dtype and `value.constant_bytes()` as its
    /// byte payload. Constants therefore compare and hash by dtype plus raw bytes.
    /// For floats this is a bitwise comparison: `0.0` and `-0.0` give distinct ops,
    /// while two NaNs with the same bit pattern give equal ops.
    pub fn constant<T: ConstantScalar>(value: T) -> Self {
        Self::Constant {
            dtype: T::dtype(),
            bytes: value.constant_bytes(),
        }
    }
}
126
127impl PartialEq for StdTensorOp {
128 fn eq(&self, other: &Self) -> bool {
129 if std::mem::discriminant(self) != std::mem::discriminant(other) {
130 return false;
131 }
132 match (self, other) {
133 (Self::Add, Self::Add)
134 | (Self::Mul, Self::Mul)
135 | (Self::Neg, Self::Neg)
136 | (Self::Conj, Self::Conj)
137 | (Self::Div, Self::Div)
138 | (Self::Abs, Self::Abs)
139 | (Self::Sign, Self::Sign)
140 | (Self::Maximum, Self::Maximum)
141 | (Self::Minimum, Self::Minimum)
142 | (Self::Select, Self::Select)
143 | (Self::Clamp, Self::Clamp)
144 | (Self::Exp, Self::Exp)
145 | (Self::Log, Self::Log)
146 | (Self::Sin, Self::Sin)
147 | (Self::Cos, Self::Cos)
148 | (Self::Tanh, Self::Tanh)
149 | (Self::Sqrt, Self::Sqrt)
150 | (Self::Rsqrt, Self::Rsqrt)
151 | (Self::Pow, Self::Pow)
152 | (Self::Expm1, Self::Expm1)
153 | (Self::Log1p, Self::Log1p)
154 | (Self::DynamicUpdateSlice, Self::DynamicUpdateSlice) => true,
155 (Self::DotGeneral { config: a }, Self::DotGeneral { config: b }) => a == b,
156 (Self::Transpose { perm: a }, Self::Transpose { perm: b }) => a == b,
157 (Self::Reshape { to_shape: a }, Self::Reshape { to_shape: b }) => a == b,
158 (
159 Self::BroadcastInDim {
160 shape: sa,
161 dims: da,
162 },
163 Self::BroadcastInDim {
164 shape: sb,
165 dims: db,
166 },
167 ) => sa == sb && da == db,
168 (Self::Convert { from: fa, to: ta }, Self::Convert { from: fb, to: tb }) => {
169 fa == fb && ta == tb
170 }
171 (
172 Self::Constant {
173 dtype: da,
174 bytes: ba,
175 },
176 Self::Constant {
177 dtype: db,
178 bytes: bb,
179 },
180 ) => da == db && ba == bb,
181 (Self::ReduceSum { axes: a }, Self::ReduceSum { axes: b })
182 | (Self::ReduceProd { axes: a }, Self::ReduceProd { axes: b })
183 | (Self::ReduceMax { axes: a }, Self::ReduceMax { axes: b })
184 | (Self::ReduceMin { axes: a }, Self::ReduceMin { axes: b })
185 | (Self::Reverse { axes: a }, Self::Reverse { axes: b }) => a == b,
186 (Self::Compare(a), Self::Compare(b)) => a == b,
187 (
188 Self::ExtractDiag {
189 axis_a: aa,
190 axis_b: ba,
191 },
192 Self::ExtractDiag {
193 axis_a: ab,
194 axis_b: bb,
195 },
196 )
197 | (
198 Self::EmbedDiag {
199 axis_a: aa,
200 axis_b: ba,
201 },
202 Self::EmbedDiag {
203 axis_a: ab,
204 axis_b: bb,
205 },
206 ) => aa == ab && ba == bb,
207 (Self::Tril { k: a }, Self::Tril { k: b })
208 | (Self::Triu { k: a }, Self::Triu { k: b }) => a == b,
209 (Self::Gather(a), Self::Gather(b)) => a == b,
210 (
211 Self::GatherDynamicSliceSizes {
212 offset_dims: oa,
213 collapsed_slice_dims: ca,
214 start_index_map: sa,
215 index_vector_dim: ia,
216 slice_sizes: za,
217 },
218 Self::GatherDynamicSliceSizes {
219 offset_dims: ob,
220 collapsed_slice_dims: cb,
221 start_index_map: sb,
222 index_vector_dim: ib,
223 slice_sizes: zb,
224 },
225 ) => oa == ob && ca == cb && sa == sb && ia == ib && za == zb,
226 (Self::Scatter(a), Self::Scatter(b)) => a == b,
227 (Self::Slice(a), Self::Slice(b)) => a == b,
228 (Self::DynamicSlice { slice_sizes: a }, Self::DynamicSlice { slice_sizes: b }) => {
229 a == b
230 }
231 (Self::Pad(a), Self::Pad(b)) => a == b,
232 (
233 Self::Concatenate {
234 axis: a,
235 input_count: na,
236 },
237 Self::Concatenate {
238 axis: b,
239 input_count: nb,
240 },
241 ) => a == b && na == nb,
242 (Self::ShapeOf { axis: a }, Self::ShapeOf { axis: b })
243 | (Self::DynamicTruncate { axis: a }, Self::DynamicTruncate { axis: b })
244 | (Self::PadToMatch { axis: a }, Self::PadToMatch { axis: b }) => a == b,
245 (Self::Extension(a), Self::Extension(b)) => ext_op_eq(a.as_ref(), b.as_ref()),
246 _ => false,
247 }
248 }
249}
250
251impl Eq for StdTensorOp {}
252
253impl Hash for StdTensorOp {
254 fn hash<H: Hasher>(&self, state: &mut H) {
255 std::mem::discriminant(self).hash(state);
256 match self {
257 Self::Add
258 | Self::Mul
259 | Self::Neg
260 | Self::Conj
261 | Self::Div
262 | Self::Abs
263 | Self::Sign
264 | Self::Maximum
265 | Self::Minimum
266 | Self::Select
267 | Self::Clamp
268 | Self::Exp
269 | Self::Log
270 | Self::Sin
271 | Self::Cos
272 | Self::Tanh
273 | Self::Sqrt
274 | Self::Rsqrt
275 | Self::Pow
276 | Self::Expm1
277 | Self::Log1p => {}
278 Self::DotGeneral { config } => {
279 config.hash(state);
280 }
281 Self::Transpose { perm } => perm.hash(state),
282 Self::Reshape { to_shape } => {
283 to_shape.hash(state);
284 }
285 Self::BroadcastInDim { shape, dims } => {
286 shape.hash(state);
287 dims.hash(state);
288 }
289 Self::Convert { from, to } => {
290 from.hash(state);
291 to.hash(state);
292 }
293 Self::Constant { dtype, bytes } => {
294 dtype.hash(state);
295 bytes.hash(state);
296 }
297 Self::ReduceSum { axes } => {
298 axes.hash(state);
299 }
300 Self::Compare(dir) => dir.hash(state),
301 Self::ExtractDiag { axis_a, axis_b } | Self::EmbedDiag { axis_a, axis_b } => {
302 axis_a.hash(state);
303 axis_b.hash(state);
304 }
305 Self::Tril { k } | Self::Triu { k } => k.hash(state),
306 Self::Gather(config) => config.hash(state),
307 Self::GatherDynamicSliceSizes {
308 offset_dims,
309 collapsed_slice_dims,
310 start_index_map,
311 index_vector_dim,
312 slice_sizes,
313 } => {
314 offset_dims.hash(state);
315 collapsed_slice_dims.hash(state);
316 start_index_map.hash(state);
317 index_vector_dim.hash(state);
318 slice_sizes.hash(state);
319 }
320 Self::Scatter(config) => config.hash(state),
321 Self::Slice(config) => config.hash(state),
322 Self::DynamicSlice { slice_sizes } => slice_sizes.hash(state),
323 Self::DynamicUpdateSlice => {}
324 Self::Pad(config) => config.hash(state),
325 Self::Concatenate { axis, input_count } => {
326 axis.hash(state);
327 input_count.hash(state);
328 }
329 Self::Reverse { axes } => axes.hash(state),
330 Self::ShapeOf { axis } | Self::DynamicTruncate { axis } | Self::PadToMatch { axis } => {
331 axis.hash(state)
332 }
333 Self::ReduceProd { axes } | Self::ReduceMax { axes } | Self::ReduceMin { axes } => {
334 axes.hash(state);
335 }
336 Self::Extension(op) => hash_extension(op.as_ref(), state),
337 }
338 }
339}
340
341fn n_inputs_from_dim_exprs(min_inputs: usize, exprs: &[&[DimExpr]]) -> usize {
342 let max_idx = exprs
343 .iter()
344 .flat_map(|exprs| exprs.iter())
345 .filter_map(DimExpr::max_input_idx)
346 .max()
347 .map_or(0, |max_idx| max_idx + 1);
348 max_idx.max(min_inputs)
349}
350
351impl GraphOperation for StdTensorOp {
352 type Operand = tenferro_tensor::Tensor;
353 type Context = ();
354 type InputKey = TensorInputKey;
355
356 fn input_count(&self) -> usize {
357 match self {
358 Self::Add | Self::Mul | Self::DotGeneral { .. } | Self::Gather(_) => 2,
359 Self::GatherDynamicSliceSizes { slice_sizes, .. } => {
360 n_inputs_from_dim_exprs(2, &[slice_sizes])
361 }
362 Self::Neg
363 | Self::Conj
364 | Self::Transpose { .. }
365 | Self::Convert { .. }
366 | Self::ExtractDiag { .. }
367 | Self::EmbedDiag { .. }
368 | Self::Tril { .. }
369 | Self::Triu { .. }
370 | Self::Slice(_)
371 | Self::Pad(_)
372 | Self::Reverse { .. }
373 | Self::ShapeOf { .. } => 1,
374 Self::DynamicTruncate { .. } | Self::PadToMatch { .. } => 2,
375 Self::Reshape { to_shape } => n_inputs_from_dim_exprs(1, &[to_shape]),
376 Self::BroadcastInDim { shape, .. } => n_inputs_from_dim_exprs(1, &[shape]),
377 Self::ReduceSum { .. }
378 | Self::ReduceProd { .. }
379 | Self::ReduceMax { .. }
380 | Self::ReduceMin { .. } => 1,
381 Self::Div | Self::Maximum | Self::Minimum | Self::Pow | Self::DynamicSlice { .. } => 2,
382 Self::Constant { .. } => 0,
383 Self::Scatter(_) | Self::DynamicUpdateSlice => 3,
384 Self::Concatenate { input_count, .. } => *input_count,
385 Self::Abs
386 | Self::Sign
387 | Self::Exp
388 | Self::Log
389 | Self::Sin
390 | Self::Cos
391 | Self::Tanh
392 | Self::Sqrt
393 | Self::Rsqrt
394 | Self::Expm1
395 | Self::Log1p => 1,
396 Self::Select | Self::Clamp => 3,
397 Self::Compare(_) => 2,
398 Self::Extension(op) => ExtensionOp::input_count(op.as_ref()),
399 }
400 }
401
402 fn output_count(&self) -> usize {
403 match self {
404 Self::Add
405 | Self::Mul
406 | Self::Neg
407 | Self::Conj
408 | Self::DotGeneral { .. }
409 | Self::Transpose { .. }
410 | Self::Reshape { .. }
411 | Self::BroadcastInDim { .. }
412 | Self::Convert { .. }
413 | Self::ReduceSum { .. }
414 | Self::Div
415 | Self::Abs
416 | Self::Sign
417 | Self::Maximum
418 | Self::Minimum
419 | Self::Compare(_)
420 | Self::Select
421 | Self::Clamp
422 | Self::Constant { .. }
423 | Self::Exp
424 | Self::Log
425 | Self::Sin
426 | Self::Cos
427 | Self::Tanh
428 | Self::Sqrt
429 | Self::Rsqrt
430 | Self::Pow
431 | Self::Expm1
432 | Self::Log1p
433 | Self::ExtractDiag { .. }
434 | Self::EmbedDiag { .. }
435 | Self::Tril { .. }
436 | Self::Triu { .. }
437 | Self::Gather(_)
438 | Self::GatherDynamicSliceSizes { .. }
439 | Self::Scatter(_)
440 | Self::Slice(_)
441 | Self::DynamicSlice { .. }
442 | Self::DynamicUpdateSlice
443 | Self::Pad(_)
444 | Self::Reverse { .. }
445 | Self::ShapeOf { .. }
446 | Self::DynamicTruncate { .. }
447 | Self::PadToMatch { .. }
448 | Self::ReduceProd { .. }
449 | Self::ReduceMax { .. }
450 | Self::ReduceMin { .. } => 1,
451 Self::Concatenate { .. } => 1,
452 Self::Extension(op) => ExtensionOp::output_count(op.as_ref()),
453 }
454 }
455}
456
#[cfg(feature = "autodiff")]
impl Primitive for StdTensorOp {
    type ADContext = crate::ad::context::ShapeGuardContext;

    /// Returns `Self::Add`, the op handed to `tidu` as this primitive set's addition.
    /// Presumably it is used to accumulate tangent/cotangent contributions; TODO
    /// confirm against the `tidu::Primitive` docs.
    fn add() -> Self {
        Self::Add
    }

    /// Forward-mode rule: delegates entirely to `crate::ad::linearize`.
    fn jvp_rule(
        &self,
        builder: &mut impl PrimitiveBuilder<Self>,
        primal_in: &[ValueKey<Self>],
        primal_out: &[ValueKey<Self>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut Self::ADContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        crate::ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
    }

    /// Reverse-mode rule: converts each `PrimitiveValue` input with `Into`, then
    /// delegates to `crate::ad::transpose_rule`.
    fn transpose_rule(
        &self,
        builder: &mut impl PrimitiveBuilder<Self>,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[PrimitiveValue<Self>],
        mode: &OperationRole,
        ctx: &mut Self::ADContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let converted: Vec<_> = inputs.iter().cloned().map(Into::into).collect();
        crate::ad::transpose_rule(self, builder, cotangent_out, &converted, mode, ctx)
    }
}
488
// Test-only inherent wrappers around the AD rules. In Rust method resolution,
// inherent methods take precedence over trait methods with the same name. Inside
// tests, `op.jvp_rule(...)` / `op.transpose_rule(...)` therefore resolve here. They
// accept a concrete `computegraph::graph::GraphBuilder` or any
// `crate::ad::PrimitiveRuleBuilder`, instead of the `tidu::PrimitiveBuilder`
// required by the `Primitive` impl. Presumably this is for test convenience; TODO
// confirm.
#[cfg(all(test, feature = "autodiff"))]
impl StdTensorOp {
    /// Test helper: forwards to `crate::ad::linearize`, exactly as `Primitive::jvp_rule`
    /// does, but with a concrete `GraphBuilder`.
    pub(crate) fn jvp_rule(
        &self,
        builder: &mut computegraph::graph::GraphBuilder<Self>,
        primal_in: &[ValueKey<Self>],
        primal_out: &[ValueKey<Self>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut crate::ad::context::ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        crate::ad::linearize(self, builder, primal_in, primal_out, tangent_in, ctx)
    }

    /// Test helper: forwards to `crate::ad::transpose_rule`, taking `ValueRef` inputs
    /// directly. The `Primitive` impl instead first converts its `PrimitiveValue`
    /// inputs with `Into`.
    pub(crate) fn transpose_rule(
        &self,
        builder: &mut impl crate::ad::PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[computegraph::ValueRef<Self>],
        mode: &OperationRole,
        ctx: &mut crate::ad::context::ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        crate::ad::transpose_rule(self, builder, cotangent_out, inputs, mode, ctx)
    }
}