pub fn tropical_einsum_frule<T, Alg, Backend>(
_ctx: &mut Backend::Context,
subscripts: &str,
primals: &[&Tensor<T>],
tangents: &[Option<&Tensor<T::Inner>>],
) -> Result<Tensor<T::Inner>>where
Alg: Semiring<Scalar = T>,
T: TropicalScalar + HasAlgebra<Algebra = Alg>,
Backend: TensorSemiringCore<Alg>,
Forward-mode rule (frule) for tropical einsum.
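Given tropical primals and one optional real-valued (`T::Inner`) tangent per primal, computes the output tangent. A `None` entry supplies no tangent for that operand. Each tangent contribution is routed through the winning term that the tropical forward pass selects for each output element, i.e. the index combination that attains the tropical sum (for max-plus, the argmax).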
Examples
ⓘ
// Bring the frule, the max-plus scalar/algebra types, the CPU backend, and tensor types into scope.
use tenferro_ext_tropical::ad::tropical_einsum_frule;
use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra};
use tenferro_prims::{CpuBackend, CpuContext};
use tenferro_tensor::{MemoryOrder, Tensor};
// CPU execution context. NOTE(review): the `1` is presumably a thread count; confirm against `CpuContext::new`.
let mut ctx = CpuContext::new(1);
// Primal `a`, 2x2 column-major: a(0,0)=1, a(1,0)=2, a(0,1)=3, a(1,1)=4.
let a = Tensor::<MaxPlus<f64>>::from_slice(
&[MaxPlus(1.0), MaxPlus(2.0), MaxPlus(3.0), MaxPlus(4.0)],
&[2, 2], MemoryOrder::ColumnMajor,
).unwrap();
// Primal `b`, 2x2 column-major: b(0,0)=5, b(1,0)=6, b(0,1)=7, b(1,1)=8.
let b = Tensor::<MaxPlus<f64>>::from_slice(
&[MaxPlus(5.0), MaxPlus(6.0), MaxPlus(7.0), MaxPlus(8.0)],
&[2, 2], MemoryOrder::ColumnMajor,
).unwrap();
// Plain f64 tangent for `a`: 1.0 at a(0,0), zero elsewhere.
let da = Tensor::<f64>::from_slice(
&[1.0, 0.0, 0.0, 0.0], &[2, 2], MemoryOrder::ColumnMajor,
).unwrap();
// "ij,jk->ik" is a max-plus matrix product: c(i,k) = max_j (a(i,j) + b(j,k)).
// `None` supplies no tangent for `b`.
// With these inputs every c(i,k) is attained at j = 1, so da's only nonzero entry (a(0,0), j = 0)
// is never a winner. Under winner routing, `dc` is therefore expected to be all zeros.
// The example only checks the output shape.
let dc = tropical_einsum_frule::<MaxPlus<f64>, MaxPlusAlgebra<f64>, CpuBackend>(
&mut ctx, "ij,jk->ik", &[&a, &b], &[Some(&da), None],
).unwrap();
assert_eq!(dc.dims(), &[2, 2]);