tropical_einsum_frule

Function tropical_einsum_frule 

Source
pub fn tropical_einsum_frule<T, Alg, Backend>(
    _ctx: &mut Backend::Context,
    subscripts: &str,
    primals: &[&Tensor<T>],
    tangents: &[Option<&Tensor<T::Inner>>],
) -> Result<Tensor<T::Inner>>
where Alg: Semiring<Scalar = T>, T: TropicalScalar + HasAlgebra<Algebra = Alg>, Backend: TensorSemiringCore<Alg>,
Expand description

Forward-mode rule (frule) for tropical einsum.

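Given tropical primals and optional real-valued tangents (one per operand; `None` means that operand contributes no tangent), compute the output tangent. Each output entry receives the tangent contributions of exactly those input entries that won the tropical reduction (e.g. the maximizing index `j` in max-plus) during the forward pass; losing entries contribute nothing.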

§Examples

use tenferro_ext_tropical::ad::tropical_einsum_frule;
use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
use tenferro_prims::{CpuBackend, CpuContext};
use tenferro_tensor::{MemoryOrder, Tensor};

let mut ctx = CpuContext::new(1);
// Column-major 2x2 primals: a = [[1, 3], [2, 4]], b = [[5, 7], [6, 8]].
let a = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
    &[2, 2], MemoryOrder::ColumnMajor,
).unwrap();
let b = Tensor::<MaxPlus<f64>>::from_slice(
    &[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
    &[2, 2], MemoryOrder::ColumnMajor,
).unwrap();

// In max-plus, c[i,k] = max_j (a[i,j] + b[j,k]). For these values j = 1 wins
// for every output entry, so only tangents on a[:, 1] or b[1, :] can reach dc.
// Seed a[0, 1] (column-major offset 2); it wins for both entries of row 0,
// so dc should be [[1, 1], [0, 0]].
let da = Tensor::<f64>::from_slice(
    &[0.0, 0.0, 1.0, 0.0], &[2, 2], MemoryOrder::ColumnMajor,
).unwrap();

// `None` for `b`: no tangent is supplied for the second operand.
let dc = tropical_einsum_frule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
    &mut ctx, "ij,jk->ik", &[&a, &b], &[Some(&da), None],
).unwrap();
assert_eq!(dc.dims(), &[2, 2]);