tenferro_ops/
semiring_op_kind.rs1use tenferro_tensor::DotGeneralConfig;
2
3#[derive(Clone, Debug, Hash, PartialEq, Eq)]
4pub enum SemiringOpKind {
5 Add,
6 Mul,
7 DotGeneral(DotGeneralConfig),
8 ReduceSum { axes: Vec<usize> },
9 Transpose { perm: Vec<usize> },
10 Reshape { shape: Vec<usize> },
11 BroadcastInDim { shape: Vec<usize>, dims: Vec<usize> },
12 ExtractDiag { axis_a: usize, axis_b: usize },
13 EmbedDiag { axis_a: usize, axis_b: usize },
14}
15
16impl SemiringOpKind {
17 pub fn n_inputs(&self) -> usize {
18 match self {
19 Self::Add | Self::Mul | Self::DotGeneral(_) => 2,
20 Self::ReduceSum { .. }
21 | Self::Transpose { .. }
22 | Self::Reshape { .. }
23 | Self::BroadcastInDim { .. }
24 | Self::ExtractDiag { .. }
25 | Self::EmbedDiag { .. } => 1,
26 }
27 }
28}