strided_opteinsum/lib.rs
1//! N-ary Einstein summation with nested contraction notation.
2//!
3//! This crate provides an einsum frontend that parses nested string
4//! notation (e.g. `"(ij,jk),kl->il"`), supports mixed `f64` / `Complex64`
5//! operands, and delegates pairwise contractions to [`strided_einsum2`].
6//! For three or more tensors in a single contraction node the
7//! [`omeco`] greedy optimizer is used to find an efficient pairwise order.
8//!
9//! # Quick start
10//!
11//! ```ignore
12//! use strided_opteinsum::{einsum, EinsumOperand};
13//!
14//! let result = einsum("(ij,jk),kl->il", vec![a.into(), b.into(), c.into()], None)?;
15//! ```
16
17use std::collections::HashMap;
18
19/// Error types for einsum operations.
20pub mod error;
21/// Recursive contraction-tree evaluation.
22pub mod expr;
23/// Type-erased einsum operands (`f64` / `Complex64`, owned / borrowed).
24pub mod operand;
25/// Nested einsum string parser.
26pub mod parse;
27/// Single-tensor operations (permute, trace, diagonal extraction).
28pub mod single_tensor;
29/// Runtime type dispatch over `f64` and `Complex64` tensors.
30pub mod typed_tensor;
31
32pub use error::{EinsumError, Result};
33pub use operand::{EinsumOperand, EinsumScalar, StridedData};
34pub use parse::{parse_einsum, EinsumCode, EinsumNode};
35pub use typed_tensor::{needs_c64_promotion, TypedTensor};
36
37/// Parse and evaluate an einsum expression in one call.
38///
39/// Pass `size_dict` to specify sizes for output indices not present in any
40/// input (generative outputs like `"->ii"` or `"i->ij"`).
41///
42/// # Example
43/// ```ignore
44/// let result = einsum("(ij,jk),kl->il", vec![a.into(), b.into(), c.into()], None)?;
45/// ```
46pub fn einsum<'a>(
47 notation: &str,
48 operands: Vec<EinsumOperand<'a>>,
49 size_dict: Option<&HashMap<char, usize>>,
50) -> Result<EinsumOperand<'a>> {
51 let code = parse_einsum(notation)?;
52 code.evaluate(operands, size_dict)
53}
54
55/// Parse and evaluate an einsum expression, writing the result into a
56/// pre-allocated output buffer with alpha/beta scaling.
57///
58/// `output = alpha * einsum(operands) + beta * output`
59///
60/// Pass `size_dict` to specify sizes for output indices not present in any
61/// input (generative outputs like `"->ii"` or `"i->ij"`).
62///
63/// # Example
64/// ```ignore
65/// use strided_opteinsum::{einsum_into, EinsumOperand};
66///
67/// let mut c = StridedArray::<f64>::col_major(&[2, 2]);
68/// einsum_into("ij,jk->ik", vec![a.into(), b.into()], c.view_mut(), 1.0, 0.0, None)?;
69/// ```
70pub fn einsum_into<T: EinsumScalar>(
71 notation: &str,
72 operands: Vec<EinsumOperand<'_>>,
73 output: strided_view::StridedViewMut<T>,
74 alpha: T,
75 beta: T,
76 size_dict: Option<&HashMap<char, usize>>,
77) -> Result<()> {
78 let code = parse_einsum(notation)?;
79 code.evaluate_into(operands, output, alpha, beta, size_dict)
80}