tenferro_ext_tropical/ad/rules.rs
1use tenferro_algebra::{HasAlgebra, Semiring};
2use tenferro_device::{Error, Result};
3use tenferro_einsum::Subscripts;
4use tenferro_prims::TensorSemiringCore;
5use tenferro_tensor::Tensor;
6
7use super::backward::tropical_backward;
8use super::common::contracted_modes;
9use super::forward::{tropical_forward_tangent, tropical_forward_with_argmax};
10use super::TropicalScalar;
11
12/// Reverse-mode rule (rrule) for tropical einsum.
13///
14/// Given a tropical einsum operation and a cotangent (in standard reals),
15/// computes gradients for each input operand (also in standard reals).
16///
17/// # Examples
18///
19/// ```ignore
20/// use tenferro_ext_tropical::ad::tropical_einsum_rrule;
21/// use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
22/// use tenferro_prims::{CpuBackend, CpuContext};
23/// use tenferro_tensor::{MemoryOrder, Tensor};
24///
25/// let mut ctx = CpuContext::new(1);
26/// let a = Tensor::<MaxPlus<f64>>::from_slice(
27/// &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
28/// &[2, 2], MemoryOrder::ColumnMajor,
29/// ).unwrap();
30/// let b = Tensor::<MaxPlus<f64>>::from_slice(
31/// &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
32/// &[2, 2], MemoryOrder::ColumnMajor,
33/// ).unwrap();
34/// let grad_c = Tensor::<f64>::from_slice(
35/// &[1.0, 1.0, 1.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor,
36/// ).unwrap();
37///
38/// let grads = tropical_einsum_rrule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
39/// &mut ctx, "ij,jk->ik", &[&a, &b], &grad_c,
40/// ).unwrap();
41/// assert_eq!(grads.len(), 2);
42/// ```
43pub fn tropical_einsum_rrule<T, Alg, Backend>(
44 _ctx: &mut Backend::Context,
45 subscripts: &str,
46 operands: &[&Tensor<T>],
47 cotangent: &Tensor<T::Inner>,
48) -> Result<Vec<Tensor<T::Inner>>>
49where
50 Alg: Semiring<Scalar = T>,
51 T: TropicalScalar + HasAlgebra<Algebra = Alg>,
52 Backend: TensorSemiringCore<Alg>,
53{
54 let subs = Subscripts::parse(subscripts)?;
55 validate_operand_count(operands.len(), "tropical_einsum_rrule")?;
56 let contracted = contracted_modes(&subs);
57 let (output, tracker) = tropical_forward_with_argmax(operands, &subs, &contracted)?;
58
59 if cotangent.dims() != output.dims() {
60 return Err(Error::InvalidArgument(format!(
61 "cotangent shape mismatch: expected {:?}, got {:?}",
62 output.dims(),
63 cotangent.dims()
64 )));
65 }
66
67 tropical_backward(operands, cotangent, &tracker, &subs, &contracted)
68}
69
70/// Forward-mode rule (frule) for tropical einsum.
71///
72/// Given tropical primals and optional standard-real tangents, compute the
73/// output tangent by routing each tangent contribution through the winner
74/// selected during the tropical forward pass.
75///
76/// # Examples
77///
78/// ```ignore
79/// use tenferro_ext_tropical::ad::tropical_einsum_frule;
80/// use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
81/// use tenferro_prims::{CpuBackend, CpuContext};
82/// use tenferro_tensor::{MemoryOrder, Tensor};
83///
84/// let mut ctx = CpuContext::new(1);
85/// let a = Tensor::<MaxPlus<f64>>::from_slice(
86/// &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
87/// &[2, 2], MemoryOrder::ColumnMajor,
88/// ).unwrap();
89/// let b = Tensor::<MaxPlus<f64>>::from_slice(
90/// &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
91/// &[2, 2], MemoryOrder::ColumnMajor,
92/// ).unwrap();
93/// let da = Tensor::<f64>::from_slice(
94/// &[1.0, 0.0, 0.0, 0.0], &[2, 2], MemoryOrder::ColumnMajor,
95/// ).unwrap();
96///
97/// let dc = tropical_einsum_frule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
98/// &mut ctx, "ij,jk->ik", &[&a, &b], &[Some(&da), None],
99/// ).unwrap();
100/// assert_eq!(dc.dims(), &[2, 2]);
101/// ```
102pub fn tropical_einsum_frule<T, Alg, Backend>(
103 _ctx: &mut Backend::Context,
104 subscripts: &str,
105 primals: &[&Tensor<T>],
106 tangents: &[Option<&Tensor<T::Inner>>],
107) -> Result<Tensor<T::Inner>>
108where
109 Alg: Semiring<Scalar = T>,
110 T: TropicalScalar + HasAlgebra<Algebra = Alg>,
111 Backend: TensorSemiringCore<Alg>,
112{
113 let subs = Subscripts::parse(subscripts)?;
114 validate_operand_count(primals.len(), "tropical_einsum_frule")?;
115 if tangents.len() != primals.len() {
116 return Err(Error::InvalidArgument(format!(
117 "tropical_einsum_frule expects {} tangents, got {}",
118 primals.len(),
119 tangents.len()
120 )));
121 }
122 for (idx, (primal, tangent)) in primals.iter().zip(tangents.iter()).enumerate() {
123 if let Some(tangent) = tangent {
124 if tangent.dims() != primal.dims() {
125 return Err(Error::InvalidArgument(format!(
126 "tangent shape mismatch for operand {idx}: expected {:?}, got {:?}",
127 primal.dims(),
128 tangent.dims()
129 )));
130 }
131 }
132 }
133
134 let contracted = contracted_modes(&subs);
135 let (output, tracker) = tropical_forward_with_argmax(primals, &subs, &contracted)?;
136 tropical_forward_tangent(
137 primals,
138 tangents,
139 &tracker,
140 &subs,
141 &contracted,
142 output.dims(),
143 )
144}
145fn validate_operand_count(count: usize, api_name: &str) -> Result<()> {
146 if count == 0 || count > 2 {
147 return Err(Error::InvalidArgument(format!(
148 "{api_name} supports 1 or 2 operands"
149 )));
150 }
151 Ok(())
152}