tenferro_ext_tropical/ad/
scalar.rs1use tenferro_algebra::{HasAlgebra, Scalar, Standard};
2
3pub trait TropicalScalar: Scalar {
18 type Inner: Scalar
20 + num_traits::Float
21 + std::ops::AddAssign
22 + HasAlgebra<Algebra = Standard<Self::Inner>>;
23
24 fn inner(&self) -> Self::Inner;
26
27 fn from_inner(v: Self::Inner) -> Self;
29
30 fn mul_backward_a(a_inner: Self::Inner, b_inner: Self::Inner, dout: Self::Inner)
32 -> Self::Inner;
33
34 fn mul_backward_b(a_inner: Self::Inner, b_inner: Self::Inner, dout: Self::Inner)
36 -> Self::Inner;
37}
38
/// Implements [`TropicalScalar`] for `crate::$semiring<$elem>` when the
/// semiring product is addition of the inner values (e.g. max-plus, min-plus).
///
/// Since `∂(a + b)/∂a = ∂(a + b)/∂b = 1`, both backward rules pass the
/// upstream gradient through unchanged.
macro_rules! impl_tropical_scalar_additive {
    ($semiring:ident, $elem:ty) => {
        impl TropicalScalar for crate::$semiring<$elem> {
            type Inner = $elem;

            fn inner(&self) -> Self::Inner {
                self.0
            }

            fn from_inner(v: Self::Inner) -> Self {
                Self(v)
            }

            // Gradient w.r.t. the left operand: identity on the upstream gradient.
            fn mul_backward_a(_lhs: Self::Inner, _rhs: Self::Inner, grad: Self::Inner) -> Self::Inner {
                grad
            }

            // Gradient w.r.t. the right operand: identity on the upstream gradient.
            fn mul_backward_b(_lhs: Self::Inner, _rhs: Self::Inner, grad: Self::Inner) -> Self::Inner {
                grad
            }
        }
    };
}
62
/// Implements [`TropicalScalar`] for `crate::$semiring<$elem>` when the
/// semiring product is ordinary multiplication of the inner values (e.g. max-mul).
///
/// Since `∂(a · b)/∂a = b` and `∂(a · b)/∂b = a`, each backward rule scales the
/// upstream gradient by the *other* operand.
macro_rules! impl_tropical_scalar_multiplicative {
    ($semiring:ident, $elem:ty) => {
        impl TropicalScalar for crate::$semiring<$elem> {
            type Inner = $elem;

            fn inner(&self) -> Self::Inner {
                self.0
            }

            fn from_inner(v: Self::Inner) -> Self {
                Self(v)
            }

            // Gradient w.r.t. the left operand: upstream gradient times the right operand.
            fn mul_backward_a(_lhs: Self::Inner, rhs: Self::Inner, grad: Self::Inner) -> Self::Inner {
                grad * rhs
            }

            // Gradient w.r.t. the right operand: upstream gradient times the left operand.
            fn mul_backward_b(lhs: Self::Inner, _rhs: Self::Inner, grad: Self::Inner) -> Self::Inner {
                grad * lhs
            }
        }
    };
}
86
87impl_tropical_scalar_additive!(MaxPlus, f32);
88impl_tropical_scalar_additive!(MaxPlus, f64);
89impl_tropical_scalar_additive!(MinPlus, f32);
90impl_tropical_scalar_additive!(MinPlus, f64);
91impl_tropical_scalar_multiplicative!(MaxMul, f32);
92impl_tropical_scalar_multiplicative!(MaxMul, f64);