tenferro_ops/
semiring_op.rs1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4
5use computegraph::GraphOp;
6use tenferro_algebra::Algebra;
7use tenferro_tensor::{DotGeneralConfig, TypedTensor};
8
9use crate::dim_expr::DimExpr;
10use crate::semiring_op_kind::SemiringOpKind;
11use crate::semiring_ops::SemiringOps;
12
/// Key type used as the `GraphOp::InputKey` of [`SemiringOp`].
///
/// It wraps a single numeric `id`. Equality, ordering-free hashing and `Debug`
/// output are all derived from that id alone.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SemiringInputKey {
    /// Numeric identifier of the graph input.
    pub id: u64,
}
17
18pub struct SemiringOp<Alg: Algebra> {
19 pub kind: SemiringOpKind,
20 _marker: PhantomData<Alg>,
21}
22
23impl<Alg: Algebra> SemiringOp<Alg> {
24 pub fn new(kind: SemiringOpKind) -> Self {
25 Self {
26 kind,
27 _marker: PhantomData,
28 }
29 }
30}
31
32impl<Alg: Algebra> Clone for SemiringOp<Alg> {
33 fn clone(&self) -> Self {
34 Self {
35 kind: self.kind.clone(),
36 _marker: PhantomData,
37 }
38 }
39}
40
41impl<Alg: Algebra> fmt::Debug for SemiringOp<Alg> {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("SemiringOp")
44 .field("kind", &self.kind)
45 .finish()
46 }
47}
48
49impl<Alg: Algebra> Hash for SemiringOp<Alg> {
50 fn hash<H: Hasher>(&self, state: &mut H) {
51 self.kind.hash(state);
52 }
53}
54
55impl<Alg: Algebra> PartialEq for SemiringOp<Alg> {
56 fn eq(&self, other: &Self) -> bool {
57 self.kind == other.kind
58 }
59}
60
61impl<Alg: Algebra> Eq for SemiringOp<Alg> {}
62
63impl<Alg> GraphOp for SemiringOp<Alg>
64where
65 Alg: Algebra + Send + Sync + 'static,
66{
67 type Operand = TypedTensor<Alg::Scalar>;
68 type Context = ();
69 type InputKey = SemiringInputKey;
70
71 fn n_inputs(&self) -> usize {
72 self.kind.n_inputs()
73 }
74
75 fn n_outputs(&self) -> usize {
76 1
77 }
78}
79
80impl<Alg> SemiringOps for SemiringOp<Alg>
81where
82 Alg: Algebra + Send + Sync + 'static,
83{
84 fn add_op() -> Self {
85 Self::new(SemiringOpKind::Add)
86 }
87
88 fn mul_op() -> Self {
89 Self::new(SemiringOpKind::Mul)
90 }
91
92 fn dot_general(config: DotGeneralConfig) -> Self {
93 Self::new(SemiringOpKind::DotGeneral(config))
94 }
95
96 fn reduce_sum(axes: Vec<usize>, _input_shape: Vec<DimExpr>) -> Self {
97 Self::new(SemiringOpKind::ReduceSum { axes })
98 }
99
100 fn transpose_op(perm: Vec<usize>) -> Self {
101 Self::new(SemiringOpKind::Transpose { perm })
102 }
103
104 fn reshape(_from_shape: Vec<DimExpr>, to_shape: Vec<DimExpr>) -> Self {
105 Self::new(SemiringOpKind::Reshape {
106 shape: concrete_shape(to_shape),
107 })
108 }
109
110 fn broadcast_in_dim(shape: Vec<DimExpr>, dims: Vec<usize>) -> Self {
111 Self::new(SemiringOpKind::BroadcastInDim {
112 shape: concrete_shape(shape),
113 dims,
114 })
115 }
116
117 fn extract_diag(axis_a: usize, axis_b: usize) -> Self {
118 Self::new(SemiringOpKind::ExtractDiag { axis_a, axis_b })
119 }
120
121 fn embed_diag(axis_a: usize, axis_b: usize) -> Self {
122 Self::new(SemiringOpKind::EmbedDiag { axis_a, axis_b })
123 }
124}
125
126fn concrete_shape(shape: Vec<DimExpr>) -> Vec<usize> {
127 shape
128 .into_iter()
129 .map(|dim| match dim {
130 DimExpr::Const(value) => value,
131 other => panic!("SemiringOp only supports concrete shape expressions, got {other:?}"),
132 })
133 .collect()
134}