Skip to main content

tenferro_ops/
semiring_op_kind.rs

1use tenferro_tensor::DotGeneralConfig;
2
3#[derive(Clone, Debug, Hash, PartialEq, Eq)]
4pub enum SemiringOpKind {
5    Add,
6    Mul,
7    DotGeneral(DotGeneralConfig),
8    ReduceSum { axes: Vec<usize> },
9    Transpose { perm: Vec<usize> },
10    Reshape { shape: Vec<usize> },
11    BroadcastInDim { shape: Vec<usize>, dims: Vec<usize> },
12    ExtractDiag { axis_a: usize, axis_b: usize },
13    EmbedDiag { axis_a: usize, axis_b: usize },
14}
15
16impl SemiringOpKind {
17    pub fn n_inputs(&self) -> usize {
18        match self {
19            Self::Add | Self::Mul | Self::DotGeneral(_) => 2,
20            Self::ReduceSum { .. }
21            | Self::Transpose { .. }
22            | Self::Reshape { .. }
23            | Self::BroadcastInDim { .. }
24            | Self::ExtractDiag { .. }
25            | Self::EmbedDiag { .. } => 1,
26        }
27    }
28}