1use std::collections::{BTreeMap, HashMap};
2
3use num_complex::Complex64;
4#[cfg(test)]
5use num_traits::Zero;
6use strided_einsum2::{einsum2_into, einsum2_into_owned};
7use strided_kernel::copy_scale;
8use strided_view::{StridedArray, StridedViewMut};
9
10use crate::operand::{EinsumOperand, EinsumScalar, StridedData};
11use crate::parse::{EinsumCode, EinsumNode};
12use crate::single_tensor::single_tensor_einsum;
13
14pub struct BufferPool {
40 f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
41 c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
42}
43
44impl BufferPool {
45 pub fn new() -> Self {
47 Self {
48 f64_pool: BTreeMap::new(),
49 c64_pool: BTreeMap::new(),
50 }
51 }
52}
53
/// Scalar types whose scratch arrays can be recycled through a [`BufferPool`].
trait PoolOps: EinsumScalar {
    /// Returns a column-major array of shape `dims`, reusing a cached buffer from `pool`
    /// when one is large enough and allocating a new one otherwise.
    ///
    /// The implementations build the array with `*_uninit` constructors, so its contents
    /// must be fully written before being read.
    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Self>;

    /// Returns the storage behind `data` to `pool` for later reuse.
    ///
    /// In the implementations, only `StridedData::Owned` storage is recycled; borrowed views
    /// are ignored.
    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Self>);
}
73
/// Removes and returns the smallest cached buffer whose length is at least `total`.
///
/// Returns `None` when no bucket is large enough. A bucket that becomes empty is removed
/// from the map, so later range scans never land on an empty bucket.
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, total: usize) -> Option<Vec<T>> {
    // The first key >= total is the tightest bucket that can satisfy the request.
    let bucket_len = pool.range(total..).map(|(len, _)| *len).next()?;
    let buffers = pool.get_mut(&bucket_len)?;
    let taken = buffers.pop();
    let now_empty = buffers.is_empty();
    if now_empty {
        pool.remove(&bucket_len);
    }
    taken
}
85
86impl PoolOps for f64 {
87 fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<f64> {
88 let total: usize = dims.iter().product();
89 match take_best_fit(&mut pool.f64_pool, total) {
91 Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
92 None => unsafe { StridedArray::col_major_uninit(dims) },
93 }
94 }
95
96 fn pool_release(pool: &mut BufferPool, data: StridedData<'_, f64>) {
97 if let StridedData::Owned(arr) = data {
98 let buf = arr.into_data();
99 pool.f64_pool.entry(buf.len()).or_default().push(buf);
100 }
101 }
102}
103
104impl PoolOps for Complex64 {
105 fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Complex64> {
106 let total: usize = dims.iter().product();
107 match take_best_fit(&mut pool.c64_pool, total) {
109 Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
110 None => unsafe { StridedArray::col_major_uninit(dims) },
111 }
112 }
113
114 fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Complex64>) {
115 if let StridedData::Owned(arr) = data {
116 let buf = arr.into_data();
117 pool.c64_pool.entry(buf.len()).or_default().push(buf);
118 }
119 }
120}
121
122fn collect_all_ids(node: &EinsumNode) -> Vec<char> {
127 let mut result = Vec::new();
128 collect_all_ids_inner(node, &mut result);
129 result
130}
131
132fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
133 match node {
134 EinsumNode::Leaf { ids, .. } => {
135 for &id in ids {
136 if !result.contains(&id) {
137 result.push(id);
138 }
139 }
140 }
141 EinsumNode::Contract { args } => {
142 for arg in args {
143 collect_all_ids_inner(arg, result);
144 }
145 }
146 }
147}
148
149fn compute_contract_output_ids(args: &[EinsumNode], needed_ids: &[char]) -> Vec<char> {
159 if args.len() == 2 {
160 let left_ids = collect_all_ids(&args[0]);
161 let right_ids = collect_all_ids(&args[1]);
162 return compute_binary_output_ids(&left_ids, &right_ids, needed_ids);
163 }
164
165 let mut all_ids_ordered = Vec::new();
167 for arg in args {
168 for id in collect_all_ids(arg) {
169 if !all_ids_ordered.contains(&id) {
170 all_ids_ordered.push(id);
171 }
172 }
173 }
174
175 all_ids_ordered
177 .into_iter()
178 .filter(|id| needed_ids.contains(id))
179 .collect()
180}
181
182fn compute_child_needed_ids(
192 output_ids: &[char],
193 child_idx: usize,
194 args: &[EinsumNode],
195) -> Vec<char> {
196 let mut needed: Vec<char> = output_ids.to_vec();
197
198 let child_ids = collect_all_ids(&args[child_idx]);
200 for (j, arg) in args.iter().enumerate() {
201 if j == child_idx {
202 continue;
203 }
204 let sibling_ids = collect_all_ids(arg);
205 for &id in &child_ids {
206 if sibling_ids.contains(&id) && !needed.contains(&id) {
207 needed.push(id);
208 }
209 }
210 }
211
212 needed
213}
214
215fn out_dims_from_map(
220 dim_map: &HashMap<char, usize>,
221 output_ids: &[char],
222 size_dict: &HashMap<char, usize>,
223) -> crate::Result<Vec<usize>> {
224 let mut out_dims = Vec::with_capacity(output_ids.len());
225 for &id in output_ids {
226 if let Some(&dim) = dim_map.get(&id) {
227 out_dims.push(dim);
228 } else if let Some(&dim) = size_dict.get(&id) {
229 out_dims.push(dim);
230 } else {
231 return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
232 }
233 }
234 Ok(out_dims)
235}
236
237fn out_dims_from_ids(
239 left_ids: &[char],
240 left_dims: &[usize],
241 right_ids: &[char],
242 right_dims: &[usize],
243 output_ids: &[char],
244 size_dict: &HashMap<char, usize>,
245) -> crate::Result<Vec<usize>> {
246 let mut out_dims = Vec::with_capacity(output_ids.len());
247 for &id in output_ids {
248 if let Some(pos) = left_ids.iter().position(|&c| c == id) {
249 out_dims.push(left_dims[pos]);
250 } else if let Some(pos) = right_ids.iter().position(|&c| c == id) {
251 out_dims.push(right_dims[pos]);
252 } else if let Some(&dim) = size_dict.get(&id) {
253 out_dims.push(dim);
254 } else {
255 return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
256 }
257 }
258 Ok(out_dims)
259}
260
/// Canonical output label order for a binary contraction, restricted to `needed_ids`.
///
/// Labels appear in three groups: left-only labels (in left order), then right-only labels
/// (in right order), then labels shared by both sides (in left order, i.e. batch labels).
/// Duplicates are dropped.
fn compute_binary_output_ids(
    left_ids: &[char],
    right_ids: &[char],
    needed_ids: &[char],
) -> Vec<char> {
    let left_only = left_ids.iter().filter(|id| !right_ids.contains(id));
    let right_only = right_ids.iter().filter(|id| !left_ids.contains(id));
    let shared = left_ids.iter().filter(|id| right_ids.contains(id));

    let mut out = Vec::new();
    for &id in left_only.chain(right_only).chain(shared) {
        if needed_ids.contains(&id) && !out.contains(&id) {
            out.push(id);
        }
    }
    out
}
290
/// Contracts two operands of the same scalar type `T` into a newly acquired output array.
///
/// The output shape comes from `out_dims_from_ids`, and its storage is taken from `pool`
/// when a large enough buffer is cached.
///
/// When both inputs are owned they are moved into `einsum2_into_owned`, so their buffers are
/// not returned to the pool here. Otherwise both are viewed, contracted with `einsum2_into`,
/// and then handed to `T::pool_release`, which recycles only owned storage.
///
/// NOTE(review): the output comes from an `*_uninit` constructor and is written with
/// beta = `T::zero()`. This relies on the einsum2 kernels not reading the destination when
/// beta is zero — confirm.
fn eval_pair_alloc<T: PoolOps>(
    ld: StridedData<'_, T>,
    left_ids: &[char],
    rd: StridedData<'_, T>,
    right_ids: &[char],
    output_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    let out_dims = out_dims_from_ids(
        left_ids,
        ld.dims(),
        right_ids,
        rd.dims(),
        output_ids,
        size_dict,
    )?;
    let mut c_arr = T::pool_acquire(pool, &out_dims);
    match (ld, rd) {
        (StridedData::Owned(a), StridedData::Owned(b)) => {
            // Both inputs are moved into the kernel. The two trailing `false` flags are
            // defined by `einsum2_into_owned` — see that API for their meaning.
            einsum2_into_owned(
                c_arr.view_mut(),
                a,
                b,
                output_ids,
                left_ids,
                right_ids,
                T::one(),
                T::zero(),
                false,
                false,
            )?;
        }
        (ld, rd) => {
            // At least one side is a borrowed view here. The views must be dropped (they are
            // last used by the kernel call) before `ld` / `rd` can be moved into the pool.
            let a_view = ld.as_view();
            let b_view = rd.as_view();
            einsum2_into(
                c_arr.view_mut(),
                &a_view,
                &b_view,
                output_ids,
                left_ids,
                right_ids,
                T::one(),
                T::zero(),
            )?;
            T::pool_release(pool, ld);
            T::pool_release(pool, rd);
        }
    }
    Ok(T::wrap_array(c_arr))
}
349
350fn eval_pair(
355 left: EinsumOperand<'_>,
356 left_ids: &[char],
357 right: EinsumOperand<'_>,
358 right_ids: &[char],
359 output_ids: &[char],
360 pool: &mut BufferPool,
361 size_dict: &HashMap<char, usize>,
362) -> crate::Result<EinsumOperand<'static>> {
363 match (left, right) {
364 (EinsumOperand::F64(ld), EinsumOperand::F64(rd)) => {
365 eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
366 }
367 (EinsumOperand::C64(ld), EinsumOperand::C64(rd)) => {
368 eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
369 }
370 (left, right) => {
371 let left_c64 = left.to_c64_owned();
373 let right_c64 = right.to_c64_owned();
374 eval_pair(
375 left_c64, left_ids, right_c64, right_ids, output_ids, pool, size_dict,
376 )
377 }
378 }
379}
380
381fn eval_pair_into<T: EinsumScalar>(
391 left: EinsumOperand<'_>,
392 left_ids: &[char],
393 right: EinsumOperand<'_>,
394 right_ids: &[char],
395 output: StridedViewMut<T>,
396 output_ids: &[char],
397 alpha: T,
398 beta: T,
399) -> crate::Result<()> {
400 let left_data = T::extract_data(left)?;
401 let right_data = T::extract_data(right)?;
402
403 match (left_data, right_data) {
404 (StridedData::Owned(a), StridedData::Owned(b)) => {
405 einsum2_into_owned(
406 output, a, b, output_ids, left_ids, right_ids, alpha, beta, false, false,
407 )?;
408 }
409 (StridedData::Owned(a), StridedData::View(b)) => {
410 einsum2_into(
411 output,
412 &a.view(),
413 &b,
414 output_ids,
415 left_ids,
416 right_ids,
417 alpha,
418 beta,
419 )?;
420 }
421 (StridedData::View(a), StridedData::Owned(b)) => {
422 einsum2_into(
423 output,
424 &a,
425 &b.view(),
426 output_ids,
427 left_ids,
428 right_ids,
429 alpha,
430 beta,
431 )?;
432 }
433 (StridedData::View(a), StridedData::View(b)) => {
434 einsum2_into(output, &a, &b, output_ids, left_ids, right_ids, alpha, beta)?;
435 }
436 }
437 Ok(())
438}
439
440fn accumulate_into<T: EinsumScalar>(
448 output: &mut StridedViewMut<T>,
449 result: &StridedArray<T>,
450 alpha: T,
451 beta: T,
452) -> crate::Result<()> {
453 let result_view = result.view();
454 if beta == T::zero() {
455 if alpha == T::one() {
456 strided_kernel::copy_into(output, &result_view)?;
457 } else {
458 copy_scale(output, &result_view, alpha)?;
459 }
460 } else {
461 let dims = output.dims().to_vec();
465 let mut temp = StridedArray::<T>::col_major(&dims);
466 strided_kernel::copy_into(&mut temp.view_mut(), &result_view)?;
467 let mut output_copy = StridedArray::<T>::col_major(&dims);
479 strided_kernel::copy_into(&mut output_copy.view_mut(), &output.as_view())?;
480 strided_kernel::zip_map2_into(output, &temp.view(), &output_copy.view(), |r, o| {
481 alpha * r + beta * o
482 })?;
483 }
484 Ok(())
485}
486
487fn eval_single_typed<T: EinsumScalar>(
493 data: &StridedData<'_, T>,
494 input_ids: &[char],
495 output_ids: &[char],
496 size_dict: &HashMap<char, usize>,
497) -> crate::Result<EinsumOperand<'static>> {
498 let view = data.as_view();
499 let result = single_tensor_einsum(&view, input_ids, output_ids, Some(size_dict))?;
500 Ok(T::wrap_array(result))
501}
502
503fn eval_single(
504 operand: &EinsumOperand<'_>,
505 input_ids: &[char],
506 output_ids: &[char],
507 size_dict: &HashMap<char, usize>,
508) -> crate::Result<EinsumOperand<'static>> {
509 match operand {
510 EinsumOperand::F64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
511 EinsumOperand::C64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
512 }
513}
514
/// Returns `true` when `output_ids` is a pure reordering of `input_ids`.
///
/// That requires both lists to have the same length, neither to contain a repeated label,
/// and every output label to occur in the input. When this holds, `compute_permutation` maps
/// each output axis to a distinct input axis, so the operand can simply be permuted.
///
/// Repeated output labels (e.g. `ij -> ii`) are rejected. Previously they slipped through
/// because only the input was checked for duplicates.
fn is_permutation_only(input_ids: &[char], output_ids: &[char]) -> bool {
    fn all_distinct(ids: &[char]) -> bool {
        ids.iter()
            .enumerate()
            .all(|(i, id)| !ids[..i].contains(id))
    }

    input_ids.len() == output_ids.len()
        && all_distinct(input_ids)
        && all_distinct(output_ids)
        && output_ids.iter().all(|id| input_ids.contains(id))
}
539
/// For each label of `output_ids`, returns the axis index where it occurs in `input_ids`.
///
/// # Panics
/// Panics if an output label is missing from `input_ids`. Callers are expected to have
/// checked that the labels match (e.g. with `is_permutation_only`).
fn compute_permutation(input_ids: &[char], output_ids: &[char]) -> Vec<usize> {
    let mut perm = Vec::with_capacity(output_ids.len());
    for out_label in output_ids {
        let axis = input_ids
            .iter()
            .position(|in_label| in_label == out_label)
            .unwrap();
        perm.push(axis);
    }
    perm
}
547
548fn execute_nested<'a>(
559 nested: &omeco::NestedEinsum<char>,
560 children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
561 pool: &mut BufferPool,
562 size_dict: &HashMap<char, usize>,
563) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
564 match nested {
565 omeco::NestedEinsum::Leaf { tensor_index } => {
566 let slot = children.get_mut(*tensor_index).ok_or_else(|| {
567 crate::EinsumError::Internal(format!(
568 "optimizer referenced child index {} out of bounds",
569 tensor_index
570 ))
571 })?;
572 let (op, ids) = slot.take().ok_or_else(|| {
573 crate::EinsumError::Internal(format!(
574 "child operand {} was already consumed",
575 tensor_index
576 ))
577 })?;
578 Ok((op, ids))
579 }
580 omeco::NestedEinsum::Node { args, eins } => {
581 if args.len() != 2 {
582 return Err(crate::EinsumError::Internal(format!(
583 "optimizer produced non-binary node with {} children",
584 args.len()
585 )));
586 }
587 let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
588 let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
589 let output_ids: Vec<char> = eins.iy.clone();
590 let result = eval_pair(
591 left,
592 &left_ids,
593 right,
594 &right_ids,
595 &output_ids,
596 pool,
597 size_dict,
598 )?;
599 Ok((result, output_ids))
600 }
601 }
602}
603
604fn execute_nested_into<'a, T: EinsumScalar>(
610 nested: &omeco::NestedEinsum<char>,
611 children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
612 output: StridedViewMut<T>,
613 output_ids: &[char],
614 alpha: T,
615 beta: T,
616 pool: &mut BufferPool,
617 size_dict: &HashMap<char, usize>,
618) -> crate::Result<()> {
619 match nested {
620 omeco::NestedEinsum::Node { args, eins: _ } => {
621 if args.len() != 2 {
622 return Err(crate::EinsumError::Internal(format!(
623 "optimizer produced non-binary node with {} children",
624 args.len()
625 )));
626 }
627 let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
629 let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
630 eval_pair_into(
632 left, &left_ids, right, &right_ids, output, output_ids, alpha, beta,
633 )
634 }
635 omeco::NestedEinsum::Leaf { tensor_index } => {
636 let slot = children.get_mut(*tensor_index).ok_or_else(|| {
638 crate::EinsumError::Internal(format!(
639 "optimizer referenced child index {} out of bounds",
640 tensor_index
641 ))
642 })?;
643 let (op, op_ids) = slot.take().ok_or_else(|| {
644 crate::EinsumError::Internal(format!(
645 "child operand {} was already consumed",
646 tensor_index
647 ))
648 })?;
649 let data = T::extract_data(op)?;
650 let arr = data.into_array();
651 if op_ids != output_ids {
653 let perm = compute_permutation(&op_ids, output_ids);
654 let permuted = arr.permuted(&perm)?;
655 accumulate_into(&mut { output }, &permuted, alpha, beta)?;
656 } else {
657 accumulate_into(&mut { output }, &arr, alpha, beta)?;
658 }
659 Ok(())
660 }
661 }
662}
663
/// Recursively evaluates the parsed einsum tree rooted at `node`.
///
/// `operands` holds one slot per input tensor. Each leaf *takes* its operand out of its
/// slot, so every operand is consumed at most once. `needed_ids` lists the labels the caller
/// still needs from this subtree (its output plus labels shared with siblings). Every other
/// label is contracted away as early as possible.
///
/// Returns the evaluated operand together with the labels of its axes, in axis order.
///
/// # Errors
/// - `OperandCountMismatch` if a leaf refers to a tensor index beyond `operands`.
/// - `Internal` if an operand was already consumed or the contraction-order optimizer fails.
/// - Any error propagated from the pairwise / single-tensor evaluation.
///
/// # Panics
/// On a `Contract` node with no children (`unreachable!`). NOTE(review): this assumes the
/// parser never produces an empty group — confirm, or turn it into an error.
fn eval_node<'a>(
    node: &EinsumNode,
    operands: &mut Vec<Option<EinsumOperand<'a>>>,
    needed_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
    match node {
        EinsumNode::Leaf { ids, tensor_index } => {
            let found = operands.len();
            let slot = operands.get_mut(*tensor_index).ok_or_else(|| {
                crate::EinsumError::OperandCountMismatch {
                    expected: tensor_index + 1,
                    found,
                }
            })?;
            // `take` leaves `None` behind, so a second use of the same operand is detected.
            let op = slot.take().ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "operand {} was already consumed",
                    tensor_index
                ))
            })?;
            // Leaves are returned as-is; any reduction happens at the parent node.
            Ok((op, ids.clone()))
        }
        EinsumNode::Contract { args } => {
            let node_output_ids = compute_contract_output_ids(args, needed_ids);

            match args.len() {
                0 => unreachable!("empty Contract node"),
                1 => {
                    // Single child: evaluate it, then reshape its labels to this node's output.
                    let child_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let (child_op, child_ids) =
                        eval_node(&args[0], operands, &child_needed, pool, size_dict)?;

                    // Already in the right label order: no work needed.
                    if child_ids == node_output_ids {
                        return Ok((child_op, node_output_ids));
                    }

                    // Pure reordering: an axis permutation avoids a full einsum.
                    if is_permutation_only(&child_ids, &node_output_ids) {
                        let perm = compute_permutation(&child_ids, &node_output_ids);
                        return Ok((child_op.permuted(&perm)?, node_output_ids));
                    }

                    // Anything else (traces, reductions, …) goes through the single-tensor path.
                    let result = eval_single(&child_op, &child_ids, &node_output_ids, size_dict)?;
                    Ok((result, node_output_ids))
                }
                2 => {
                    // Binary node: evaluate both children (left first), then contract the pair.
                    let left_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let right_needed = compute_child_needed_ids(&node_output_ids, 1, args);
                    let (left, left_ids) =
                        eval_node(&args[0], operands, &left_needed, pool, size_dict)?;
                    let (right, right_ids) =
                        eval_node(&args[1], operands, &right_needed, pool, size_dict)?;
                    let result = eval_pair(
                        left,
                        &left_ids,
                        right,
                        &right_ids,
                        &node_output_ids,
                        pool,
                        size_dict,
                    )?;
                    Ok((result, node_output_ids))
                }
                _ => {
                    // Three or more children: evaluate each, then let the omeco greedy
                    // optimizer choose a pairwise contraction order over the results.
                    let mut children: Vec<Option<(EinsumOperand<'a>, Vec<char>)>> = Vec::new();
                    for (i, arg) in args.iter().enumerate() {
                        let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
                        let (op, ids) = eval_node(arg, operands, &child_needed, pool, size_dict)?;
                        children.push(Some((op, ids)));
                    }

                    // Label sizes as seen on the evaluated children (input to the cost model).
                    let mut dim_sizes: HashMap<char, usize> = HashMap::new();
                    for child_opt in &children {
                        if let Some((op, ids)) = child_opt {
                            for (j, &id) in ids.iter().enumerate() {
                                dim_sizes.insert(id, op.dims()[j]);
                            }
                        }
                    }

                    // Every slot was filled by the loop above, so `unwrap` cannot fail here.
                    let input_ids: Vec<Vec<char>> = children
                        .iter()
                        .map(|c| c.as_ref().unwrap().1.clone())
                        .collect();
                    let code = omeco::EinCode::new(input_ids, node_output_ids.clone());

                    let optimizer = omeco::GreedyMethod::default();
                    let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
                        .ok_or_else(|| {
                            crate::EinsumError::Internal(
                                "optimizer failed to produce a plan".into(),
                            )
                        })?;

                    let (result, result_ids) =
                        execute_nested(&nested, &mut children, pool, size_dict)?;
                    Ok((result, result_ids))
                }
            }
        }
    }
}
794
795impl EinsumCode {
800 pub fn evaluate<'a>(
812 &self,
813 operands: Vec<EinsumOperand<'a>>,
814 size_dict: Option<&HashMap<char, usize>>,
815 ) -> crate::Result<EinsumOperand<'a>> {
816 self.evaluate_with_pool(operands, size_dict, None)
817 }
818
819 pub fn evaluate_with_pool<'a>(
824 &self,
825 operands: Vec<EinsumOperand<'a>>,
826 size_dict: Option<&HashMap<char, usize>>,
827 pool: Option<&mut BufferPool>,
828 ) -> crate::Result<EinsumOperand<'a>> {
829 let expected = leaf_count(&self.root);
830 if operands.len() != expected {
831 return Err(crate::EinsumError::OperandCountMismatch {
832 expected,
833 found: operands.len(),
834 });
835 }
836
837 let mut ops: Vec<Option<EinsumOperand<'a>>> = operands.into_iter().map(Some).collect();
838 let mut temp_pool;
839 let pool = match pool {
840 Some(p) => p,
841 None => {
842 temp_pool = BufferPool::new();
843 &mut temp_pool
844 }
845 };
846
847 let mut unified = build_dim_map(&self.root, &ops);
849 if let Some(sd) = size_dict {
850 merge_size_dict(&mut unified, sd)?;
851 }
852
853 let (result, result_ids) =
854 eval_node(&self.root, &mut ops, &self.output_ids, pool, &unified)?;
855
856 if result_ids == self.output_ids {
858 return Ok(result);
859 }
860
861 if is_permutation_only(&result_ids, &self.output_ids) {
863 let perm = compute_permutation(&result_ids, &self.output_ids);
864 return Ok(result.permuted(&perm)?);
865 }
866
867 let adjusted = eval_single(&result, &result_ids, &self.output_ids, &unified)?;
869 Ok(adjusted)
870 }
871}
872
873fn leaf_count(node: &EinsumNode) -> usize {
874 match node {
875 EinsumNode::Leaf { .. } => 1,
876 EinsumNode::Contract { args } => args.iter().map(leaf_count).sum(),
877 }
878}
879
880fn build_dim_map(
885 node: &EinsumNode,
886 operands: &[Option<EinsumOperand<'_>>],
887) -> HashMap<char, usize> {
888 let mut dim_map = HashMap::new();
889 build_dim_map_inner(node, operands, &mut dim_map);
890 dim_map
891}
892
893fn merge_size_dict(
897 unified: &mut HashMap<char, usize>,
898 user: &HashMap<char, usize>,
899) -> crate::Result<()> {
900 for (&label, &size) in user {
901 if let Some(&existing) = unified.get(&label) {
902 if existing != size {
903 return Err(crate::EinsumError::DimensionMismatch {
904 axis: label.to_string(),
905 dim_a: existing,
906 dim_b: size,
907 });
908 }
909 } else {
910 unified.insert(label, size);
911 }
912 }
913 Ok(())
914}
915
916fn build_dim_map_inner(
917 node: &EinsumNode,
918 operands: &[Option<EinsumOperand<'_>>],
919 dim_map: &mut HashMap<char, usize>,
920) {
921 match node {
922 EinsumNode::Leaf { ids, tensor_index } => {
923 if let Some(Some(op)) = operands.get(*tensor_index) {
924 for (i, &id) in ids.iter().enumerate() {
925 dim_map.insert(id, op.dims()[i]);
926 }
927 }
928 }
929 EinsumNode::Contract { args } => {
930 for arg in args {
931 build_dim_map_inner(arg, operands, dim_map);
932 }
933 }
934 }
935}
936
937impl EinsumCode {
938 pub fn evaluate_into<T: EinsumScalar>(
951 &self,
952 operands: Vec<EinsumOperand<'_>>,
953 output: StridedViewMut<T>,
954 alpha: T,
955 beta: T,
956 size_dict: Option<&HashMap<char, usize>>,
957 ) -> crate::Result<()> {
958 self.evaluate_into_with_pool(operands, output, alpha, beta, size_dict, None)
959 }
960
961 pub fn evaluate_into_with_pool<T: EinsumScalar>(
970 &self,
971 operands: Vec<EinsumOperand<'_>>,
972 mut output: StridedViewMut<T>,
973 alpha: T,
974 beta: T,
975 size_dict: Option<&HashMap<char, usize>>,
976 pool: Option<&mut BufferPool>,
977 ) -> crate::Result<()> {
978 let expected = leaf_count(&self.root);
979 if operands.len() != expected {
980 return Err(crate::EinsumError::OperandCountMismatch {
981 expected,
982 found: operands.len(),
983 });
984 }
985
986 let mut ops: Vec<Option<EinsumOperand<'_>>> = operands.into_iter().map(Some).collect();
988 T::validate_operands(&ops)?;
989
990 let mut unified = build_dim_map(&self.root, &ops);
992 if let Some(sd) = size_dict {
993 merge_size_dict(&mut unified, sd)?;
994 }
995
996 let expected_dims = out_dims_from_map(&unified, &self.output_ids, &unified)?;
998 if output.dims() != expected_dims.as_slice() {
999 return Err(crate::EinsumError::OutputShapeMismatch {
1000 expected: expected_dims,
1001 got: output.dims().to_vec(),
1002 });
1003 }
1004
1005 let mut temp_pool;
1006 let pool = match pool {
1007 Some(p) => p,
1008 None => {
1009 temp_pool = BufferPool::new();
1010 &mut temp_pool
1011 }
1012 };
1013
1014 match &self.root {
1015 EinsumNode::Leaf { ids, tensor_index } => {
1016 let op = ops[*tensor_index].take().ok_or_else(|| {
1018 crate::EinsumError::Internal("operand already consumed".into())
1019 })?;
1020 let single_result = eval_single(&op, ids, &self.output_ids, &unified)?;
1021 let data = T::extract_data(single_result)?;
1022 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
1023 }
1024 EinsumNode::Contract { args } => match args.len() {
1025 0 => unreachable!("empty Contract node"),
1026 1 => {
1027 let child_needed = compute_child_needed_ids(&self.output_ids, 0, args);
1029 let (child_op, child_ids) =
1030 eval_node(&args[0], &mut ops, &child_needed, pool, &unified)?;
1031
1032 if child_ids == self.output_ids {
1033 let data = T::extract_data(child_op)?;
1035 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
1036 } else if is_permutation_only(&child_ids, &self.output_ids) {
1037 let perm = compute_permutation(&child_ids, &self.output_ids);
1039 let data = T::extract_data(child_op)?;
1040 let arr = data.into_array();
1041 let permuted = arr.permuted(&perm)?;
1042 accumulate_into(&mut output, &permuted, alpha, beta)?;
1043 } else {
1044 let result =
1046 eval_single(&child_op, &child_ids, &self.output_ids, &unified)?;
1047 let data = T::extract_data(result)?;
1048 accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
1049 }
1050 }
1051 2 => {
1052 let left_needed = compute_child_needed_ids(&self.output_ids, 0, args);
1054 let right_needed = compute_child_needed_ids(&self.output_ids, 1, args);
1055 let (left, left_ids) =
1056 eval_node(&args[0], &mut ops, &left_needed, pool, &unified)?;
1057 let (right, right_ids) =
1058 eval_node(&args[1], &mut ops, &right_needed, pool, &unified)?;
1059 eval_pair_into(
1060 left,
1061 &left_ids,
1062 right,
1063 &right_ids,
1064 output,
1065 &self.output_ids,
1066 alpha,
1067 beta,
1068 )?;
1069 }
1070 _ => {
1071 let node_output_ids = compute_contract_output_ids(args, &self.output_ids);
1073
1074 let mut children: Vec<Option<(EinsumOperand<'_>, Vec<char>)>> = Vec::new();
1075 for (i, arg) in args.iter().enumerate() {
1076 let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
1077 let (op, ids) = eval_node(arg, &mut ops, &child_needed, pool, &unified)?;
1078 children.push(Some((op, ids)));
1079 }
1080
1081 let mut dim_sizes: HashMap<char, usize> = HashMap::new();
1082 for child_opt in &children {
1083 if let Some((op, ids)) = child_opt {
1084 for (j, &id) in ids.iter().enumerate() {
1085 dim_sizes.insert(id, op.dims()[j]);
1086 }
1087 }
1088 }
1089
1090 let input_ids: Vec<Vec<char>> = children
1091 .iter()
1092 .map(|c| c.as_ref().unwrap().1.clone())
1093 .collect();
1094 let code = omeco::EinCode::new(input_ids, self.output_ids.clone());
1095
1096 let optimizer = omeco::GreedyMethod::default();
1097 let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
1098 .ok_or_else(|| {
1099 crate::EinsumError::Internal(
1100 "optimizer failed to produce a plan".into(),
1101 )
1102 })?;
1103
1104 execute_nested_into(
1105 &nested,
1106 &mut children,
1107 output,
1108 &self.output_ids,
1109 alpha,
1110 beta,
1111 pool,
1112 &unified,
1113 )?;
1114 }
1115 },
1116 }
1117
1118 Ok(())
1119 }
1120}
1121
1122#[cfg(test)]
1127mod tests {
1128 use super::*;
1129 use crate::parse::parse_einsum;
1130 use approx::assert_abs_diff_eq;
1131 use strided_view::{row_major_strides, StridedArray};
1132
1133 fn make_f64(dims: &[usize], data: Vec<f64>) -> EinsumOperand<'static> {
1134 let strides = row_major_strides(dims);
1135 StridedArray::from_parts(data, dims, &strides, 0)
1136 .unwrap()
1137 .into()
1138 }
1139
1140 #[test]
1141 fn test_binary_output_ids_canonical_lo_ro_batch_order() {
1142 let out = compute_binary_output_ids(&['b', 'a', 'x'], &['x', 'c', 'a'], &['b', 'c', 'a']);
1143 assert_eq!(out, vec!['b', 'c', 'a']);
1144 }
1145
1146 #[test]
1147 fn test_matmul() {
1148 let code = parse_einsum("ij,jk->ik").unwrap();
1149 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1150 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1151 let result = code.evaluate(vec![a, b], None).unwrap();
1152 match result {
1153 EinsumOperand::F64(data) => {
1154 let arr = data.as_array();
1155 assert_eq!(arr.dims(), &[2, 2]);
1156 assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
1157 assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
1158 assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
1159 assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
1160 }
1161 _ => panic!("expected F64"),
1162 }
1163 }
1164
1165 #[test]
1166 fn test_nested_three_tensor() {
1167 let code = parse_einsum("(ij,jk),kl->il").unwrap();
1168 let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
1170 let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1171 let c = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1172 let result = code.evaluate(vec![a, b, c], None).unwrap();
1174 match result {
1175 EinsumOperand::F64(data) => {
1176 let arr = data.as_array();
1177 assert_eq!(arr.dims(), &[2, 2]);
1178 assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
1179 assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
1180 assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
1181 assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
1182 }
1183 _ => panic!("expected F64"),
1184 }
1185 }
1186
1187 #[test]
1188 fn test_outer_product() {
1189 let code = parse_einsum("i,j->ij").unwrap();
1190 let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
1191 let b = make_f64(&[2], vec![10.0, 20.0]);
1192 let result = code.evaluate(vec![a, b], None).unwrap();
1193 match result {
1194 EinsumOperand::F64(data) => {
1195 let arr = data.as_array();
1196 assert_eq!(arr.dims(), &[3, 2]);
1197 assert_abs_diff_eq!(arr.get(&[0, 0]), 10.0);
1198 assert_abs_diff_eq!(arr.get(&[2, 1]), 60.0);
1199 }
1200 _ => panic!("expected F64"),
1201 }
1202 }
1203
1204 #[test]
1205 fn test_dot_product() {
1206 let code = parse_einsum("i,i->").unwrap();
1207 let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
1208 let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
1209 let result = code.evaluate(vec![a, b], None).unwrap();
1210 match result {
1211 EinsumOperand::F64(data) => {
1212 assert_abs_diff_eq!(data.as_array().data()[0], 32.0);
1214 }
1215 _ => panic!("expected F64"),
1216 }
1217 }
1218
1219 #[test]
1220 fn test_single_tensor_permute() {
1221 let code = parse_einsum("ij->ji").unwrap();
1222 let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1223 let result = code.evaluate(vec![a], None).unwrap();
1224 match result {
1225 EinsumOperand::F64(data) => {
1226 let arr = data.as_array();
1227 assert_eq!(arr.dims(), &[3, 2]);
1228 assert_abs_diff_eq!(arr.get(&[0, 0]), 1.0);
1229 assert_abs_diff_eq!(arr.get(&[0, 1]), 4.0);
1230 }
1231 _ => panic!("expected F64"),
1232 }
1233 }
1234
1235 #[test]
1236 fn test_single_tensor_trace() {
1237 let code = parse_einsum("ii->").unwrap();
1238 let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1239 let result = code.evaluate(vec![a], None).unwrap();
1240 match result {
1241 EinsumOperand::F64(data) => {
1242 assert_abs_diff_eq!(data.as_array().data()[0], 6.0);
1243 }
1244 _ => panic!("expected F64"),
1245 }
1246 }
1247
1248 #[test]
1249 fn test_three_tensor_flat_omeco() {
1250 let code = parse_einsum("ij,jk,kl->il").unwrap();
1252 let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1253 let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
1254 let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
1256 let result = code.evaluate(vec![a, b, c], None).unwrap();
1257 match result {
1258 EinsumOperand::F64(data) => {
1259 let arr = data.as_array();
1260 assert_eq!(arr.dims(), &[2, 2]);
1261 assert_abs_diff_eq!(arr.get(&[0, 0]), 4.0, epsilon = 1e-10);
1262 assert_abs_diff_eq!(arr.get(&[0, 1]), 2.0, epsilon = 1e-10);
1263 assert_abs_diff_eq!(arr.get(&[1, 0]), 10.0, epsilon = 1e-10);
1264 assert_abs_diff_eq!(arr.get(&[1, 1]), 5.0, epsilon = 1e-10);
1265 }
1266 _ => panic!("expected F64"),
1267 }
1268 }
1269
1270 #[test]
1271 fn test_four_tensor_flat_omeco() {
1272 let code = parse_einsum("ij,jk,kl,lm->im").unwrap();
1274 let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1276 let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); let d = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1278 let result = code.evaluate(vec![a, b, c, d], None).unwrap();
1280 match result {
1281 EinsumOperand::F64(data) => {
1282 let arr = data.as_array();
1283 assert_eq!(arr.dims(), &[2, 2]);
1284 assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0, epsilon = 1e-10);
1285 assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0, epsilon = 1e-10);
1286 }
1287 _ => panic!("expected F64"),
1288 }
1289 }
1290
1291 #[test]
1292 fn test_orphan_output_axis_returns_error() {
1293 let code = parse_einsum("ij,jk->iz").unwrap();
1294 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1295 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1296 let err = code.evaluate(vec![a, b], None).unwrap_err();
1297 assert!(matches!(err, crate::EinsumError::OrphanOutputAxis(ref s) if s == "z"));
1298 }
1299
1300 #[test]
1301 fn test_operand_count_mismatch_too_few() {
1302 let code = parse_einsum("ij,jk->ik").unwrap();
1303 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1304 let err = code.evaluate(vec![a], None).unwrap_err();
1305 assert!(matches!(
1306 err,
1307 crate::EinsumError::OperandCountMismatch {
1308 expected: 2,
1309 found: 1
1310 }
1311 ));
1312 }
1313
1314 #[test]
1315 fn test_operand_count_mismatch_too_many() {
1316 let code = parse_einsum("ij->ji").unwrap();
1317 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1318 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1319 let err = code.evaluate(vec![a, b], None).unwrap_err();
1320 assert!(matches!(
1321 err,
1322 crate::EinsumError::OperandCountMismatch {
1323 expected: 1,
1324 found: 2
1325 }
1326 ));
1327 }
1328
1329 #[test]
1334 fn test_into_matmul() {
1335 let code = parse_einsum("ij,jk->ik").unwrap();
1336 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1337 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1338 let mut c = StridedArray::<f64>::col_major(&[2, 2]);
1339 code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
1340 .unwrap();
1341 assert_abs_diff_eq!(c.get(&[0, 0]), 19.0);
1342 assert_abs_diff_eq!(c.get(&[0, 1]), 22.0);
1343 assert_abs_diff_eq!(c.get(&[1, 0]), 43.0);
1344 assert_abs_diff_eq!(c.get(&[1, 1]), 50.0);
1345 }
1346
1347 #[test]
1348 fn test_into_matmul_alpha_beta() {
1349 let code = parse_einsum("ij,jk->ik").unwrap();
1351 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1352 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1353 let mut c = StridedArray::<f64>::col_major(&[2, 2]);
1357 for v in c.data_mut().iter_mut() {
1358 *v = 1.0;
1359 }
1360 code.evaluate_into(vec![a, b], c.view_mut(), 2.0, 3.0, None)
1361 .unwrap();
1362 assert_abs_diff_eq!(c.get(&[0, 0]), 41.0, epsilon = 1e-10);
1363 assert_abs_diff_eq!(c.get(&[0, 1]), 47.0, epsilon = 1e-10);
1364 assert_abs_diff_eq!(c.get(&[1, 0]), 89.0, epsilon = 1e-10);
1365 assert_abs_diff_eq!(c.get(&[1, 1]), 103.0, epsilon = 1e-10);
1366 }
1367
1368 #[test]
1369 fn test_into_single_tensor_permute() {
1370 let code = parse_einsum("ij->ji").unwrap();
1371 let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1372 let mut c = StridedArray::<f64>::col_major(&[3, 2]);
1373 code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
1374 .unwrap();
1375 assert_eq!(c.dims(), &[3, 2]);
1376 assert_abs_diff_eq!(c.get(&[0, 0]), 1.0);
1377 assert_abs_diff_eq!(c.get(&[0, 1]), 4.0);
1378 assert_abs_diff_eq!(c.get(&[1, 0]), 2.0);
1379 assert_abs_diff_eq!(c.get(&[2, 1]), 6.0);
1380 }
1381
1382 #[test]
1383 fn test_into_single_tensor_trace() {
1384 let code = parse_einsum("ii->").unwrap();
1385 let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1386 let mut c = StridedArray::<f64>::col_major(&[]);
1387 code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
1388 .unwrap();
1389 assert_abs_diff_eq!(c.data()[0], 6.0);
1390 }
1391
1392 #[test]
1393 fn test_into_three_tensor_omeco() {
1394 let code = parse_einsum("ij,jk,kl->il").unwrap();
1395 let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1396 let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
1397 let c_op = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
1398 let mut out = StridedArray::<f64>::col_major(&[2, 2]);
1399 code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
1400 .unwrap();
1401 assert_abs_diff_eq!(out.get(&[0, 0]), 4.0, epsilon = 1e-10);
1402 assert_abs_diff_eq!(out.get(&[0, 1]), 2.0, epsilon = 1e-10);
1403 assert_abs_diff_eq!(out.get(&[1, 0]), 10.0, epsilon = 1e-10);
1404 assert_abs_diff_eq!(out.get(&[1, 1]), 5.0, epsilon = 1e-10);
1405 }
1406
1407 #[test]
1408 fn test_into_nested() {
1409 let code = parse_einsum("(ij,jk),kl->il").unwrap();
1410 let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
1411 let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1412 let c_op = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1413 let mut out = StridedArray::<f64>::col_major(&[2, 2]);
1414 code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
1415 .unwrap();
1416 assert_abs_diff_eq!(out.get(&[0, 0]), 19.0, epsilon = 1e-10);
1417 assert_abs_diff_eq!(out.get(&[0, 1]), 22.0, epsilon = 1e-10);
1418 assert_abs_diff_eq!(out.get(&[1, 0]), 43.0, epsilon = 1e-10);
1419 assert_abs_diff_eq!(out.get(&[1, 1]), 50.0, epsilon = 1e-10);
1420 }
1421
1422 #[test]
1423 fn test_into_dot_product() {
1424 let code = parse_einsum("i,i->").unwrap();
1425 let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
1426 let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
1427 let mut c = StridedArray::<f64>::col_major(&[]);
1428 code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
1429 .unwrap();
1430 assert_abs_diff_eq!(c.data()[0], 32.0);
1431 }
1432
1433 #[test]
1434 fn test_into_type_mismatch_f64_output_c64_input() {
1435 let code = parse_einsum("ij->ji").unwrap();
1436 let c64_data = vec![
1437 Complex64::new(1.0, 0.0),
1438 Complex64::new(2.0, 0.0),
1439 Complex64::new(3.0, 0.0),
1440 Complex64::new(4.0, 0.0),
1441 ];
1442 let strides = row_major_strides(&[2, 2]);
1443 let arr = StridedArray::from_parts(c64_data, &[2, 2], &strides, 0).unwrap();
1444 let op = EinsumOperand::C64(StridedData::Owned(arr));
1445 let mut out = StridedArray::<f64>::col_major(&[2, 2]);
1446 let err = code
1447 .evaluate_into(vec![op], out.view_mut(), 1.0, 0.0, None)
1448 .unwrap_err();
1449 assert!(matches!(err, crate::EinsumError::TypeMismatch { .. }));
1450 }
1451
1452 #[test]
1453 fn test_into_shape_mismatch() {
1454 let code = parse_einsum("ij,jk->ik").unwrap();
1455 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1456 let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1457 let mut out = StridedArray::<f64>::col_major(&[3, 3]); let err = code
1459 .evaluate_into(vec![a, b], out.view_mut(), 1.0, 0.0, None)
1460 .unwrap_err();
1461 assert!(matches!(
1462 err,
1463 crate::EinsumError::OutputShapeMismatch { .. }
1464 ));
1465 }
1466
1467 #[test]
1468 fn test_into_c64_output() {
1469 let code = parse_einsum("ij,jk->ik").unwrap();
1470 let c64 = |r| Complex64::new(r, 0.0);
1471 let a_data = vec![c64(1.0), c64(2.0), c64(3.0), c64(4.0)];
1472 let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
1473 let strides = row_major_strides(&[2, 2]);
1474 let a = EinsumOperand::C64(StridedData::Owned(
1475 StridedArray::from_parts(a_data, &[2, 2], &strides, 0).unwrap(),
1476 ));
1477 let b = EinsumOperand::C64(StridedData::Owned(
1478 StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
1479 ));
1480 let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
1481 code.evaluate_into(
1482 vec![a, b],
1483 out.view_mut(),
1484 c64(1.0),
1485 Complex64::zero(),
1486 None,
1487 )
1488 .unwrap();
1489 assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
1490 assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
1491 }
1492
1493 #[test]
1494 fn test_into_mixed_types_c64_output() {
1495 let code = parse_einsum("ij,jk->ik").unwrap();
1497 let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1498 let c64 = |r| Complex64::new(r, 0.0);
1499 let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
1500 let strides = row_major_strides(&[2, 2]);
1501 let b = EinsumOperand::C64(StridedData::Owned(
1502 StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
1503 ));
1504 let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
1505 code.evaluate_into(
1506 vec![a, b],
1507 out.view_mut(),
1508 c64(1.0),
1509 Complex64::zero(),
1510 None,
1511 )
1512 .unwrap();
1513 assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
1514 assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
1515 }
1516}