Crate tenferro_einsum

High-level einsum with N-ary contraction tree optimization.

This crate provides:

  • String notation: "ij,jk->ik" (NumPy/PyTorch compatible)
  • Parenthesized notation: "ij,(jk,kl)->il" respects user-specified contraction order via NestedEinsum
  • Integer label notation: subscripts expressed with u32 labels instead of characters
  • N-ary contraction: Automatic or manual optimization of pairwise contraction order via ContractionTree
  • v2 builder: build_einsum_fragment lowers einsum into a compute graph fragment using DotGeneral, ReduceSum, Transpose, etc.
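To illustrate why the contraction order matters for an N-ary einsum such as "ij,jk,kl->il", the sketch below compares the multiply-add counts of the two possible pairwise orders. It is a standalone cost model and does not call this crate; the function names `matmul_cost`, `cost_left_first`, and `cost_right_first` are hypothetical, and the cost function actually used by ContractionTree may differ.

```rust
/// Multiply-add count for contracting an (m x k) operand with a (k x n) operand.
/// Hypothetical helper for illustration; not part of tenferro_einsum.
fn matmul_cost(m: u64, k: u64, n: u64) -> u64 {
    m * k * n
}

/// Cost of evaluating "ij,jk,kl->il" as "(ij,jk),kl":
/// first contract over j to get an (i x k) intermediate, then over k.
fn cost_left_first(i: u64, j: u64, k: u64, l: u64) -> u64 {
    matmul_cost(i, j, k) + matmul_cost(i, k, l)
}

/// Cost of evaluating "ij,jk,kl->il" as "ij,(jk,kl)":
/// first contract over k to get a (j x l) intermediate, then over j.
fn cost_right_first(i: u64, j: u64, k: u64, l: u64) -> u64 {
    matmul_cost(j, k, l) + matmul_cost(i, j, l)
}
```

With i = 10, j = 1000, k = 10, l = 1000, the left-first order costs 200,000 multiply-adds while the right-first order costs 20,000,000, which is the kind of gap an automatic optimizer is meant to avoid and a parenthesized expression lets the user pin down by hand.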

Examples

use computegraph::fragment::FragmentBuilder;
use computegraph::types::ValRef;
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_einsum::{ContractionTree, Subscripts, build_einsum_fragment};

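// Parse the string notation and optimize the pairwise contraction order
// for inputs of shape [2, 3] and [3, 4].
let subs = Subscripts::parse("ij,jk->ik").unwrap();
let tree = ContractionTree::optimize(&subs, &[&[2, 3], &[3, 4]]).unwrap();

// Register the two operands as user-supplied graph inputs.
let mut builder = FragmentBuilder::<StdTensorOp>::new();
let a = builder.add_input(TensorInputKey::User { id: 0 });
let b = builder.add_input(TensorInputKey::User { id: 1 });

// Lower the einsum into the compute graph fragment.
let result = build_einsum_fragment(
    &mut builder,
    &tree,
    &[ValRef::Local(a), ValRef::Local(b)],
    &[vec![2, 3], vec![3, 4]],
);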
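
For reference, the expression "ij,jk->ik" used above denotes an ordinary matrix product: the shared label j is summed over, and i and k survive into the output. The following is a minimal standalone sketch of that semantics on row-major `f64` buffers. The function `einsum_ij_jk_ik` is hypothetical and only shows what the expression means; it is not how this crate lowers or executes the contraction.

```rust
/// Naive reference for "ij,jk->ik" on row-major buffers.
/// `a` has shape (i, j), `b` has shape (j, k); the result has shape (i, k).
/// Hypothetical helper for illustration; not part of tenferro_einsum.
fn einsum_ij_jk_ik(a: &[f64], b: &[f64], i: usize, j: usize, k: usize) -> Vec<f64> {
    assert_eq!(a.len(), i * j, "a must hold i * j elements");
    assert_eq!(b.len(), j * k, "b must hold j * k elements");

    let mut out = vec![0.0; i * k];
    for ii in 0..i {
        for jj in 0..j {
            for kk in 0..k {
                // Sum over the shared label j.
                out[ii * k + kk] += a[ii * j + jj] * b[jj * k + kk];
            }
        }
    }
    out
}
```

Called with a 2 x 3 operand and a 3 x 4 operand, matching the shapes in the example, it returns a buffer of 8 elements for the 2 x 4 result.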

Re-exports

pub use builder::build_einsum_fragment;

Modules

builder
planning
syntax

Structs

ContractionOptimizerOptions
Public options for automatic contraction-path optimization.
ContractionTree
Contraction tree determining pairwise contraction order for N-ary einsum.
Subscripts
Einsum subscripts using integer labels (omeinsum-rs compatible).

Enums

NestedEinsum
Recursive einsum tree that preserves parenthesized grouping.

Functions

build_size_dict
Build a label -> size mapping from subscripts and input shapes.
compute_output_shape
Compute output shape from output subscripts and size dictionary.
eager_einsum
Eager N-ary einsum on concrete Tensor values.
typed_eager_einsum
Execute eager einsum over typed tensors and return a typed result.
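
The label -> size mapping and output-shape computation described for build_size_dict and compute_output_shape can be sketched in standalone form as follows. This is an illustration under assumptions, not the crate's implementation: the names `size_dict_sketch` and `output_shape_sketch`, their signatures, and their error handling are hypothetical, and the real functions take the crate's own subscript types.

```rust
use std::collections::HashMap;

/// Map each integer label to its dimension size, checking that a label
/// shared between operands always has the same size.
/// Hypothetical sketch; not the signature of tenferro_einsum::build_size_dict.
fn size_dict_sketch(
    inputs: &[Vec<u32>],
    shapes: &[Vec<usize>],
) -> Result<HashMap<u32, usize>, String> {
    if inputs.len() != shapes.len() {
        return Err("number of operands and shapes differ".to_string());
    }
    let mut sizes = HashMap::new();
    for (labels, shape) in inputs.iter().zip(shapes) {
        if labels.len() != shape.len() {
            return Err(format!("rank mismatch: {:?} vs {:?}", labels, shape));
        }
        for (&label, &dim) in labels.iter().zip(shape) {
            // `insert` returns the previous size, if the label was seen before.
            match sizes.insert(label, dim) {
                Some(prev) if prev != dim => {
                    return Err(format!("label {} has sizes {} and {}", label, prev, dim));
                }
                _ => {}
            }
        }
    }
    Ok(sizes)
}

/// Look up each output label in the size dictionary.
/// Returns None if an output label does not appear in any input.
fn output_shape_sketch(output: &[u32], sizes: &HashMap<u32, usize>) -> Option<Vec<usize>> {
    output.iter().map(|label| sizes.get(label).copied()).collect()
}
```

For "ij,jk->ik" written with labels 0, 1, 2 and input shapes [2, 3] and [3, 4], the dictionary maps 0 -> 2, 1 -> 3, 2 -> 4, and the output labels [0, 2] give the shape [2, 4].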