tenferro_ext_tropical/ad/scalar.rs

1use tenferro_algebra::{HasAlgebra, Scalar, Standard};
2
3/// Trait for extracting the inner float type from a tropical scalar wrapper.
4///
5/// This enables generic code that operates on the inner values for backward
6/// pass computations.
7///
8/// # Examples
9///
10/// ```ignore
11/// use tenferro_ext_tropical::MaxPlus;
12/// use tenferro_ext_tropical::ad::TropicalScalar;
13///
14/// let x = MaxPlus(3.0_f64);
15/// assert_eq!(x.inner(), 3.0);
16/// ```
17pub trait TropicalScalar: Scalar {
18    /// The inner floating-point type.
19    type Inner: Scalar
20        + num_traits::Float
21        + std::ops::AddAssign
22        + HasAlgebra<Algebra = Standard<Self::Inner>>;
23
24    /// Extract the inner value.
25    fn inner(&self) -> Self::Inner;
26
27    /// Wrap an inner value into the tropical type.
28    fn from_inner(v: Self::Inner) -> Self;
29
30    /// Backward contribution for tropical multiplication w.r.t. the first operand.
31    fn mul_backward_a(a_inner: Self::Inner, b_inner: Self::Inner, dout: Self::Inner)
32        -> Self::Inner;
33
34    /// Backward contribution for tropical multiplication w.r.t. the second operand.
35    fn mul_backward_b(a_inner: Self::Inner, b_inner: Self::Inner, dout: Self::Inner)
36        -> Self::Inner;
37}
38
// Implements `TropicalScalar` for `crate::$ty_ctor<$elem>`, where the wrapper's
// tropical multiplication acts additively on the inner values (as the macro
// name indicates). Both backward rules ignore the forward operands and pass
// the incoming cotangent `dout` through unchanged.
macro_rules! impl_tropical_scalar_additive {
    ($ty_ctor:ident, $elem:ty) => {
        impl TropicalScalar for crate::$ty_ctor<$elem> {
            type Inner = $elem;

            fn inner(&self) -> Self::Inner {
                self.0
            }

            fn from_inner(v: Self::Inner) -> Self {
                Self(v)
            }

            // d(a + b)/da == 1, so the cotangent reaches `a` unscaled.
            fn mul_backward_a(
                _a_inner: Self::Inner,
                _b_inner: Self::Inner,
                dout: Self::Inner,
            ) -> Self::Inner {
                dout
            }

            // d(a + b)/db == 1, so the cotangent reaches `b` unscaled.
            fn mul_backward_b(
                _a_inner: Self::Inner,
                _b_inner: Self::Inner,
                dout: Self::Inner,
            ) -> Self::Inner {
                dout
            }
        }
    };
}
62
// Implements `TropicalScalar` for `crate::$ty_ctor<$elem>`, where the wrapper's
// tropical multiplication acts multiplicatively on the inner values (as the
// macro name indicates). Each operand's cotangent is `dout` scaled by the
// *other* forward operand.
macro_rules! impl_tropical_scalar_multiplicative {
    ($ty_ctor:ident, $elem:ty) => {
        impl TropicalScalar for crate::$ty_ctor<$elem> {
            type Inner = $elem;

            fn inner(&self) -> Self::Inner {
                self.0
            }

            fn from_inner(v: Self::Inner) -> Self {
                Self(v)
            }

            // d(a * b)/da == b.
            fn mul_backward_a(
                _a_inner: Self::Inner,
                b_inner: Self::Inner,
                dout: Self::Inner,
            ) -> Self::Inner {
                dout * b_inner
            }

            // d(a * b)/db == a.
            fn mul_backward_b(
                a_inner: Self::Inner,
                _b_inner: Self::Inner,
                dout: Self::Inner,
            ) -> Self::Inner {
                dout * a_inner
            }
        }
    };
}
86
87impl_tropical_scalar_additive!(MaxPlus, f32);
88impl_tropical_scalar_additive!(MaxPlus, f64);
89impl_tropical_scalar_additive!(MinPlus, f32);
90impl_tropical_scalar_additive!(MinPlus, f64);
91impl_tropical_scalar_multiplicative!(MaxMul, f32);
92impl_tropical_scalar_multiplicative!(MaxMul, f64);