Function try_einsum 

pub fn try_einsum<B: TensorNetworkOps, const D: usize>(
    subscripts: &str,
    inputs: Vec<Tensor<B, D>>,
) -> Result<Tensor<B, D>>

Fallible high-level einsum on Burn tensors, dispatching to the backend’s TensorNetworkOps::tn_einsum implementation.

The const rank D is shared by every input tensor and by the output tensor, so this wrapper only suits contractions in which all operands and the result have the same rank, such as matrix multiplication (ij,jk->ik). For rank-changing contractions, such as a matrix-vector product (ij,j->i), call TensorNetworkOps::tn_einsum directly.
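To decide which entry point fits a given subscripts string, the rank rule can be checked mechanically. The sketch below is std-only, and `fits_single_rank` is a hypothetical helper, not part of tenferro_ext_burn. It assumes explicit `->` output notation, single-character indices and no `...` ellipsis.

```rust
/// Hypothetical helper (not part of tenferro_ext_burn): returns Some(true) when
/// every input term and the output term have the same number of indices, i.e.
/// when one const rank D can describe all of them.
/// Assumes single-character indices, an explicit "->" output and no ellipsis.
/// Returns None when there is no "->" (implicit output is not handled here).
fn fits_single_rank(subscripts: &str) -> Option<bool> {
    // Drop whitespace so "ij, jk -> ik" is treated like "ij,jk->ik".
    let s: String = subscripts.chars().filter(|c| !c.is_whitespace()).collect();
    let (lhs, rhs) = s.split_once("->")?;
    let out_rank = rhs.chars().count();
    // Every input term must have exactly the output's rank.
    Some(lhs.split(',').all(|term| term.chars().count() == out_rank))
}

fn main() {
    for spec in ["ij,jk->ik", "bij,bjk->bik", "ij->ji", "ij,j->i", "ij,jk->"] {
        match fits_single_rank(spec) {
            Some(true) => println!("{spec}: same rank everywhere, try_einsum fits"),
            Some(false) => println!("{spec}: rank changes, use tn_einsum directly"),
            None => println!("{spec}: implicit output, not handled by this sketch"),
        }
    }
}
```

Under these assumptions, matrix multiplication, batched matrix multiplication and transposition pass the check, while a matrix-vector product or a full contraction to a scalar does not.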

§Examples

use burn::backend::NdArray;
use burn::tensor::Tensor;
use tenferro_ext_burn::try_einsum;

let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
// j is summed over, so c has shape [3, 5] and every entry equals 4.0.
let c: Tensor<NdArray<f64>, 2> = try_einsum("ij,jk->ik", vec![a, b]).unwrap();
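To make the semantics of "ij,jk->ik" concrete, here is a naive std-only reference loop for the same contraction. It is an illustration of einsum summation over a repeated index, not how any backend evaluates tn_einsum, and `einsum_ij_jk_ik` is a hypothetical name.

```rust
/// Naive reference for "ij,jk->ik": c[i][k] = sum over j of a[i][j] * b[j][k].
/// Illustration only; not the backend's implementation.
fn einsum_ij_jk_ik(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let (n_i, n_j) = (a.len(), b.len());
    let n_k = b.first().map_or(0, |row| row.len());
    let mut c = vec![vec![0.0_f64; n_k]; n_i];
    for i in 0..n_i {
        // The shared index j must have the same extent in both inputs.
        assert_eq!(a[i].len(), n_j, "extent of j differs between inputs");
        for k in 0..n_k {
            // j appears in both inputs but not in the output, so it is summed over.
            c[i][k] = (0..n_j).map(|j| a[i][j] * b[j][k]).sum::<f64>();
        }
    }
    c
}

fn main() {
    // Same shapes as the Burn example above: a is 3x4, b is 4x5, all ones.
    let a = vec![vec![1.0; 4]; 3];
    let b = vec![vec![1.0; 5]; 4];
    let c = einsum_ij_jk_ik(&a, &b);
    println!("{} x {}, c[0][0] = {}", c.len(), c[0].len(), c[0][0]);
}
```

With all-ones inputs each output entry is a sum of four products 1 * 1, which matches the 3 x 5 result of 4.0 values expected from the try_einsum call.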