Skip to main content

tenferro_einsum/
lib.rs

1//! High-level einsum with N-ary contraction tree optimization.
2//!
3//! This crate provides:
4//!
5//! - **String notation**: `"ij,jk->ik"` (NumPy/PyTorch compatible)
6//! - **Parenthesized notation**: `"ij,(jk,kl)->il"` respects user-specified
7//!   contraction order via [`NestedEinsum`]
8//! - **Integer label notation**: subscripts given directly as `u32` labels rather than characters
9//! - **Repeated labels**: `"ii->i"` extracts diagonals, `"ii->"` traces, and
10//!   `"i->ii"` embeds a vector on a diagonal
11//! - **N-ary contraction**: Automatic or manual optimization of pairwise
12//!   contraction order via [`ContractionTree`]
13//! - **Tensordot sugar**: NumPy-style `tensordot` extension methods that contract
14//!   paired axes; they are implemented as contraction sugar, not as linear algebra APIs.
15//! - **Extension runtime**: traced einsum is lowered to a registered tenferro
16//!   extension runtime (see [`register_runtime`]), which keeps core op definitions small.
17//! - **Tensor extension traits**: graph-building helpers are available as
18//!   methods on `GraphCompiler`, eager input slices, and tensor receivers.
19//!
20//! # Examples
21//!
22//! ```
23//! use tenferro_einsum::{ContractionTree, Subscripts};
24//!
25//! let subs = Subscripts::parse("ij,jk->ik").unwrap();
26//! let tree = ContractionTree::optimize(&subs, &[&[2, 3], &[3, 4]]).unwrap();
27//! assert_eq!(tree.step_count(), 1);
28//! ```
29//!
30//! ```
31//! use tenferro_einsum::Subscripts;
32//!
33//! let trace = Subscripts::parse("ii->").unwrap();
34//! let diagonal = Subscripts::parse("ii->i").unwrap();
35//! let embedded = Subscripts::parse("i->ii").unwrap();
36//! let higher_rank = Subscripts::parse("iij->ij").unwrap();
37//!
38//! assert!(trace.output.is_empty());
39//! assert_eq!(diagonal.output, vec![b'i' as u32]);
40//! assert_eq!(embedded.output, vec![b'i' as u32, b'i' as u32]);
41//! assert_eq!(higher_rank.inputs[0], vec![b'i' as u32, b'i' as u32, b'j' as u32]);
42//! ```
43
44mod binary_dot;
45mod builder;
46mod cache;
47mod eager;
48#[cfg(feature = "autodiff")]
49mod eager_ad;
50mod error;
51mod extension;
52pub mod lowering;
53mod optimize;
54mod planning;
55mod subscripts;
56mod syntax;
57mod tensordot;
58mod traced;
59#[cfg(test)]
60mod typed_eager;
61pub(crate) mod util;
62
63pub use cache::EINSUM_EXTENSION_FAMILY_ID;
64#[cfg(feature = "autodiff")]
65pub use eager_ad::{EagerEinsumExt, EagerTensorEinsumExt};
66pub use error::{Error, Result};
67#[cfg(feature = "autodiff")]
68pub use extension::ad_rules;
69pub use extension::register_runtime;
70pub use optimize::EinsumOptimize;
71pub use planning::tree::{ContractionOptimizerOptions, ContractionTree};
72pub use subscripts::{parse_einsum_subscripts, EinsumSubscripts};
73pub use syntax::nested::NestedEinsum;
74pub use syntax::subscripts::Subscripts;
75pub use tensordot::TensorDotAxes;
76pub use traced::{GraphCompilerEinsumExt, TracedTensorEinsumExt};
77
78#[cfg(test)]
79mod tests;
80#[cfg(test)]
81mod typed_eager_tests;