1use std::collections::HashMap;
2
3use num_complex::Complex64;
4#[cfg(test)]
5use num_traits::Zero;
6use strided_einsum2::{einsum2_into, einsum2_into_owned};
7use strided_kernel::copy_scale;
8use strided_view::{StridedArray, StridedViewMut};
9
10use crate::operand::{EinsumOperand, EinsumScalar, StridedData};
11use crate::parse::{EinsumCode, EinsumNode};
12use crate::single_tensor::single_tensor_einsum;
13
14struct BufferPool {
24 f64_pool: HashMap<usize, Vec<Vec<f64>>>,
25 c64_pool: HashMap<usize, Vec<Vec<Complex64>>>,
26}
27
28impl BufferPool {
29 fn new() -> Self {
30 Self {
31 f64_pool: HashMap::new(),
32 c64_pool: HashMap::new(),
33 }
34 }
35}
36
37trait PoolOps: EinsumScalar {
45 fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Self>;
51
52 fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Self>);
55}
56
57impl PoolOps for f64 {
58 fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<f64> {
59 let total: usize = dims.iter().product();
60 match pool.f64_pool.get_mut(&total).and_then(|v| v.pop()) {
62 Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
63 None => unsafe { StridedArray::col_major_uninit(dims) },
64 }
65 }
66
67 fn pool_release(pool: &mut BufferPool, data: StridedData<'_, f64>) {
68 if let StridedData::Owned(arr) = data {
69 let buf = arr.into_data();
70 pool.f64_pool.entry(buf.len()).or_default().push(buf);
71 }
72 }
73}
74
75impl PoolOps for Complex64 {
76 fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Complex64> {
77 let total: usize = dims.iter().product();
78 match pool.c64_pool.get_mut(&total).and_then(|v| v.pop()) {
80 Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
81 None => unsafe { StridedArray::col_major_uninit(dims) },
82 }
83 }
84
85 fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Complex64>) {
86 if let StridedData::Owned(arr) = data {
87 let buf = arr.into_data();
88 pool.c64_pool.entry(buf.len()).or_default().push(buf);
89 }
90 }
91}
92
93fn collect_all_ids(node: &EinsumNode) -> Vec<char> {
98 let mut result = Vec::new();
99 collect_all_ids_inner(node, &mut result);
100 result
101}
102
103fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
104 match node {
105 EinsumNode::Leaf { ids, .. } => {
106 for &id in ids {
107 if !result.contains(&id) {
108 result.push(id);
109 }
110 }
111 }
112 EinsumNode::Contract { args } => {
113 for arg in args {
114 collect_all_ids_inner(arg, result);
115 }
116 }
117 }
118}
119
120fn compute_contract_output_ids(args: &[EinsumNode], needed_ids: &[char]) -> Vec<char> {
130 let mut all_ids_ordered = Vec::new();
132 for arg in args {
133 for id in collect_all_ids(arg) {
134 if !all_ids_ordered.contains(&id) {
135 all_ids_ordered.push(id);
136 }
137 }
138 }
139
140 all_ids_ordered
142 .into_iter()
143 .filter(|id| needed_ids.contains(id))
144 .collect()
145}
146
147fn compute_child_needed_ids(
157 output_ids: &[char],
158 child_idx: usize,
159 args: &[EinsumNode],
160) -> Vec<char> {
161 let mut needed: Vec<char> = output_ids.to_vec();
162
163 let child_ids = collect_all_ids(&args[child_idx]);
165 for (j, arg) in args.iter().enumerate() {
166 if j == child_idx {
167 continue;
168 }
169 let sibling_ids = collect_all_ids(arg);
170 for &id in &child_ids {
171 if sibling_ids.contains(&id) && !needed.contains(&id) {
172 needed.push(id);
173 }
174 }
175 }
176
177 needed
178}
179
180fn out_dims_from_map(
185 dim_map: &HashMap<char, usize>,
186 output_ids: &[char],
187 size_dict: &HashMap<char, usize>,
188) -> crate::Result<Vec<usize>> {
189 let mut out_dims = Vec::with_capacity(output_ids.len());
190 for &id in output_ids {
191 if let Some(&dim) = dim_map.get(&id) {
192 out_dims.push(dim);
193 } else if let Some(&dim) = size_dict.get(&id) {
194 out_dims.push(dim);
195 } else {
196 return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
197 }
198 }
199 Ok(out_dims)
200}
201
202fn out_dims_from_ids(
204 left_ids: &[char],
205 left_dims: &[usize],
206 right_ids: &[char],
207 right_dims: &[usize],
208 output_ids: &[char],
209 size_dict: &HashMap<char, usize>,
210) -> crate::Result<Vec<usize>> {
211 let mut out_dims = Vec::with_capacity(output_ids.len());
212 for &id in output_ids {
213 if let Some(pos) = left_ids.iter().position(|&c| c == id) {
214 out_dims.push(left_dims[pos]);
215 } else if let Some(pos) = right_ids.iter().position(|&c| c == id) {
216 out_dims.push(right_dims[pos]);
217 } else if let Some(&dim) = size_dict.get(&id) {
218 out_dims.push(dim);
219 } else {
220 return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
221 }
222 }
223 Ok(out_dims)
224}
225
226fn eval_pair_alloc<T: PoolOps>(
231 ld: StridedData<'_, T>,
232 left_ids: &[char],
233 rd: StridedData<'_, T>,
234 right_ids: &[char],
235 output_ids: &[char],
236 pool: &mut BufferPool,
237 size_dict: &HashMap<char, usize>,
238) -> crate::Result<EinsumOperand<'static>> {
239 let out_dims = out_dims_from_ids(
240 left_ids,
241 ld.dims(),
242 right_ids,
243 rd.dims(),
244 output_ids,
245 size_dict,
246 )?;
247 let mut c_arr = T::pool_acquire(pool, &out_dims);
248 {
249 let a_view = ld.as_view();
250 let b_view = rd.as_view();
251 einsum2_into(
252 c_arr.view_mut(),
253 &a_view,
254 &b_view,
255 output_ids,
256 left_ids,
257 right_ids,
258 T::one(),
259 T::zero(),
260 )?;
261 }
262 T::pool_release(pool, ld);
263 T::pool_release(pool, rd);
264 Ok(T::wrap_array(c_arr))
265}
266
267fn eval_pair(
272 left: EinsumOperand<'_>,
273 left_ids: &[char],
274 right: EinsumOperand<'_>,
275 right_ids: &[char],
276 output_ids: &[char],
277 pool: &mut BufferPool,
278 size_dict: &HashMap<char, usize>,
279) -> crate::Result<EinsumOperand<'static>> {
280 match (left, right) {
281 (EinsumOperand::F64(ld), EinsumOperand::F64(rd)) => {
282 eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
283 }
284 (EinsumOperand::C64(ld), EinsumOperand::C64(rd)) => {
285 eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
286 }
287 (left, right) => {
288 let left_c64 = left.to_c64_owned();
290 let right_c64 = right.to_c64_owned();
291 eval_pair(
292 left_c64, left_ids, right_c64, right_ids, output_ids, pool, size_dict,
293 )
294 }
295 }
296}
297
298fn eval_pair_into<T: EinsumScalar>(
308 left: EinsumOperand<'_>,
309 left_ids: &[char],
310 right: EinsumOperand<'_>,
311 right_ids: &[char],
312 output: StridedViewMut<T>,
313 output_ids: &[char],
314 alpha: T,
315 beta: T,
316) -> crate::Result<()> {
317 let left_data = T::extract_data(left)?;
318 let right_data = T::extract_data(right)?;
319
320 match (left_data, right_data) {
321 (StridedData::Owned(a), StridedData::Owned(b)) => {
322 einsum2_into_owned(
323 output, a, b, output_ids, left_ids, right_ids, alpha, beta, false, false,
324 )?;
325 }
326 (StridedData::Owned(a), StridedData::View(b)) => {
327 einsum2_into(
328 output,
329 &a.view(),
330 &b,
331 output_ids,
332 left_ids,
333 right_ids,
334 alpha,
335 beta,
336 )?;
337 }
338 (StridedData::View(a), StridedData::Owned(b)) => {
339 einsum2_into(
340 output,
341 &a,
342 &b.view(),
343 output_ids,
344 left_ids,
345 right_ids,
346 alpha,
347 beta,
348 )?;
349 }
350 (StridedData::View(a), StridedData::View(b)) => {
351 einsum2_into(output, &a, &b, output_ids, left_ids, right_ids, alpha, beta)?;
352 }
353 }
354 Ok(())
355}
356
357fn accumulate_into<T: EinsumScalar>(
365 output: &mut StridedViewMut<T>,
366 result: &StridedArray<T>,
367 alpha: T,
368 beta: T,
369) -> crate::Result<()> {
370 let result_view = result.view();
371 if beta == T::zero() {
372 if alpha == T::one() {
373 strided_kernel::copy_into(output, &result_view)?;
374 } else {
375 copy_scale(output, &result_view, alpha)?;
376 }
377 } else {
378 let dims = output.dims().to_vec();
382 let mut temp = StridedArray::<T>::col_major(&dims);
383 strided_kernel::copy_into(&mut temp.view_mut(), &result_view)?;
384 let mut output_copy = StridedArray::<T>::col_major(&dims);
396 strided_kernel::copy_into(&mut output_copy.view_mut(), &output.as_view())?;
397 strided_kernel::zip_map2_into(output, &temp.view(), &output_copy.view(), |r, o| {
398 alpha * r + beta * o
399 })?;
400 }
401 Ok(())
402}
403
404fn eval_single_typed<T: EinsumScalar>(
410 data: &StridedData<'_, T>,
411 input_ids: &[char],
412 output_ids: &[char],
413 size_dict: &HashMap<char, usize>,
414) -> crate::Result<EinsumOperand<'static>> {
415 let view = data.as_view();
416 let result = single_tensor_einsum(&view, input_ids, output_ids, Some(size_dict))?;
417 Ok(T::wrap_array(result))
418}
419
420fn eval_single(
421 operand: &EinsumOperand<'_>,
422 input_ids: &[char],
423 output_ids: &[char],
424 size_dict: &HashMap<char, usize>,
425) -> crate::Result<EinsumOperand<'static>> {
426 match operand {
427 EinsumOperand::F64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
428 EinsumOperand::C64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
429 }
430}
431
/// Reports whether `output_ids` is a pure reordering of `input_ids`.
///
/// This holds only when both lists have the same length, neither contains a
/// repeated label, and every output label also occurs in the input. Repeated
/// output labels (for example `ij -> ii`) are rejected: every output label
/// would be found in the input, yet the mapping would drop `j` and use `i`
/// twice, which is not an axis permutation.
fn is_permutation_only(input_ids: &[char], output_ids: &[char]) -> bool {
    /// True when no label occurs twice in `ids`.
    fn all_distinct(ids: &[char]) -> bool {
        ids.iter()
            .enumerate()
            .all(|(position, label)| !ids[..position].contains(label))
    }

    input_ids.len() == output_ids.len()
        && all_distinct(input_ids)
        && all_distinct(output_ids)
        && output_ids.iter().all(|label| input_ids.contains(label))
}
456
/// Computes the axis permutation that reorders `input_ids` into `output_ids`.
///
/// Entry `k` of the result is the position in `input_ids` of the `k`-th
/// output label (the first position if the label is repeated).
///
/// # Panics
///
/// Panics if an output label does not occur in `input_ids`.
fn compute_permutation(input_ids: &[char], output_ids: &[char]) -> Vec<usize> {
    let mut perm = Vec::with_capacity(output_ids.len());
    for wanted in output_ids {
        let source_axis = input_ids.iter().position(|have| have == wanted).unwrap();
        perm.push(source_axis);
    }
    perm
}
464
465fn execute_nested<'a>(
476 nested: &omeco::NestedEinsum<char>,
477 children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
478 pool: &mut BufferPool,
479 size_dict: &HashMap<char, usize>,
480) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
481 match nested {
482 omeco::NestedEinsum::Leaf { tensor_index } => {
483 let slot = children.get_mut(*tensor_index).ok_or_else(|| {
484 crate::EinsumError::Internal(format!(
485 "optimizer referenced child index {} out of bounds",
486 tensor_index
487 ))
488 })?;
489 let (op, ids) = slot.take().ok_or_else(|| {
490 crate::EinsumError::Internal(format!(
491 "child operand {} was already consumed",
492 tensor_index
493 ))
494 })?;
495 Ok((op, ids))
496 }
497 omeco::NestedEinsum::Node { args, eins } => {
498 if args.len() != 2 {
499 return Err(crate::EinsumError::Internal(format!(
500 "optimizer produced non-binary node with {} children",
501 args.len()
502 )));
503 }
504 let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
505 let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
506 let output_ids: Vec<char> = eins.iy.clone();
507 let result = eval_pair(
508 left,
509 &left_ids,
510 right,
511 &right_ids,
512 &output_ids,
513 pool,
514 size_dict,
515 )?;
516 Ok((result, output_ids))
517 }
518 }
519}
520
521fn execute_nested_into<'a, T: EinsumScalar>(
527 nested: &omeco::NestedEinsum<char>,
528 children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
529 output: StridedViewMut<T>,
530 output_ids: &[char],
531 alpha: T,
532 beta: T,
533 pool: &mut BufferPool,
534 size_dict: &HashMap<char, usize>,
535) -> crate::Result<()> {
536 match nested {
537 omeco::NestedEinsum::Node { args, eins: _ } => {
538 if args.len() != 2 {
539 return Err(crate::EinsumError::Internal(format!(
540 "optimizer produced non-binary node with {} children",
541 args.len()
542 )));
543 }
544 let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
546 let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
547 eval_pair_into(
549 left, &left_ids, right, &right_ids, output, output_ids, alpha, beta,
550 )
551 }
552 omeco::NestedEinsum::Leaf { tensor_index } => {
553 let slot = children.get_mut(*tensor_index).ok_or_else(|| {
555 crate::EinsumError::Internal(format!(
556 "optimizer referenced child index {} out of bounds",
557 tensor_index
558 ))
559 })?;
560 let (op, op_ids) = slot.take().ok_or_else(|| {
561 crate::EinsumError::Internal(format!(
562 "child operand {} was already consumed",
563 tensor_index
564 ))
565 })?;
566 let data = T::extract_data(op)?;
567 let arr = data.into_array();
568 if op_ids != output_ids {
570 let perm = compute_permutation(&op_ids, output_ids);
571 let permuted = arr.permuted(&perm)?;
572 accumulate_into(&mut { output }, &permuted, alpha, beta)?;
573 } else {
574 accumulate_into(&mut { output }, &arr, alpha, beta)?;
575 }
576 Ok(())
577 }
578 }
579}
580
581fn eval_node<'a>(
594 node: &EinsumNode,
595 operands: &mut Vec<Option<EinsumOperand<'a>>>,
596 needed_ids: &[char],
597 pool: &mut BufferPool,
598 size_dict: &HashMap<char, usize>,
599) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
600 match node {
601 EinsumNode::Leaf { ids, tensor_index } => {
602 let found = operands.len();
603 let slot = operands.get_mut(*tensor_index).ok_or_else(|| {
604 crate::EinsumError::OperandCountMismatch {
605 expected: tensor_index + 1,
606 found,
607 }
608 })?;
609 let op = slot.take().ok_or_else(|| {
610 crate::EinsumError::Internal(format!(
611 "operand {} was already consumed",
612 tensor_index
613 ))
614 })?;
615 Ok((op, ids.clone()))
617 }
618 EinsumNode::Contract { args } => {
619 let node_output_ids = compute_contract_output_ids(args, needed_ids);
621
622 match args.len() {
623 0 => unreachable!("empty Contract node"),
624 1 => {
625 let child_needed = compute_child_needed_ids(&node_output_ids, 0, args);
627 let (child_op, child_ids) =
628 eval_node(&args[0], operands, &child_needed, pool, size_dict)?;
629
630 if child_ids == node_output_ids {
632 return Ok((child_op, node_output_ids));
633 }
634
635 if is_permutation_only(&child_ids, &node_output_ids) {
637 let perm = compute_permutation(&child_ids, &node_output_ids);
638 return Ok((child_op.permuted(&perm)?, node_output_ids));
639 }
640
641 let result = eval_single(&child_op, &child_ids, &node_output_ids, size_dict)?;
643 Ok((result, node_output_ids))
644 }
645 2 => {
646 let left_needed = compute_child_needed_ids(&node_output_ids, 0, args);
648 let right_needed = compute_child_needed_ids(&node_output_ids, 1, args);
649 let (left, left_ids) =
650 eval_node(&args[0], operands, &left_needed, pool, size_dict)?;
651 let (right, right_ids) =
652 eval_node(&args[1], operands, &right_needed, pool, size_dict)?;
653 let result = eval_pair(
654 left,
655 &left_ids,
656 right,
657 &right_ids,
658 &node_output_ids,
659 pool,
660 size_dict,
661 )?;
662 Ok((result, node_output_ids))
663 }
664 _ => {
665 let mut children: Vec<Option<(EinsumOperand<'a>, Vec<char>)>> = Vec::new();
670 for (i, arg) in args.iter().enumerate() {
671 let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
672 let (op, ids) = eval_node(arg, operands, &child_needed, pool, size_dict)?;
673 children.push(Some((op, ids)));
674 }
675
676 let mut dim_sizes: HashMap<char, usize> = HashMap::new();
678 for child_opt in &children {
679 if let Some((op, ids)) = child_opt {
680 for (j, &id) in ids.iter().enumerate() {
681 dim_sizes.insert(id, op.dims()[j]);
682 }
683 }
684 }
685
686 let input_ids: Vec<Vec<char>> = children
688 .iter()
689 .map(|c| c.as_ref().unwrap().1.clone())
690 .collect();
691 let code = omeco::EinCode::new(input_ids, node_output_ids.clone());
692
693 let optimizer = omeco::GreedyMethod::default();
695 let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
696 .ok_or_else(|| {
697 crate::EinsumError::Internal(
698 "optimizer failed to produce a plan".into(),
699 )
700 })?;
701
702 let (result, result_ids) =
704 execute_nested(&nested, &mut children, pool, size_dict)?;
705 Ok((result, result_ids))
706 }
707 }
708 }
709 }
710}
711
712impl EinsumCode {
717 pub fn evaluate<'a>(
726 &self,
727 operands: Vec<EinsumOperand<'a>>,
728 size_dict: Option<&HashMap<char, usize>>,
729 ) -> crate::Result<EinsumOperand<'a>> {
730 let expected = leaf_count(&self.root);
731 if operands.len() != expected {
732 return Err(crate::EinsumError::OperandCountMismatch {
733 expected,
734 found: operands.len(),
735 });
736 }
737
738 let mut ops: Vec<Option<EinsumOperand<'a>>> = operands.into_iter().map(Some).collect();
739 let mut pool = BufferPool::new();
740
741 let mut unified = build_dim_map(&self.root, &ops);
743 if let Some(sd) = size_dict {
744 merge_size_dict(&mut unified, sd)?;
745 }
746
747 let (result, result_ids) =
748 eval_node(&self.root, &mut ops, &self.output_ids, &mut pool, &unified)?;
749
750 if result_ids == self.output_ids {
752 return Ok(result);
753 }
754
755 if is_permutation_only(&result_ids, &self.output_ids) {
757 let perm = compute_permutation(&result_ids, &self.output_ids);
758 return Ok(result.permuted(&perm)?);
759 }
760
761 let adjusted = eval_single(&result, &result_ids, &self.output_ids, &unified)?;
763 Ok(adjusted)
764 }
765}
766
767fn leaf_count(node: &EinsumNode) -> usize {
768 match node {
769 EinsumNode::Leaf { .. } => 1,
770 EinsumNode::Contract { args } => args.iter().map(leaf_count).sum(),
771 }
772}
773
774fn build_dim_map(
779 node: &EinsumNode,
780 operands: &[Option<EinsumOperand<'_>>],
781) -> HashMap<char, usize> {
782 let mut dim_map = HashMap::new();
783 build_dim_map_inner(node, operands, &mut dim_map);
784 dim_map
785}
786
787fn merge_size_dict(
791 unified: &mut HashMap<char, usize>,
792 user: &HashMap<char, usize>,
793) -> crate::Result<()> {
794 for (&label, &size) in user {
795 if let Some(&existing) = unified.get(&label) {
796 if existing != size {
797 return Err(crate::EinsumError::DimensionMismatch {
798 axis: label.to_string(),
799 dim_a: existing,
800 dim_b: size,
801 });
802 }
803 } else {
804 unified.insert(label, size);
805 }
806 }
807 Ok(())
808}
809
810fn build_dim_map_inner(
811 node: &EinsumNode,
812 operands: &[Option<EinsumOperand<'_>>],
813 dim_map: &mut HashMap<char, usize>,
814) {
815 match node {
816 EinsumNode::Leaf { ids, tensor_index } => {
817 if let Some(Some(op)) = operands.get(*tensor_index) {
818 for (i, &id) in ids.iter().enumerate() {
819 dim_map.insert(id, op.dims()[i]);
820 }
821 }
822 }
823 EinsumNode::Contract { args } => {
824 for arg in args {
825 build_dim_map_inner(arg, operands, dim_map);
826 }
827 }
828 }
829}
830
831impl EinsumCode {
832 pub fn evaluate_into<T: EinsumScalar>(
845 &self,
846 operands: Vec<EinsumOperand<'_>>,
847 mut output: StridedViewMut<T>,
848 alpha: T,
849 beta: T,
850 size_dict: Option<&HashMap<char, usize>>,
851 ) -> crate::Result<()> {
852 let expected = leaf_count(&self.root);
853 if operands.len() != expected {
854 return Err(crate::EinsumError::OperandCountMismatch {
855 expected,
856 found: operands.len(),
857 });
858 }
859
860 let mut ops: Vec<Option<EinsumOperand<'_>>> = operands.into_iter().map(Some).collect();
862 T::validate_operands(&ops)?;
863
864 let mut unified = build_dim_map(&self.root, &ops);
866 if let Some(sd) = size_dict {
867 merge_size_dict(&mut unified, sd)?;
868 }
869
870 let expected_dims = out_dims_from_map(&unified, &self.output_ids, &unified)?;
872 if output.dims() != expected_dims.as_slice() {
873 return Err(crate::EinsumError::OutputShapeMismatch {
874 expected: expected_dims,
875 got: output.dims().to_vec(),
876 });
877 }
878
879 let mut pool = BufferPool::new();
880
881 match &self.root {
882 EinsumNode::Leaf { ids, tensor_index } => {
883 let op = ops[*tensor_index].take().ok_or_else(|| {
885 crate::EinsumError::Internal("operand already consumed".into())
886 })?;
887 let single_result = eval_single(&op, ids, &self.output_ids, &unified)?;
888 let data = T::extract_data(single_result)?;
889 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
890 }
891 EinsumNode::Contract { args } => match args.len() {
892 0 => unreachable!("empty Contract node"),
893 1 => {
894 let child_needed = compute_child_needed_ids(&self.output_ids, 0, args);
896 let (child_op, child_ids) =
897 eval_node(&args[0], &mut ops, &child_needed, &mut pool, &unified)?;
898
899 if child_ids == self.output_ids {
900 let data = T::extract_data(child_op)?;
902 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
903 } else if is_permutation_only(&child_ids, &self.output_ids) {
904 let perm = compute_permutation(&child_ids, &self.output_ids);
906 let data = T::extract_data(child_op)?;
907 let arr = data.into_array();
908 let permuted = arr.permuted(&perm)?;
909 accumulate_into(&mut output, &permuted, alpha, beta)?;
910 } else {
911 let result =
913 eval_single(&child_op, &child_ids, &self.output_ids, &unified)?;
914 let data = T::extract_data(result)?;
915 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
916 }
917 }
918 2 => {
919 let left_needed = compute_child_needed_ids(&self.output_ids, 0, args);
921 let right_needed = compute_child_needed_ids(&self.output_ids, 1, args);
922 let (left, left_ids) =
923 eval_node(&args[0], &mut ops, &left_needed, &mut pool, &unified)?;
924 let (right, right_ids) =
925 eval_node(&args[1], &mut ops, &right_needed, &mut pool, &unified)?;
926 eval_pair_into(
927 left,
928 &left_ids,
929 right,
930 &right_ids,
931 output,
932 &self.output_ids,
933 alpha,
934 beta,
935 )?;
936 }
937 _ => {
938 let node_output_ids = compute_contract_output_ids(args, &self.output_ids);
940
941 let mut children: Vec<Option<(EinsumOperand<'_>, Vec<char>)>> = Vec::new();
942 for (i, arg) in args.iter().enumerate() {
943 let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
944 let (op, ids) =
945 eval_node(arg, &mut ops, &child_needed, &mut pool, &unified)?;
946 children.push(Some((op, ids)));
947 }
948
949 let mut dim_sizes: HashMap<char, usize> = HashMap::new();
950 for child_opt in &children {
951 if let Some((op, ids)) = child_opt {
952 for (j, &id) in ids.iter().enumerate() {
953 dim_sizes.insert(id, op.dims()[j]);
954 }
955 }
956 }
957
958 let input_ids: Vec<Vec<char>> = children
959 .iter()
960 .map(|c| c.as_ref().unwrap().1.clone())
961 .collect();
962 let code = omeco::EinCode::new(input_ids, self.output_ids.clone());
963
964 let optimizer = omeco::GreedyMethod::default();
965 let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
966 .ok_or_else(|| {
967 crate::EinsumError::Internal(
968 "optimizer failed to produce a plan".into(),
969 )
970 })?;
971
972 execute_nested_into(
973 &nested,
974 &mut children,
975 output,
976 &self.output_ids,
977 alpha,
978 beta,
979 &mut pool,
980 &unified,
981 )?;
982 }
983 },
984 }
985
986 Ok(())
987 }
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::parse_einsum;
    use approx::assert_abs_diff_eq;
    use strided_view::{row_major_strides, StridedArray};

    /// Wraps row-major `data` of shape `dims` as an owned real operand.
    fn make_f64(dims: &[usize], data: Vec<f64>) -> EinsumOperand<'static> {
        let array = StridedArray::from_parts(data, dims, &row_major_strides(dims), 0).unwrap();
        array.into()
    }

    /// Wraps row-major real values of shape `dims` as an owned complex
    /// operand with zero imaginary parts.
    fn make_c64(dims: &[usize], reals: &[f64]) -> EinsumOperand<'static> {
        let data: Vec<Complex64> = reals.iter().map(|&re| Complex64::new(re, 0.0)).collect();
        let strides = row_major_strides(dims);
        EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(data, dims, &strides, 0).unwrap(),
        ))
    }

    /// Unwraps the real payload of `result`, panicking for complex results.
    fn expect_f64(result: EinsumOperand<'_>) -> StridedData<'_, f64> {
        match result {
            EinsumOperand::F64(data) => data,
            EinsumOperand::C64(_) => panic!("expected F64"),
        }
    }

    // ---- allocating `evaluate` ----

    #[test]
    fn test_matmul() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let data = expect_f64(code.evaluate(vec![lhs, rhs], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[2, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
        assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
        assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
        assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_nested_three_tensor() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        // The first factor is the identity, so the result equals b * c.
        let identity = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let data = expect_f64(code.evaluate(vec![identity, b, c], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[2, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
        assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
        assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
        assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_outer_product() {
        let code = parse_einsum("i,j->ij").unwrap();
        let column = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let row = make_f64(&[2], vec![10.0, 20.0]);
        let data = expect_f64(code.evaluate(vec![column, row], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[3, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 10.0);
        assert_abs_diff_eq!(arr.get(&[2, 1]), 60.0);
    }

    #[test]
    fn test_dot_product() {
        let code = parse_einsum("i,i->").unwrap();
        let u = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let v = make_f64(&[3], vec![4.0, 5.0, 6.0]);
        let data = expect_f64(code.evaluate(vec![u, v], None).unwrap());
        assert_abs_diff_eq!(data.as_array().data()[0], 32.0);
    }

    #[test]
    fn test_single_tensor_permute() {
        let code = parse_einsum("ij->ji").unwrap();
        let input = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let data = expect_f64(code.evaluate(vec![input], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[3, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 1.0);
        assert_abs_diff_eq!(arr.get(&[0, 1]), 4.0);
    }

    #[test]
    fn test_single_tensor_trace() {
        let code = parse_einsum("ii->").unwrap();
        let diagonal = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        let data = expect_f64(code.evaluate(vec![diagonal], None).unwrap());
        assert_abs_diff_eq!(data.as_array().data()[0], 6.0);
    }

    #[test]
    fn test_three_tensor_flat_omeco() {
        // A flat three-operand expression goes through the optimizer path.
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let data = expect_f64(code.evaluate(vec![a, b, c], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[2, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(arr.get(&[0, 1]), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(arr.get(&[1, 0]), 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(arr.get(&[1, 1]), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_four_tensor_flat_omeco() {
        let code = parse_einsum("ij,jk,kl,lm->im").unwrap();
        // Identity factors at positions 0 and 2 leave b * d as the result.
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let d = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let data = expect_f64(code.evaluate(vec![a, b, c, d], None).unwrap());
        let arr = data.as_array();
        assert_eq!(arr.dims(), &[2, 2]);
        assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0, epsilon = 1e-10);
        assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_orphan_output_axis_returns_error() {
        let code = parse_einsum("ij,jk->iz").unwrap();
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![lhs, rhs], None).unwrap_err();
        assert!(matches!(err, crate::EinsumError::OrphanOutputAxis(ref s) if s == "z"));
    }

    #[test]
    fn test_operand_count_mismatch_too_few() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let only = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let err = code.evaluate(vec![only], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 2,
                found: 1
            }
        ));
    }

    #[test]
    fn test_operand_count_mismatch_too_many() {
        let code = parse_einsum("ij->ji").unwrap();
        let first = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let extra = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![first, extra], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 1,
                found: 2
            }
        ));
    }

    // ---- in-place `evaluate_into` ----

    #[test]
    fn test_into_matmul() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![lhs, rhs], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 19.0);
        assert_abs_diff_eq!(out.get(&[0, 1]), 22.0);
        assert_abs_diff_eq!(out.get(&[1, 0]), 43.0);
        assert_abs_diff_eq!(out.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_into_matmul_alpha_beta() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        // Pre-fill with ones: the expected value is 2 * (A*B) + 3 * 1.
        out.data_mut().iter_mut().for_each(|slot| *slot = 1.0);
        code.evaluate_into(vec![lhs, rhs], out.view_mut(), 2.0, 3.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 41.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 47.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 89.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 103.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_single_tensor_permute() {
        let code = parse_einsum("ij->ji").unwrap();
        let input = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut out = StridedArray::<f64>::col_major(&[3, 2]);
        code.evaluate_into(vec![input], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_eq!(out.dims(), &[3, 2]);
        assert_abs_diff_eq!(out.get(&[0, 0]), 1.0);
        assert_abs_diff_eq!(out.get(&[0, 1]), 4.0);
        assert_abs_diff_eq!(out.get(&[1, 0]), 2.0);
        assert_abs_diff_eq!(out.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_into_single_tensor_trace() {
        let code = parse_einsum("ii->").unwrap();
        let diagonal = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        let mut out = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![diagonal], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.data()[0], 6.0);
    }

    #[test]
    fn test_into_three_tensor_omeco() {
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b, c], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_nested() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        let identity = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![identity, b, c], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 19.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 22.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 43.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_dot_product() {
        let code = parse_einsum("i,i->").unwrap();
        let u = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let v = make_f64(&[3], vec![4.0, 5.0, 6.0]);
        let mut out = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![u, v], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.data()[0], 32.0);
    }

    #[test]
    fn test_into_type_mismatch_f64_output_c64_input() {
        let code = parse_einsum("ij->ji").unwrap();
        let complex_input = make_c64(&[2, 2], &[1.0, 2.0, 3.0, 4.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        let err = code
            .evaluate_into(vec![complex_input], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(err, crate::EinsumError::TypeMismatch { .. }));
    }

    #[test]
    fn test_into_shape_mismatch() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // The contraction yields a 2x2 result, so a 3x3 destination is wrong.
        let mut out = StridedArray::<f64>::col_major(&[3, 3]);
        let err = code
            .evaluate_into(vec![lhs, rhs], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OutputShapeMismatch { .. }
        ));
    }

    #[test]
    fn test_into_c64_output() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let lhs = make_c64(&[2, 2], &[1.0, 2.0, 3.0, 4.0]);
        let rhs = make_c64(&[2, 2], &[5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![lhs, rhs],
            out.view_mut(),
            Complex64::new(1.0, 0.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }

    #[test]
    fn test_into_mixed_types_c64_output() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        // Real left operand, complex right operand, complex destination.
        let lhs = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = make_c64(&[2, 2], &[5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![lhs, rhs],
            out.view_mut(),
            Complex64::new(1.0, 0.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }
}