tenferro_ext_tropical/ad/rules.rs

1use tenferro_algebra::{HasAlgebra, Semiring};
2use tenferro_device::{Error, Result};
3use tenferro_einsum::Subscripts;
4use tenferro_prims::TensorSemiringCore;
5use tenferro_tensor::Tensor;
6
7use super::backward::tropical_backward;
8use super::common::contracted_modes;
9use super::forward::{tropical_forward_tangent, tropical_forward_with_argmax};
10use super::TropicalScalar;
11
12/// Reverse-mode rule (rrule) for tropical einsum.
13///
14/// Given a tropical einsum operation and a cotangent (in standard reals),
15/// computes gradients for each input operand (also in standard reals).
16///
17/// # Examples
18///
19/// ```ignore
20/// use tenferro_ext_tropical::ad::tropical_einsum_rrule;
21/// use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
22/// use tenferro_prims::{CpuBackend, CpuContext};
23/// use tenferro_tensor::{MemoryOrder, Tensor};
24///
25/// let mut ctx = CpuContext::new(1);
26/// let a = Tensor::<MaxPlus<f64>>::from_slice(
27///     &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
28///     &[2, 2], MemoryOrder::ColumnMajor,
29/// ).unwrap();
30/// let b = Tensor::<MaxPlus<f64>>::from_slice(
31///     &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
32///     &[2, 2], MemoryOrder::ColumnMajor,
33/// ).unwrap();
34/// let grad_c = Tensor::<f64>::from_slice(
35///     &[1.0, 1.0, 1.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor,
36/// ).unwrap();
37///
38/// let grads = tropical_einsum_rrule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
39///     &mut ctx, "ij,jk->ik", &[&a, &b], &grad_c,
40/// ).unwrap();
41/// assert_eq!(grads.len(), 2);
42/// ```
43pub fn tropical_einsum_rrule<T, Alg, Backend>(
44    _ctx: &mut Backend::Context,
45    subscripts: &str,
46    operands: &[&Tensor<T>],
47    cotangent: &Tensor<T::Inner>,
48) -> Result<Vec<Tensor<T::Inner>>>
49where
50    Alg: Semiring<Scalar = T>,
51    T: TropicalScalar + HasAlgebra<Algebra = Alg>,
52    Backend: TensorSemiringCore<Alg>,
53{
54    let subs = Subscripts::parse(subscripts)?;
55    validate_operand_count(operands.len(), "tropical_einsum_rrule")?;
56    let contracted = contracted_modes(&subs);
57    let (output, tracker) = tropical_forward_with_argmax(operands, &subs, &contracted)?;
58
59    if cotangent.dims() != output.dims() {
60        return Err(Error::InvalidArgument(format!(
61            "cotangent shape mismatch: expected {:?}, got {:?}",
62            output.dims(),
63            cotangent.dims()
64        )));
65    }
66
67    tropical_backward(operands, cotangent, &tracker, &subs, &contracted)
68}
69
70/// Forward-mode rule (frule) for tropical einsum.
71///
72/// Given tropical primals and optional standard-real tangents, compute the
73/// output tangent by routing each tangent contribution through the winner
74/// selected during the tropical forward pass.
75///
76/// # Examples
77///
78/// ```ignore
79/// use tenferro_ext_tropical::ad::tropical_einsum_frule;
80/// use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
81/// use tenferro_prims::{CpuBackend, CpuContext};
82/// use tenferro_tensor::{MemoryOrder, Tensor};
83///
84/// let mut ctx = CpuContext::new(1);
85/// let a = Tensor::<MaxPlus<f64>>::from_slice(
86///     &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
87///     &[2, 2], MemoryOrder::ColumnMajor,
88/// ).unwrap();
89/// let b = Tensor::<MaxPlus<f64>>::from_slice(
90///     &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
91///     &[2, 2], MemoryOrder::ColumnMajor,
92/// ).unwrap();
93/// let da = Tensor::<f64>::from_slice(
94///     &[1.0, 0.0, 0.0, 0.0], &[2, 2], MemoryOrder::ColumnMajor,
95/// ).unwrap();
96///
97/// let dc = tropical_einsum_frule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
98///     &mut ctx, "ij,jk->ik", &[&a, &b], &[Some(&da), None],
99/// ).unwrap();
100/// assert_eq!(dc.dims(), &[2, 2]);
101/// ```
102pub fn tropical_einsum_frule<T, Alg, Backend>(
103    _ctx: &mut Backend::Context,
104    subscripts: &str,
105    primals: &[&Tensor<T>],
106    tangents: &[Option<&Tensor<T::Inner>>],
107) -> Result<Tensor<T::Inner>>
108where
109    Alg: Semiring<Scalar = T>,
110    T: TropicalScalar + HasAlgebra<Algebra = Alg>,
111    Backend: TensorSemiringCore<Alg>,
112{
113    let subs = Subscripts::parse(subscripts)?;
114    validate_operand_count(primals.len(), "tropical_einsum_frule")?;
115    if tangents.len() != primals.len() {
116        return Err(Error::InvalidArgument(format!(
117            "tropical_einsum_frule expects {} tangents, got {}",
118            primals.len(),
119            tangents.len()
120        )));
121    }
122    for (idx, (primal, tangent)) in primals.iter().zip(tangents.iter()).enumerate() {
123        if let Some(tangent) = tangent {
124            if tangent.dims() != primal.dims() {
125                return Err(Error::InvalidArgument(format!(
126                    "tangent shape mismatch for operand {idx}: expected {:?}, got {:?}",
127                    primal.dims(),
128                    tangent.dims()
129                )));
130            }
131        }
132    }
133
134    let contracted = contracted_modes(&subs);
135    let (output, tracker) = tropical_forward_with_argmax(primals, &subs, &contracted)?;
136    tropical_forward_tangent(
137        primals,
138        tangents,
139        &tracker,
140        &subs,
141        &contracted,
142        output.dims(),
143    )
144}
145fn validate_operand_count(count: usize, api_name: &str) -> Result<()> {
146    if count == 0 || count > 2 {
147        return Err(Error::InvalidArgument(format!(
148            "{api_name} supports 1 or 2 operands"
149        )));
150    }
151    Ok(())
152}