tropical_einsum_rrule

Function tropical_einsum_rrule 

Source
pub fn tropical_einsum_rrule<T, Alg, Backend>(
    _ctx: &mut Backend::Context,
    subscripts: &str,
    operands: &[&Tensor<T>],
    cotangent: &Tensor<T::Inner>,
) -> Result<Vec<Tensor<T::Inner>>>
where Alg: Semiring<Scalar = T>, T: TropicalScalar + HasAlgebra<Algebra = Alg>, Backend: TensorSemiringCore<Alg>,
Expand description

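Reverse-mode differentiation rule (rrule) for tropical einsum.

Given the einsum `subscripts`, the tropical `operands`, and a `cotangent` for the einsum output expressed in the underlying scalar type `T::Inner` (for example `f64` when `T` is `MaxPlus<f64>`), computes one gradient per input operand, each returned as a `Tensor<T::Inner>`.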

Examples

use tenferro_ext_tropical::ad::tropical_einsum_rrule;
use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
use tenferro_prims::{CpuBackend, CpuContext};
use tenferro_tensor::{MemoryOrder, Tensor};

let mut ctx = CpuContext::new(1);

// Cotangent for the 2x2 einsum output, given in the underlying scalar type (`f64`).
let cotangent = Tensor::<f64>::from_slice(
    &[1.0, 1.0, 1.0, 1.0],
    &[2, 2],
    MemoryOrder::ColumnMajor,
)
.unwrap();

// The two 2x2 max-plus operands of the contraction, stored column-major.
let lhs = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
    &[2, 2],
    MemoryOrder::ColumnMajor,
)
.unwrap();
let rhs = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
    &[2, 2],
    MemoryOrder::ColumnMajor,
)
.unwrap();

// Pull the cotangent back through the tropical contraction "ij,jk->ik".
let operand_grads = tropical_einsum_rrule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
    &mut ctx,
    "ij,jk->ik",
    &[&lhs, &rhs],
    &cotangent,
)
.unwrap();

// One gradient comes back for each input operand.
assert_eq!(operand_grads.len(), 2);