1pub mod convert;
19
20use mdarray::{Array, DynRank, View, ViewMut};
21use num_complex::Complex64;
22use strided_opteinsum::{EinsumOperand, EinsumScalar};
23
24use crate::convert::{
25 array_to_strided_view, reverse_notation, strided_array_to_mdarray,
26 view_mut_to_strided_view_mut, view_to_strided_view,
27};
28
29#[derive(Debug, thiserror::Error)]
31pub enum Error {
32 #[error(transparent)]
33 Einsum(#[from] strided_opteinsum::EinsumError),
34}
35
36pub type Result<T> = std::result::Result<T, Error>;
37
38pub enum MdOperand<'a> {
42 F64Array(&'a Array<f64, DynRank>),
43 C64Array(&'a Array<Complex64, DynRank>),
44 F64View(View<'a, f64, DynRank>),
45 C64View(View<'a, Complex64, DynRank>),
46}
47
48impl<'a> From<&'a Array<f64, DynRank>> for MdOperand<'a> {
49 fn from(arr: &'a Array<f64, DynRank>) -> Self {
50 MdOperand::F64Array(arr)
51 }
52}
53
54impl<'a> From<&'a Array<Complex64, DynRank>> for MdOperand<'a> {
55 fn from(arr: &'a Array<Complex64, DynRank>) -> Self {
56 MdOperand::C64Array(arr)
57 }
58}
59
60impl<'a> From<View<'a, f64, DynRank>> for MdOperand<'a> {
61 fn from(view: View<'a, f64, DynRank>) -> Self {
62 MdOperand::F64View(view)
63 }
64}
65
66impl<'a> From<View<'a, Complex64, DynRank>> for MdOperand<'a> {
67 fn from(view: View<'a, Complex64, DynRank>) -> Self {
68 MdOperand::C64View(view)
69 }
70}
71
72fn to_einsum_operand<'a>(op: &'a MdOperand<'a>) -> EinsumOperand<'a> {
74 match op {
75 MdOperand::F64Array(arr) => {
76 let sv = array_to_strided_view(*arr);
77 EinsumOperand::from_view(&sv)
78 }
79 MdOperand::C64Array(arr) => {
80 let sv = array_to_strided_view(*arr);
81 EinsumOperand::from_view(&sv)
82 }
83 MdOperand::F64View(view) => {
84 let sv = view_to_strided_view(view);
85 EinsumOperand::from_view(&sv)
86 }
87 MdOperand::C64View(view) => {
88 let sv = view_to_strided_view(view);
89 EinsumOperand::from_view(&sv)
90 }
91 }
92}
93
94pub fn einsum<T: EinsumScalar>(
104 notation: &str,
105 operands: Vec<MdOperand<'_>>,
106) -> Result<Array<T, DynRank>> {
107 let reversed = reverse_notation(notation);
108 let einsum_ops: Vec<EinsumOperand<'_>> =
109 operands.iter().map(|op| to_einsum_operand(op)).collect();
110 let result = strided_opteinsum::einsum(&reversed, einsum_ops, None)?;
111 let data = T::extract_data(result)?;
112 let strided_arr = data.into_array();
113 Ok(strided_array_to_mdarray(strided_arr))
114}
115
116pub fn einsum_into<'a, T: EinsumScalar>(
121 notation: &str,
122 operands: Vec<MdOperand<'_>>,
123 output: &'a mut ViewMut<'a, T, DynRank>,
124 alpha: T,
125 beta: T,
126) -> Result<()> {
127 let reversed = reverse_notation(notation);
128 let einsum_ops: Vec<EinsumOperand<'_>> =
129 operands.iter().map(|op| to_einsum_operand(op)).collect();
130 let strided_out = view_mut_to_strided_view_mut(output);
131 strided_opteinsum::einsum_into(&reversed, einsum_ops, strided_out, alpha, beta, None)?;
132 Ok(())
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::mem::ManuallyDrop;

    /// Builds an `Array<f64, DynRank>` of shape `dims` that takes ownership of `data`.
    ///
    /// `data` is interpreted in row-major order, matching how `get_elem` computes offsets.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `dims`. Without this check a
    /// mis-sized fixture would be handed to the unsafe `from_raw_parts` below.
    fn make_array(data: Vec<f64>, dims: &[usize]) -> Array<f64, DynRank> {
        let expected_len: usize = dims.iter().product();
        assert_eq!(
            data.len(),
            expected_len,
            "fixture has {} elements but shape {:?} needs {}",
            data.len(),
            dims,
            expected_len
        );
        let shape: DynRank = mdarray::Shape::from_dims(dims);
        // Transfer the Vec's allocation to the Array without running the Vec's destructor.
        let mut data = ManuallyDrop::new(data);
        let capacity = data.capacity();
        let ptr = data.as_mut_ptr();
        // SAFETY: `ptr`/`capacity` describe a live Vec<f64> allocation whose ownership moves
        // to the Array (ManuallyDrop prevents a double free), and the assert above ensures
        // it holds exactly as many initialised elements as `shape` requires.
        // NOTE(review): assumes these are the full `from_raw_parts` requirements; confirm
        // against the mdarray docs.
        unsafe { Array::from_raw_parts(ptr, shape, capacity) }
    }

    /// Reads the element of `arr` at `indices`, assuming a contiguous row-major layout.
    ///
    /// Pass `&[]` to read the single element of a rank-0 array.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the rank of `arr` or any index is out of bounds.
    fn get_elem(arr: &Array<f64, DynRank>, indices: &[usize]) -> f64 {
        let dims = arr.dims();
        assert_eq!(
            indices.len(),
            dims.len(),
            "expected {} indices, got {}",
            dims.len(),
            indices.len()
        );
        let mut offset = 0usize;
        for (axis, &idx) in indices.iter().enumerate() {
            assert!(
                idx < dims[axis],
                "index {idx} out of bounds for axis {axis} of length {}",
                dims[axis]
            );
            // Row-major stride: product of all trailing dimensions.
            let stride: usize = dims[axis + 1..].iter().product();
            offset += idx * stride;
        }
        // SAFETY: every index was bounds-checked above, so `offset` is strictly less than the
        // element count (the product of `dims`) of the contiguous buffer.
        unsafe { *arr.as_ptr().add(offset) }
    }

    #[test]
    fn test_matmul_2x3_times_3x2() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.dims(), &[2, 2]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_trace() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let c: Array<f64, DynRank> = einsum("ii->", vec![(&a).into()]).unwrap();

        assert_eq!(c.dims(), &[] as &[usize]);
        assert_abs_diff_eq!(get_elem(&c, &[]), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_transpose() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let c: Array<f64, DynRank> = einsum("ij->ji", vec![(&a).into()]).unwrap();

        assert_eq!(c.dims(), &[3, 2]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[2, 0]), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[2, 1]), 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dot_product() {
        let a = make_array(vec![1.0, 2.0, 3.0], &[3]);
        let b = make_array(vec![4.0, 5.0, 6.0], &[3]);

        let c: Array<f64, DynRank> = einsum("i,i->", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.dims(), &[] as &[usize]);
        assert_abs_diff_eq!(get_elem(&c, &[]), 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_outer_product() {
        let a = make_array(vec![1.0, 2.0], &[2]);
        let b = make_array(vec![3.0, 4.0, 5.0], &[3]);

        let c: Array<f64, DynRank> = einsum("i,j->ij", vec![(&a).into(), (&b).into()]).unwrap();

        assert_eq!(c.dims(), &[2, 3]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 2]), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 2]), 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_three_operand_chain() {
        let a = make_array((1..=6).map(|x| x as f64).collect(), &[2, 3]);
        let b = make_array((1..=12).map(|x| x as f64).collect(), &[3, 4]);
        let c = make_array((1..=8).map(|x| x as f64).collect(), &[4, 2]);

        let result: Array<f64, DynRank> =
            einsum("ij,jk,kl->il", vec![(&a).into(), (&b).into(), (&c).into()]).unwrap();

        assert_eq!(result.dims(), &[2, 2]);
        // A·B = [[38, 44, 50, 56], [78, 92, 106, 120]]; (A·B)·C below.
        assert_abs_diff_eq!(get_elem(&result, &[0, 0]), 812.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[0, 1]), 1000.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[1, 0]), 1724.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&result, &[1, 1]), 2120.0, epsilon = 1e-10);
    }

    #[test]
    fn test_view_input() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let va = a.expr();
        let vb = b.expr();

        let c: Array<f64, DynRank> = einsum("ij,jk->ik", vec![va.into(), vb.into()]).unwrap();

        assert_eq!(c.dims(), &[2, 2]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }

    #[test]
    fn test_einsum_into_matmul() {
        let a = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
        let mut c: Array<f64, DynRank> = Array::zeros(&[2usize, 2]);

        // Scope the view: `einsum_into` borrows it for the view's whole lifetime, so it must
        // be dropped before `c` is read again.
        {
            let mut view = c.expr_mut();
            einsum_into(
                "ij,jk->ik",
                vec![(&a).into(), (&b).into()],
                &mut view,
                1.0,
                0.0,
            )
            .unwrap();
        }

        assert_eq!(c.dims(), &[2, 2]);
        assert_abs_diff_eq!(get_elem(&c, &[0, 0]), 58.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[0, 1]), 64.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 0]), 139.0, epsilon = 1e-10);
        assert_abs_diff_eq!(get_elem(&c, &[1, 1]), 154.0, epsilon = 1e-10);
    }
}