Skip to main content

tenferro_core_ops/
catalog.rs

/// Coarse execution category attached to every core primitive operation.
///
/// Each catalog row names exactly one category; the variant docs below list
/// representative catalog operations for each one.
///
/// # Examples
///
/// ```rust
/// use tenferro_core_ops::{descriptor, OpCategory, PrimitiveOpKind};
///
/// let shape_of = descriptor(PrimitiveOpKind::ShapeOf);
/// assert_eq!(shape_of.category, OpCategory::Host);
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum OpCategory {
    /// Basic elementwise ops such as `add`, `mul`, `compare`, `select`, and `clamp`.
    Elementwise,
    /// Elementwise math functions such as `exp`, `log`, `sin`, `sqrt`, and `pow`.
    Analytic,
    /// Layout and dtype ops: `transpose`, `reshape`, `broadcast_in_dim`, `convert`,
    /// the diagonal ops, and `tril`/`triu`.
    Structural,
    /// Axis reductions: `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_min`.
    Reduction,
    /// Tensor contraction (`dot_general`).
    Contraction,
    /// Gather/scatter/slice-style ops such as `gather`, `scatter`, `slice`, `pad`,
    /// `concatenate`, and `reverse`.
    Indexing,
    /// `dynamic_truncate` and `pad_to_match`; both are flagged `host_only` in the catalog.
    Dynamic,
    /// `shape_of` and `constant`; both are flagged `host_only` in the catalog.
    Host,
}
24
/// Rule the catalog uses to describe acceptable dtypes for an operation.
///
/// The `Same*` policies are named for inputs that share one dtype of the given
/// kind; the actual check lives in the dtype validator, not in this file. Each
/// variant doc lists the catalog operations that carry it.
///
/// # Examples
///
/// ```rust
/// use tenferro_core_ops::{descriptor, DTypePolicy, PrimitiveOpKind};
///
/// let compare = descriptor(PrimitiveOpKind::Compare);
/// assert_eq!(compare.dtype_policy, DTypePolicy::CompareToBool);
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum DTypePolicy {
    /// Any element kind. Used by the structural and indexing ops (e.g. `transpose`,
    /// `gather`, `concatenate`) and by `dynamic_truncate` / `pad_to_match`.
    SameAny,
    /// Numeric element kinds. Used by `add`, `mul`, `neg`, `reduce_sum`, `reduce_prod`.
    SameNumeric,
    /// Float element kinds. Used by `sign`, `maximum`, `minimum`, `clamp`,
    /// `reduce_max`, `reduce_min`.
    SameFloat,
    /// Preserve real float dtype and map complex magnitude to the matching real dtype.
    AbsToReal,
    /// Float or complex element kinds. Used by `conj`, `div`, the analytic ops, and
    /// `dot_general`.
    SameFloatOrComplex,
    /// Used by `compare`.
    CompareToBool,
    /// Used by `select`.
    BoolSelect,
    /// Used by `convert`.
    Convert,
    /// Used by `shape_of`.
    Shape,
    /// Used by `constant`.
    Constant,
}
51
52/// Static metadata for one core primitive operation.
53///
54/// # Examples
55///
56/// ```rust
57/// use tenferro_core_ops::{descriptor, PrimitiveOpKind};
58///
59/// let add = descriptor(PrimitiveOpKind::Add);
60/// assert_eq!(add.min_inputs, 2);
61/// assert_eq!(add.max_inputs, 2);
62/// ```
63#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
64pub struct PrimitiveOpDescriptor {
65    /// Catalog key for this operation.
66    pub kind: PrimitiveOpKind,
67    /// Stable snake-case operation name for diagnostics and descriptors.
68    pub name: &'static str,
69    /// Broad execution category.
70    pub category: OpCategory,
71    /// Dtype compatibility policy.
72    pub dtype_policy: DTypePolicy,
73    /// Minimum number of inputs accepted by the op.
74    pub min_inputs: u8,
75    /// Maximum number of inputs accepted by the op.
76    pub max_inputs: u8,
77    /// Whether this op is executed by host/runtime logic rather than a tensor backend.
78    pub host_only: bool,
79}
80
// Catalog table: the single source of truth for core primitive ops.
//
// `primitive_ops!(m)` invokes the macro `m` with one row per op, each row shaped as
//
//     Variant, "snake_name", OpCategory variant, DTypePolicy variant, min_inputs, max_inputs, host_only;
//
// `define_kind` turns the rows into `PrimitiveOpKind` and `define_descriptors` turns
// them into the descriptor table. A row's position therefore becomes both the
// variant's `as_index()` and its slot in `all_primitive_descriptors()`. Reordering
// rows renumbers every kind after the moved row, so append new rows instead (the
// `as_index` doc example expects `Add` at index 0).
//
// NOTE(review): `concatenate` uses `u8::MAX` as `max_inputs`, apparently as a
// "variadic" marker, while `StdTensorOp::Concatenate::input_count` is a `usize`;
// confirm how more than 255 inputs are validated.
macro_rules! primitive_ops {
    ($macro:ident) => {
        $macro! {
            // Elementwise
            Add, "add", Elementwise, SameNumeric, 2, 2, false;
            Mul, "mul", Elementwise, SameNumeric, 2, 2, false;
            Neg, "neg", Elementwise, SameNumeric, 1, 1, false;
            Conj, "conj", Elementwise, SameFloatOrComplex, 1, 1, false;
            Div, "div", Elementwise, SameFloatOrComplex, 2, 2, false;
            Abs, "abs", Elementwise, AbsToReal, 1, 1, false;
            Sign, "sign", Elementwise, SameFloat, 1, 1, false;
            Maximum, "maximum", Elementwise, SameFloat, 2, 2, false;
            Minimum, "minimum", Elementwise, SameFloat, 2, 2, false;
            Compare, "compare", Elementwise, CompareToBool, 2, 2, false;
            Select, "select", Elementwise, BoolSelect, 3, 3, false;
            Clamp, "clamp", Elementwise, SameFloat, 3, 3, false;
            // Analytic
            Exp, "exp", Analytic, SameFloatOrComplex, 1, 1, false;
            Log, "log", Analytic, SameFloatOrComplex, 1, 1, false;
            Sin, "sin", Analytic, SameFloatOrComplex, 1, 1, false;
            Cos, "cos", Analytic, SameFloatOrComplex, 1, 1, false;
            Tanh, "tanh", Analytic, SameFloatOrComplex, 1, 1, false;
            Sqrt, "sqrt", Analytic, SameFloatOrComplex, 1, 1, false;
            Rsqrt, "rsqrt", Analytic, SameFloatOrComplex, 1, 1, false;
            Pow, "pow", Analytic, SameFloatOrComplex, 2, 2, false;
            Expm1, "expm1", Analytic, SameFloatOrComplex, 1, 1, false;
            Log1p, "log1p", Analytic, SameFloatOrComplex, 1, 1, false;
            // Contraction
            DotGeneral, "dot_general", Contraction, SameFloatOrComplex, 2, 2, false;
            // Reductions
            ReduceSum, "reduce_sum", Reduction, SameNumeric, 1, 1, false;
            ReduceProd, "reduce_prod", Reduction, SameNumeric, 1, 1, false;
            ReduceMax, "reduce_max", Reduction, SameFloat, 1, 1, false;
            ReduceMin, "reduce_min", Reduction, SameFloat, 1, 1, false;
            // Structural
            Transpose, "transpose", Structural, SameAny, 1, 1, false;
            Reshape, "reshape", Structural, SameAny, 1, 1, false;
            BroadcastInDim, "broadcast_in_dim", Structural, SameAny, 1, 1, false;
            Convert, "convert", Structural, Convert, 1, 1, false;
            ExtractDiag, "extract_diag", Structural, SameAny, 1, 1, false;
            EmbedDiag, "embed_diag", Structural, SameAny, 1, 1, false;
            Tril, "tril", Structural, SameAny, 1, 1, false;
            Triu, "triu", Structural, SameAny, 1, 1, false;
            // Indexing
            Gather, "gather", Indexing, SameAny, 2, 2, false;
            GatherDynamicSliceSizes, "gather_dynamic_slice_sizes", Indexing, SameAny, 2, 2, false;
            Scatter, "scatter", Indexing, SameAny, 3, 3, false;
            Slice, "slice", Indexing, SameAny, 1, 1, false;
            DynamicSlice, "dynamic_slice", Indexing, SameAny, 2, 2, false;
            DynamicUpdateSlice, "dynamic_update_slice", Indexing, SameAny, 3, 3, false;
            Pad, "pad", Indexing, SameAny, 1, 1, false;
            // `u8::MAX` marks the open-ended upper bound (see NOTE above).
            Concatenate, "concatenate", Indexing, SameAny, 1, u8::MAX, false;
            Reverse, "reverse", Indexing, SameAny, 1, 1, false;
            // Host-only rows (`host_only = true`)
            ShapeOf, "shape_of", Host, Shape, 1, 1, true;
            DynamicTruncate, "dynamic_truncate", Dynamic, SameAny, 2, 2, true;
            PadToMatch, "pad_to_match", Dynamic, SameAny, 2, 2, true;
            Constant, "constant", Host, Constant, 0, 0, true;
        }
    };
}
135
macro_rules! define_kind {
    ($( $op:ident, $label:literal, $cat:ident, $dtype:ident, $lo:expr, $hi:expr, $host_only:expr; )*) => {
        /// Dense catalog key identifying one core primitive operation.
        ///
        /// Variants are declared in catalog-row order, so each variant's discriminant
        /// is also its position in the descriptor table.
        ///
        /// # Examples
        ///
        /// ```rust
        /// use tenferro_core_ops::{descriptor, PrimitiveOpKind};
        ///
        /// let add = PrimitiveOpKind::Add;
        /// assert_eq!(descriptor(add).name, "add");
        /// ```
        #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
        pub enum PrimitiveOpKind {
            $( $op ),*
        }

        impl PrimitiveOpKind {
            /// Total number of kinds in the catalog (one per catalog row).
            ///
            /// # Examples
            ///
            /// ```rust
            /// use tenferro_core_ops::{all_primitive_descriptors, PrimitiveOpKind};
            ///
            /// assert_eq!(PrimitiveOpKind::COUNT, all_primitive_descriptors().len());
            /// ```
            pub const COUNT: usize = {
                let every_kind = [$( Self::$op ),*];
                every_kind.len()
            };

            /// Zero-based position of this kind in catalog order; always `< COUNT`.
            ///
            /// # Examples
            ///
            /// ```rust
            /// use tenferro_core_ops::PrimitiveOpKind;
            ///
            /// assert_eq!(PrimitiveOpKind::Add.as_index(), 0);
            /// assert!(PrimitiveOpKind::Constant.as_index() < PrimitiveOpKind::COUNT);
            /// ```
            pub const fn as_index(self) -> usize {
                // Fieldless enum: the implicit discriminant is the declaration position.
                self as usize
            }
        }
    };
}
179
// Expand the catalog rows into the `PrimitiveOpKind` enum, `COUNT`, and `as_index`.
primitive_ops!(define_kind);
181
macro_rules! define_descriptors {
    ($( $variant:ident, $name:literal, $category:ident, $policy:ident, $min:expr, $max:expr, $host:expr; )*) => {
        /// Descriptor table in catalog order.
        ///
        /// Entry `i` describes the kind whose `as_index()` is `i`. This holds by
        /// construction because `PrimitiveOpKind` and this table are expanded from
        /// the same `primitive_ops!` row list, in the same order.
        const DESCRIPTORS: &[PrimitiveOpDescriptor] = &[
            $(
                PrimitiveOpDescriptor {
                    kind: PrimitiveOpKind::$variant,
                    name: $name,
                    category: OpCategory::$category,
                    dtype_policy: DTypePolicy::$policy,
                    min_inputs: $min,
                    max_inputs: $max,
                    host_only: $host,
                },
            )*
        ];

        /// Return the descriptor for a primitive operation kind.
        ///
        /// This is a `const fn`, so the lookup can also be used in constant contexts.
        /// The index is always in bounds (see the table invariant above), so it
        /// never panics.
        ///
        /// # Examples
        ///
        /// ```rust
        /// use tenferro_core_ops::{descriptor, PrimitiveOpKind};
        ///
        /// assert_eq!(descriptor(PrimitiveOpKind::Add).name, "add");
        ///
        /// const SHAPE_OF_IS_HOST_ONLY: bool = descriptor(PrimitiveOpKind::ShapeOf).host_only;
        /// assert!(SHAPE_OF_IS_HOST_ONLY);
        /// ```
        pub const fn descriptor(kind: PrimitiveOpKind) -> &'static PrimitiveOpDescriptor {
            // A per-variant `match` would compute the same index in every arm; the
            // dense `as_index` is the table position directly.
            &DESCRIPTORS[kind.as_index()]
        }
    };
}
216
// Expand the same catalog rows into `DESCRIPTORS` and the `descriptor` lookup.
primitive_ops!(define_descriptors);
218
219/// Return all core primitive operation descriptors in catalog order.
220///
221/// # Examples
222///
223/// ```rust
224/// use tenferro_core_ops::all_primitive_descriptors;
225///
226/// assert!(all_primitive_descriptors()
227///     .iter()
228///     .any(|descriptor| descriptor.name == "add"));
229/// ```
230pub fn all_primitive_descriptors() -> &'static [PrimitiveOpDescriptor] {
231    DESCRIPTORS
232}
233
// Emits the `StdTensorOp` graph-op enum, its `primitive_kind` mapping, and a
// `#[cfg(test)]` `sample_from_kind` constructor. The expansion names
// `DotGeneralConfig`, `DimExpr`, `DType`, `CompareDir`, `GatherConfig`,
// `ScatterConfig`, `SliceConfig`, `PadConfig`, `ExtensionOp`, and `Arc` without a
// path, so those names must be in scope wherever this macro is invoked.
#[doc(hidden)]
#[macro_export]
macro_rules! define_std_tensor_op {
    () => {
        /// Operation node in the standard tensor graph IR.
        ///
        /// Every variant except [`StdTensorOp::Extension`] corresponds to exactly one
        /// core `PrimitiveOpKind` (see `primitive_kind`). Variant payloads carry the
        /// op's attributes (axes, permutations, configs, constant bytes).
        #[derive(Clone, Debug)]
        pub enum StdTensorOp {
            // Semiring arithmetic core
            Add,
            Mul,
            Neg,
            Conj,
            DotGeneral {
                config: DotGeneralConfig,
            },
            Transpose {
                perm: Vec<usize>,
            },
            Reshape {
                to_shape: Vec<DimExpr>,
            },
            BroadcastInDim {
                shape: Vec<DimExpr>,
                dims: Vec<usize>,
            },
            Convert {
                from: DType,
                to: DType,
            },
            Constant {
                dtype: DType,
                bytes: Vec<u8>,
            },
            ReduceSum {
                axes: Vec<usize>,
            },

            // Elementwise (non-semiring)
            Div,
            Abs,
            Sign,
            Maximum,
            Minimum,
            Compare(CompareDir),
            Select,
            Clamp,

            // Analytic
            Exp,
            Log,
            Sin,
            Cos,
            Tanh,
            Sqrt,
            Rsqrt,
            Pow,
            Expm1,
            Log1p,

            // Diagonal extraction / embedding (AD-closed pair)
            ExtractDiag {
                axis_a: usize,
                axis_b: usize,
            },
            EmbedDiag {
                axis_a: usize,
                axis_b: usize,
            },
            Tril {
                k: i64,
            },
            Triu {
                k: i64,
            },

            // Indexing
            Gather(GatherConfig),
            GatherDynamicSliceSizes {
                offset_dims: Vec<usize>,
                collapsed_slice_dims: Vec<usize>,
                start_index_map: Vec<usize>,
                index_vector_dim: usize,
                slice_sizes: Vec<DimExpr>,
            },
            Scatter(ScatterConfig),
            Slice(SliceConfig),
            DynamicSlice {
                slice_sizes: Vec<usize>,
            },
            DynamicUpdateSlice,
            Pad(PadConfig),
            // NOTE(review): the core catalog caps `concatenate` at `u8::MAX` inputs,
            // while `input_count` is a `usize`; confirm a validator reconciles them.
            Concatenate {
                axis: usize,
                input_count: usize,
            },
            Reverse {
                axes: Vec<usize>,
            },
            ShapeOf {
                axis: usize,
            },
            DynamicTruncate {
                axis: usize,
            },
            PadToMatch {
                axis: usize,
            },

            // Reductions
            ReduceProd {
                axes: Vec<usize>,
            },
            ReduceMax {
                axes: Vec<usize>,
            },
            ReduceMin {
                axes: Vec<usize>,
            },

            /// Out-of-tree extension carrier.
            ///
            /// See [`crate::ext_op`] and `docs/spec/extension-op.md`. Identity,
            /// hashing, equality, arity, shape inference, and AD rules are delegated
            /// to the inner [`ExtensionOp`] trait object.
            Extension(Arc<dyn ExtensionOp>),
        }

        impl StdTensorOp {
            /// Return the core primitive catalog kind for this graph operation.
            ///
            /// Extension operations do not claim a core primitive kind; they are
            /// dispatched through their extension family id instead.
            ///
            /// # Examples
            ///
            /// ```rust
            /// use tenferro_core_ops::PrimitiveOpKind;
            /// use tenferro_ops::std_tensor_op::StdTensorOp;
            ///
            /// assert_eq!(StdTensorOp::Add.primitive_kind(), Some(PrimitiveOpKind::Add));
            /// ```
            pub fn primitive_kind(&self) -> Option<$crate::PrimitiveOpKind> {
                // No wildcard arm: adding a variant forces an explicit mapping here.
                let kind = match self {
                    Self::Add => $crate::PrimitiveOpKind::Add,
                    Self::Mul => $crate::PrimitiveOpKind::Mul,
                    Self::Neg => $crate::PrimitiveOpKind::Neg,
                    Self::Conj => $crate::PrimitiveOpKind::Conj,
                    Self::DotGeneral { .. } => $crate::PrimitiveOpKind::DotGeneral,
                    Self::Transpose { .. } => $crate::PrimitiveOpKind::Transpose,
                    Self::Reshape { .. } => $crate::PrimitiveOpKind::Reshape,
                    Self::BroadcastInDim { .. } => $crate::PrimitiveOpKind::BroadcastInDim,
                    Self::Convert { .. } => $crate::PrimitiveOpKind::Convert,
                    Self::Constant { .. } => $crate::PrimitiveOpKind::Constant,
                    Self::ReduceSum { .. } => $crate::PrimitiveOpKind::ReduceSum,
                    Self::Div => $crate::PrimitiveOpKind::Div,
                    Self::Abs => $crate::PrimitiveOpKind::Abs,
                    Self::Sign => $crate::PrimitiveOpKind::Sign,
                    Self::Maximum => $crate::PrimitiveOpKind::Maximum,
                    Self::Minimum => $crate::PrimitiveOpKind::Minimum,
                    Self::Compare(_) => $crate::PrimitiveOpKind::Compare,
                    Self::Select => $crate::PrimitiveOpKind::Select,
                    Self::Clamp => $crate::PrimitiveOpKind::Clamp,
                    Self::Exp => $crate::PrimitiveOpKind::Exp,
                    Self::Log => $crate::PrimitiveOpKind::Log,
                    Self::Sin => $crate::PrimitiveOpKind::Sin,
                    Self::Cos => $crate::PrimitiveOpKind::Cos,
                    Self::Tanh => $crate::PrimitiveOpKind::Tanh,
                    Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
                    Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
                    Self::Pow => $crate::PrimitiveOpKind::Pow,
                    Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
                    Self::Log1p => $crate::PrimitiveOpKind::Log1p,
                    Self::ExtractDiag { .. } => $crate::PrimitiveOpKind::ExtractDiag,
                    Self::EmbedDiag { .. } => $crate::PrimitiveOpKind::EmbedDiag,
                    Self::Tril { .. } => $crate::PrimitiveOpKind::Tril,
                    Self::Triu { .. } => $crate::PrimitiveOpKind::Triu,
                    Self::Gather(_) => $crate::PrimitiveOpKind::Gather,
                    Self::GatherDynamicSliceSizes { .. } => {
                        $crate::PrimitiveOpKind::GatherDynamicSliceSizes
                    }
                    Self::Scatter(_) => $crate::PrimitiveOpKind::Scatter,
                    Self::Slice(_) => $crate::PrimitiveOpKind::Slice,
                    Self::DynamicSlice { .. } => $crate::PrimitiveOpKind::DynamicSlice,
                    Self::DynamicUpdateSlice => $crate::PrimitiveOpKind::DynamicUpdateSlice,
                    Self::Pad(_) => $crate::PrimitiveOpKind::Pad,
                    Self::Concatenate { .. } => $crate::PrimitiveOpKind::Concatenate,
                    Self::Reverse { .. } => $crate::PrimitiveOpKind::Reverse,
                    Self::ShapeOf { .. } => $crate::PrimitiveOpKind::ShapeOf,
                    Self::DynamicTruncate { .. } => $crate::PrimitiveOpKind::DynamicTruncate,
                    Self::PadToMatch { .. } => $crate::PrimitiveOpKind::PadToMatch,
                    Self::ReduceProd { .. } => $crate::PrimitiveOpKind::ReduceProd,
                    Self::ReduceMax { .. } => $crate::PrimitiveOpKind::ReduceMax,
                    Self::ReduceMin { .. } => $crate::PrimitiveOpKind::ReduceMin,
                    Self::Extension(_) => return None,
                };
                Some(kind)
            }

            /// Test-only: build one `StdTensorOp` for the given catalog kind.
            ///
            /// Payloads are fixed small values (axis 0, single-element vectors, an
            /// `f64` zero constant, `F32 -> F64` for `Convert`). Never returns
            /// `StdTensorOp::Extension`. The match has no wildcard arm, so adding a
            /// catalog kind forces a sample to be written here.
            #[cfg(test)]
            pub(crate) fn sample_from_kind(kind: $crate::PrimitiveOpKind) -> Self {
                match kind {
                    $crate::PrimitiveOpKind::Add => Self::Add,
                    $crate::PrimitiveOpKind::Mul => Self::Mul,
                    $crate::PrimitiveOpKind::Neg => Self::Neg,
                    $crate::PrimitiveOpKind::Conj => Self::Conj,
                    $crate::PrimitiveOpKind::DotGeneral => Self::DotGeneral {
                        config: DotGeneralConfig {
                            lhs_contracting_dims: vec![0],
                            rhs_contracting_dims: vec![0],
                            lhs_batch_dims: vec![],
                            rhs_batch_dims: vec![],
                        },
                    },
                    $crate::PrimitiveOpKind::Transpose => Self::Transpose { perm: vec![0] },
                    $crate::PrimitiveOpKind::Reshape => Self::Reshape {
                        to_shape: vec![DimExpr::Const(1)],
                    },
                    $crate::PrimitiveOpKind::BroadcastInDim => Self::BroadcastInDim {
                        shape: vec![DimExpr::Const(1)],
                        dims: vec![0],
                    },
                    $crate::PrimitiveOpKind::Convert => Self::Convert {
                        from: DType::F32,
                        to: DType::F64,
                    },
                    $crate::PrimitiveOpKind::Constant => Self::Constant {
                        dtype: DType::F64,
                        bytes: 0.0_f64.to_le_bytes().to_vec(),
                    },
                    $crate::PrimitiveOpKind::ReduceSum => Self::ReduceSum { axes: vec![0] },
                    $crate::PrimitiveOpKind::Div => Self::Div,
                    $crate::PrimitiveOpKind::Abs => Self::Abs,
                    $crate::PrimitiveOpKind::Sign => Self::Sign,
                    $crate::PrimitiveOpKind::Maximum => Self::Maximum,
                    $crate::PrimitiveOpKind::Minimum => Self::Minimum,
                    $crate::PrimitiveOpKind::Compare => Self::Compare(CompareDir::Eq),
                    $crate::PrimitiveOpKind::Select => Self::Select,
                    $crate::PrimitiveOpKind::Clamp => Self::Clamp,
                    $crate::PrimitiveOpKind::Exp => Self::Exp,
                    $crate::PrimitiveOpKind::Log => Self::Log,
                    $crate::PrimitiveOpKind::Sin => Self::Sin,
                    $crate::PrimitiveOpKind::Cos => Self::Cos,
                    $crate::PrimitiveOpKind::Tanh => Self::Tanh,
                    $crate::PrimitiveOpKind::Sqrt => Self::Sqrt,
                    $crate::PrimitiveOpKind::Rsqrt => Self::Rsqrt,
                    $crate::PrimitiveOpKind::Pow => Self::Pow,
                    $crate::PrimitiveOpKind::Expm1 => Self::Expm1,
                    $crate::PrimitiveOpKind::Log1p => Self::Log1p,
                    $crate::PrimitiveOpKind::ExtractDiag => Self::ExtractDiag {
                        axis_a: 0,
                        axis_b: 1,
                    },
                    $crate::PrimitiveOpKind::EmbedDiag => Self::EmbedDiag {
                        axis_a: 0,
                        axis_b: 1,
                    },
                    $crate::PrimitiveOpKind::Tril => Self::Tril { k: 0 },
                    $crate::PrimitiveOpKind::Triu => Self::Triu { k: 0 },
                    $crate::PrimitiveOpKind::Gather => Self::Gather(GatherConfig {
                        offset_dims: vec![],
                        collapsed_slice_dims: vec![0],
                        start_index_map: vec![0],
                        index_vector_dim: 1,
                        slice_sizes: vec![1],
                    }),
                    $crate::PrimitiveOpKind::GatherDynamicSliceSizes => {
                        Self::GatherDynamicSliceSizes {
                            offset_dims: vec![],
                            collapsed_slice_dims: vec![0],
                            start_index_map: vec![0],
                            index_vector_dim: 1,
                            slice_sizes: vec![DimExpr::Const(1)],
                        }
                    }
                    $crate::PrimitiveOpKind::Scatter => Self::Scatter(ScatterConfig {
                        update_window_dims: vec![],
                        inserted_window_dims: vec![0],
                        scatter_dims_to_operand_dims: vec![0],
                        index_vector_dim: 1,
                    }),
                    $crate::PrimitiveOpKind::Slice => Self::Slice(SliceConfig {
                        starts: vec![0],
                        limits: vec![1],
                        strides: vec![1],
                    }),
                    $crate::PrimitiveOpKind::DynamicSlice => Self::DynamicSlice {
                        slice_sizes: vec![1],
                    },
                    $crate::PrimitiveOpKind::DynamicUpdateSlice => Self::DynamicUpdateSlice,
                    $crate::PrimitiveOpKind::Pad => Self::Pad(PadConfig {
                        edge_padding_low: vec![0],
                        edge_padding_high: vec![0],
                        interior_padding: vec![0],
                    }),
                    $crate::PrimitiveOpKind::Concatenate => Self::Concatenate {
                        axis: 0,
                        input_count: 1,
                    },
                    $crate::PrimitiveOpKind::Reverse => Self::Reverse { axes: vec![0] },
                    $crate::PrimitiveOpKind::ShapeOf => Self::ShapeOf { axis: 0 },
                    $crate::PrimitiveOpKind::DynamicTruncate => Self::DynamicTruncate { axis: 0 },
                    $crate::PrimitiveOpKind::PadToMatch => Self::PadToMatch { axis: 0 },
                    $crate::PrimitiveOpKind::ReduceProd => Self::ReduceProd { axes: vec![0] },
                    $crate::PrimitiveOpKind::ReduceMax => Self::ReduceMax { axes: vec![0] },
                    $crate::PrimitiveOpKind::ReduceMin => Self::ReduceMin { axes: vec![0] },
                }
            }
        }
    };
}
543
#[doc(hidden)]
#[macro_export]
macro_rules! define_elementwise_fusion_op {
    () => {
        /// Elementwise op kinds supported by backend fusion implementations.
        ///
        /// Each variant corresponds to exactly one core `PrimitiveOpKind`; the
        /// `#[cfg(test)]` helpers below convert between the two.
        #[doc(hidden)]
        #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
        pub enum ElementwiseFusionOp {
            Add,
            Multiply,
            Negate,
            Conj,
            Divide,
            Abs,
            Maximum,
            Minimum,
            Clamp,
            Exp,
            Log,
            Sin,
            Cos,
            Tanh,
            Sqrt,
            Rsqrt,
            Pow,
            Expm1,
            Log1p,
        }

        #[cfg(test)]
        impl ElementwiseFusionOp {
            /// Iterate over every fusion op in declaration order.
            ///
            /// This list must name every variant: `from_primitive_kind` searches it.
            pub(crate) fn iter() -> impl Iterator<Item = Self> {
                // Calling the trait function (instead of `array.into_iter()`) yields the
                // elements by value on every edition, not only Rust 2021.
                IntoIterator::into_iter([
                    Self::Add,
                    Self::Multiply,
                    Self::Negate,
                    Self::Conj,
                    Self::Divide,
                    Self::Abs,
                    Self::Maximum,
                    Self::Minimum,
                    Self::Clamp,
                    Self::Exp,
                    Self::Log,
                    Self::Sin,
                    Self::Cos,
                    Self::Tanh,
                    Self::Sqrt,
                    Self::Rsqrt,
                    Self::Pow,
                    Self::Expm1,
                    Self::Log1p,
                ])
            }

            /// Map a core primitive kind to its fusion op, or `None` when the kind has
            /// no fusion counterpart (e.g. `Sign`, `Compare`, `Select`).
            ///
            /// Defined as the inverse of `primitive_kind`, so the two directions share
            /// one mapping table and cannot drift apart. `primitive_kind` is injective,
            /// so at most one op matches.
            pub(crate) fn from_primitive_kind(kind: $crate::PrimitiveOpKind) -> Option<Self> {
                Self::iter().find(|op| op.primitive_kind() == kind)
            }

            /// Core primitive kind implemented by this fusion op.
            pub(crate) fn primitive_kind(self) -> $crate::PrimitiveOpKind {
                // Exhaustive on purpose: adding a variant forces a mapping here.
                match self {
                    Self::Add => $crate::PrimitiveOpKind::Add,
                    Self::Multiply => $crate::PrimitiveOpKind::Mul,
                    Self::Negate => $crate::PrimitiveOpKind::Neg,
                    Self::Conj => $crate::PrimitiveOpKind::Conj,
                    Self::Divide => $crate::PrimitiveOpKind::Div,
                    Self::Abs => $crate::PrimitiveOpKind::Abs,
                    Self::Maximum => $crate::PrimitiveOpKind::Maximum,
                    Self::Minimum => $crate::PrimitiveOpKind::Minimum,
                    Self::Clamp => $crate::PrimitiveOpKind::Clamp,
                    Self::Exp => $crate::PrimitiveOpKind::Exp,
                    Self::Log => $crate::PrimitiveOpKind::Log,
                    Self::Sin => $crate::PrimitiveOpKind::Sin,
                    Self::Cos => $crate::PrimitiveOpKind::Cos,
                    Self::Tanh => $crate::PrimitiveOpKind::Tanh,
                    Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
                    Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
                    Self::Pow => $crate::PrimitiveOpKind::Pow,
                    Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
                    Self::Log1p => $crate::PrimitiveOpKind::Log1p,
                }
            }
        }
    };
}
651
652#[doc(hidden)]
653#[macro_export]
654macro_rules! define_exec_op {
655    () => {
656        #[derive(Clone, Debug)]
657        pub enum ExecOp {
658            Transpose {
659                perm: Vec<usize>,
660            },
661            Reshape {
662                shape: Vec<DimExpr>,
663            },
664            BroadcastInDim {
665                shape: Vec<DimExpr>,
666                dims: Vec<usize>,
667            },
668            Convert {
669                to: DType,
670            },
671            Constant {
672                dtype: DType,
673                bytes: Vec<u8>,
674            },
675            DotGeneral(DotGeneralConfig),
676            DotGeneralWithConj {
677                config: DotGeneralConfig,
678                lhs_conj: bool,
679                rhs_conj: bool,
680            },
681            ReduceSum {
682                axes: Vec<usize>,
683            },
684            ExtractDiag {
685                axis_a: usize,
686                axis_b: usize,
687            },
688            EmbedDiag {
689                axis_a: usize,
690                axis_b: usize,
691            },
692            Tril {
693                k: i64,
694            },
695            Triu {
696                k: i64,
697            },
698            Add,
699            Multiply,
700            Negate,
701            Conj,
702            Divide,
703            Abs,
704            Sign,
705            Maximum,
706            Minimum,
707            Compare(CompareDir),
708            Select,
709            Clamp,
710            Exp,
711            Log,
712            Sin,
713            Cos,
714            Tanh,
715            Sqrt,
716            Rsqrt,
717            Pow,
718            Expm1,
719            Log1p,
720            Gather(GatherConfig),
721            GatherDynamicSliceSizes {
722                offset_dims: Vec<usize>,
723                collapsed_slice_dims: Vec<usize>,
724                start_index_map: Vec<usize>,
725                index_vector_dim: usize,
726                slice_sizes: Vec<DimExpr>,
727            },
728            Scatter(ScatterConfig),
729            Slice(SliceConfig),
730            DynamicSlice {
731                slice_sizes: Vec<usize>,
732            },
733            DynamicUpdateSlice,
734            Pad(PadConfig),
735            Concatenate {
736                axis: usize,
737            },
738            Reverse {
739                axes: Vec<usize>,
740            },
741            ShapeOf {
742                axis: usize,
743            },
744            DynamicTruncate {
745                axis: usize,
746            },
747            PadToMatch {
748                axis: usize,
749            },
750            ReduceProd {
751                axes: Vec<usize>,
752            },
753            ReduceMax {
754                axes: Vec<usize>,
755            },
756            ReduceMin {
757                axes: Vec<usize>,
758            },
759            /// Out-of-tree extension carrier in the execution IR.
760            ///
761            /// Payload and dispatch are defined by the inner [`ExtensionOp`]. The
762            /// execution pipeline treats extensions as single-instruction FFI
763            /// boundaries (spec Section 8): no elementwise fusion, and dispatch is
764            /// routed through the executor's registered extension runtime.
765            Extension(Arc<dyn ExtensionOp>),
766        }
767
768        impl ExecOp {
769            pub(crate) fn primitive_kind(&self) -> Option<$crate::PrimitiveOpKind> {
770                let kind = match self {
771                    Self::Transpose { .. } => $crate::PrimitiveOpKind::Transpose,
772                    Self::Reshape { .. } => $crate::PrimitiveOpKind::Reshape,
773                    Self::BroadcastInDim { .. } => $crate::PrimitiveOpKind::BroadcastInDim,
774                    Self::Convert { .. } => $crate::PrimitiveOpKind::Convert,
775                    Self::Constant { .. } => $crate::PrimitiveOpKind::Constant,
776                    Self::DotGeneral(_) | Self::DotGeneralWithConj { .. } => {
777                        $crate::PrimitiveOpKind::DotGeneral
778                    }
779                    Self::ReduceSum { .. } => $crate::PrimitiveOpKind::ReduceSum,
780                    Self::ExtractDiag { .. } => $crate::PrimitiveOpKind::ExtractDiag,
781                    Self::EmbedDiag { .. } => $crate::PrimitiveOpKind::EmbedDiag,
782                    Self::Tril { .. } => $crate::PrimitiveOpKind::Tril,
783                    Self::Triu { .. } => $crate::PrimitiveOpKind::Triu,
784                    Self::Add => $crate::PrimitiveOpKind::Add,
785                    Self::Multiply => $crate::PrimitiveOpKind::Mul,
786                    Self::Negate => $crate::PrimitiveOpKind::Neg,
787                    Self::Conj => $crate::PrimitiveOpKind::Conj,
788                    Self::Divide => $crate::PrimitiveOpKind::Div,
789                    Self::Abs => $crate::PrimitiveOpKind::Abs,
790                    Self::Sign => $crate::PrimitiveOpKind::Sign,
791                    Self::Maximum => $crate::PrimitiveOpKind::Maximum,
792                    Self::Minimum => $crate::PrimitiveOpKind::Minimum,
793                    Self::Compare(_) => $crate::PrimitiveOpKind::Compare,
794                    Self::Select => $crate::PrimitiveOpKind::Select,
795                    Self::Clamp => $crate::PrimitiveOpKind::Clamp,
796                    Self::Exp => $crate::PrimitiveOpKind::Exp,
797                    Self::Log => $crate::PrimitiveOpKind::Log,
798                    Self::Sin => $crate::PrimitiveOpKind::Sin,
799                    Self::Cos => $crate::PrimitiveOpKind::Cos,
800                    Self::Tanh => $crate::PrimitiveOpKind::Tanh,
801                    Self::Sqrt => $crate::PrimitiveOpKind::Sqrt,
802                    Self::Rsqrt => $crate::PrimitiveOpKind::Rsqrt,
803                    Self::Pow => $crate::PrimitiveOpKind::Pow,
804                    Self::Expm1 => $crate::PrimitiveOpKind::Expm1,
805                    Self::Log1p => $crate::PrimitiveOpKind::Log1p,
806                    Self::Gather(_) => $crate::PrimitiveOpKind::Gather,
807                    Self::GatherDynamicSliceSizes { .. } => {
808                        $crate::PrimitiveOpKind::GatherDynamicSliceSizes
809                    }
810                    Self::Scatter(_) => $crate::PrimitiveOpKind::Scatter,
811                    Self::Slice(_) => $crate::PrimitiveOpKind::Slice,
812                    Self::DynamicSlice { .. } => $crate::PrimitiveOpKind::DynamicSlice,
813                    Self::DynamicUpdateSlice => $crate::PrimitiveOpKind::DynamicUpdateSlice,
814                    Self::Pad(_) => $crate::PrimitiveOpKind::Pad,
815                    Self::Concatenate { .. } => $crate::PrimitiveOpKind::Concatenate,
816                    Self::Reverse { .. } => $crate::PrimitiveOpKind::Reverse,
817                    Self::ShapeOf { .. } => $crate::PrimitiveOpKind::ShapeOf,
818                    Self::DynamicTruncate { .. } => $crate::PrimitiveOpKind::DynamicTruncate,
819                    Self::PadToMatch { .. } => $crate::PrimitiveOpKind::PadToMatch,
820                    Self::ReduceProd { .. } => $crate::PrimitiveOpKind::ReduceProd,
821                    Self::ReduceMax { .. } => $crate::PrimitiveOpKind::ReduceMax,
822                    Self::ReduceMin { .. } => $crate::PrimitiveOpKind::ReduceMin,
823                    Self::Extension(_) => return None,
824                };
825                Some(kind)
826            }
827
828            pub(crate) fn from_std_tensor_op(
829                op: &tenferro_ops::std_tensor_op::StdTensorOp,
830            ) -> Self {
831                match op {
832                    tenferro_ops::std_tensor_op::StdTensorOp::Add => Self::Add,
833                    tenferro_ops::std_tensor_op::StdTensorOp::Mul => Self::Multiply,
834                    tenferro_ops::std_tensor_op::StdTensorOp::Neg => Self::Negate,
835                    tenferro_ops::std_tensor_op::StdTensorOp::Conj => Self::Conj,
836                    tenferro_ops::std_tensor_op::StdTensorOp::Div => Self::Divide,
837                    tenferro_ops::std_tensor_op::StdTensorOp::Abs => Self::Abs,
838                    tenferro_ops::std_tensor_op::StdTensorOp::Sign => Self::Sign,
839                    tenferro_ops::std_tensor_op::StdTensorOp::Maximum => Self::Maximum,
840                    tenferro_ops::std_tensor_op::StdTensorOp::Minimum => Self::Minimum,
841                    tenferro_ops::std_tensor_op::StdTensorOp::Compare(dir) => {
842                        Self::Compare(dir.clone())
843                    }
844                    tenferro_ops::std_tensor_op::StdTensorOp::Select => Self::Select,
845                    tenferro_ops::std_tensor_op::StdTensorOp::Clamp => Self::Clamp,
846                    tenferro_ops::std_tensor_op::StdTensorOp::Exp => Self::Exp,
847                    tenferro_ops::std_tensor_op::StdTensorOp::Log => Self::Log,
848                    tenferro_ops::std_tensor_op::StdTensorOp::Sin => Self::Sin,
849                    tenferro_ops::std_tensor_op::StdTensorOp::Cos => Self::Cos,
850                    tenferro_ops::std_tensor_op::StdTensorOp::Tanh => Self::Tanh,
851                    tenferro_ops::std_tensor_op::StdTensorOp::Sqrt => Self::Sqrt,
852                    tenferro_ops::std_tensor_op::StdTensorOp::Rsqrt => Self::Rsqrt,
853                    tenferro_ops::std_tensor_op::StdTensorOp::Pow => Self::Pow,
854                    tenferro_ops::std_tensor_op::StdTensorOp::Expm1 => Self::Expm1,
855                    tenferro_ops::std_tensor_op::StdTensorOp::Log1p => Self::Log1p,
856                    tenferro_ops::std_tensor_op::StdTensorOp::Transpose { perm } => {
857                        Self::Transpose { perm: perm.clone() }
858                    }
859                    tenferro_ops::std_tensor_op::StdTensorOp::Reshape { to_shape } => {
860                        Self::Reshape {
861                            shape: to_shape.clone(),
862                        }
863                    }
864                    tenferro_ops::std_tensor_op::StdTensorOp::BroadcastInDim { shape, dims } => {
865                        Self::BroadcastInDim {
866                            shape: shape.clone(),
867                            dims: dims.clone(),
868                        }
869                    }
870                    tenferro_ops::std_tensor_op::StdTensorOp::Convert { to, .. } => {
871                        Self::Convert { to: *to }
872                    }
873                    tenferro_ops::std_tensor_op::StdTensorOp::Constant { dtype, bytes } => {
874                        Self::Constant {
875                            dtype: *dtype,
876                            bytes: bytes.clone(),
877                        }
878                    }
879                    tenferro_ops::std_tensor_op::StdTensorOp::DotGeneral { config } => {
880                        Self::DotGeneral(config.clone())
881                    }
882                    tenferro_ops::std_tensor_op::StdTensorOp::ReduceSum { axes } => {
883                        Self::ReduceSum { axes: axes.clone() }
884                    }
885                    tenferro_ops::std_tensor_op::StdTensorOp::ReduceProd { axes } => {
886                        Self::ReduceProd { axes: axes.clone() }
887                    }
888                    tenferro_ops::std_tensor_op::StdTensorOp::ReduceMax { axes } => {
889                        Self::ReduceMax { axes: axes.clone() }
890                    }
891                    tenferro_ops::std_tensor_op::StdTensorOp::ReduceMin { axes } => {
892                        Self::ReduceMin { axes: axes.clone() }
893                    }
894                    tenferro_ops::std_tensor_op::StdTensorOp::ExtractDiag { axis_a, axis_b } => {
895                        Self::ExtractDiag {
896                            axis_a: *axis_a,
897                            axis_b: *axis_b,
898                        }
899                    }
900                    tenferro_ops::std_tensor_op::StdTensorOp::EmbedDiag { axis_a, axis_b } => {
901                        Self::EmbedDiag {
902                            axis_a: *axis_a,
903                            axis_b: *axis_b,
904                        }
905                    }
906                    tenferro_ops::std_tensor_op::StdTensorOp::Tril { k } => Self::Tril { k: *k },
907                    tenferro_ops::std_tensor_op::StdTensorOp::Triu { k } => Self::Triu { k: *k },
908                    tenferro_ops::std_tensor_op::StdTensorOp::Gather(config) => {
909                        Self::Gather(config.clone())
910                    }
911                    tenferro_ops::std_tensor_op::StdTensorOp::GatherDynamicSliceSizes {
912                        offset_dims,
913                        collapsed_slice_dims,
914                        start_index_map,
915                        index_vector_dim,
916                        slice_sizes,
917                    } => Self::GatherDynamicSliceSizes {
918                        offset_dims: offset_dims.clone(),
919                        collapsed_slice_dims: collapsed_slice_dims.clone(),
920                        start_index_map: start_index_map.clone(),
921                        index_vector_dim: *index_vector_dim,
922                        slice_sizes: slice_sizes.clone(),
923                    },
924                    tenferro_ops::std_tensor_op::StdTensorOp::Scatter(config) => {
925                        Self::Scatter(config.clone())
926                    }
927                    tenferro_ops::std_tensor_op::StdTensorOp::Slice(config) => {
928                        Self::Slice(config.clone())
929                    }
930                    tenferro_ops::std_tensor_op::StdTensorOp::DynamicSlice { slice_sizes } => {
931                        Self::DynamicSlice {
932                            slice_sizes: slice_sizes.clone(),
933                        }
934                    }
935                    tenferro_ops::std_tensor_op::StdTensorOp::DynamicUpdateSlice => {
936                        Self::DynamicUpdateSlice
937                    }
938                    tenferro_ops::std_tensor_op::StdTensorOp::Pad(config) => {
939                        Self::Pad(config.clone())
940                    }
941                    tenferro_ops::std_tensor_op::StdTensorOp::Concatenate { axis, .. } => {
942                        Self::Concatenate { axis: *axis }
943                    }
944                    tenferro_ops::std_tensor_op::StdTensorOp::Reverse { axes } => {
945                        Self::Reverse { axes: axes.clone() }
946                    }
947                    tenferro_ops::std_tensor_op::StdTensorOp::ShapeOf { axis } => {
948                        Self::ShapeOf { axis: *axis }
949                    }
950                    tenferro_ops::std_tensor_op::StdTensorOp::DynamicTruncate { axis } => {
951                        Self::DynamicTruncate { axis: *axis }
952                    }
953                    tenferro_ops::std_tensor_op::StdTensorOp::PadToMatch { axis } => {
954                        Self::PadToMatch { axis: *axis }
955                    }
956                    tenferro_ops::std_tensor_op::StdTensorOp::Extension(op) => {
957                        Self::Extension(op.clone())
958                    }
959                }
960            }
961
962            pub(crate) fn elementwise_fusion_op(&self) -> Option<ElementwiseFusionOp> {
963                match self {
964                    Self::Add => Some(ElementwiseFusionOp::Add),
965                    Self::Multiply => Some(ElementwiseFusionOp::Multiply),
966                    Self::Negate => Some(ElementwiseFusionOp::Negate),
967                    Self::Conj => Some(ElementwiseFusionOp::Conj),
968                    Self::Divide => Some(ElementwiseFusionOp::Divide),
969                    Self::Abs => Some(ElementwiseFusionOp::Abs),
970                    Self::Maximum => Some(ElementwiseFusionOp::Maximum),
971                    Self::Minimum => Some(ElementwiseFusionOp::Minimum),
972                    Self::Clamp => Some(ElementwiseFusionOp::Clamp),
973                    Self::Exp => Some(ElementwiseFusionOp::Exp),
974                    Self::Log => Some(ElementwiseFusionOp::Log),
975                    Self::Sin => Some(ElementwiseFusionOp::Sin),
976                    Self::Cos => Some(ElementwiseFusionOp::Cos),
977                    Self::Tanh => Some(ElementwiseFusionOp::Tanh),
978                    Self::Sqrt => Some(ElementwiseFusionOp::Sqrt),
979                    Self::Rsqrt => Some(ElementwiseFusionOp::Rsqrt),
980                    Self::Pow => Some(ElementwiseFusionOp::Pow),
981                    Self::Expm1 => Some(ElementwiseFusionOp::Expm1),
982                    Self::Log1p => Some(ElementwiseFusionOp::Log1p),
983                    _ => None,
984                }
985            }
986
987            #[cfg(test)]
988            pub(crate) fn input_arity_bounds(&self) -> Option<(u8, u8)> {
989                self.primitive_kind().map(|kind| {
990                    let descriptor = $crate::descriptor(kind);
991                    (descriptor.min_inputs, descriptor.max_inputs)
992                })
993            }
994
995            #[cfg(test)]
996            pub(crate) fn sample_from_kind(kind: $crate::PrimitiveOpKind) -> Self {
997                match kind {
998                    $crate::PrimitiveOpKind::Transpose => Self::Transpose { perm: vec![0] },
999                    $crate::PrimitiveOpKind::Reshape => Self::Reshape {
1000                        shape: vec![DimExpr::Const(1)],
1001                    },
1002                    $crate::PrimitiveOpKind::BroadcastInDim => Self::BroadcastInDim {
1003                        shape: vec![DimExpr::Const(1)],
1004                        dims: vec![0],
1005                    },
1006                    $crate::PrimitiveOpKind::Convert => Self::Convert { to: DType::F64 },
1007                    $crate::PrimitiveOpKind::Constant => Self::Constant {
1008                        dtype: DType::F64,
1009                        bytes: 0.0_f64.to_le_bytes().to_vec(),
1010                    },
1011                    $crate::PrimitiveOpKind::DotGeneral => Self::DotGeneral(DotGeneralConfig {
1012                        lhs_contracting_dims: vec![0],
1013                        rhs_contracting_dims: vec![0],
1014                        lhs_batch_dims: vec![],
1015                        rhs_batch_dims: vec![],
1016                    }),
1017                    $crate::PrimitiveOpKind::ReduceSum => Self::ReduceSum { axes: vec![0] },
1018                    $crate::PrimitiveOpKind::ExtractDiag => Self::ExtractDiag {
1019                        axis_a: 0,
1020                        axis_b: 1,
1021                    },
1022                    $crate::PrimitiveOpKind::EmbedDiag => Self::EmbedDiag {
1023                        axis_a: 0,
1024                        axis_b: 1,
1025                    },
1026                    $crate::PrimitiveOpKind::Tril => Self::Tril { k: 0 },
1027                    $crate::PrimitiveOpKind::Triu => Self::Triu { k: 0 },
1028                    $crate::PrimitiveOpKind::Add => Self::Add,
1029                    $crate::PrimitiveOpKind::Mul => Self::Multiply,
1030                    $crate::PrimitiveOpKind::Neg => Self::Negate,
1031                    $crate::PrimitiveOpKind::Conj => Self::Conj,
1032                    $crate::PrimitiveOpKind::Div => Self::Divide,
1033                    $crate::PrimitiveOpKind::Abs => Self::Abs,
1034                    $crate::PrimitiveOpKind::Sign => Self::Sign,
1035                    $crate::PrimitiveOpKind::Maximum => Self::Maximum,
1036                    $crate::PrimitiveOpKind::Minimum => Self::Minimum,
1037                    $crate::PrimitiveOpKind::Compare => Self::Compare(CompareDir::Eq),
1038                    $crate::PrimitiveOpKind::Select => Self::Select,
1039                    $crate::PrimitiveOpKind::Clamp => Self::Clamp,
1040                    $crate::PrimitiveOpKind::Exp => Self::Exp,
1041                    $crate::PrimitiveOpKind::Log => Self::Log,
1042                    $crate::PrimitiveOpKind::Sin => Self::Sin,
1043                    $crate::PrimitiveOpKind::Cos => Self::Cos,
1044                    $crate::PrimitiveOpKind::Tanh => Self::Tanh,
1045                    $crate::PrimitiveOpKind::Sqrt => Self::Sqrt,
1046                    $crate::PrimitiveOpKind::Rsqrt => Self::Rsqrt,
1047                    $crate::PrimitiveOpKind::Pow => Self::Pow,
1048                    $crate::PrimitiveOpKind::Expm1 => Self::Expm1,
1049                    $crate::PrimitiveOpKind::Log1p => Self::Log1p,
1050                    $crate::PrimitiveOpKind::Gather => Self::Gather(GatherConfig {
1051                        offset_dims: vec![],
1052                        collapsed_slice_dims: vec![0],
1053                        start_index_map: vec![0],
1054                        index_vector_dim: 1,
1055                        slice_sizes: vec![1],
1056                    }),
1057                    $crate::PrimitiveOpKind::GatherDynamicSliceSizes => {
1058                        Self::GatherDynamicSliceSizes {
1059                            offset_dims: vec![],
1060                            collapsed_slice_dims: vec![0],
1061                            start_index_map: vec![0],
1062                            index_vector_dim: 1,
1063                            slice_sizes: vec![DimExpr::Const(1)],
1064                        }
1065                    }
1066                    $crate::PrimitiveOpKind::Scatter => Self::Scatter(ScatterConfig {
1067                        update_window_dims: vec![],
1068                        inserted_window_dims: vec![0],
1069                        scatter_dims_to_operand_dims: vec![0],
1070                        index_vector_dim: 1,
1071                    }),
1072                    $crate::PrimitiveOpKind::Slice => Self::Slice(SliceConfig {
1073                        starts: vec![0],
1074                        limits: vec![1],
1075                        strides: vec![1],
1076                    }),
1077                    $crate::PrimitiveOpKind::DynamicSlice => Self::DynamicSlice {
1078                        slice_sizes: vec![1],
1079                    },
1080                    $crate::PrimitiveOpKind::DynamicUpdateSlice => Self::DynamicUpdateSlice,
1081                    $crate::PrimitiveOpKind::Pad => Self::Pad(PadConfig {
1082                        edge_padding_low: vec![0],
1083                        edge_padding_high: vec![0],
1084                        interior_padding: vec![0],
1085                    }),
1086                    $crate::PrimitiveOpKind::Concatenate => Self::Concatenate { axis: 0 },
1087                    $crate::PrimitiveOpKind::Reverse => Self::Reverse { axes: vec![0] },
1088                    $crate::PrimitiveOpKind::ShapeOf => Self::ShapeOf { axis: 0 },
1089                    $crate::PrimitiveOpKind::DynamicTruncate => Self::DynamicTruncate { axis: 0 },
1090                    $crate::PrimitiveOpKind::PadToMatch => Self::PadToMatch { axis: 0 },
1091                    $crate::PrimitiveOpKind::ReduceProd => Self::ReduceProd { axes: vec![0] },
1092                    $crate::PrimitiveOpKind::ReduceMax => Self::ReduceMax { axes: vec![0] },
1093                    $crate::PrimitiveOpKind::ReduceMin => Self::ReduceMin { axes: vec![0] },
1094                }
1095            }
1096        }
1097    };
1098}