mdarray_opteinsum/
lib.rs

1//! N-ary Einstein summation for `mdarray` arrays.
2//!
3//! This crate provides a thin wrapper over [`strided_opteinsum`] that accepts
4//! `mdarray` `Array<T, DynRank>` and `View<T, DynRank>` types directly.
5//! Row-major / column-major layout conversion is handled transparently.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use mdarray::{Array, DynRank};
11//! use mdarray_opteinsum::einsum;
12//!
13//! let a: Array<f64, DynRank> = /* 3x4 matrix */;
14//! let b: Array<f64, DynRank> = /* 4x5 matrix */;
15//! let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()])?;
16//! ```
17
18pub mod convert;
19
20use mdarray::{Array, DynRank, View, ViewMut};
21use num_complex::Complex64;
22use strided_opteinsum::{EinsumOperand, EinsumScalar};
23
24use crate::convert::{
25    array_to_strided_view, reverse_notation, strided_array_to_mdarray,
26    view_mut_to_strided_view_mut, view_to_strided_view,
27};
28
29/// Error type for mdarray-opteinsum operations.
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32    #[error(transparent)]
33    Einsum(#[from] strided_opteinsum::EinsumError),
34}
35
36pub type Result<T> = std::result::Result<T, Error>;
37
38/// A type-erased einsum operand wrapping mdarray types.
39///
40/// Construct via `From` impls on owned arrays or views.
41pub enum MdOperand<'a> {
42    F64Array(&'a Array<f64, DynRank>),
43    C64Array(&'a Array<Complex64, DynRank>),
44    F64View(View<'a, f64, DynRank>),
45    C64View(View<'a, Complex64, DynRank>),
46}
47
48impl<'a> From<&'a Array<f64, DynRank>> for MdOperand<'a> {
49    fn from(arr: &'a Array<f64, DynRank>) -> Self {
50        MdOperand::F64Array(arr)
51    }
52}
53
54impl<'a> From<&'a Array<Complex64, DynRank>> for MdOperand<'a> {
55    fn from(arr: &'a Array<Complex64, DynRank>) -> Self {
56        MdOperand::C64Array(arr)
57    }
58}
59
60impl<'a> From<View<'a, f64, DynRank>> for MdOperand<'a> {
61    fn from(view: View<'a, f64, DynRank>) -> Self {
62        MdOperand::F64View(view)
63    }
64}
65
66impl<'a> From<View<'a, Complex64, DynRank>> for MdOperand<'a> {
67    fn from(view: View<'a, Complex64, DynRank>) -> Self {
68        MdOperand::C64View(view)
69    }
70}
71
72/// Convert an `MdOperand` into a `strided_opteinsum::EinsumOperand` with reversed dims.
73fn to_einsum_operand<'a>(op: &'a MdOperand<'a>) -> EinsumOperand<'a> {
74    match op {
75        MdOperand::F64Array(arr) => {
76            let sv = array_to_strided_view(*arr);
77            EinsumOperand::from_view(&sv)
78        }
79        MdOperand::C64Array(arr) => {
80            let sv = array_to_strided_view(*arr);
81            EinsumOperand::from_view(&sv)
82        }
83        MdOperand::F64View(view) => {
84            let sv = view_to_strided_view(view);
85            EinsumOperand::from_view(&sv)
86        }
87        MdOperand::C64View(view) => {
88            let sv = view_to_strided_view(view);
89            EinsumOperand::from_view(&sv)
90        }
91    }
92}
93
94/// Parse and evaluate an einsum expression on mdarray operands.
95///
96/// Returns the result as an owned `Array<T, DynRank>`.
97///
98/// # Example
99///
100/// ```ignore
101/// let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()])?;
102/// ```
103pub fn einsum<T: EinsumScalar>(
104    notation: &str,
105    operands: Vec<MdOperand<'_>>,
106) -> Result<Array<T, DynRank>> {
107    let reversed = reverse_notation(notation);
108    let einsum_ops: Vec<EinsumOperand<'_>> =
109        operands.iter().map(|op| to_einsum_operand(op)).collect();
110    let result = strided_opteinsum::einsum(&reversed, einsum_ops, None)?;
111    let data = T::extract_data(result)?;
112    let strided_arr = data.into_array();
113    Ok(strided_array_to_mdarray(strided_arr))
114}
115
116/// Parse and evaluate an einsum expression, writing the result into a
117/// pre-allocated mdarray output with alpha/beta scaling.
118///
119/// `output = alpha * einsum(operands) + beta * output`
120pub fn einsum_into<'a, T: EinsumScalar>(
121    notation: &str,
122    operands: Vec<MdOperand<'_>>,
123    output: &'a mut ViewMut<'a, T, DynRank>,
124    alpha: T,
125    beta: T,
126) -> Result<()> {
127    let reversed = reverse_notation(notation);
128    let einsum_ops: Vec<EinsumOperand<'_>> =
129        operands.iter().map(|op| to_einsum_operand(op)).collect();
130    let strided_out = view_mut_to_strided_view_mut(output);
131    strided_opteinsum::einsum_into(&reversed, einsum_ops, strided_out, alpha, beta, None)?;
132    Ok(())
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::mem::ManuallyDrop;

    /// Helper: create a row-major DynRank array from a Vec and shape.
    ///
    /// Panics if `data.len()` differs from the product of `dims`, so a bad
    /// fixture fails loudly instead of handing a mis-sized buffer to
    /// `from_raw_parts`.
    fn make_array(data: Vec<f64>, dims: &[usize]) -> Array<f64, DynRank> {
        assert_eq!(
            data.len(),
            dims.iter().product::<usize>(),
            "data length does not match shape {:?}",
            dims
        );
        let shape: DynRank = mdarray::Shape::from_dims(dims);
        // The allocation is transferred to the Array, so the Vec must not free it.
        let mut data = ManuallyDrop::new(data);
        let capacity = data.capacity();
        let ptr = data.as_mut_ptr();
        // SAFETY: `ptr` and `capacity` come from a `Vec<f64>` whose destructor is
        // suppressed by `ManuallyDrop`, so ownership of the allocation is handed
        // over exactly once; the assertion above guarantees one initialized
        // element per position of `shape`.
        unsafe { Array::from_raw_parts(ptr, shape, capacity) }
    }

    /// Helper: read element at row-major indices from a DynRank array.
    ///
    /// Panics if the number of indices differs from the array's rank or if any
    /// index is out of range for its axis.
    fn get_elem(arr: &Array<f64, DynRank>, indices: &[usize]) -> f64 {
        let dims = arr.dims();
        assert_eq!(indices.len(), dims.len(), "index rank mismatch");
        // Horner-style row-major linearization: ((i0 * d1 + i1) * d2 + i2) ...
        let offset = indices
            .iter()
            .zip(dims)
            .enumerate()
            .fold(0usize, |acc, (axis, (&idx, &len))| {
                assert!(
                    idx < len,
                    "index {} out of bounds for axis {} of length {}",
                    idx,
                    axis,
                    len
                );
                acc * len + idx
            });
        // SAFETY: every index was checked against its axis length above, so
        // `offset` is below the element count of the dense row-major buffer.
        unsafe { *arr.as_ptr().add(offset) }
    }

    #[test]
    fn test_matmul_2x3_times_3x2() {
        // A = [[1,2,3],[4,5,6]]  (2x3, row-major)
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        // B = [[7,8],[9,10],[11,12]]  (3x2, row-major)
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.dims(), &[2, 2]);
        // C[0,0] = 1*7 + 2*9 + 3*11 = 58
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        // C[0,1] = 1*8 + 2*10 + 3*12 = 64
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        // C[1,0] = 4*7 + 5*9 + 6*11 = 139
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        // C[1,1] = 4*8 + 5*10 + 6*12 = 154
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_trace() {
        // A = [[1,2],[3,4]]  (2x2)
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let c: Array<f64, DynRank> = einsum("ii->", vec![(&a).into()]).unwrap();

        // Scalar result: trace = 1 + 4 = 5
        assert_eq!(c.dims(), &[] as &[usize]);
        assert_abs_diff_eq!(unsafe { *c.as_ptr() }, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_transpose() {
        // A = [[1,2,3],[4,5,6]]  (2x3)
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let c: Array<f64, DynRank> = einsum("ij->ji", vec![(&a).into()]).unwrap();

        assert_eq!(c.dims(), &[3, 2]);
        // C[0,0] = A[0,0] = 1
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 1.0, epsilon = 1e-10);
        // C[1,0] = A[0,1] = 2
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 2.0, epsilon = 1e-10);
        // C[0,1] = A[1,0] = 4
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 4.0, epsilon = 1e-10);
        // C[2,1] = A[1,2] = 6
        assert_abs_diff_eq!(get_elem(&c, &[2, 1]), 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dot_product() {
        // a = [1,2,3], b = [4,5,6]
        let a = make_array(vec![1.0, 2.0, 3.0], &[3]);
        let b = make_array(vec![4.0, 5.0, 6.0], &[3]);

        let c: Array<f64, DynRank> = einsum("i,i->", vec![(&a).into(), (&b).into()]).unwrap();

        // dot = 1*4 + 2*5 + 3*6 = 32
        assert_eq!(c.dims(), &[] as &[usize]);
        assert_abs_diff_eq!(unsafe { *c.as_ptr() }, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_outer_product() {
        // a = [1,2], b = [3,4,5]
        let a = make_array(vec![1.0, 2.0], &[2]);
        let b = make_array(vec![3.0, 4.0, 5.0], &[3]);

        let c: Array<f64, DynRank> = einsum("i,j->ij", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.dims(), &[2, 3]);
        // C[0,0] = 1*3 = 3
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 3.0, epsilon = 1e-10);
        // C[0,2] = 1*5 = 5
        assert_abs_diff_eq!(get_elem(&c, &[0, 2]), 5.0, epsilon = 1e-10);
        // C[1,1] = 2*4 = 8
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 8.0, epsilon = 1e-10);
        // C[1,2] = 2*5 = 10
        assert_abs_diff_eq!(get_elem(&c, &[1, 2]), 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_three_operand_chain() {
        // A[2,3] * B[3,4] * C[4,2] -> result[2,2]
        let a = make_array((1..=6).map(|x| x as f64).collect(), &[2, 3]);
        let b = make_array((1..=12).map(|x| x as f64).collect(), &[3, 4]);
        let c = make_array((1..=8).map(|x| x as f64).collect(), &[4, 2]);

        let result: Array<f64, DynRank> =
            einsum("ij,jk,kl->il", vec![(&a).into(), (&b).into(), (&c).into()]).unwrap();

        assert_eq!(result.dims(), &[2, 2]);

        // Verify against manual computation:
        // AB = A * B (2x4), then ABC = AB * C (2x2)
        // AB row 0 = [38, 44, 50, 56]
        // AB row 1 = [83, 98, 113, 128]
        // ABC[0,0] = 38*1 + 44*3 + 50*5 + 56*7 = 38+132+250+392 = 812
        // ABC[0,1] = 38*2 + 44*4 + 50*6 + 56*8 = 76+176+300+448 = 1000
        // ABC[1,0] = 83*1 + 98*3 + 113*5 + 128*7 = 83+294+565+896 = 1838
        // ABC[1,1] = 83*2 + 98*4 + 113*6 + 128*8 = 166+392+678+1024 = 2260
        assert_abs_diff_eq!(get_elem(&result, &[0, 0]), 812.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[0, 1]), 1000.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[1, 0]), 1838.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[1, 1]), 2260.0, epsilon = 1e-10);
    }

    #[test]
    fn test_view_input() {
        // Test that View inputs work
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let va = a.expr();
        let vb = b.expr();

        let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![va.into(), vb.into()]).unwrap();

        assert_eq!(c.dims(), &[2, 2]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        // Off-diagonal entries catch an accidentally transposed result.
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_einsum_into_matmul() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
        let mut c: Array<f64, DynRank> = Array::zeros(&[2usize, 2]);

        {
            let mut view = c.expr_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                1.0,
                0.0,
            )
            .unwrap();
        }

        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        // Off-diagonal entries catch an accidentally transposed output.
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_einsum_into_alpha_beta() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
        let mut c: Array<f64, DynRank> = Array::zeros(&[2usize, 2]);

        // First pass: c = 1 * (A B) + 0 * c = A B.
        {
            let mut view = c.expr_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                1.0,
                0.0,
            )
            .unwrap();
        }
        // Second pass: c = 2 * (A B) + 1 * c = 3 * (A B).
        {
            let mut view = c.expr_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                2.0,
                1.0,
            )
            .unwrap();
        }

        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 174.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 192.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 417.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 462.0, epsilon = 1e-10);
    }
}