ndarray_opteinsum/
lib.rs

1//! N-ary Einstein summation for `ndarray` arrays.
2//!
3//! This crate provides a thin wrapper over [`strided_opteinsum`] that accepts
4//! `ndarray` `ArrayD<T>` and `ArrayViewD<T>` types directly.
5//! Dims and strides are passed through without reversal (ndarray has explicit strides).
6//!
7//! # Example
8//!
9//! ```ignore
10//! use ndarray::ArrayD;
11//! use ndarray_opteinsum::einsum;
12//!
13//! let a = ArrayD::from_shape_vec(vec![2, 3], (1..=6).map(|x| x as f64).collect()).unwrap();
14//! let b = ArrayD::from_shape_vec(vec![3, 2], (7..=12).map(|x| x as f64).collect()).unwrap();
15//! let c: ArrayD<f64> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()]).unwrap();
16//! ```
17
18pub mod convert;
19
20use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
21use num_complex::Complex64;
22use strided_opteinsum::{EinsumOperand, EinsumScalar};
23
24use crate::convert::{
25    array_to_strided_view, strided_array_to_ndarray, view_mut_to_strided_view_mut,
26    view_to_strided_view,
27};
28
29/// Error type for ndarray-opteinsum operations.
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32    #[error(transparent)]
33    Einsum(#[from] strided_opteinsum::EinsumError),
34}
35
36pub type Result<T> = std::result::Result<T, Error>;
37
38/// A type-erased einsum operand wrapping ndarray types.
39///
40/// Construct via `From` impls on owned arrays or views.
41pub enum NdOperand<'a> {
42    F64Array(&'a ArrayD<f64>),
43    C64Array(&'a ArrayD<Complex64>),
44    F64View(ArrayViewD<'a, f64>),
45    C64View(ArrayViewD<'a, Complex64>),
46}
47
48impl<'a> From<&'a ArrayD<f64>> for NdOperand<'a> {
49    fn from(arr: &'a ArrayD<f64>) -> Self {
50        NdOperand::F64Array(arr)
51    }
52}
53
54impl<'a> From<&'a ArrayD<Complex64>> for NdOperand<'a> {
55    fn from(arr: &'a ArrayD<Complex64>) -> Self {
56        NdOperand::C64Array(arr)
57    }
58}
59
60impl<'a> From<ArrayViewD<'a, f64>> for NdOperand<'a> {
61    fn from(view: ArrayViewD<'a, f64>) -> Self {
62        NdOperand::F64View(view)
63    }
64}
65
66impl<'a> From<ArrayViewD<'a, Complex64>> for NdOperand<'a> {
67    fn from(view: ArrayViewD<'a, Complex64>) -> Self {
68        NdOperand::C64View(view)
69    }
70}
71
72/// Convert an `NdOperand` into a `strided_opteinsum::EinsumOperand`.
73fn to_einsum_operand<'a>(op: &NdOperand<'a>) -> EinsumOperand<'a> {
74    match op {
75        NdOperand::F64Array(arr) => {
76            let sv = array_to_strided_view(arr);
77            EinsumOperand::from_view(&sv)
78        }
79        NdOperand::C64Array(arr) => {
80            let sv = array_to_strided_view(arr);
81            EinsumOperand::from_view(&sv)
82        }
83        NdOperand::F64View(view) => {
84            let sv = view_to_strided_view(view);
85            EinsumOperand::from_view(&sv)
86        }
87        NdOperand::C64View(view) => {
88            let sv = view_to_strided_view(view);
89            EinsumOperand::from_view(&sv)
90        }
91    }
92}
93
94/// Parse and evaluate an einsum expression on ndarray operands.
95///
96/// Returns the result as an owned `ArrayD<T>`.
97///
98/// # Example
99///
100/// ```ignore
101/// let c: ArrayD<f64> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()])?;
102/// ```
103pub fn einsum<T: EinsumScalar>(notation: &str, operands: Vec<NdOperand<'_>>) -> Result<ArrayD<T>> {
104    let einsum_ops: Vec<EinsumOperand<'_>> =
105        operands.iter().map(|op| to_einsum_operand(op)).collect();
106    let result = strided_opteinsum::einsum(notation, einsum_ops, None)?;
107    let data = T::extract_data(result)?;
108    let strided_arr = data.into_array();
109    Ok(strided_array_to_ndarray(strided_arr))
110}
111
112/// Parse and evaluate an einsum expression, writing the result into a
113/// pre-allocated ndarray output with alpha/beta scaling.
114///
115/// `output = alpha * einsum(operands) + beta * output`
116pub fn einsum_into<'a, T: EinsumScalar>(
117    notation: &str,
118    operands: Vec<NdOperand<'_>>,
119    output: &'a mut ArrayViewMutD<'a, T>,
120    alpha: T,
121    beta: T,
122) -> Result<()> {
123    let einsum_ops: Vec<EinsumOperand<'_>> =
124        operands.iter().map(|op| to_einsum_operand(op)).collect();
125    let strided_out = view_mut_to_strided_view_mut(output);
126    strided_opteinsum::einsum_into(notation, einsum_ops, strided_out, alpha, beta, None)?;
127    Ok(())
128}
129
#[cfg(test)]
mod tests {
    // `use super::*` already brings `ArrayD`, `Complex64`, `einsum` and
    // `einsum_into` into scope, so no separate `ndarray` import is needed.
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a dynamic-rank `f64` array from row-major `data` and `shape`.
    fn make_array(data: Vec<f64>, shape: Vec<usize>) -> ArrayD<f64> {
        ArrayD::from_shape_vec(shape, data).unwrap()
    }

