// tenferro_einsum/lib.rs — crate root
//! High-level einsum with N-ary contraction tree optimization.
//!
//! This crate provides:
//!
//! - **String notation**: `"ij,jk->ik"` (NumPy/PyTorch compatible)
//! - **Parenthesized notation**: `"ij,(jk,kl)->il"` respects user-specified
//!   contraction order via [`NestedEinsum`]
//! - **Integer label notation**: using `u32` labels
//! - **N-ary contraction**: Automatic or manual optimization of pairwise
//!   contraction order via [`ContractionTree`]
//! - **v2 builder**: [`build_einsum_fragment`] lowers einsum into a compute
//!   graph fragment using `DotGeneral`, `ReduceSum`, `Transpose`, etc.
//!
//! # Examples
//!
//! ```ignore
//! use computegraph::fragment::FragmentBuilder;
//! use computegraph::types::ValRef;
//! use tenferro_ops::input_key::TensorInputKey;
//! use tenferro_ops::std_tensor_op::StdTensorOp;
//! use tenferro_einsum::{ContractionTree, Subscripts, build_einsum_fragment};
//!
//! let subs = Subscripts::parse("ij,jk->ik").unwrap();
//! let tree = ContractionTree::optimize(&subs, &[&[2, 3], &[3, 4]]).unwrap();
//!
//! let mut builder = FragmentBuilder::<StdTensorOp>::new();
//! let a = builder.add_input(TensorInputKey::User { id: 0 });
//! let b = builder.add_input(TensorInputKey::User { id: 1 });
//!
//! let result = build_einsum_fragment(
//!     &mut builder,
//!     &tree,
//!     &[ValRef::Local(a), ValRef::Local(b)],
//!     &[vec![2, 3], vec![3, 4]],
//! );
//! ```

38pub mod builder;
39mod eager;
40pub mod planning;
41pub mod syntax;
42mod typed_eager;
43pub(crate) mod util;
44
45// Re-exports for convenience
46pub use builder::build_einsum_fragment;
47pub use eager::eager_einsum;
48pub use planning::tree::{ContractionOptimizerOptions, ContractionTree};
49pub use syntax::nested::NestedEinsum;
50pub use syntax::subscripts::Subscripts;
51pub use typed_eager::typed_eager_einsum;
52pub use util::{build_size_dict, compute_output_shape};
53
54#[cfg(test)]
55mod tests;