strided_opteinsum/lib.rs
//! N-ary Einstein summation with nested contraction notation.
//!
//! This crate provides an einsum frontend that parses nested string
//! notation (e.g. `"(ij,jk),kl->il"`), supports mixed `f64` / `Complex64`
//! operands, and delegates pairwise contractions to [`strided_einsum2`].
//! For three or more tensors in a single contraction node, the
//! [`omeco`] greedy optimizer is used to find an efficient pairwise order.
//!
//! # Quick start
//!
//! ```ignore
//! use strided_opteinsum::{einsum, EinsumOperand};
//!
//! let result = einsum("(ij,jk),kl->il", vec![a.into(), b.into(), c.into()], None)?;
//! ```
16
17use std::collections::HashMap;
18
19/// Error types for einsum operations.
20pub mod error;
21/// Recursive contraction-tree evaluation.
22pub mod expr;
23/// Type-erased einsum operands (`f64` / `Complex64`, owned / borrowed).
24pub mod operand;
25/// Nested einsum string parser.
26pub mod parse;
27/// Single-tensor operations (permute, trace, diagonal extraction).
28pub mod single_tensor;
29/// Runtime type dispatch over `f64` and `Complex64` tensors.
30pub mod typed_tensor;
31
32pub use error::{EinsumError, Result};
33pub use expr::BufferPool;
34pub use operand::{EinsumOperand, EinsumScalar, StridedData};
35pub use parse::{parse_einsum, EinsumCode, EinsumNode};
36pub use typed_tensor::{needs_c64_promotion, TypedTensor};
37
38/// Parse and evaluate an einsum expression in one call.
39///
40/// Pass `size_dict` to specify sizes for output indices not present in any
41/// input (generative outputs like `"->ii"` or `"i->ij"`).
42///
43/// # Example
44/// ```ignore
45/// let result = einsum("(ij,jk),kl->il", vec![a.into(), b.into(), c.into()], None)?;
46/// ```
47pub fn einsum<'a>(
48 notation: &str,
49 operands: Vec<EinsumOperand<'a>>,
50 size_dict: Option<&HashMap<char, usize>>,
51) -> Result<EinsumOperand<'a>> {
52 let code = parse_einsum(notation)?;
53 code.evaluate(operands, size_dict)
54}
55
56/// Parse and evaluate an einsum expression with optional buffer pool reuse.
57///
58/// Pass `Some(&mut pool)` to reuse intermediate buffers across calls.
59/// Pass `None` for independent allocation (equivalent to [`einsum`]).
60pub fn einsum_with_pool<'a>(
61 notation: &str,
62 operands: Vec<EinsumOperand<'a>>,
63 size_dict: Option<&HashMap<char, usize>>,
64 pool: Option<&mut BufferPool>,
65) -> Result<EinsumOperand<'a>> {
66 let code = parse_einsum(notation)?;
67 code.evaluate_with_pool(operands, size_dict, pool)
68}
69
70/// Parse and evaluate an einsum expression, writing the result into a
71/// pre-allocated output buffer with alpha/beta scaling.
72///
73/// `output = alpha * einsum(operands) + beta * output`
74///
75/// Pass `size_dict` to specify sizes for output indices not present in any
76/// input (generative outputs like `"->ii"` or `"i->ij"`).
77///
78/// # Example
79/// ```ignore
80/// use strided_opteinsum::{einsum_into, EinsumOperand};
81///
82/// let mut c = StridedArray::<f64>::col_major(&[2, 2]);
83/// einsum_into("ij,jk->ik", vec![a.into(), b.into()], c.view_mut(), 1.0, 0.0, None)?;
84/// ```
85pub fn einsum_into<T: EinsumScalar>(
86 notation: &str,
87 operands: Vec<EinsumOperand<'_>>,
88 output: strided_view::StridedViewMut<T>,
89 alpha: T,
90 beta: T,
91 size_dict: Option<&HashMap<char, usize>>,
92) -> Result<()> {
93 let code = parse_einsum(notation)?;
94 code.evaluate_into(operands, output, alpha, beta, size_dict)
95}
96
97/// Parse and evaluate an einsum expression into an output buffer with
98/// optional buffer pool reuse.
99///
100/// `output = alpha * einsum(operands) + beta * output`
101///
102/// Pass `Some(&mut pool)` to reuse intermediate buffers across calls.
103/// Pass `None` for independent allocation (equivalent to [`einsum_into`]).
104pub fn einsum_into_with_pool<T: EinsumScalar>(
105 notation: &str,
106 operands: Vec<EinsumOperand<'_>>,
107 output: strided_view::StridedViewMut<T>,
108 alpha: T,
109 beta: T,
110 size_dict: Option<&HashMap<char, usize>>,
111 pool: Option<&mut BufferPool>,
112) -> Result<()> {
113 let code = parse_einsum(notation)?;
114 code.evaluate_into_with_pool(operands, output, alpha, beta, size_dict, pool)
115}