1#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
14pub enum OpCategory {
15 Elementwise,
16 Analytic,
17 Structural,
18 Reduction,
19 Contraction,
20 Indexing,
21 Dynamic,
22 Host,
23}
24
25#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
38pub enum DTypePolicy {
39 SameAny,
40 SameNumeric,
41 SameFloat,
42 AbsToReal,
44 SameFloatOrComplex,
45 CompareToBool,
46 BoolSelect,
47 Convert,
48 Shape,
49 Constant,
50}
51
52#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
64pub struct PrimitiveOpDescriptor {
65 pub kind: PrimitiveOpKind,
67 pub name: &'static str,
69 pub category: OpCategory,
71 pub dtype_policy: DTypePolicy,
73 pub min_inputs: u8,
75 pub max_inputs: u8,
77 pub host_only: bool,
79}
80
81macro_rules! primitive_ops {
82 ($macro:ident) => {
83 $macro! {
84 Add, "add", Elementwise, SameNumeric, 2, 2, false;
85 Mul, "mul", Elementwise, SameNumeric, 2, 2, false;
86 Neg, "neg", Elementwise, SameNumeric, 1, 1, false;
87 Conj, "conj", Elementwise, SameFloatOrComplex, 1, 1, false;
88 Div, "div", Elementwise, SameFloatOrComplex, 2, 2, false;
89 Abs, "abs", Elementwise, AbsToReal, 1, 1, false;
90 Sign, "sign", Elementwise, SameFloat, 1, 1, false;
91 Maximum, "maximum", Elementwise, SameFloat, 2, 2, false;
92 Minimum, "minimum", Elementwise, SameFloat, 2, 2, false;
93 Compare, "compare", Elementwise, CompareToBool, 2, 2, false;
94 Select, "select", Elementwise, BoolSelect, 3, 3, false;
95 Clamp, "clamp", Elementwise, SameFloat, 3, 3, false;
96 Exp, "exp", Analytic, SameFloatOrComplex, 1, 1, false;
97 Log, "log", Analytic, SameFloatOrComplex, 1, 1, false;
98 Sin, "sin", Analytic, SameFloatOrComplex, 1, 1, false;
99 Cos, "cos", Analytic, SameFloatOrComplex, 1, 1, false;
100 Tanh, "tanh", Analytic, SameFloatOrComplex, 1, 1, false;
101 Sqrt, "sqrt", Analytic, SameFloatOrComplex, 1, 1, false;
102 Rsqrt, "rsqrt", Analytic, SameFloatOrComplex, 1, 1, false;
103 Pow, "pow", Analytic, SameFloatOrComplex, 2, 2, false;
104 Expm1, "expm1", Analytic, SameFloatOrComplex, 1, 1, false;
105 Log1p, "log1p", Analytic, SameFloatOrComplex, 1, 1, false;
106 DotGeneral, "dot_general", Contraction, SameFloatOrComplex, 2, 2, false;
107 ReduceSum, "reduce_sum", Reduction, SameNumeric, 1, 1, false;
108 ReduceProd, "reduce_prod", Reduction, SameNumeric, 1, 1, false;
109 ReduceMax, "reduce_max", Reduction, SameFloat, 1, 1, false;
110 ReduceMin, "reduce_min", Reduction, SameFloat, 1, 1, false;
111 Transpose, "transpose", Structural, SameAny, 1, 1, false;
112 Reshape, "reshape", Structural, SameAny, 1, 1, false;
113 BroadcastInDim, "broadcast_in_dim", Structural, SameAny, 1, 1, false;
114 Convert, "convert", Structural, Convert, 1, 1, false;
115 ExtractDiag, "extract_diag", Structural, SameAny, 1, 1, false;
116 EmbedDiag, "embed_diag", Structural, SameAny, 1, 1, false;
117 Tril, "tril", Structural, SameAny, 1, 1, false;
118 Triu, "triu", Structural, SameAny, 1, 1, false;
119 Gather, "gather", Indexing, SameAny, 2, 2, false;
120 GatherDynamicSliceSizes, "gather_dynamic_slice_sizes", Indexing, SameAny, 2, 2, false;
121 Scatter, "scatter", Indexing, SameAny, 3, 3, false;
122 Slice, "slice", Indexing, SameAny, 1, 1, false;
123 DynamicSlice, "dynamic_slice", Indexing, SameAny, 2, 2, false;
124 DynamicUpdateSlice, "dynamic_update_slice", Indexing, SameAny, 3, 3, false;
125 Pad, "pad", Indexing, SameAny, 1, 1, false;
126 Concatenate, "concatenate", Indexing, SameAny, 1, u8::MAX, false;
127 Reverse, "reverse", Indexing, SameAny, 1, 1, false;
128 ShapeOf, "shape_of", Host, Shape, 1, 1, true;
129 DynamicTruncate, "dynamic_truncate", Dynamic, SameAny, 2, 2, true;
130 PadToMatch, "pad_to_match", Dynamic, SameAny, 2, 2, true;
131 Constant, "constant", Host, Constant, 0, 0, true;
132 }
133 };
134}
135
136macro_rules! define_kind {
137 ($( $variant:ident, $name:literal, $category:ident, $policy:ident, $min:expr, $max:expr, $host:expr; )*) => {
138 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
148 pub enum PrimitiveOpKind {
149 $( $variant, )*
150 }
151
152 impl PrimitiveOpKind {
153 pub const COUNT: usize = [$(PrimitiveOpKind::$variant),*].len();
163
164 pub const fn as_index(self) -> usize {
174 self as usize
175 }
176 }
177 };
178}
179
180primitive_ops!(define_kind);
181
182macro_rules! define_descriptors {
183 ($( $variant:ident, $name:literal, $category:ident, $policy:ident, $min:expr, $max:expr, $host:expr; )*) => {
184 const DESCRIPTORS: &[PrimitiveOpDescriptor] = &[
185 $(
186 PrimitiveOpDescriptor {
187 kind: PrimitiveOpKind::$variant,
188 name: $name,
189 category: OpCategory::$category,
190 dtype_policy: DTypePolicy::$policy,
191 min_inputs: $min,
192 max_inputs: $max,
193 host_only: $host,
194 },
195 )*
196 ];
197
198 pub fn descriptor(kind: PrimitiveOpKind) -> &'static PrimitiveOpDescriptor {
208 match kind {
209 $(
210 PrimitiveOpKind::$variant => &DESCRIPTORS[PrimitiveOpKind::$variant as usize],
211 )*
212 }
213 }
214 };
215}
216
217primitive_ops!(define_descriptors);
218
219pub fn all_primitive_descriptors() -> &'static [PrimitiveOpDescriptor] {
231 DESCRIPTORS
232}
233
234#[doc(hidden)]
235#[macro_export]
236macro_rules! define_std_tensor_op {
237 () => {
238 #[derive(Clone, Debug)]
239 pub enum StdTensorOp {
240 Add,
242 Mul,
243 Neg,
244 Conj,
245 DotGeneral {
246 config: DotGeneralConfig,
247 },
248 Transpose {
249 perm: Vec<usize>,
250 },
251 Reshape {
252 to_shape: Vec<DimExpr>,
253 },
254 BroadcastInDim {
255 shape: Vec<DimExpr>,
256 dims: Vec<usize>,
257 },
258 Convert {
259 from: DType,
260 to: DType,
261 },
262 Constant {
263 dtype: DType,
264 bytes: Vec<u8>,
265 },
266 ReduceSum {
267 axes: Vec<usize>,
268 },
269
270 Div,
272 Abs,
273 Sign,
274 Maximum,
275 Minimum,
276 Compare(CompareDir),
277 Select,
278 Clamp,
279
280 Exp,
282 Log,
283 Sin,
284 Cos,
285 Tanh,
286 Sqrt,
287 Rsqrt,
288 Pow,
289 Expm1,
290 Log1p,
291
292 ExtractDiag {
294 axis_a: usize,
295 axis_b: usize,
296 },
297 EmbedDiag {
298 axis_a: usize,
299 axis_b: usize,
300 },
301 Tril {
302 k: i64,
303 },
304 Triu {
305 k: i64,
306 },
307
308 Gather(GatherConfig),
310 GatherDynamicSliceSizes {
311 offset_dims: Vec<usize>,
312 collapsed_slice_dims: Vec<usize>,
313 start_index_map: Vec<usize>,
314 index_vector_dim: usize,
315 slice_sizes: Vec<DimExpr>,
316 },
317 Scatter(ScatterConfig),
318 Slice(SliceConfig),
319 DynamicSlice {
320 slice_sizes: Vec<usize>,
321 },
322 DynamicUpdateSlice,
323 Pad(PadConfig),
324 Concatenate {
325 axis: usize,
326 input_count: usize,
327 },
328 Reverse {
329 axes: Vec<usize>,
330 },
331 ShapeOf {
332 axis: usize,
333 },
334 DynamicTruncate {
335 axis: usize,
336 },
337 PadToMatch {
338 axis: usize,
339 },
340
341 ReduceProd {
343 axes: Vec<usize>,
344 },
345 ReduceMax {
346 axes: Vec<usize>,
347 },
348 ReduceMin {
349 axes: Vec<usize>,
350 },
351
352 Extension(Arc<dyn ExtensionOp>),
358 }
359
360 impl StdTensorOp {
361 pub fn primitive_kind(&self) -> Option<$crate::PrimitiveOpKind> {
375 let kind = match self {
376 Self::Add => $crate::PrimitiveOpKind::Add,
377 Self::Mul => $crate::PrimitiveOpKind::Mul,
378 Self::Neg => $crate::PrimitiveOpKind::Neg,
379 Self::Conj => $crate::PrimitiveOpKind::Conj,
380 Self::DotGeneral { .. } => $crate::PrimitiveOpKind::DotGeneral,
381 Self::Transpose { .. } => $crate::PrimitiveOpKind::Transpose,
382 Self::Reshape { .. } => $crate::PrimitiveOpKind::Reshape,
383 Self::BroadcastInDim { .. } => $crate::PrimitiveOpKind::BroadcastInDim,
384 Self::Convert { .. } => $crate::PrimitiveOpKind::Convert,
385 Self::Constant { .. } => $crate::PrimitiveOpKind::Constant,
386 Self::ReduceSum { .. } => $crate::PrimitiveOpKind::ReduceSum,
387 Self::Div => $crate::PrimitiveOpKind::Div,
388 Self::Abs => $crate::PrimitiveOpKind::Abs,
389 Self::Sign => $crate::PrimitiveOpKind::Sign,
390 Self::Maximum => $crate::PrimitiveOpKind::Maximum,
391 Self::Minimum => $crate::PrimitiveOpKind::Minimum,
392 Self::Compare(_) => $crate::PrimitiveOpKind::Compare,
393 Self::Select => $crate::PrimitiveOpKind::Select,
394 Self::Clamp => $crate::PrimitiveOpKind::Clamp,
395 Self::Exp => $crate::PrimitiveOpKind::Exp,
396 Self::Log => $crate::PrimitiveOpKind::Log,
397 Self::Sin => $crate::PrimitiveOpKind::Sin,
398 Self::Cos => $crate::PrimitiveOpKind::Cos,
399 Self::Tanh => $crate::PrimitiveOpKind::Tanh,
400 Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
401 Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
402 Self::Pow => $crate::PrimitiveOpKind::Pow,
403 Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
404 Self::Log1p => $crate::PrimitiveOpKind::Log1p,
405 Self::ExtractDiag { .. } => $crate::PrimitiveOpKind::ExtractDiag,
406 Self::EmbedDiag { .. } => $crate::PrimitiveOpKind::EmbedDiag,
407 Self::Tril { .. } => $crate::PrimitiveOpKind::Tril,
408 Self::Triu { .. } => $crate::PrimitiveOpKind::Triu,
409 Self::Gather(_) => $crate::PrimitiveOpKind::Gather,
410 Self::GatherDynamicSliceSizes { .. } => {
411 $crate::PrimitiveOpKind::GatherDynamicSliceSizes
412 }
413 Self::Scatter(_) => $crate::PrimitiveOpKind::Scatter,
414 Self::Slice(_) => $crate::PrimitiveOpKind::Slice,
415 Self::DynamicSlice { .. } => $crate::PrimitiveOpKind::DynamicSlice,
416 Self::DynamicUpdateSlice => $crate::PrimitiveOpKind::DynamicUpdateSlice,
417 Self::Pad(_) => $crate::PrimitiveOpKind::Pad,
418 Self::Concatenate { .. } => $crate::PrimitiveOpKind::Concatenate,
419 Self::Reverse { .. } => $crate::PrimitiveOpKind::Reverse,
420 Self::ShapeOf { .. } => $crate::PrimitiveOpKind::ShapeOf,
421 Self::DynamicTruncate { .. } => $crate::PrimitiveOpKind::DynamicTruncate,
422 Self::PadToMatch { .. } => $crate::PrimitiveOpKind::PadToMatch,
423 Self::ReduceProd { .. } => $crate::PrimitiveOpKind::ReduceProd,
424 Self::ReduceMax { .. } => $crate::PrimitiveOpKind::ReduceMax,
425 Self::ReduceMin { .. } => $crate::PrimitiveOpKind::ReduceMin,
426 Self::Extension(_) => return None,
427 };
428 Some(kind)
429 }
430
431 #[cfg(test)]
432 pub(crate) fn sample_from_kind(kind: $crate::PrimitiveOpKind) -> Self {
433 match kind {
434 $crate::PrimitiveOpKind::Add => Self::Add,
435 $crate::PrimitiveOpKind::Mul => Self::Mul,
436 $crate::PrimitiveOpKind::Neg => Self::Neg,
437 $crate::PrimitiveOpKind::Conj => Self::Conj,
438 $crate::PrimitiveOpKind::DotGeneral => Self::DotGeneral {
439 config: DotGeneralConfig {
440 lhs_contracting_dims: vec![0],
441 rhs_contracting_dims: vec![0],
442 lhs_batch_dims: vec![],
443 rhs_batch_dims: vec![],
444 },
445 },
446 $crate::PrimitiveOpKind::Transpose => Self::Transpose { perm: vec![0] },
447 $crate::PrimitiveOpKind::Reshape => Self::Reshape {
448 to_shape: vec![DimExpr::Const(1)],
449 },
450 $crate::PrimitiveOpKind::BroadcastInDim => Self::BroadcastInDim {
451 shape: vec![DimExpr::Const(1)],
452 dims: vec![0],
453 },
454 $crate::PrimitiveOpKind::Convert => Self::Convert {
455 from: DType::F32,
456 to: DType::F64,
457 },
458 $crate::PrimitiveOpKind::Constant => Self::Constant {
459 dtype: DType::F64,
460 bytes: 0.0_f64.to_le_bytes().to_vec(),
461 },
462 $crate::PrimitiveOpKind::ReduceSum => Self::ReduceSum { axes: vec![0] },
463 $crate::PrimitiveOpKind::Div => Self::Div,
464 $crate::PrimitiveOpKind::Abs => Self::Abs,
465 $crate::PrimitiveOpKind::Sign => Self::Sign,
466 $crate::PrimitiveOpKind::Maximum => Self::Maximum,
467 $crate::PrimitiveOpKind::Minimum => Self::Minimum,
468 $crate::PrimitiveOpKind::Compare => Self::Compare(CompareDir::Eq),
469 $crate::PrimitiveOpKind::Select => Self::Select,
470 $crate::PrimitiveOpKind::Clamp => Self::Clamp,
471 $crate::PrimitiveOpKind::Exp => Self::Exp,
472 $crate::PrimitiveOpKind::Log => Self::Log,
473 $crate::PrimitiveOpKind::Sin => Self::Sin,
474 $crate::PrimitiveOpKind::Cos => Self::Cos,
475 $crate::PrimitiveOpKind::Tanh => Self::Tanh,
476 $crate::PrimitiveOpKind::Sqrt => Self::Sqrt,
477 $crate::PrimitiveOpKind::Rsqrt => Self::Rsqrt,
478 $crate::PrimitiveOpKind::Pow => Self::Pow,
479 $crate::PrimitiveOpKind::Expm1 => Self::Expm1,
480 $crate::PrimitiveOpKind::Log1p => Self::Log1p,
481 $crate::PrimitiveOpKind::ExtractDiag => Self::ExtractDiag {
482 axis_a: 0,
483 axis_b: 1,
484 },
485 $crate::PrimitiveOpKind::EmbedDiag => Self::EmbedDiag {
486 axis_a: 0,
487 axis_b: 1,
488 },
489 $crate::PrimitiveOpKind::Tril => Self::Tril { k: 0 },
490 $crate::PrimitiveOpKind::Triu => Self::Triu { k: 0 },
491 $crate::PrimitiveOpKind::Gather => Self::Gather(GatherConfig {
492 offset_dims: vec![],
493 collapsed_slice_dims: vec![0],
494 start_index_map: vec![0],
495 index_vector_dim: 1,
496 slice_sizes: vec![1],
497 }),
498 $crate::PrimitiveOpKind::GatherDynamicSliceSizes => {
499 Self::GatherDynamicSliceSizes {
500 offset_dims: vec![],
501 collapsed_slice_dims: vec![0],
502 start_index_map: vec![0],
503 index_vector_dim: 1,
504 slice_sizes: vec![DimExpr::Const(1)],
505 }
506 }
507 $crate::PrimitiveOpKind::Scatter => Self::Scatter(ScatterConfig {
508 update_window_dims: vec![],
509 inserted_window_dims: vec![0],
510 scatter_dims_to_operand_dims: vec![0],
511 index_vector_dim: 1,
512 }),
513 $crate::PrimitiveOpKind::Slice => Self::Slice(SliceConfig {
514 starts: vec![0],
515 limits: vec![1],
516 strides: vec![1],
517 }),
518 $crate::PrimitiveOpKind::DynamicSlice => Self::DynamicSlice {
519 slice_sizes: vec![1],
520 },
521 $crate::PrimitiveOpKind::DynamicUpdateSlice => Self::DynamicUpdateSlice,
522 $crate::PrimitiveOpKind::Pad => Self::Pad(PadConfig {
523 edge_padding_low: vec![0],
524 edge_padding_high: vec![0],
525 interior_padding: vec![0],
526 }),
527 $crate::PrimitiveOpKind::Concatenate => Self::Concatenate {
528 axis: 0,
529 input_count: 1,
530 },
531 $crate::PrimitiveOpKind::Reverse => Self::Reverse { axes: vec![0] },
532 $crate::PrimitiveOpKind::ShapeOf => Self::ShapeOf { axis: 0 },
533 $crate::PrimitiveOpKind::DynamicTruncate => Self::DynamicTruncate { axis: 0 },
534 $crate::PrimitiveOpKind::PadToMatch => Self::PadToMatch { axis: 0 },
535 $crate::PrimitiveOpKind::ReduceProd => Self::ReduceProd { axes: vec![0] },
536 $crate::PrimitiveOpKind::ReduceMax => Self::ReduceMax { axes: vec![0] },
537 $crate::PrimitiveOpKind::ReduceMin => Self::ReduceMin { axes: vec![0] },
538 }
539 }
540 }
541 };
542}
543
544#[doc(hidden)]
545#[macro_export]
546macro_rules! define_elementwise_fusion_op {
547 () => {
548 #[doc(hidden)]
550 #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
551 pub enum ElementwiseFusionOp {
552 Add,
553 Multiply,
554 Negate,
555 Conj,
556 Divide,
557 Abs,
558 Maximum,
559 Minimum,
560 Clamp,
561 Exp,
562 Log,
563 Sin,
564 Cos,
565 Tanh,
566 Sqrt,
567 Rsqrt,
568 Pow,
569 Expm1,
570 Log1p,
571 }
572
573 #[cfg(test)]
574 impl ElementwiseFusionOp {
575 pub(crate) fn iter() -> impl Iterator<Item = Self> {
576 [
577 Self::Add,
578 Self::Multiply,
579 Self::Negate,
580 Self::Conj,
581 Self::Divide,
582 Self::Abs,
583 Self::Maximum,
584 Self::Minimum,
585 Self::Clamp,
586 Self::Exp,
587 Self::Log,
588 Self::Sin,
589 Self::Cos,
590 Self::Tanh,
591 Self::Sqrt,
592 Self::Rsqrt,
593 Self::Pow,
594 Self::Expm1,
595 Self::Log1p,
596 ]
597 .into_iter()
598 }
599
600 pub(crate) fn from_primitive_kind(kind: $crate::PrimitiveOpKind) -> Option<Self> {
601 match kind {
602 $crate::PrimitiveOpKind::Add => Some(Self::Add),
603 $crate::PrimitiveOpKind::Mul => Some(Self::Multiply),
604 $crate::PrimitiveOpKind::Neg => Some(Self::Negate),
605 $crate::PrimitiveOpKind::Conj => Some(Self::Conj),
606 $crate::PrimitiveOpKind::Div => Some(Self::Divide),
607 $crate::PrimitiveOpKind::Abs => Some(Self::Abs),
608 $crate::PrimitiveOpKind::Maximum => Some(Self::Maximum),
609 $crate::PrimitiveOpKind::Minimum => Some(Self::Minimum),
610 $crate::PrimitiveOpKind::Clamp => Some(Self::Clamp),
611 $crate::PrimitiveOpKind::Exp => Some(Self::Exp),
612 $crate::PrimitiveOpKind::Log => Some(Self::Log),
613 $crate::PrimitiveOpKind::Sin => Some(Self::Sin),
614 $crate::PrimitiveOpKind::Cos => Some(Self::Cos),
615 $crate::PrimitiveOpKind::Tanh => Some(Self::Tanh),
616 $crate::PrimitiveOpKind::Sqrt => Some(Self::Sqrt),
617 $crate::PrimitiveOpKind::Rsqrt => Some(Self::Rsqrt),
618 $crate::PrimitiveOpKind::Pow => Some(Self::Pow),
619 $crate::PrimitiveOpKind::Expm1 => Some(Self::Expm1),
620 $crate::PrimitiveOpKind::Log1p => Some(Self::Log1p),
621 _ => None,
622 }
623 }
624
625 pub(crate) fn primitive_kind(self) -> $crate::PrimitiveOpKind {
626 match self {
627 Self::Add => $crate::PrimitiveOpKind::Add,
628 Self::Multiply => $crate::PrimitiveOpKind::Mul,
629 Self::Negate => $crate::PrimitiveOpKind::Neg,
630 Self::Conj => $crate::PrimitiveOpKind::Conj,
631 Self::Divide => $crate::PrimitiveOpKind::Div,
632 Self::Abs => $crate::PrimitiveOpKind::Abs,
633 Self::Maximum => $crate::PrimitiveOpKind::Maximum,
634 Self::Minimum => $crate::PrimitiveOpKind::Minimum,
635 Self::Clamp => $crate::PrimitiveOpKind::Clamp,
636 Self::Exp => $crate::PrimitiveOpKind::Exp,
637 Self::Log => $crate::PrimitiveOpKind::Log,
638 Self::Sin => $crate::PrimitiveOpKind::Sin,
639 Self::Cos => $crate::PrimitiveOpKind::Cos,
640 Self::Tanh => $crate::PrimitiveOpKind::Tanh,
641 Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
642 Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
643 Self::Pow => $crate::PrimitiveOpKind::Pow,
644 Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
645 Self::Log1p => $crate::PrimitiveOpKind::Log1p,
646 }
647 }
648 }
649 };
650}
651
652#[doc(hidden)]
653#[macro_export]
654macro_rules! define_exec_op {
655 () => {
656 #[derive(Clone, Debug)]
657 pub enum ExecOp {
658 Transpose {
659 perm: Vec<usize>,
660 },
661 Reshape {
662 shape: Vec<DimExpr>,
663 },
664 BroadcastInDim {
665 shape: Vec<DimExpr>,
666 dims: Vec<usize>,
667 },
668 Convert {
669 to: DType,
670 },
671 Constant {
672 dtype: DType,
673 bytes: Vec<u8>,
674 },
675 DotGeneral(DotGeneralConfig),
676 DotGeneralWithConj {
677 config: DotGeneralConfig,
678 lhs_conj: bool,
679 rhs_conj: bool,
680 },
681 ReduceSum {
682 axes: Vec<usize>,
683 },
684 ExtractDiag {
685 axis_a: usize,
686 axis_b: usize,
687 },
688 EmbedDiag {
689 axis_a: usize,
690 axis_b: usize,
691 },
692 Tril {
693 k: i64,
694 },
695 Triu {
696 k: i64,
697 },
698 Add,
699 Multiply,
700 Negate,
701 Conj,
702 Divide,
703 Abs,
704 Sign,
705 Maximum,
706 Minimum,
707 Compare(CompareDir),
708 Select,
709 Clamp,
710 Exp,
711 Log,
712 Sin,
713 Cos,
714 Tanh,
715 Sqrt,
716 Rsqrt,
717 Pow,
718 Expm1,
719 Log1p,
720 Gather(GatherConfig),
721 GatherDynamicSliceSizes {
722 offset_dims: Vec<usize>,
723 collapsed_slice_dims: Vec<usize>,
724 start_index_map: Vec<usize>,
725 index_vector_dim: usize,
726 slice_sizes: Vec<DimExpr>,
727 },
728 Scatter(ScatterConfig),
729 Slice(SliceConfig),
730 DynamicSlice {
731 slice_sizes: Vec<usize>,
732 },
733 DynamicUpdateSlice,
734 Pad(PadConfig),
735 Concatenate {
736 axis: usize,
737 },
738 Reverse {
739 axes: Vec<usize>,
740 },
741 ShapeOf {
742 axis: usize,
743 },
744 DynamicTruncate {
745 axis: usize,
746 },
747 PadToMatch {
748 axis: usize,
749 },
750 ReduceProd {
751 axes: Vec<usize>,
752 },
753 ReduceMax {
754 axes: Vec<usize>,
755 },
756 ReduceMin {
757 axes: Vec<usize>,
758 },
759 Extension(Arc<dyn ExtensionOp>),
766 }
767
768 impl ExecOp {
769 pub(crate) fn primitive_kind(&self) -> Option<$crate::PrimitiveOpKind> {
770 let kind = match self {
771 Self::Transpose { .. } => $crate::PrimitiveOpKind::Transpose,
772 Self::Reshape { .. } => $crate::PrimitiveOpKind::Reshape,
773 Self::BroadcastInDim { .. } => $crate::PrimitiveOpKind::BroadcastInDim,
774 Self::Convert { .. } => $crate::PrimitiveOpKind::Convert,
775 Self::Constant { .. } => $crate::PrimitiveOpKind::Constant,
776 Self::DotGeneral(_) | Self::DotGeneralWithConj { .. } => {
777 $crate::PrimitiveOpKind::DotGeneral
778 }
779 Self::ReduceSum { .. } => $crate::PrimitiveOpKind::ReduceSum,
780 Self::ExtractDiag { .. } => $crate::PrimitiveOpKind::ExtractDiag,
781 Self::EmbedDiag { .. } => $crate::PrimitiveOpKind::EmbedDiag,
782 Self::Tril { .. } => $crate::PrimitiveOpKind::Tril,
783 Self::Triu { .. } => $crate::PrimitiveOpKind::Triu,
784 Self::Add => $crate::PrimitiveOpKind::Add,
785 Self::Multiply => $crate::PrimitiveOpKind::Mul,
786 Self::Negate => $crate::PrimitiveOpKind::Neg,
787 Self::Conj => $crate::PrimitiveOpKind::Conj,
788 Self::Divide => $crate::PrimitiveOpKind::Div,
789 Self::Abs => $crate::PrimitiveOpKind::Abs,
790 Self::Sign => $crate::PrimitiveOpKind::Sign,
791 Self::Maximum => $crate::PrimitiveOpKind::Maximum,
792 Self::Minimum => $crate::PrimitiveOpKind::Minimum,
793 Self::Compare(_) => $crate::PrimitiveOpKind::Compare,
794 Self::Select => $crate::PrimitiveOpKind::Select,
795 Self::Clamp => $crate::PrimitiveOpKind::Clamp,
796 Self::Exp => $crate::PrimitiveOpKind::Exp,
797 Self::Log => $crate::PrimitiveOpKind::Log,
798 Self::Sin => $crate::PrimitiveOpKind::Sin,
799 Self::Cos => $crate::PrimitiveOpKind::Cos,
800 Self::Tanh => $crate::PrimitiveOpKind::Tanh,
801 Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
802 Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
803 Self::Pow => $crate::PrimitiveOpKind::Pow,
804 Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
805 Self::Log1p => $crate::PrimitiveOpKind::Log1p,
806 Self::Gather(_) => $crate::PrimitiveOpKind::Gather,
807 Self::GatherDynamicSliceSizes { .. } => {
808 $crate::PrimitiveOpKind::GatherDynamicSliceSizes
809 }
810 Self::Scatter(_) => $crate::PrimitiveOpKind::Scatter,
811 Self::Slice(_) => $crate::PrimitiveOpKind::Slice,
812 Self::DynamicSlice { .. } => $crate::PrimitiveOpKind::DynamicSlice,
813 Self::DynamicUpdateSlice => $crate::PrimitiveOpKind::DynamicUpdateSlice,
814 Self::Pad(_) => $crate::PrimitiveOpKind::Pad,
815 Self::Concatenate { .. } => $crate::PrimitiveOpKind::Concatenate,
816 Self::Reverse { .. } => $crate::PrimitiveOpKind::Reverse,
817 Self::ShapeOf { .. } => $crate::PrimitiveOpKind::ShapeOf,
818 Self::DynamicTruncate { .. } => $crate::PrimitiveOpKind::DynamicTruncate,
819 Self::PadToMatch { .. } => $crate::PrimitiveOpKind::PadToMatch,
820 Self::ReduceProd { .. } => $crate::PrimitiveOpKind::ReduceProd,
821 Self::ReduceMax { .. } => $crate::PrimitiveOpKind::ReduceMax,
822 Self::ReduceMin { .. } => $crate::PrimitiveOpKind::ReduceMin,
823 Self::Extension(_) => return None,
824 };
825 Some(kind)
826 }
827
828 pub(crate) fn from_std_tensor_op(
829 op: &tenferro_ops::std_tensor_op::StdTensorOp,
830 ) -> Self {
831 match op {
832 tenferro_ops::std_tensor_op::StdTensorOp::Add => Self::Add,
833 tenferro_ops::std_tensor_op::StdTensorOp::Mul => Self::Multiply,
834 tenferro_ops::std_tensor_op::StdTensorOp::Neg => Self::Negate,
835 tenferro_ops::std_tensor_op::StdTensorOp::Conj => Self::Conj,
836 tenferro_ops::std_tensor_op::StdTensorOp::Div => Self::Divide,
837 tenferro_ops::std_tensor_op::StdTensorOp::Abs => Self::Abs,
838 tenferro_ops::std_tensor_op::StdTensorOp::Sign => Self::Sign,
839 tenferro_ops::std_tensor_op::StdTensorOp::Maximum => Self::Maximum,
840 tenferro_ops::std_tensor_op::StdTensorOp::Minimum => Self::Minimum,
841 tenferro_ops::std_tensor_op::StdTensorOp::Compare(dir) => {
842 Self::Compare(dir.clone())
843 }
844 tenferro_ops::std_tensor_op::StdTensorOp::Select => Self::Select,
845 tenferro_ops::std_tensor_op::StdTensorOp::Clamp => Self::Clamp,
846 tenferro_ops::std_tensor_op::StdTensorOp::Exp => Self::Exp,
847 tenferro_ops::std_tensor_op::StdTensorOp::Log => Self::Log,
848 tenferro_ops::std_tensor_op::StdTensorOp::Sin => Self::Sin,
849 tenferro_ops::std_tensor_op::StdTensorOp::Cos => Self::Cos,
850 tenferro_ops::std_tensor_op::StdTensorOp::Tanh => Self::Tanh,
851 tenferro_ops::std_tensor_op::StdTensorOp::Sqrt => Self::Sqrt,
852 tenferro_ops::std_tensor_op::StdTensorOp::Rsqrt => Self::Rsqrt,
853 tenferro_ops::std_tensor_op::StdTensorOp::Pow => Self::Pow,
854 tenferro_ops::std_tensor_op::StdTensorOp::Expm1 => Self::Expm1,
855 tenferro_ops::std_tensor_op::StdTensorOp::Log1p => Self::Log1p,
856 tenferro_ops::std_tensor_op::StdTensorOp::Transpose { perm } => {
857 Self::Transpose { perm: perm.clone() }
858 }
859 tenferro_ops::std_tensor_op::StdTensorOp::Reshape { to_shape } => {
860 Self::Reshape {
861 shape: to_shape.clone(),
862 }
863 }
864 tenferro_ops::std_tensor_op::StdTensorOp::BroadcastInDim { shape, dims } => {
865 Self::BroadcastInDim {
866 shape: shape.clone(),
867 dims: dims.clone(),
868 }
869 }
870 tenferro_ops::std_tensor_op::StdTensorOp::Convert { to, .. } => {
871 Self::Convert { to: *to }
872 }
873 tenferro_ops::std_tensor_op::StdTensorOp::Constant { dtype, bytes } => {
874 Self::Constant {
875 dtype: *dtype,
876 bytes: bytes.clone(),
877 }
878 }
879 tenferro_ops::std_tensor_op::StdTensorOp::DotGeneral { config } => {
880 Self::DotGeneral(config.clone())
881 }
882 tenferro_ops::std_tensor_op::StdTensorOp::ReduceSum { axes } => {
883 Self::ReduceSum { axes: axes.clone() }
884 }
885 tenferro_ops::std_tensor_op::StdTensorOp::ReduceProd { axes } => {
886 Self::ReduceProd { axes: axes.clone() }
887 }
888 tenferro_ops::std_tensor_op::StdTensorOp::ReduceMax { axes } => {
889 Self::ReduceMax { axes: axes.clone() }
890 }
891 tenferro_ops::std_tensor_op::StdTensorOp::ReduceMin { axes } => {
892 Self::ReduceMin { axes: axes.clone() }
893 }
894 tenferro_ops::std_tensor_op::StdTensorOp::ExtractDiag { axis_a, axis_b } => {
895 Self::ExtractDiag {
896 axis_a: *axis_a,
897 axis_b: *axis_b,
898 }
899 }
900 tenferro_ops::std_tensor_op::StdTensorOp::EmbedDiag { axis_a, axis_b } => {
901 Self::EmbedDiag {
902 axis_a: *axis_a,
903 axis_b: *axis_b,
904 }
905 }
906 tenferro_ops::std_tensor_op::StdTensorOp::Tril { k } => Self::Tril { k: *k },
907 tenferro_ops::std_tensor_op::StdTensorOp::Triu { k } => Self::Triu { k: *k },
908 tenferro_ops::std_tensor_op::StdTensorOp::Gather(config) => {
909 Self::Gather(config.clone())
910 }
911 tenferro_ops::std_tensor_op::StdTensorOp::GatherDynamicSliceSizes {
912 offset_dims,
913 collapsed_slice_dims,
914 start_index_map,
915 index_vector_dim,
916 slice_sizes,
917 } => Self::GatherDynamicSliceSizes {
918 offset_dims: offset_dims.clone(),
919 collapsed_slice_dims: collapsed_slice_dims.clone(),
920 start_index_map: start_index_map.clone(),
921 index_vector_dim: *index_vector_dim,
922 slice_sizes: slice_sizes.clone(),
923 },
924 tenferro_ops::std_tensor_op::StdTensorOp::Scatter(config) => {
925 Self::Scatter(config.clone())
926 }
927 tenferro_ops::std_tensor_op::StdTensorOp::Slice(config) => {
928 Self::Slice(config.clone())
929 }
930 tenferro_ops::std_tensor_op::StdTensorOp::DynamicSlice { slice_sizes } => {
931 Self::DynamicSlice {
932 slice_sizes: slice_sizes.clone(),
933 }
934 }
935 tenferro_ops::std_tensor_op::StdTensorOp::DynamicUpdateSlice => {
936 Self::DynamicUpdateSlice
937 }
938 tenferro_ops::std_tensor_op::StdTensorOp::Pad(config) => {
939 Self::Pad(config.clone())
940 }
941 tenferro_ops::std_tensor_op::StdTensorOp::Concatenate { axis, .. } => {
942 Self::Concatenate { axis: *axis }
943 }
944 tenferro_ops::std_tensor_op::StdTensorOp::Reverse { axes } => {
945 Self::Reverse { axes: axes.clone() }
946 }
947 tenferro_ops::std_tensor_op::StdTensorOp::ShapeOf { axis } => {
948 Self::ShapeOf { axis: *axis }
949 }
950 tenferro_ops::std_tensor_op::StdTensorOp::DynamicTruncate { axis } => {
951 Self::DynamicTruncate { axis: *axis }
952 }
953 tenferro_ops::std_tensor_op::StdTensorOp::PadToMatch { axis } => {
954 Self::PadToMatch { axis: *axis }
955 }
956 tenferro_ops::std_tensor_op::StdTensorOp::Extension(op) => {
957 Self::Extension(op.clone())
958 }
959 }
960 }
961
962 pub(crate) fn elementwise_fusion_op(&self) -> Option<ElementwiseFusionOp> {
963 match self {
964 Self::Add => Some(ElementwiseFusionOp::Add),
965 Self::Multiply => Some(ElementwiseFusionOp::Multiply),
966 Self::Negate => Some(ElementwiseFusionOp::Negate),
967 Self::Conj => Some(ElementwiseFusionOp::Conj),
968 Self::Divide => Some(ElementwiseFusionOp::Divide),
969 Self::Abs => Some(ElementwiseFusionOp::Abs),
970 Self::Maximum => Some(ElementwiseFusionOp::Maximum),
971 Self::Minimum => Some(ElementwiseFusionOp::Minimum),
972 Self::Clamp => Some(ElementwiseFusionOp::Clamp),
973 Self::Exp => Some(ElementwiseFusionOp::Exp),
974 Self::Log => Some(ElementwiseFusionOp::Log),
975 Self::Sin => Some(ElementwiseFusionOp::Sin),
976 Self::Cos => Some(ElementwiseFusionOp::Cos),
977 Self::Tanh => Some(ElementwiseFusionOp::Tanh),
978 Self::Sqrt => Some(ElementwiseFusionOp::Sqrt),
979 Self::Rsqrt => Some(ElementwiseFusionOp::Rsqrt),
980 Self::Pow => Some(ElementwiseFusionOp::Pow),
981 Self::Expm1 => Some(ElementwiseFusionOp::Expm1),
982 Self::Log1p => Some(ElementwiseFusionOp::Log1p),
983 _ => None,
984 }
985 }
986
987 #[cfg(test)]
988 pub(crate) fn input_arity_bounds(&self) -> Option<(u8, u8)> {
989 self.primitive_kind().map(|kind| {
990 let descriptor = $crate::descriptor(kind);
991 (descriptor.min_inputs, descriptor.max_inputs)
992 })
993 }
994
995 #[cfg(test)]
996 pub(crate) fn sample_from_kind(kind: $crate::PrimitiveOpKind) -> Self {
997 match kind {
998 $crate::PrimitiveOpKind::Transpose => Self::Transpose { perm: vec![0] },
999 $crate::PrimitiveOpKind::Reshape => Self::Reshape {
1000 shape: vec![DimExpr::Const(1)],
1001 },
1002 $crate::PrimitiveOpKind::BroadcastInDim => Self::BroadcastInDim {
1003 shape: vec![DimExpr::Const(1)],
1004 dims: vec![0],
1005 },
1006 $crate::PrimitiveOpKind::Convert => Self::Convert { to: DType::F64 },
1007 $crate::PrimitiveOpKind::Constant => Self::Constant {
1008 dtype: DType::F64,
1009 bytes: 0.0_f64.to_le_bytes().to_vec(),
1010 },
1011 $crate::PrimitiveOpKind::DotGeneral => Self::DotGeneral(DotGeneralConfig {
1012 lhs_contracting_dims: vec![0],
1013 rhs_contracting_dims: vec![0],
1014 lhs_batch_dims: vec![],
1015 rhs_batch_dims: vec![],
1016 }),
1017 $crate::PrimitiveOpKind::ReduceSum => Self::ReduceSum { axes: vec![0] },
1018 $crate::PrimitiveOpKind::ExtractDiag => Self::ExtractDiag {
1019 axis_a: 0,
1020 axis_b: 1,
1021 },
1022 $crate::PrimitiveOpKind::EmbedDiag => Self::EmbedDiag {
1023 axis_a: 0,
1024 axis_b: 1,
1025 },
1026 $crate::PrimitiveOpKind::Tril => Self::Tril { k: 0 },
1027 $crate::PrimitiveOpKind::Triu => Self::Triu { k: 0 },
1028 $crate::PrimitiveOpKind::Add => Self::Add,
1029 $crate::PrimitiveOpKind::Mul => Self::Multiply,
1030 $crate::PrimitiveOpKind::Neg => Self::Negate,
1031 $crate::PrimitiveOpKind::Conj => Self::Conj,
1032 $crate::PrimitiveOpKind::Div => Self::Divide,
1033 $crate::PrimitiveOpKind::Abs => Self::Abs,
1034 $crate::PrimitiveOpKind::Sign => Self::Sign,
1035 $crate::PrimitiveOpKind::Maximum => Self::Maximum,
1036 $crate::PrimitiveOpKind::Minimum => Self::Minimum,
1037 $crate::PrimitiveOpKind::Compare => Self::Compare(CompareDir::Eq),
1038 $crate::PrimitiveOpKind::Select => Self::Select,
1039 $crate::PrimitiveOpKind::Clamp => Self::Clamp,
1040 $crate::PrimitiveOpKind::Exp => Self::Exp,
1041 $crate::PrimitiveOpKind::Log => Self::Log,
1042 $crate::PrimitiveOpKind::Sin => Self::Sin,
1043 $crate::PrimitiveOpKind::Cos => Self::Cos,
1044 $crate::PrimitiveOpKind::Tanh => Self::Tanh,
1045 $crate::PrimitiveOpKind::Sqrt => Self::Sqrt,
1046 $crate::PrimitiveOpKind::Rsqrt => Self::Rsqrt,
1047 $crate::PrimitiveOpKind::Pow => Self::Pow,
1048 $crate::PrimitiveOpKind::Expm1 => Self::Expm1,
1049 $crate::PrimitiveOpKind::Log1p => Self::Log1p,
1050 $crate::PrimitiveOpKind::Gather => Self::Gather(GatherConfig {
1051 offset_dims: vec![],
1052 collapsed_slice_dims: vec![0],
1053 start_index_map: vec![0],
1054 index_vector_dim: 1,
1055 slice_sizes: vec![1],
1056 }),
1057 $crate::PrimitiveOpKind::GatherDynamicSliceSizes => {
1058 Self::GatherDynamicSliceSizes {
1059 offset_dims: vec![],
1060 collapsed_slice_dims: vec![0],
1061 start_index_map: vec![0],
1062 index_vector_dim: 1,
1063 slice_sizes: vec![DimExpr::Const(1)],
1064 }
1065 }
1066 $crate::PrimitiveOpKind::Scatter => Self::Scatter(ScatterConfig {
1067 update_window_dims: vec![],
1068 inserted_window_dims: vec![0],
1069 scatter_dims_to_operand_dims: vec![0],
1070 index_vector_dim: 1,
1071 }),
1072 $crate::PrimitiveOpKind::Slice => Self::Slice(SliceConfig {
1073 starts: vec![0],
1074 limits: vec![1],
1075 strides: vec![1],
1076 }),
1077 $crate::PrimitiveOpKind::DynamicSlice => Self::DynamicSlice {
1078 slice_sizes: vec![1],
1079 },
1080 $crate::PrimitiveOpKind::DynamicUpdateSlice => Self::DynamicUpdateSlice,
1081 $crate::PrimitiveOpKind::Pad => Self::Pad(PadConfig {
1082 edge_padding_low: vec![0],
1083 edge_padding_high: vec![0],
1084 interior_padding: vec![0],
1085 }),
1086 $crate::PrimitiveOpKind::Concatenate => Self::Concatenate { axis: 0 },
1087 $crate::PrimitiveOpKind::Reverse => Self::Reverse { axes: vec![0] },
1088 $crate::PrimitiveOpKind::ShapeOf => Self::ShapeOf { axis: 0 },
1089 $crate::PrimitiveOpKind::DynamicTruncate => Self::DynamicTruncate { axis: 0 },
1090 $crate::PrimitiveOpKind::PadToMatch => Self::PadToMatch { axis: 0 },
1091 $crate::PrimitiveOpKind::ReduceProd => Self::ReduceProd { axes: vec![0] },
1092 $crate::PrimitiveOpKind::ReduceMax => Self::ReduceMax { axes: vec![0] },
1093 $crate::PrimitiveOpKind::ReduceMin => Self::ReduceMin { axes: vec![0] },
1094 }
1095 }
1096 }
1097 };
1098}