1pub mod convert;
19
20use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
21use num_complex::Complex64;
22use strided_opteinsum::{EinsumOperand, EinsumScalar};
23
24use crate::convert::{
25 array_to_strided_view, strided_array_to_ndarray, view_mut_to_strided_view_mut,
26 view_to_strided_view,
27};
28
29#[derive(Debug, thiserror::Error)]
31pub enum Error {
32 #[error(transparent)]
33 Einsum(#[from] strided_opteinsum::EinsumError),
34}
35
36pub type Result<T> = std::result::Result<T, Error>;
37
38pub enum NdOperand<'a> {
42 F64Array(&'a ArrayD<f64>),
43 C64Array(&'a ArrayD<Complex64>),
44 F64View(ArrayViewD<'a, f64>),
45 C64View(ArrayViewD<'a, Complex64>),
46}
47
48impl<'a> From<&'a ArrayD<f64>> for NdOperand<'a> {
49 fn from(arr: &'a ArrayD<f64>) -> Self {
50 NdOperand::F64Array(arr)
51 }
52}
53
54impl<'a> From<&'a ArrayD<Complex64>> for NdOperand<'a> {
55 fn from(arr: &'a ArrayD<Complex64>) -> Self {
56 NdOperand::C64Array(arr)
57 }
58}
59
60impl<'a> From<ArrayViewD<'a, f64>> for NdOperand<'a> {
61 fn from(view: ArrayViewD<'a, f64>) -> Self {
62 NdOperand::F64View(view)
63 }
64}
65
66impl<'a> From<ArrayViewD<'a, Complex64>> for NdOperand<'a> {
67 fn from(view: ArrayViewD<'a, Complex64>) -> Self {
68 NdOperand::C64View(view)
69 }
70}
71
72fn to_einsum_operand<'a>(op: &NdOperand<'a>) -> EinsumOperand<'a> {
74 match op {
75 NdOperand::F64Array(arr) => {
76 let sv = array_to_strided_view(arr);
77 EinsumOperand::from_view(&sv)
78 }
79 NdOperand::C64Array(arr) => {
80 let sv = array_to_strided_view(arr);
81 EinsumOperand::from_view(&sv)
82 }
83 NdOperand::F64View(view) => {
84 let sv = view_to_strided_view(view);
85 EinsumOperand::from_view(&sv)
86 }
87 NdOperand::C64View(view) => {
88 let sv = view_to_strided_view(view);
89 EinsumOperand::from_view(&sv)
90 }
91 }
92}
93
94pub fn einsum<T: EinsumScalar>(notation: &str, operands: Vec<NdOperand<'_>>) -> Result<ArrayD<T>> {
104 let einsum_ops: Vec<EinsumOperand<'_>> =
105 operands.iter().map(|op| to_einsum_operand(op)).collect();
106 let result = strided_opteinsum::einsum(notation, einsum_ops, None)?;
107 let data = T::extract_data(result)?;
108 let strided_arr = data.into_array();
109 Ok(strided_array_to_ndarray(strided_arr))
110}
111
112pub fn einsum_into<'a, T: EinsumScalar>(
117 notation: &str,
118 operands: Vec<NdOperand<'_>>,
119 output: &'a mut ArrayViewMutD<'a, T>,
120 alpha: T,
121 beta: T,
122) -> Result<()> {
123 let einsum_ops: Vec<EinsumOperand<'_>> =
124 operands.iter().map(|op| to_einsum_operand(op)).collect();
125 let strided_out = view_mut_to_strided_view_mut(output);
126 strided_opteinsum::einsum_into(notation, einsum_ops, strided_out, alpha, beta, None)?;
127 Ok(())
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::ArrayD;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Builds a dynamic-dimensional `f64` array of `shape` from `data`,
    /// panicking if the element count does not match the shape.
    fn make_array(data: Vec<f64>, shape: Vec<usize>) -> ArrayD<f64> {
        ArrayD::from_shape_vec(shape, data).unwrap()
    }

    /// Builds an array of `shape` holding `1.0, 2.0, ..., len`.
    fn ramp(len: u32, shape: Vec<usize>) -> ArrayD<f64> {
        make_array((1..=len).map(f64::from).collect(), shape)
    }

    /// Checks each `(index, value)` pair against `actual` within `TOL`.
    fn assert_entries(actual: &ArrayD<f64>, expected: &[(&[usize], f64)]) {
        for &(index, want) in expected {
            assert_abs_diff_eq!(actual[index], want, epsilon = TOL);
        }
    }

    #[test]
    fn test_matmul_2x3_times_3x2() {
        let lhs = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let rhs = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let product: ArrayD<f64> =
            einsum("ij,jk->ik", vec![NdOperand::from(&lhs), NdOperand::from(&rhs)]).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_entries(
            &product,
            &[
                (&[0, 0], 58.0),
                (&[0, 1], 64.0),
                (&[1, 0], 139.0),
                (&[1, 1], 154.0),
            ],
        );
    }

    #[test]
    fn test_trace() {
        let square = make_array(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let trace: ArrayD<f64> = einsum("ii->", vec![NdOperand::from(&square)]).unwrap();

        // A scalar result is a zero-dimensional array.
        assert_eq!(trace.ndim(), 0);
        assert_entries(&trace, &[(&[], 5.0)]);
    }

    #[test]
    fn test_transpose() {
        let matrix = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let transposed: ArrayD<f64> = einsum("ij->ji", vec![NdOperand::from(&matrix)]).unwrap();

        assert_eq!(transposed.shape(), &[3, 2]);
        assert_entries(
            &transposed,
            &[
                (&[0, 0], 1.0),
                (&[1, 0], 2.0),
                (&[0, 1], 4.0),
                (&[2, 1], 6.0),
            ],
        );
    }

    #[test]
    fn test_dot_product() {
        let lhs = make_array(vec![1.0, 2.0, 3.0], vec![3]);
        let rhs = make_array(vec![4.0, 5.0, 6.0], vec![3]);

        let dot: ArrayD<f64> =
            einsum("i,i->", vec![NdOperand::from(&lhs), NdOperand::from(&rhs)]).unwrap();

        // A scalar result is a zero-dimensional array.
        assert_eq!(dot.ndim(), 0);
        assert_entries(&dot, &[(&[], 32.0)]);
    }

    #[test]
    fn test_outer_product() {
        let column = make_array(vec![1.0, 2.0], vec![2]);
        let row = make_array(vec![3.0, 4.0, 5.0], vec![3]);

        let outer: ArrayD<f64> =
            einsum("i,j->ij", vec![NdOperand::from(&column), NdOperand::from(&row)]).unwrap();

        assert_eq!(outer.shape(), &[2, 3]);
        assert_entries(&outer, &[(&[0, 0], 3.0), (&[0, 2], 5.0), (&[1, 1], 8.0)]);
    }

    #[test]
    fn test_three_operand_chain() {
        let first = ramp(6, vec![2, 3]);
        let second = ramp(12, vec![3, 4]);
        let third = ramp(8, vec![4, 2]);

        let chained: ArrayD<f64> = einsum(
            "ij,jk,kl->il",
            vec![
                NdOperand::from(&first),
                NdOperand::from(&second),
                NdOperand::from(&third),
            ],
        )
        .unwrap();

        assert_eq!(chained.shape(), &[2, 2]);
        assert_entries(&chained, &[(&[0, 0], 812.0), (&[0, 1], 1000.0)]);
    }

    #[test]
    fn test_view_input() {
        let lhs = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let rhs = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        // Operands are passed as views instead of references to owned arrays.
        let product: ArrayD<f64> = einsum(
            "ij,jk->ik",
            vec![
                lhs.view().into_dimensionality().unwrap().into(),
                rhs.view().into_dimensionality().unwrap().into(),
            ],
        )
        .unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_entries(&product, &[(&[0, 0], 58.0), (&[1, 1], 154.0)]);
    }

    #[test]
    fn test_einsum_into_matmul() {
        let lhs = make_array(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let rhs = make_array(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        let mut product = ArrayD::<f64>::zeros(vec![2, 2]);

        {
            // Scope the mutable view so `product` can be read again below.
            let mut target = product.view_mut();
            let operands = vec![NdOperand::from(&lhs), NdOperand::from(&rhs)];
            einsum_into("ij,jk->ik", operands, &mut target, 1.0, 0.0).unwrap();
        }

        assert_entries(&product, &[(&[0, 0], 58.0), (&[1, 1], 154.0)]);
    }
}