einsum_ad

Function `einsum_ad`

Source
pub fn einsum_ad<'a, T>(
    subscripts: &'a str,
    operands: &'a [&'a AdTensor<T>],
) -> EinsumAdBuilder<'a, T>
where T: Scalar + HasAlgebra<Algebra = Standard<T>>, CpuBackend: TensorPrims<Standard<T>, Context = CpuContext>,
Expand description

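Creates a builder for an automatic-differentiation (AD) einsum over `operands`, with the contraction described by `subscripts` in einsum notation (e.g. `"ij,jk->ik"`). Call `.run()` on the returned builder to evaluate it.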

Examples

// Example: differentiable matrix product of two 2x2 tensors via `einsum_ad`.
use ad_tensors_rs::{einsum_ad, set_default_runtime, AdTensor, RuntimeContext};
use tenferro_prims::CpuContext;
use tenferro_tensor::{MemoryOrder, Tensor};

// Install a CPU runtime as the default. The returned guard is bound to `_guard`
// (not `_`, which would drop it immediately) so it stays alive until the end of
// the example. Presumably dropping it restores the previous default runtime —
// confirm in the `set_default_runtime` docs.
// NOTE(review): the `1` passed to `CpuContext::new` looks like a thread count — confirm.
let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
// Two 2x2 f64 tensors built from flat data interpreted in column-major order.
let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor).unwrap();
let b = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], MemoryOrder::ColumnMajor).unwrap();
// Wrap each plain tensor as an AD tensor holding a primal value.
// NOTE(review): assumes `new_primal` attaches no tangent/gradient data — confirm.
let ad_a = AdTensor::new_primal(a);
let ad_b = AdTensor::new_primal(b);
// "ij,jk->ik" sums over the shared index `j`, i.e. the matrix product a·b.
// `.run()` evaluates the builder; its output is fallible and is unwrapped here.
let out = einsum_ad("ij,jk->ik", &[&ad_a, &ad_b]).run().unwrap();
// Contracting a 2x2 with a 2x2 over one shared axis yields a 2x2 result.
assert_eq!(out.dims(), &[2, 2]);