pub fn tropical_einsum_rrule<T, Alg, Backend>(
_ctx: &mut Backend::Context,
subscripts: &str,
operands: &[&Tensor<T>],
cotangent: &Tensor<T::Inner>,
) -> Result<Vec<Tensor<T::Inner>>>where
Alg: Semiring<Scalar = T>,
T: TropicalScalar + HasAlgebra<Algebra = Alg>,
Backend: TensorSemiringCore<Alg>,Expand description
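Reverse-mode rule (rrule) for tropical einsum.
Given the einsum `subscripts`, the tropical input `operands`, and the `cotangent` of the einsum output (expressed in the underlying real scalar type `T::Inner`, not in the tropical semiring), computes the gradient with respect to each input operand, also expressed in `T::Inner`. The returned vector contains one gradient tensor per entry of `operands`.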
# Examples
ⓘ
use tenferro_ext_tropical::ad::tropical_einsum_rrule;
use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
use tenferro_prims::{CpuBackend, CpuContext};
use tenferro_tensor::{MemoryOrder, Tensor};
// CPU backend context. The argument `1` is forwarded unchanged to `CpuContext::new`
// (presumably a worker/thread count — confirm against `CpuContext` docs).
let mut ctx = CpuContext::new(1);

// Two 2x2 max-plus operands, laid out column-major.
let lhs = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
    &[2, 2],
    MemoryOrder::ColumnMajor,
)
.unwrap();
let rhs = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
    &[2, 2],
    MemoryOrder::ColumnMajor,
)
.unwrap();

// Cotangent of the 2x2 einsum output, given as plain f64 values (all ones).
let cotangent =
    Tensor::<f64>::from_slice(&[1.0; 4], &[2, 2], MemoryOrder::ColumnMajor).unwrap();

// "ij,jk->ik" is a matrix-product pattern: the index `j`, absent from the output,
// is the one reduced over.
let operand_grads = tropical_einsum_rrule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
    &mut ctx,
    "ij,jk->ik",
    &[&lhs, &rhs],
    &cotangent,
)
.unwrap();

// One gradient tensor is produced for each of the two operands.
assert_eq!(operand_grads.len(), 2);