1use num_complex::Complex64;
2use num_traits::Zero;
3use strided_einsum2::Scalar;
4use strided_kernel::copy_into;
5use strided_view::{col_major_strides, ElementOpApply, StridedArray, StridedView};
6
7use crate::typed_tensor::TypedTensor;
8
9#[derive(Debug)]
11pub enum StridedData<'a, T> {
12 Owned(StridedArray<T>),
13 View(StridedView<'a, T>),
14}
15
16impl<'a, T> StridedData<'a, T> {
17 pub fn dims(&self) -> &[usize] {
19 match self {
20 StridedData::Owned(arr) => arr.dims(),
21 StridedData::View(view) => view.dims(),
22 }
23 }
24
25 pub fn as_view(&self) -> StridedView<'_, T> {
27 match self {
28 StridedData::Owned(arr) => arr.view(),
29 StridedData::View(view) => view.clone(),
30 }
31 }
32
33 pub fn as_array(&self) -> &StridedArray<T> {
38 match self {
39 StridedData::Owned(arr) => arr,
40 StridedData::View(_) => panic!("StridedData::as_array called on a View variant"),
41 }
42 }
43}
44
45impl<'a, T> StridedData<'a, T> {
46 pub fn permuted(self, perm: &[usize]) -> strided_view::Result<Self> {
48 match self {
49 StridedData::Owned(arr) => Ok(StridedData::Owned(arr.permuted(perm)?)),
50 StridedData::View(view) => Ok(StridedData::View(view.permute(perm)?)),
51 }
52 }
53}
54
55impl<'a, T> StridedData<'a, T>
56where
57 T: Copy + ElementOpApply + Send + Sync + Zero + Default,
58{
59 pub fn into_array(self) -> StridedArray<T> {
64 match self {
65 StridedData::Owned(arr) => arr,
66 StridedData::View(view) => {
67 let dims = view.dims().to_vec();
68 let mut dest = StridedArray::<T>::col_major(&dims);
69 let mut dest_view = dest.view_mut();
70 copy_into(&mut dest_view, &view).expect("copy_into failed in into_array");
71 dest
72 }
73 }
74 }
75}
76
77#[derive(Debug)]
79pub enum EinsumOperand<'a> {
80 F64(StridedData<'a, f64>),
81 C64(StridedData<'a, Complex64>),
82}
83
84impl<'a> EinsumOperand<'a> {
85 pub fn is_f64(&self) -> bool {
87 matches!(self, EinsumOperand::F64(_))
88 }
89
90 pub fn is_c64(&self) -> bool {
92 matches!(self, EinsumOperand::C64(_))
93 }
94
95 pub fn dims(&self) -> &[usize] {
97 match self {
98 EinsumOperand::F64(data) => data.dims(),
99 EinsumOperand::C64(data) => data.dims(),
100 }
101 }
102
103 pub fn permuted(self, perm: &[usize]) -> crate::Result<Self> {
105 match self {
106 EinsumOperand::F64(data) => Ok(EinsumOperand::F64(data.permuted(perm)?)),
107 EinsumOperand::C64(data) => Ok(EinsumOperand::C64(data.permuted(perm)?)),
108 }
109 }
110
111 pub fn from_view<T: EinsumScalar>(view: &StridedView<'a, T>) -> Self {
115 T::wrap_data(StridedData::View(view.clone()))
116 }
117
118 pub fn to_c64_owned_ref(&self) -> EinsumOperand<'static> {
122 match self {
123 EinsumOperand::C64(data) => {
124 let view = data.as_view();
125 let dims = view.dims().to_vec();
126 let mut dest = StridedArray::<Complex64>::col_major(&dims);
127 copy_into(&mut dest.view_mut(), &view).expect("copy_into failed");
128 EinsumOperand::C64(StridedData::Owned(dest))
129 }
130 EinsumOperand::F64(data) => {
131 let view = data.as_view();
132 let dims = view.dims().to_vec();
133 let strides = col_major_strides(&dims);
134 let mut f64_dest = StridedArray::<f64>::col_major(&dims);
135 copy_into(&mut f64_dest.view_mut(), &view).expect("copy_into failed");
136 let c64_data: Vec<Complex64> = f64_dest
137 .data()
138 .iter()
139 .map(|&x| Complex64::new(x, 0.0))
140 .collect();
141 let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
142 .expect("from_parts failed");
143 EinsumOperand::C64(StridedData::Owned(c64_array))
144 }
145 }
146 }
147
148 pub fn to_c64_owned(self) -> EinsumOperand<'static> {
154 match self {
155 EinsumOperand::C64(data) => EinsumOperand::C64(StridedData::Owned(data.into_array())),
156 EinsumOperand::F64(data) => {
157 let view = data.as_view();
158 let dims = view.dims().to_vec();
159 let strides = col_major_strides(&dims);
160 let f64_array = match data {
163 StridedData::Owned(arr) => arr,
164 StridedData::View(v) => {
165 let mut dest = StridedArray::<f64>::col_major(v.dims());
166 let mut dest_view = dest.view_mut();
167 copy_into(&mut dest_view, &v).expect("copy_into failed in to_c64_owned");
168 dest
169 }
170 };
171 let c64_data: Vec<Complex64> = f64_array
172 .data()
173 .iter()
174 .map(|&x| Complex64::new(x, 0.0))
175 .collect();
176 let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
177 .expect("from_parts failed in to_c64_owned");
178 EinsumOperand::C64(StridedData::Owned(c64_array))
179 }
180 }
181 }
182}
183
184impl From<StridedArray<f64>> for EinsumOperand<'static> {
185 fn from(arr: StridedArray<f64>) -> Self {
186 EinsumOperand::F64(StridedData::Owned(arr))
187 }
188}
189
190impl From<StridedArray<Complex64>> for EinsumOperand<'static> {
191 fn from(arr: StridedArray<Complex64>) -> Self {
192 EinsumOperand::C64(StridedData::Owned(arr))
193 }
194}
195
196mod private {
201 pub trait Sealed {}
202 impl Sealed for f64 {}
203 impl Sealed for num_complex::Complex64 {}
204}
205
206pub trait EinsumScalar: private::Sealed + Scalar + Default + 'static {
210 fn type_name() -> &'static str;
212
213 fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_>;
215
216 fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static>;
218
219 fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor;
221
222 fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, Self>>;
227
228 fn validate_operands(ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()>;
231}
232
233impl EinsumScalar for f64 {
234 fn type_name() -> &'static str {
235 "f64"
236 }
237
238 fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_> {
239 EinsumOperand::F64(data)
240 }
241
242 fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static> {
243 EinsumOperand::F64(StridedData::Owned(arr))
244 }
245
246 fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor {
247 TypedTensor::F64(arr)
248 }
249
250 fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, f64>> {
251 match op {
252 EinsumOperand::F64(data) => Ok(data),
253 EinsumOperand::C64(_) => Err(crate::EinsumError::TypeMismatch {
254 output_type: "f64",
255 computed_type: "Complex64",
256 }),
257 }
258 }
259
260 fn validate_operands(ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()> {
261 for op in ops.iter().flatten() {
262 if op.is_c64() {
263 return Err(crate::EinsumError::TypeMismatch {
264 output_type: "f64",
265 computed_type: "Complex64",
266 });
267 }
268 }
269 Ok(())
270 }
271}
272
273impl EinsumScalar for Complex64 {
274 fn type_name() -> &'static str {
275 "Complex64"
276 }
277
278 fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_> {
279 EinsumOperand::C64(data)
280 }
281
282 fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static> {
283 EinsumOperand::C64(StridedData::Owned(arr))
284 }
285
286 fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor {
287 TypedTensor::C64(arr)
288 }
289
290 fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, Complex64>> {
291 match op {
292 EinsumOperand::C64(data) => Ok(data),
293 EinsumOperand::F64(data) => {
294 let view = data.as_view();
297 let dims = view.dims().to_vec();
298 let strides = col_major_strides(&dims);
299 let mut f64_col = StridedArray::<f64>::col_major(&dims);
300 copy_into(&mut f64_col.view_mut(), &view)
301 .expect("copy_into failed in extract_data");
302 let c64_data: Vec<Complex64> = f64_col
303 .data()
304 .iter()
305 .map(|&x| Complex64::new(x, 0.0))
306 .collect();
307 let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
308 .expect("from_parts failed in extract_data");
309 Ok(StridedData::Owned(c64_array))
310 }
311 }
312 }
313
314 fn validate_operands(_ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()> {
315 Ok(())
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;
    use strided_view::StridedArray;

    // ---- fixtures -------------------------------------------------------------------

    /// Shorthand constructor for `Complex64` values.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Column-major `f64` array whose leading storage elements are `values`.
    fn f64_array(dims: &[usize], values: &[f64]) -> StridedArray<f64> {
        let mut arr = StridedArray::<f64>::col_major(dims);
        arr.data_mut()[..values.len()].copy_from_slice(values);
        arr
    }

    /// Column-major `Complex64` array whose leading storage elements are `values`.
    fn c64_array(dims: &[usize], values: &[Complex64]) -> StridedArray<Complex64> {
        let mut arr = StridedArray::<Complex64>::col_major(dims);
        arr.data_mut()[..values.len()].copy_from_slice(values);
        arr
    }

    /// Column-major `f64` array whose storage element `i` equals `i * scale`.
    fn ramp(dims: &[usize], scale: f64) -> StridedArray<f64> {
        let mut arr = StridedArray::<f64>::col_major(dims);
        for (i, slot) in arr.data_mut().iter_mut().enumerate() {
            *slot = i as f64 * scale;
        }
        arr
    }

    /// An owned 2x2 `f64` operand.
    fn f64_operand() -> EinsumOperand<'static> {
        EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 2]))
    }

    /// An owned 2x2 `Complex64` operand.
    fn c64_operand() -> EinsumOperand<'static> {
        EinsumOperand::from(StridedArray::<Complex64>::col_major(&[2, 2]))
    }

    // ---- assertion helpers ----------------------------------------------------------

    /// Unwraps `EinsumOperand::C64(StridedData::Owned(_))` and panics on any other shape.
    fn expect_c64_owned<'r>(op: &'r EinsumOperand<'_>) -> &'r StridedArray<Complex64> {
        match op {
            EinsumOperand::C64(StridedData::Owned(arr)) => arr,
            _ => panic!("expected C64 Owned"),
        }
    }

    /// Asserts `data[i] == expected[i]` for every index of `expected`.
    fn assert_storage<T: PartialEq + std::fmt::Debug>(data: &[T], expected: &[T]) {
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(&data[i], want, "storage index {}", i);
        }
    }

    /// True iff `err` is the "f64 output, Complex64 input" type mismatch.
    fn is_f64_vs_c64_mismatch(err: crate::EinsumError) -> bool {
        matches!(
            err,
            crate::EinsumError::TypeMismatch {
                output_type: "f64",
                computed_type: "Complex64",
            }
        )
    }

    // ---- EinsumOperand construction -------------------------------------------------

    #[test]
    fn test_f64_owned() {
        let op = EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 3]));
        assert!(op.is_f64());
        assert!(!op.is_c64());
        assert_eq!(op.dims(), &[2, 3]);
    }

    #[test]
    fn test_c64_owned() {
        let op = EinsumOperand::from(StridedArray::<Complex64>::col_major(&[4, 5]));
        assert!(op.is_c64());
        assert_eq!(op.dims(), &[4, 5]);
    }

    #[test]
    fn test_f64_view() {
        let arr = StridedArray::<f64>::col_major(&[2, 3]);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        assert!(op.is_f64());
        assert_eq!(op.dims(), &[2, 3]);
    }

    // ---- promotion to Complex64 -----------------------------------------------------

    #[test]
    fn test_promote_f64_to_c64() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0]));
        let promoted = op.to_c64_owned();
        assert!(promoted.is_c64());
        let arr = expect_c64_owned(&promoted);
        assert_storage(arr.data(), &[c(1.0, 0.0), c(2.0, 0.0)]);
    }

    #[test]
    fn test_to_c64_owned_ref_from_f64_owned() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0]));
        let promoted = op.to_c64_owned_ref();
        assert!(promoted.is_c64());
        let arr = expect_c64_owned(&promoted);
        assert_eq!(arr.dims(), &[2, 2]);
        assert_storage(
            arr.data(),
            &[c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)],
        );
    }

    #[test]
    fn test_to_c64_owned_ref_from_f64_view() {
        let arr = f64_array(&[2, 2], &[5.0, 6.0, 7.0, 8.0]);
        let view = arr.view();
        let promoted = EinsumOperand::from_view(&view).to_c64_owned_ref();
        assert!(promoted.is_c64());
        let c_arr = expect_c64_owned(&promoted);
        assert_eq!(c_arr.dims(), &[2, 2]);
        assert_storage(
            c_arr.data(),
            &[c(5.0, 0.0), c(6.0, 0.0), c(7.0, 0.0), c(8.0, 0.0)],
        );
    }

    #[test]
    fn test_to_c64_owned_ref_from_c64_owned() {
        let values = [c(1.0, 2.0), c(3.0, 4.0), c(5.0, 6.0), c(7.0, 8.0)];
        let op = EinsumOperand::from(c64_array(&[2, 2], &values));
        let copied = op.to_c64_owned_ref();
        assert!(copied.is_c64());
        let c_arr = expect_c64_owned(&copied);
        assert_eq!(c_arr.dims(), &[2, 2]);
        assert_storage(c_arr.data(), &values);
    }

    #[test]
    fn test_to_c64_owned_ref_from_c64_view() {
        let values = [c(1.0, -1.0), c(2.0, -2.0), c(3.0, -3.0)];
        let arr = c64_array(&[3], &values);
        let view = arr.view();
        let copied = EinsumOperand::from_view(&view).to_c64_owned_ref();
        assert!(copied.is_c64());
        let c_arr = expect_c64_owned(&copied);
        assert_eq!(c_arr.dims(), &[3]);
        assert_storage(c_arr.data(), &values);
    }

    // ---- StridedData::into_array / as_array -----------------------------------------

    #[test]
    fn test_strided_data_into_array_from_owned() {
        let result = StridedData::Owned(ramp(&[2, 3], 1.0)).into_array();
        assert_eq!(result.dims(), &[2, 3]);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[5], 5.0);
    }

    #[test]
    fn test_strided_data_into_array_from_view() {
        let arr = ramp(&[2, 3], 10.0);
        let result = StridedData::<f64>::View(arr.view()).into_array();
        assert_eq!(result.dims(), &[2, 3]);
        assert_eq!(result.get(&[0, 0]), 0.0);
        assert_eq!(result.get(&[1, 0]), 10.0);
    }

    #[test]
    fn test_strided_data_into_array_from_view_c64() {
        let arr = c64_array(
            &[2, 2],
            &[c(1.0, 2.0), c(3.0, 4.0), c(5.0, 6.0), c(7.0, 8.0)],
        );
        let result = StridedData::<Complex64>::View(arr.view()).into_array();
        assert_eq!(result.dims(), &[2, 2]);
        assert_eq!(result.get(&[0, 0]), c(1.0, 2.0));
        assert_eq!(result.get(&[1, 1]), c(7.0, 8.0));
    }

    #[test]
    fn test_strided_data_as_array_owned() {
        let data = StridedData::Owned(StridedArray::<f64>::col_major(&[3, 2]));
        assert_eq!(data.as_array().dims(), &[3, 2]);
    }

    #[test]
    #[should_panic(expected = "StridedData::as_array called on a View variant")]
    fn test_strided_data_as_array_view_panics() {
        let arr = StridedArray::<f64>::col_major(&[3, 2]);
        let data = StridedData::<f64>::View(arr.view());
        let _ = data.as_array();
    }

    // ---- EinsumScalar::validate_operands --------------------------------------------

    #[test]
    fn test_validate_operands_f64_all_f64() {
        let ops: Vec<Option<EinsumOperand>> = vec![Some(f64_operand()), Some(f64_operand())];
        assert!(f64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_f64_with_none() {
        let ops: Vec<Option<EinsumOperand>> = vec![Some(f64_operand()), None];
        assert!(f64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_f64_with_c64_returns_error() {
        let ops: Vec<Option<EinsumOperand>> = vec![Some(f64_operand()), Some(c64_operand())];
        let err = f64::validate_operands(&ops).unwrap_err();
        assert!(is_f64_vs_c64_mismatch(err));
    }

    #[test]
    fn test_validate_operands_c64_accepts_anything() {
        let ops: Vec<Option<EinsumOperand>> = vec![Some(f64_operand()), Some(c64_operand())];
        assert!(Complex64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_c64_all_f64() {
        let ops: Vec<Option<EinsumOperand>> = vec![Some(f64_operand()), Some(f64_operand())];
        assert!(Complex64::validate_operands(&ops).is_ok());
    }

    // ---- EinsumScalar::extract_data -------------------------------------------------

    #[test]
    fn test_extract_data_f64_from_f64() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[42.0]));
        let data = f64::extract_data(op).unwrap();
        assert_eq!(data.as_view().get(&[0, 0]), 42.0);
    }

    #[test]
    fn test_extract_data_f64_from_c64_returns_error() {
        let err = f64::extract_data(c64_operand()).unwrap_err();
        assert!(is_f64_vs_c64_mismatch(err));
    }

    #[test]
    fn test_extract_data_c64_from_c64() {
        let op = EinsumOperand::from(c64_array(&[2, 2], &[c(1.0, 2.0)]));
        let data = Complex64::extract_data(op).unwrap();
        assert_eq!(data.as_view().get(&[0, 0]), c(1.0, 2.0));
    }

    #[test]
    fn test_extract_data_c64_from_f64_promotes() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[5.0, 6.0, 7.0, 8.0]));
        match Complex64::extract_data(op).unwrap() {
            StridedData::Owned(c_arr) => {
                assert_eq!(c_arr.dims(), &[2, 2]);
                assert_storage(
                    c_arr.data(),
                    &[c(5.0, 0.0), c(6.0, 0.0), c(7.0, 0.0), c(8.0, 0.0)],
                );
            }
            StridedData::View(_) => panic!("expected Owned after promotion"),
        }
    }

    #[test]
    fn test_type_name() {
        assert_eq!(f64::type_name(), "f64");
        assert_eq!(Complex64::type_name(), "Complex64");
    }

    // ---- StridedData::dims / as_view / permuted -------------------------------------

    #[test]
    fn test_strided_data_dims_and_as_view() {
        let arr = ramp(&[3, 4], 1.0);

        let owned = StridedData::Owned(arr.clone());
        assert_eq!(owned.dims(), &[3, 4]);
        assert_eq!(owned.as_view().dims(), &[3, 4]);

        let borrowed = StridedData::<f64>::View(arr.view());
        assert_eq!(borrowed.dims(), &[3, 4]);
        assert_eq!(borrowed.as_view().dims(), &[3, 4]);
    }

    #[test]
    fn test_strided_data_permuted_owned() {
        let permuted = StridedData::Owned(ramp(&[2, 3], 1.0))
            .permuted(&[1, 0])
            .unwrap();
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    #[test]
    fn test_strided_data_permuted_view() {
        let arr = ramp(&[2, 3], 1.0);
        let permuted = StridedData::<f64>::View(arr.view())
            .permuted(&[1, 0])
            .unwrap();
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    #[test]
    fn test_einsum_operand_permuted_f64() {
        let permuted = EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 3]))
            .permuted(&[1, 0])
            .unwrap();
        assert!(permuted.is_f64());
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    #[test]
    fn test_einsum_operand_permuted_c64() {
        let permuted = EinsumOperand::from(StridedArray::<Complex64>::col_major(&[4, 5]))
            .permuted(&[1, 0])
            .unwrap();
        assert!(permuted.is_c64());
        assert_eq!(permuted.dims(), &[5, 4]);
    }

    // ---- EinsumOperand::to_c64_owned ------------------------------------------------

    #[test]
    fn test_to_c64_owned_c64_view() {
        let arr = c64_array(
            &[2, 2],
            &[c(1.0, -1.0), c(2.0, -2.0), c(3.0, -3.0), c(4.0, -4.0)],
        );
        let view = arr.view();
        let owned = EinsumOperand::from_view(&view).to_c64_owned();
        assert!(owned.is_c64());
        let c_arr = expect_c64_owned(&owned);
        assert_eq!(c_arr.dims(), &[2, 2]);
        assert_eq!(c_arr.data()[0], c(1.0, -1.0));
        assert_eq!(c_arr.data()[3], c(4.0, -4.0));
    }

    #[test]
    fn test_to_c64_owned_c64_already_owned() {
        let values = [c(10.0, 20.0), c(30.0, 40.0)];
        let owned = EinsumOperand::from(c64_array(&[2], &values)).to_c64_owned();
        assert!(owned.is_c64());
        assert_storage(expect_c64_owned(&owned).data(), &values);
    }

    #[test]
    fn test_to_c64_owned_f64_view() {
        let arr = f64_array(&[3], &[10.0, 20.0, 30.0]);
        let view = arr.view();
        let promoted = EinsumOperand::from_view(&view).to_c64_owned();
        assert!(promoted.is_c64());
        let c_arr = expect_c64_owned(&promoted);
        assert_eq!(c_arr.dims(), &[3]);
        assert_storage(c_arr.data(), &[c(10.0, 0.0), c(20.0, 0.0), c(30.0, 0.0)]);
    }
}