
Function einsum 

pub fn einsum<Alg, Backend>(
    ctx: &mut BackendContext<Alg, Backend>,
    subscripts: &str,
    operands: &[&Tensor<Alg::Scalar>],
    size_dict: Option<&HashMap<u32, usize>>,
) -> Result<Tensor<Alg::Scalar>>
where
    Alg: Semiring,
    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
    Backend: EinsumBackend<Alg>,
    BackendContext<Alg, Backend>: TensorTempPoolContext,

Execute einsum using string notation.

Parses the subscript string, optimizes the contraction order, and executes the contraction. The backend type parameter Backend and its context ctx are passed explicitly.
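As a rough illustration of the parsing step, here is a minimal sketch that splits a plain subscript string into per-operand label lists and an output label list. The Parsed type and parse_subscripts function are hypothetical and not part of tenferro. The sketch assumes the explicit "->" form with single-character labels and does not handle parentheses or ellipsis.

```rust
/// Parsed form of a plain subscript string such as "ij,jk->ik".
/// Hypothetical type for illustration only; not tenferro's internal representation.
#[derive(Debug, PartialEq)]
struct Parsed {
    inputs: Vec<Vec<char>>,
    output: Vec<char>,
}

/// Split "ij,jk->ik" into per-operand label lists and an output label list.
/// Assumes the explicit "->" form and single-letter labels (no parentheses, no "...").
fn parse_subscripts(s: &str) -> Result<Parsed, String> {
    let (lhs, rhs) = s
        .split_once("->")
        .ok_or_else(|| format!("missing '->' in {s:?}"))?;
    // One label list per comma-separated operand term.
    let inputs: Vec<Vec<char>> = lhs
        .split(',')
        .map(|term| term.trim().chars().collect())
        .collect();
    // An empty right-hand side (as in "ii->") means a scalar output.
    let output: Vec<char> = rhs.trim().chars().collect();
    // Reject anything that is not a letter, in inputs or output.
    for &c in inputs.iter().flatten().chain(output.iter()) {
        if !c.is_ascii_alphabetic() {
            return Err(format!("invalid label {c:?}"));
        }
    }
    Ok(Parsed { inputs, output })
}
```

For "ij,jk->ik" this yields inputs [['i','j'], ['j','k']] and output ['i','k']; for "ii->" the output list is empty.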

Parentheses in the subscript string specify the contraction order explicitly (e.g., "ij,(jk,kl)->il" contracts the second and third operands, B and C, first and then contracts the result with A). Without parentheses, the contraction order is optimized automatically.
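Why the order matters can be seen with a small cost count for the three-matrix chain "ij,jk,kl->il". The sketch below counts multiply-adds under a naive dense matrix-product model; it is only an illustration and is not a description of tenferro's optimizer or its cost model. The function names are hypothetical.

```rust
/// Multiply-add count for a dense (m x n) by (n x p) matrix product.
fn matmul_cost(m: u64, n: u64, p: u64) -> u64 {
    m * n * p
}

/// Costs of the two pairwise orders for "ij,jk,kl->il" given label sizes i, j, k, l.
/// Returns (cost of (A*B)*C, cost of A*(B*C)).
fn chain_costs(i: u64, j: u64, k: u64, l: u64) -> (u64, u64) {
    // A*B builds an (i x k) intermediate, which is then multiplied by C (k x l).
    let ab_first = matmul_cost(i, j, k) + matmul_cost(i, k, l);
    // B*C builds a (j x l) intermediate, which A (i x j) is then multiplied by.
    let bc_first = matmul_cost(j, k, l) + matmul_cost(i, j, l);
    (ab_first, bc_first)
}

/// Name the cheaper grouping under this cost model (ties go to (A*B)*C).
fn cheaper_order(i: u64, j: u64, k: u64, l: u64) -> &'static str {
    let (ab, bc) = chain_costs(i, j, k, l);
    if bc < ab { "A*(B*C)" } else { "(A*B)*C" }
}
```

With i = 1000, j = 10, k = 1000, l = 10, the costs are 20,000,000 for (A*B)*C and 200,000 for A*(B*C), a factor of 100. The grouping "ij,(jk,kl)->il" requests the cheaper of the two.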

Ellipsis notation (...) is supported for batch dimensions, following NumPy/PyTorch/JAX conventions. In each operand, the ellipsis stands for the dimensions not covered by that operand's explicit labels, so the number of batch dimensions it represents is inferred from the tensor's rank.
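A minimal sketch of the expansion, assuming the ellipsis is replaced by fresh labels drawn from a pool that is not used elsewhere in the expression. The expand_ellipsis function and the pool of uppercase labels are hypothetical. Taking labels from the end of the pool mirrors NumPy's right-aligned broadcasting; whether tenferro aligns batch dimensions the same way internally is an assumption here.

```rust
/// Replace "..." in one operand's subscript with explicit batch labels.
/// `rank` is the operand's number of dimensions. `batch_labels` is a pool of
/// labels assumed not to appear anywhere else in the expression.
fn expand_ellipsis(term: &str, rank: usize, batch_labels: &[char]) -> Result<String, String> {
    let Some((before, after)) = term.split_once("...") else {
        // No ellipsis: the explicit labels must cover every dimension.
        return if term.chars().count() == rank {
            Ok(term.to_string())
        } else {
            Err(format!("{term:?} has the wrong number of labels for rank {rank}"))
        };
    };
    let explicit = before.chars().count() + after.chars().count();
    // The ellipsis covers whatever dimensions the explicit labels leave over.
    let n_batch = rank
        .checked_sub(explicit)
        .ok_or_else(|| format!("{term:?} has more labels than rank {rank}"))?;
    if n_batch > batch_labels.len() {
        return Err(format!("need {n_batch} batch labels, have {}", batch_labels.len()));
    }
    // Take labels from the end of the pool so that operands with fewer batch
    // dimensions line up on the right (NumPy-style broadcasting alignment).
    let batch: String = batch_labels[batch_labels.len() - n_batch..].iter().collect();
    Ok(format!("{before}{batch}{after}"))
}
```

With the pool ['A', 'B', 'C'], a rank-4 operand "...ij" expands to "BCij" and a rank-3 operand "...jk" expands to "Cjk", so their shared batch dimension C lines up on the right.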

§Arguments

  • ctx — Mutable backend context (thread pool, plan cache)
  • subscripts — Einstein summation notation (e.g., "ij,jk->ik")
  • operands — Input tensors
  • size_dict — Optional map from label to dimension size, needed for output labels that do not appear in any input
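The role of size_dict can be sketched as a size-resolution pass: labels that appear in the inputs take their sizes from the operand shapes (which must agree), and only output labels absent from every input are looked up in the dictionary. The resolve_sizes helper below is hypothetical. It uses char keys for readability, whereas the real size_dict has u32 keys whose label encoding is not documented on this page.

```rust
use std::collections::HashMap;

/// Resolve a size for every label. Input labels get their size from the operand
/// shapes; output-only labels must come from `size_dict`.
/// Hypothetical helper with `char` keys (the real `size_dict` uses `u32` keys).
fn resolve_sizes(
    inputs: &[(&str, &[usize])],
    output: &str,
    size_dict: Option<&HashMap<char, usize>>,
) -> Result<HashMap<char, usize>, String> {
    let mut sizes = HashMap::new();
    for &(labels, shape) in inputs {
        // Each operand needs exactly one label per dimension.
        if labels.chars().count() != shape.len() {
            return Err(format!("{labels:?} does not match a rank-{} shape", shape.len()));
        }
        for (c, &n) in labels.chars().zip(shape) {
            // A label shared by several operands must have one consistent size.
            if let Some(prev) = sizes.insert(c, n) {
                if prev != n {
                    return Err(format!("label {c:?} has sizes {prev} and {n}"));
                }
            }
        }
    }
    for c in output.chars() {
        if !sizes.contains_key(&c) {
            // Output-only label: its size is only known from `size_dict`.
            let n = size_dict
                .and_then(|d| d.get(&c))
                .copied()
                .ok_or_else(|| format!("output label {c:?} needs an entry in size_dict"))?;
            sizes.insert(c, n);
        }
    }
    Ok(sizes)
}
```

For example, "i->ij" on a length-2 vector cannot be sized from the inputs alone; the sketch succeeds only when size_dict supplies a size for j.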

§Examples

use tenferro_algebra::Standard;
use tenferro_prims::{CpuBackend, CpuContext};
// Assumes `einsum` is in scope and that `a`, `b`, and `c` are `Tensor<f64>`
// inputs whose shapes fit each example's subscripts (construction omitted).
let mut ctx = CpuContext::new(4);

// Matrix multiplication
let ab = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ij,jk->ik", &[&a, &b], None).unwrap();

// Trace
let tr = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ii->", &[&a], None).unwrap();

// Batch matrix multiplication
let batched =
    einsum::<Standard<f64>, CpuBackend>(&mut ctx, "bij,bjk->bik", &[&a, &b], None).unwrap();

// Batch matrix multiplication with ellipsis notation
let batched_ellipsis =
    einsum::<Standard<f64>, CpuBackend>(&mut ctx, "...ij,...jk->...ik", &[&a, &b], None).unwrap();

// Explicit contraction order: contract B*C first, then A
let d = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ij,(jk,kl)->il", &[&a, &b, &c], None)
    .unwrap();
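To make the semantics of these examples concrete, here is a naive reference einsum over dense row-major f64 slices. It is a hypothetical, dependency-free sketch for checking small results by hand, not tenferro's implementation: it supports only the explicit "->" form with single-character labels (no parentheses, ellipsis, or size_dict), assumes operand shapes are consistent, and loops over every combination of label values, so its cost is the product of all label sizes.

```rust
use std::collections::HashMap;

/// Naive reference einsum over dense row-major f64 tensors given as (data, shape).
/// Hypothetical sketch: explicit "->" only, single-char labels, consistent shapes assumed.
fn naive_einsum(subscripts: &str, operands: &[(&[f64], &[usize])]) -> Vec<f64> {
    let (lhs, rhs) = subscripts.split_once("->").expect("explicit '->' required");
    let terms: Vec<Vec<char>> = lhs.split(',').map(|t| t.chars().collect()).collect();
    let out: Vec<char> = rhs.chars().collect();
    assert_eq!(terms.len(), operands.len(), "operand count mismatch");

    // Record each label's size from the operand shapes.
    let mut size: HashMap<char, usize> = HashMap::new();
    for (t, &(_, shape)) in terms.iter().zip(operands) {
        for (&c, &n) in t.iter().zip(shape.iter()) {
            size.insert(c, n);
        }
    }
    // Distinct labels: output labels first, then summed-over labels.
    let mut labels: Vec<char> = out.clone();
    for &c in terms.iter().flatten() {
        if !labels.contains(&c) {
            labels.push(c);
        }
    }
    let dims: Vec<usize> = labels.iter().map(|c| size[c]).collect();
    let out_shape: Vec<usize> = out.iter().map(|c| size[c]).collect();
    let mut result = vec![0.0; out_shape.iter().product()];

    // Row-major flat offset of a tensor labelled `ls` under label assignment `idx`.
    let offset = |ls: &[char], shape: &[usize], idx: &[usize]| -> usize {
        ls.iter().zip(shape).fold(0, |acc, (c, &n)| {
            let pos = labels.iter().position(|l| l == c).unwrap();
            acc * n + idx[pos]
        })
    };

    // Visit every assignment of label values, accumulating products into the output.
    let total: usize = dims.iter().product();
    let mut idx = vec![0usize; labels.len()];
    for _ in 0..total {
        let mut prod = 1.0;
        for (t, &(data, shape)) in terms.iter().zip(operands) {
            prod *= data[offset(&t[..], shape, &idx)];
        }
        result[offset(&out[..], &out_shape[..], &idx)] += prod;
        // Advance the odometer, last label fastest.
        for d in (0..idx.len()).rev() {
            idx[d] += 1;
            if idx[d] < dims[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    result
}
```

For a = [[1, 2], [3, 4]] and b = [[5, 6], [7, 8]], "ij,jk->ik" gives [19, 22, 43, 50] in row-major order and "ii->" on a gives [5]; a repeated label within one operand, as in the trace, selects the diagonal.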

§Errors

Returns an error if the notation is invalid or tensor shapes are incompatible with the subscripts.