Skip to main content

tenferro_ops/
semiring_op.rs

1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4
5use computegraph::GraphOp;
6use tenferro_algebra::Algebra;
7use tenferro_tensor::{DotGeneralConfig, TypedTensor};
8
9use crate::dim_expr::DimExpr;
10use crate::semiring_op_kind::SemiringOpKind;
11use crate::semiring_ops::SemiringOps;
12
/// Key identifying one external input of a graph built from [`SemiringOp`] nodes.
///
/// This is the `InputKey` type of the `GraphOp` implementation for `SemiringOp`.
/// Equality and hashing are derived from the single `id` field, so two keys refer to the
/// same input exactly when their ids match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SemiringInputKey {
    /// Numeric identifier of the input.
    pub id: u64,
}
17
18pub struct SemiringOp<Alg: Algebra> {
19    pub kind: SemiringOpKind,
20    _marker: PhantomData<Alg>,
21}
22
23impl<Alg: Algebra> SemiringOp<Alg> {
24    pub fn new(kind: SemiringOpKind) -> Self {
25        Self {
26            kind,
27            _marker: PhantomData,
28        }
29    }
30}
31
32impl<Alg: Algebra> Clone for SemiringOp<Alg> {
33    fn clone(&self) -> Self {
34        Self {
35            kind: self.kind.clone(),
36            _marker: PhantomData,
37        }
38    }
39}
40
41impl<Alg: Algebra> fmt::Debug for SemiringOp<Alg> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("SemiringOp")
44            .field("kind", &self.kind)
45            .finish()
46    }
47}
48
49impl<Alg: Algebra> Hash for SemiringOp<Alg> {
50    fn hash<H: Hasher>(&self, state: &mut H) {
51        self.kind.hash(state);
52    }
53}
54
55impl<Alg: Algebra> PartialEq for SemiringOp<Alg> {
56    fn eq(&self, other: &Self) -> bool {
57        self.kind == other.kind
58    }
59}
60
61impl<Alg: Algebra> Eq for SemiringOp<Alg> {}
62
63impl<Alg> GraphOp for SemiringOp<Alg>
64where
65    Alg: Algebra + Send + Sync + 'static,
66{
67    type Operand = TypedTensor<Alg::Scalar>;
68    type Context = ();
69    type InputKey = SemiringInputKey;
70
71    fn n_inputs(&self) -> usize {
72        self.kind.n_inputs()
73    }
74
75    fn n_outputs(&self) -> usize {
76        1
77    }
78}
79
80impl<Alg> SemiringOps for SemiringOp<Alg>
81where
82    Alg: Algebra + Send + Sync + 'static,
83{
84    fn add_op() -> Self {
85        Self::new(SemiringOpKind::Add)
86    }
87
88    fn mul_op() -> Self {
89        Self::new(SemiringOpKind::Mul)
90    }
91
92    fn dot_general(config: DotGeneralConfig) -> Self {
93        Self::new(SemiringOpKind::DotGeneral(config))
94    }
95
96    fn reduce_sum(axes: Vec<usize>, _input_shape: Vec<DimExpr>) -> Self {
97        Self::new(SemiringOpKind::ReduceSum { axes })
98    }
99
100    fn transpose_op(perm: Vec<usize>) -> Self {
101        Self::new(SemiringOpKind::Transpose { perm })
102    }
103
104    fn reshape(_from_shape: Vec<DimExpr>, to_shape: Vec<DimExpr>) -> Self {
105        Self::new(SemiringOpKind::Reshape {
106            shape: concrete_shape(to_shape),
107        })
108    }
109
110    fn broadcast_in_dim(shape: Vec<DimExpr>, dims: Vec<usize>) -> Self {
111        Self::new(SemiringOpKind::BroadcastInDim {
112            shape: concrete_shape(shape),
113            dims,
114        })
115    }
116
117    fn extract_diag(axis_a: usize, axis_b: usize) -> Self {
118        Self::new(SemiringOpKind::ExtractDiag { axis_a, axis_b })
119    }
120
121    fn embed_diag(axis_a: usize, axis_b: usize) -> Self {
122        Self::new(SemiringOpKind::EmbedDiag { axis_a, axis_b })
123    }
124}
125
126fn concrete_shape(shape: Vec<DimExpr>) -> Vec<usize> {
127    shape
128        .into_iter()
129        .map(|dim| match dim {
130            DimExpr::Const(value) => value,
131            other => panic!("SemiringOp only supports concrete shape expressions, got {other:?}"),
132        })
133        .collect()
134}