    #[test]
    fn test_matmul_2x3_times_3x2() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let c: ArrayD<f64> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 1]], 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_trace() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let c: ArrayD<f64> = einsum("ii->", vec![(&a).into()]).unwrap();

        assert_eq!(c.shape(), &[] as &[usize]);
        assert_abs_diff_eq!(c[[]], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_transpose() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let c: ArrayD<f64> = einsum("ij->ji", vec![(&a).into()]).unwrap();

        assert_eq!(c.shape(), &[3, 2]);
        assert_abs_diff_eq!(c[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[2, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[2, 1]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dot_product() {
        let a = make_array(vec![1.0, 2.0, 3.0], vec![3]);
        let b = make_array(vec![4.0, 5.0, 6.0], vec![3]);

        let c: ArrayD<f64> = einsum("i,i->", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.shape(), &[] as &[usize]);
        assert_abs_diff_eq!(c[[]], 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_outer_product() {
        let a = make_array(vec![1.0, 2.0], vec![2]);
        let b = make_array(vec![3.0, 4.0, 5.0], vec![3]);

        let c: ArrayD<f64> = einsum("i,j->ij", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.shape(), &[2, 3]);
        assert_abs_diff_eq!(c[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 2]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 2]], 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_three_operand_chain() {
        let a = make_array((1..=6).map(|x| x as f64).collect(), vec![2, 3]);
        let b = make_array((1..=12).map(|x| x as f64).collect(), vec![3, 4]);
        let c = make_array((1..=8).map(|x| x as f64).collect(), vec![4, 2]);

        let result: ArrayD<f64> =
            einsum("ij,jk,kl->il", vec![(&a).into(), (&b).into(), (&c).into()]).unwrap();

        // a·b = [[38, 44, 50, 56], [83, 98, 113, 128]]; then (a·b)·c.
        assert_eq!(result.shape(), &[2, 2]);
        assert_abs_diff_eq!(result[[0, 0]], 812.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 1000.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 0]], 1838.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 1]], 2260.0, epsilon = 1e-10);
    }

    #[test]
    fn test_view_input() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        // `view()` on an `ArrayD` is already an `ArrayViewD`, so no
        // dimensionality conversion is needed before `.into()`.
        let c: ArrayD<f64> = einsum("ij,jk->ik", vec![a.view().into(), b.view().into()]).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_transposed_view_input() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        // `a.t()` is a non-contiguous view (shape [3, 2]). Summing over its
        // first axis with "ji,jk->ik" undoes the transpose, so the result must
        // equal the plain product a·b.
        let c: ArrayD<f64> = einsum("ji,jk->ik", vec![a.t().into(), (&b).into()]).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 1]], 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_complex_dot_product() {
        let a = ArrayD::from_shape_vec(
            vec![2],
            vec![Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
        )
        .unwrap();
        let b = ArrayD::from_shape_vec(
            vec![2],
            vec![Complex64::new(3.0, 0.0), Complex64::new(1.0, -1.0)],
        )
        .unwrap();

        // Exercises both complex operand kinds: an owned array and a view.
        // (1+i)*3 + 2*(1-i) = 5 + i (no conjugation).
        let c: ArrayD<Complex64> = einsum("i,i->", vec![(&a).into(), b.view().into()]).unwrap();

        assert_eq!(c.shape(), &[] as &[usize]);
        assert_abs_diff_eq!(c[[]].re, 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[]].im, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_einsum_into_matmul() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        let mut c = ArrayD::<f64>::zeros(vec![2, 2]);

        {
            let mut view = c.view_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                1.0,
                0.0,
            )
            .unwrap();
        }

        assert_abs_diff_eq!(c[[0, 0]], 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 1]], 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_einsum_into_alpha_beta() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        // Pre-filled with ones so the beta term is observable.
        let mut c = ArrayD::<f64>::ones(vec![2, 2]);

        {
            let mut view = c.view_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                2.0,
                3.0,
            )
            .unwrap();
        }

        // output = 2 * (a·b) + 3 * 1
        assert_abs_diff_eq!(c[[0, 0]], 2.0 * 58.0 + 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[0, 1]], 2.0 * 64.0 + 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 0]], 2.0 * 139.0 + 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 2.0 * 154.0 + 3.0, epsilon = 1e-10);
    }
}