Einsum Design
Einsum is provided by a standard extension crate, `tenferro-einsum`, rather than by a root facade. The public user-facing paths are crate-root extension traits in `tenferro_einsum`:

- `GraphCompilerEinsumExt` builds traced graphs.
- `EagerEinsumExt` runs eager execution immediately.
- Tensor extension traits provide `tensordot` contraction sugar.

`tensordot` is not a `tenferro-linalg` API.
The workspace intentionally has no root `tenferro` crate and no einsum facade paths. Programs that use traced einsum must explicitly register the extension runtime with their executor.
The implementation is split between:

- `crates/tenferro-einsum/src/traced.rs` for the user-facing traced API, contraction strategy selection, symbolic-shape handling, and graph cache integration,
- `crates/tenferro-einsum/src/extension.rs` for runtime extension execution,
- `crates/tenferro-einsum/src/syntax/` for subscript and nested-order parsing,
- `crates/tenferro-einsum/src/planning/` for contraction tree planning and per-step lowering plans,
- `crates/tenferro-einsum/src/builder.rs` for graph-fragment lowering,
- `crates/tenferro-einsum/src/eager_ad.rs` for eager tensor execution.
Historical design notes that refer to direct `CudaBackend`/`RocmBackend` usage, `tenferro-prims`, or the old nine-function einsum API do not describe the current design.
Public Traced API
The extension crate exposes lazy traced einsum:
```rust
use tenferro_cpu::CpuBackend;
use tenferro_einsum::GraphCompilerEinsumExt;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};

// Column-major 2x3 and 3x2 inputs.
let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = TracedTensor::from_vec_col_major(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);

// Trace the matrix product and compile it into a program.
let mut compiler = GraphCompiler::new();
let c = compiler.einsum(&[&a, &b], "ij,jk->ik").unwrap();
let program = compiler.compile(&c).unwrap();

// The einsum extension runtime must be registered before the program runs.
let mut executor = GraphExecutor::new(CpuBackend::new());
executor.register_extension(tenferro_einsum::register_runtime).unwrap();
let result = executor.run(&program).unwrap();
assert_eq!(result.shape(), &[2, 2]);
```

`einsum_with` accepts an explicit `EinsumOptimize` strategy:
| Strategy | Meaning |
|---|---|
| `Auto(ContractionOptimizerOptions)` | TreeSA/omeco path optimization with the configured score |
| `False` | left-to-right contraction |
| `Nested(NestedEinsum)` | explicit parenthesized contraction tree |
| `Path(Vec<(usize, usize)>)` | JAX-compatible shrinking-list path; shape-independent and valid for symbolic traced inputs |
| `Tree(ContractionTree)` | concrete/precomputed tree; accepted only when concrete shapes are available |
`EinsumOptimize::default()` selects time-optimized automatic planning. The traced API stores the resolved planning policy in the extension payload as a shape-independent plan specification. `Path` pairs remain positional over the current shrinking operand list. `Tree` values are converted to fixed contraction pairs when they are accepted for concrete inputs.
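The following sketch replays a `Path` over operand label strings, which makes the positional semantics concrete. It assumes the opt_einsum/JAX convention: both contracted operands are removed from the live list and their intermediate is appended at the end. `replay_path` is a hypothetical helper, not part of `tenferro_einsum`.

```rust
/// Hypothetical helper (not a tenferro_einsum API): replays a JAX-style
/// shrinking-list path over operand labels and returns each intermediate's labels.
/// Assumes both contracted operands leave the list and the result is appended.
fn replay_path(
    inputs: &[&str],
    output: &str,
    path: &[(usize, usize)],
) -> Result<Vec<String>, String> {
    let mut live: Vec<String> = inputs.iter().map(|s| s.to_string()).collect();
    let mut intermediates = Vec::new();
    for &(i, j) in path {
        if i == j || i >= live.len() || j >= live.len() {
            return Err(format!("pair ({i}, {j}) is invalid for {} live operands", live.len()));
        }
        // Remove the larger position first so the smaller one stays valid.
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let right = live.remove(hi);
        let left = live.remove(lo);
        // Keep a label only if the output or a remaining operand still needs it.
        let mut labels = String::new();
        for c in left.chars().chain(right.chars()) {
            let needed = output.contains(c) || live.iter().any(|op| op.contains(c));
            if needed && !labels.contains(c) {
                labels.push(c);
            }
        }
        live.push(labels.clone());
        intermediates.push(labels);
    }
    if live.len() != 1 {
        return Err(format!("path leaves {} live operands instead of 1", live.len()));
    }
    Ok(intermediates)
}

fn main() {
    // "ij,jk,kl->il": contract positions (0, 1), then the two operands that remain.
    let steps = replay_path(&["ij", "jk", "kl"], "il", &[(0, 1), (0, 1)]).unwrap();
    println!("{steps:?}"); // ["ik", "li"]
}
```

The intermediate is appended at the end of the list. As a result, the second `(0, 1)` pairs `kl` with `ik` rather than with an original input. The final labels `li` still differ from the requested `il`, so a closing transpose is needed.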
Eager Tensor API
`EagerEinsumExt` provides immediate execution over slices and arrays of `EagerTensor` inputs:

```rust
use tenferro_ad::{EagerRuntime, Tensor};
use tenferro_einsum::EagerEinsumExt;

let ctx = EagerRuntime::new();

// Build concrete column-major tensors, then lift them into the eager runtime.
let a = Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = Tensor::from_vec_col_major(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
let a = ctx.constant_from(a).unwrap();
let b = ctx.constant_from(b).unwrap();

// Einsum is called directly on the array of input references.
let c = [&a, &b].einsum("ij,jk->ik").unwrap();
assert_eq!(c.shape(), &[2, 2]);
```

Runtime-owned concrete execution is internal to the extension runtime and is not exposed through a facade crate.
Subscripts And Repeated Labels
`Subscripts::parse` accepts flat NumPy/PyTorch-style labels and rejects parenthesized contraction-order notation. Use `NestedEinsum::parse` when the contraction order must be preserved.
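The flat/nested split can be illustrated with a minimal parser sketch. `parse_flat` is hypothetical, not tenferro's `Subscripts::parse`. It assumes an explicit `->` and single-letter ASCII labels, and it does not handle ellipses. Like the flat parser described above, it refuses parentheses.

```rust
/// Hypothetical flat-subscript parser (not tenferro's `Subscripts::parse`).
/// Assumes an explicit `->` and single-letter ASCII labels; no ellipsis support.
fn parse_flat(spec: &str) -> Result<(Vec<String>, String), String> {
    // Parentheses encode contraction order, which a flat parser refuses.
    if spec.contains('(') || spec.contains(')') {
        return Err("parenthesized order requires a nested parser".to_string());
    }
    let (lhs, rhs) = spec.split_once("->").ok_or("missing '->'")?;
    let inputs: Vec<String> = lhs.split(',').map(|s| s.trim().to_string()).collect();
    let output = rhs.trim().to_string();
    for c in inputs.iter().flat_map(|s| s.chars()).chain(output.chars()) {
        if !c.is_ascii_alphabetic() {
            return Err(format!("invalid label {c:?}"));
        }
    }
    // Every output label must come from some input.
    if let Some(c) = output.chars().find(|&c| !inputs.iter().any(|s| s.contains(c))) {
        return Err(format!("output label {c:?} does not appear in any input"));
    }
    Ok((inputs, output))
}

fn main() {
    println!("{:?}", parse_flat("ij,jk->ik")); // Ok((["ij", "jk"], "ik"))
    println!("{}", parse_flat("(ij,jk),kl->il").is_err()); // true
}
```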
Repeated-label semantics follow the usual einsum rules:
| Pattern | Meaning |
|---|---|
| `ii->` | extract the diagonal, then reduce it to a scalar trace |
| `ii->i` | extract the diagonal |
| `iij->ij` | extract the diagonal across the first two axes and preserve `j` |
| `i->ii` | embed the vector as the diagonal of a matrix |
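The four patterns can be checked numerically on small column-major buffers. This plain-Rust sketch uses hypothetical helpers (`diag_of`, `trace_of`, `diag_first_two`, `diag_matrix`). It models only the semantics in the table, not tenferro's kernels.

```rust
// Hypothetical helpers on dense column-major data:
// element (r, c) of an n x n matrix lives at r + c * n.

/// `ii->i`: extract the diagonal of an n x n matrix.
fn diag_of(a: &[f64], n: usize) -> Vec<f64> {
    (0..n).map(|i| a[i + i * n]).collect()
}

/// `ii->`: extract the diagonal, then sum it to a scalar trace.
fn trace_of(a: &[f64], n: usize) -> f64 {
    diag_of(a, n).iter().sum()
}

/// `iij->ij`: keep entries of an n x n x m tensor whose first two indices agree.
fn diag_first_two(t: &[f64], n: usize, m: usize) -> Vec<f64> {
    let mut out = vec![0.0; n * m];
    for j in 0..m {
        for i in 0..n {
            // Column-major offset of (i, i, j) is i + i*n + j*n*n.
            out[i + j * n] = t[i + i * n + j * n * n];
        }
    }
    out
}

/// `i->ii`: place a length-n vector on the diagonal of an n x n zero matrix.
fn diag_matrix(v: &[f64]) -> Vec<f64> {
    let n = v.len();
    let mut out = vec![0.0; n * n];
    for (i, &x) in v.iter().enumerate() {
        out[i + i * n] = x;
    }
    out
}

fn main() {
    // Column-major 2 x 2 matrix with rows [1, 3] and [2, 4], stored as [1, 2, 3, 4].
    let a = [1.0, 2.0, 3.0, 4.0];
    println!("{:?}", diag_of(&a, 2)); // [1.0, 4.0]
    println!("{}", trace_of(&a, 2)); // 5
    println!("{:?}", diag_matrix(&[1.0, 4.0])); // [1.0, 0.0, 0.0, 4.0]
}
```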
The implementation applies these rules before ordinary contraction:

- `diagonalize_repeated` repeatedly applies `extract_diagonal` to duplicate labels within one operand.
- Labels that appear neither in the output nor in any later live operand are reduced with `reduce_sum`.
- `embed_repeated` applies `embed_diagonal` when the output repeats a label more often than the current value contains it.
- `transpose_to_labels` restores the requested output order.

Strict binary/GEMM lowering intentionally rejects repeated labels and returns `None`. Those cases stay on the general eager/builder path, which handles diagonalization explicitly.
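The pre-contraction rules can be sketched at the label level for a single-operand einsum. The `Step` enum and `plan_unary` are hypothetical and track only label strings. In a multi-operand einsum, the reduce step would also keep labels used by later live operands. This sketch has no later operands, and tenferro's exact step ordering may differ.

```rust
/// Hypothetical label-level steps for a single-operand einsum.
#[derive(Debug, PartialEq)]
enum Step {
    ExtractDiagonal(char),
    ReduceSum(char),
    EmbedDiagonal(char),
    Transpose(String),
}

/// Plans `input -> output` for one operand with no later operands.
fn plan_unary(input: &str, output: &str) -> Vec<Step> {
    let mut steps = Vec::new();
    // 1. Each extra occurrence of a label costs one diagonal extraction.
    let mut labels = String::new();
    for c in input.chars() {
        if labels.contains(c) {
            steps.push(Step::ExtractDiagonal(c));
        } else {
            labels.push(c);
        }
    }
    // 2. Sum out labels the output does not use.
    let mut current = String::new();
    for c in labels.chars() {
        if output.contains(c) {
            current.push(c);
        } else {
            steps.push(Step::ReduceSum(c));
        }
    }
    // 3. Embed a diagonal for each extra time the output repeats a label.
    let mut seen = String::new();
    for c in output.chars() {
        if seen.contains(c) {
            continue;
        }
        seen.push(c);
        let (need, have) = (output.matches(c).count(), current.matches(c).count());
        for _ in have..need {
            steps.push(Step::EmbedDiagonal(c));
            current.push(c);
        }
    }
    // 4. Transpose if the label order still differs from the output.
    if current != output {
        steps.push(Step::Transpose(output.to_string()));
    }
    steps
}

fn main() {
    println!("{:?}", plan_unary("ii", "")); // [ExtractDiagonal('i'), ReduceSum('i')]
    println!("{:?}", plan_unary("ijk", "ki")); // [ReduceSum('j'), Transpose("ki")]
}
```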
Static And Symbolic Shapes
The traced extension API chooses the lowering mode from input shape availability:
| Inputs | Build-time behavior | Runtime behavior |
|---|---|---|
| All concrete shapes | optimize the contraction tree at graph build time and lower into ordinary graph ops where possible | execute the lowered graph |
| Any symbolic shape | emit one einsum extension op | optimize from actual input shapes at runtime |
`tenferro_einsum` caches concrete-shape contraction trees in the extension cache. Runtime contraction trees are keyed by subscripts, input shapes, and either the resolved planning policy or the explicit path. Repeated symbolic-shape runs with the same concrete shapes and policy therefore amortize planning cost without conflating different optimizer settings.

The same plan specification participates in traced extension payload identity. Otherwise identical ops that use different planner options or paths therefore remain distinct extension ops.

Runtime extension execution also caches the compiled inner execution program. That cache is keyed by subscripts, concrete input shapes, input dtypes, and the resolved planning policy, so repeated eager or traced extension runs do not rebuild and recompile the lowered inner graph.
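A cache keyed this way can be modeled with a key that derives `Hash` and `Eq`. The sketch below uses hypothetical types (`PlanPolicy`, `PlanKey`, `get_or_plan`) instead of tenferro's internal ones. It shows why including the policy in the key keeps an automatically planned entry and an explicit-path entry separate, even when they share subscripts and shapes.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a resolved planning policy.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum PlanPolicy {
    Auto { score: String },
    Path(Vec<(usize, usize)>),
}

/// Hypothetical cache key: subscripts, concrete input shapes, and policy.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct PlanKey {
    subscripts: String,
    shapes: Vec<Vec<usize>>,
    policy: PlanPolicy,
}

/// Returns the cached plan, running `build` (the expensive planner) only on a miss.
fn get_or_plan<F: FnOnce() -> String>(
    cache: &mut HashMap<PlanKey, String>,
    key: PlanKey,
    build: F,
) -> String {
    cache.entry(key).or_insert_with(build).clone()
}

fn main() {
    let mut cache = HashMap::new();
    let key = |policy: PlanPolicy| PlanKey {
        subscripts: "ij,jk->ik".to_string(),
        shapes: vec![vec![2, 3], vec![3, 2]],
        policy,
    };
    let auto = PlanPolicy::Auto { score: "time".to_string() };
    get_or_plan(&mut cache, key(auto.clone()), || "tree A".to_string());
    // Same subscripts and shapes with a different policy: a separate entry.
    get_or_plan(&mut cache, key(PlanPolicy::Path(vec![(0, 1)])), || "tree B".to_string());
    // Repeating the first key hits the cache, so the planner closure never runs.
    let hit = get_or_plan(&mut cache, key(auto), || unreachable!());
    println!("{} entries, hit = {hit}", cache.len()); // 2 entries, hit = tree A
}
```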
Planning
`ContractionTree` records the pairwise contraction sequence, live operand labels, size dictionary, and compiled step plans. Automatic planning first asks omeco/TreeSA for a path. If omeco does not return one, the local self-greedy fallback repeatedly chooses the pair with the smallest intermediate output size.
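A greedy fallback of this kind can be sketched as follows. Assume each operand is a label string, sizes come from a label-to-size map, and each contraction removes both operands and appends the intermediate at the end of the list. At every step the sketch evaluates all live pairs and keeps the one whose intermediate has the fewest elements. `greedy_path` and its helpers are hypothetical, and tenferro's self-greedy fallback may break ties or prune labels differently.

```rust
use std::collections::HashMap;

/// Labels that survive contracting `a` with `b`: those the output or some
/// other live operand still needs.
fn surviving(a: &str, b: &str, others: &[&str], output: &str) -> String {
    let mut out = String::new();
    for c in a.chars().chain(b.chars()) {
        let needed = output.contains(c) || others.iter().any(|op| op.contains(c));
        if needed && !out.contains(c) {
            out.push(c);
        }
    }
    out
}

/// Element count of a tensor with these labels. Panics on an unknown label;
/// the document's planner reports that case through `Result` instead.
fn element_count(labels: &str, sizes: &HashMap<char, usize>) -> usize {
    labels.chars().map(|c| sizes[&c]).product()
}

/// Hypothetical greedy planner: repeatedly contracts the live pair whose
/// intermediate is smallest, removing both and appending the result.
fn greedy_path(inputs: &[&str], output: &str, sizes: &HashMap<char, usize>) -> Vec<(usize, usize)> {
    let mut live: Vec<String> = inputs.iter().map(|s| s.to_string()).collect();
    let mut path = Vec::new();
    while live.len() > 1 {
        let mut best: Option<(usize, usize, usize, String)> = None;
        for i in 0..live.len() {
            for j in (i + 1)..live.len() {
                let others: Vec<&str> = live
                    .iter()
                    .enumerate()
                    .filter(|&(k, _)| k != i && k != j)
                    .map(|(_, s)| s.as_str())
                    .collect();
                let labels = surviving(&live[i], &live[j], &others, output);
                let cost = element_count(&labels, sizes);
                // Strict `<` keeps the first pair found on ties.
                if best.as_ref().map_or(true, |b| cost < b.2) {
                    best = Some((i, j, cost, labels));
                }
            }
        }
        let (i, j, _, labels) = best.expect("at least two live operands");
        live.remove(j); // j > i, so removing j first keeps i valid
        live.remove(i);
        live.push(labels);
        path.push((i, j));
    }
    path
}

fn main() {
    let sizes: HashMap<char, usize> =
        [('i', 10), ('j', 2), ('k', 10), ('l', 2)].into_iter().collect();
    // Contracting jk with kl first yields a 2 x 2 intermediate, the cheapest option.
    println!("{:?}", greedy_path(&["ij", "jk", "kl"], "il", &sizes)); // [(1, 2), (0, 1)]
}
```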
Planner invariants are checked with normal `Result` propagation:

- input ranks and shape labels are validated by `build_size_dict`,
- explicit paths must reference distinct live operands,
- an explicit path must leave exactly one live value when it completes,
- labels used in contraction-cost estimates must have known sizes.
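The first invariant can be illustrated with a small size-dictionary builder. `size_dict` is a hypothetical stand-in for `build_size_dict`, whose real signature is not shown in this document. It checks that each operand's label count matches its rank and that every occurrence of a label agrees on one size, and it returns an error rather than panicking.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `build_size_dict`: maps each label to its size,
/// rejecting rank mismatches and inconsistent label sizes.
fn size_dict(labels: &[&str], shapes: &[Vec<usize>]) -> Result<HashMap<char, usize>, String> {
    if labels.len() != shapes.len() {
        return Err(format!("{} subscripts but {} operands", labels.len(), shapes.len()));
    }
    let mut sizes = HashMap::new();
    for (operand, (subs, shape)) in labels.iter().zip(shapes).enumerate() {
        // Rank check: one label per axis.
        let rank = subs.chars().count();
        if rank != shape.len() {
            return Err(format!("operand {operand}: {rank} labels for rank {}", shape.len()));
        }
        // Consistency check: every occurrence of a label must agree on its size.
        for (c, &dim) in subs.chars().zip(shape) {
            match sizes.insert(c, dim) {
                Some(prev) if prev != dim => {
                    return Err(format!("label '{c}' has sizes {prev} and {dim}"));
                }
                _ => {}
            }
        }
    }
    Ok(sizes)
}

fn main() {
    let ok = size_dict(&["ij", "jk"], &[vec![2, 3], vec![3, 4]]).unwrap();
    println!("j = {}", ok[&'j']); // j = 3
    // The shared label j disagrees (3 vs 5), so this input is rejected.
    println!("{:?}", size_dict(&["ij", "jk"], &[vec![2, 3], vec![5, 4]]));
}
```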
Lowering And Execution
Each pairwise step classifies labels into:
- left-only free labels,
- right-only free labels,
- shared batch labels that survive,
- shared contraction labels that are reduced.
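A minimal sketch of this classification for one binary step follows. It assumes each operand is a label string and that `keep` holds the labels still needed after the step, meaning the output plus labels of later live operands. `StepLabels` and `classify` are hypothetical. Since strict binary lowering rejects repeated labels, the sketch returns `None` when either operand repeats one.

```rust
/// Hypothetical label classes for one pairwise contraction step.
#[derive(Debug, PartialEq)]
struct StepLabels {
    left_free: String,
    right_free: String,
    batch: String,
    contracted: String,
}

/// True if any label occurs more than once in `labels`.
fn has_repeats(labels: &str) -> bool {
    labels.chars().enumerate().any(|(i, c)| labels.chars().skip(i + 1).any(|d| d == c))
}

/// Classifies the labels of `left` and `right`; `keep` holds labels needed
/// after this step. Returns `None` when an operand repeats a label. Assumes
/// one-sided labels that are not kept were already summed out, so they count as free.
fn classify(left: &str, right: &str, keep: &str) -> Option<StepLabels> {
    if has_repeats(left) || has_repeats(right) {
        return None;
    }
    let mut step = StepLabels {
        left_free: String::new(),
        right_free: right.chars().filter(|&c| !left.contains(c)).collect(),
        batch: String::new(),
        contracted: String::new(),
    };
    for c in left.chars() {
        if !right.contains(c) {
            step.left_free.push(c);
        } else if keep.contains(c) {
            step.batch.push(c); // shared and still needed: batch
        } else {
            step.contracted.push(c); // shared and no longer needed: contracted
        }
    }
    Some(step)
}

fn main() {
    // Batched matrix product bij,bjk->bik.
    println!("{:?}", classify("bij", "bjk", "bik"));
    println!("{:?}", classify("ii", "ij", "j")); // None
}
```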
When a strict binary plan applies, the step caches the canonical matrix/GEMM layout metadata. When it does not, the builder and eager executor take the general path: diagonalization, reductions, broadcast/outer product, and `DotGeneral`.
Column-major ordering matters. In column-major layout the first axis varies fastest. For GEMM-like steps, compute dimensions therefore stay on the left and batch dimensions stay on the right, so each batch slice remains a contiguous block for the underlying tensor backend.
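The contiguity argument can be checked with column-major strides. The stride of each axis is the product of the extents before it. With shape `[m, n, batch]`, the batch axis therefore has stride `m * n`, and each batch slice occupies one contiguous run of `m * n` elements. The helpers are plain-Rust illustrations, not tenferro APIs.

```rust
/// Column-major strides: axis 0 varies fastest.
fn col_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut step = 1;
    for &extent in shape {
        strides.push(step);
        step *= extent;
    }
    strides
}

/// Offset of element `index` in a column-major buffer.
fn offset(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}

fn main() {
    // A 2 x 3 matrix per batch, 4 batches: matrix axes first, batch axis last.
    let shape = [2, 3, 4];
    let strides = col_major_strides(&shape);
    println!("{strides:?}"); // [1, 2, 6]
    // Batch slice t starts at t * 6 and spans 6 consecutive elements.
    let t = 2;
    let first = offset(&[0, 0, t], &strides);
    let last = offset(&[1, 2, t], &strides);
    println!("slice {t}: {first}..{}", last + 1); // slice 2: 12..18
}
```

If the batch axis came first, for example shape `[4, 2, 3]` with strides `[1, 4, 8]`, the elements of one batch slice would sit 4 apart instead of forming one block.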
GPU Interaction
Einsum itself remains backend-agnostic at the graph level. GPU execution happens when a compiled program is evaluated with `CudaBackend` from `crates/tenferro-gpu/src/cubecl/`.
Current GPU status:
- CUDA uses CubeCL/CubeCL-CUDA under the public `cuda` feature.
- cuTENSOR/cuBLAS paths cover selected contractions and GEMM-like operations.
- ROCm is a stub and not a supported execution path.
- Complex CubeCL expansion is blocked on upstream CubeCL support and is not part of this batch.
- GPU benchmarking is outside this batch.
AD
Graph-level AD rules for einsum live in `tenferro-einsum` and are registered as extension AD rules. Primitive operations emitted by lowering still use the core AD rules from `crates/tenferro-internal-ops/src/ad/`.
VJP construction preserves the primal planning policy. For explicit `EinsumOptimize::Path` payloads, the AD rule remaps the positional path to the VJP operand list, so the gradient contraction inherits the caller’s selected order where that order is still meaningful.
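The need for this remapping shows up in the standard reverse-mode rule for a matrix product: each gradient is another einsum whose operand list differs from the primal one. For `ij,jk->ik` with output cotangent written as C-bar:

```latex
C_{ik} = \sum_j A_{ij} B_{jk}, \qquad
\bar{A}_{ij} = \sum_k \bar{C}_{ik} B_{jk}, \qquad
\bar{B}_{jk} = \sum_i A_{ij} \bar{C}_{ik}
```

The cotangent of `A` is therefore `ik,jk->ij` over (C-bar, `B`), and the cotangent of `B` is `ij,ik->jk` over (`A`, C-bar). A positional path written against (`A`, `B`) must be re-indexed before it can order either gradient contraction. This derivation does not show how tenferro performs that re-indexing.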
Tests
Primary local checks for this surface are:
```sh
cargo test -p tenferro-einsum
cargo test -p tenferro-einsum --doc
```

GPU-specific execution tests require CUDA and are ignored by default; see `gpu-backend-design.md` for the command.