strided_opteinsum/
operand.rs

1use num_complex::Complex64;
2use num_traits::Zero;
3use strided_einsum2::Scalar;
4use strided_kernel::copy_into;
5use strided_view::{col_major_strides, ElementOpApply, StridedArray, StridedView};
6
7use crate::typed_tensor::TypedTensor;
8
9/// Type-erased strided data that can be either owned or borrowed.
10#[derive(Debug)]
11pub enum StridedData<'a, T> {
12    Owned(StridedArray<T>),
13    View(StridedView<'a, T>),
14}
15
16impl<'a, T> StridedData<'a, T> {
17    /// Return the dimensions of the underlying data.
18    pub fn dims(&self) -> &[usize] {
19        match self {
20            StridedData::Owned(arr) => arr.dims(),
21            StridedData::View(view) => view.dims(),
22        }
23    }
24
25    /// Return an immutable strided view over the data.
26    pub fn as_view(&self) -> StridedView<'_, T> {
27        match self {
28            StridedData::Owned(arr) => arr.view(),
29            StridedData::View(view) => view.clone(),
30        }
31    }
32
33    /// Return a reference to the owned array.
34    ///
35    /// # Panics
36    /// Panics if this is a `View` variant.
37    pub fn as_array(&self) -> &StridedArray<T> {
38        match self {
39            StridedData::Owned(arr) => arr,
40            StridedData::View(_) => panic!("StridedData::as_array called on a View variant"),
41        }
42    }
43}
44
45impl<'a, T> StridedData<'a, T> {
46    /// Permute dimensions (metadata-only reorder, no data copy).
47    pub fn permuted(self, perm: &[usize]) -> strided_view::Result<Self> {
48        match self {
49            StridedData::Owned(arr) => Ok(StridedData::Owned(arr.permuted(perm)?)),
50            StridedData::View(view) => Ok(StridedData::View(view.permute(perm)?)),
51        }
52    }
53}
54
55impl<'a, T> StridedData<'a, T>
56where
57    T: Copy + ElementOpApply + Send + Sync + Zero + Default,
58{
59    /// Convert into an owned `StridedArray`.
60    ///
61    /// If already owned, returns the inner array directly.
62    /// If a view, copies the data into a new column-major array.
63    pub fn into_array(self) -> StridedArray<T> {
64        match self {
65            StridedData::Owned(arr) => arr,
66            StridedData::View(view) => {
67                let dims = view.dims().to_vec();
68                let mut dest = StridedArray::<T>::col_major(&dims);
69                let mut dest_view = dest.view_mut();
70                copy_into(&mut dest_view, &view).expect("copy_into failed in into_array");
71                dest
72            }
73        }
74    }
75}
76
77/// A type-erased einsum operand holding either f64 or Complex64 strided data.
78#[derive(Debug)]
79pub enum EinsumOperand<'a> {
80    F64(StridedData<'a, f64>),
81    C64(StridedData<'a, Complex64>),
82}
83
84impl<'a> EinsumOperand<'a> {
85    /// Returns `true` if this operand holds f64 data.
86    pub fn is_f64(&self) -> bool {
87        matches!(self, EinsumOperand::F64(_))
88    }
89
90    /// Returns `true` if this operand holds Complex64 data.
91    pub fn is_c64(&self) -> bool {
92        matches!(self, EinsumOperand::C64(_))
93    }
94
95    /// Return the dimensions of the underlying data.
96    pub fn dims(&self) -> &[usize] {
97        match self {
98            EinsumOperand::F64(data) => data.dims(),
99            EinsumOperand::C64(data) => data.dims(),
100        }
101    }
102
103    /// Permute dimensions (metadata-only reorder, no data copy).
104    pub fn permuted(self, perm: &[usize]) -> crate::Result<Self> {
105        match self {
106            EinsumOperand::F64(data) => Ok(EinsumOperand::F64(data.permuted(perm)?)),
107            EinsumOperand::C64(data) => Ok(EinsumOperand::C64(data.permuted(perm)?)),
108        }
109    }
110
111    /// Create an `EinsumOperand` from a borrowed strided view.
112    ///
113    /// Type inference selects the correct variant (`F64` or `C64`) from the view's element type.
114    pub fn from_view<T: EinsumScalar>(view: &StridedView<'a, T>) -> Self {
115        T::wrap_data(StridedData::View(view.clone()))
116    }
117
118    /// Promote to an owned Complex64 operand by borrowing the data.
119    ///
120    /// Unlike `to_c64_owned`, this works on `&self` and always copies.
121    pub fn to_c64_owned_ref(&self) -> EinsumOperand<'static> {
122        match self {
123            EinsumOperand::C64(data) => {
124                let view = data.as_view();
125                let dims = view.dims().to_vec();
126                let mut dest = StridedArray::<Complex64>::col_major(&dims);
127                copy_into(&mut dest.view_mut(), &view).expect("copy_into failed");
128                EinsumOperand::C64(StridedData::Owned(dest))
129            }
130            EinsumOperand::F64(data) => {
131                let view = data.as_view();
132                let dims = view.dims().to_vec();
133                let strides = col_major_strides(&dims);
134                let mut f64_dest = StridedArray::<f64>::col_major(&dims);
135                copy_into(&mut f64_dest.view_mut(), &view).expect("copy_into failed");
136                let c64_data: Vec<Complex64> = f64_dest
137                    .data()
138                    .iter()
139                    .map(|&x| Complex64::new(x, 0.0))
140                    .collect();
141                let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
142                    .expect("from_parts failed");
143                EinsumOperand::C64(StridedData::Owned(c64_array))
144            }
145        }
146    }
147
148    /// Promote to an owned Complex64 operand.
149    ///
150    /// - If already C64 and owned, returns as-is.
151    /// - If C64 view, copies into an owned array.
152    /// - If F64, converts each element to `Complex64` and returns an owned array.
153    pub fn to_c64_owned(self) -> EinsumOperand<'static> {
154        match self {
155            EinsumOperand::C64(data) => EinsumOperand::C64(StridedData::Owned(data.into_array())),
156            EinsumOperand::F64(data) => {
157                let view = data.as_view();
158                let dims = view.dims().to_vec();
159                let strides = col_major_strides(&dims);
160                // Build a new col-major array by copying data through the view
161                // First materialize f64 into a col-major owned array, then convert
162                let f64_array = match data {
163                    StridedData::Owned(arr) => arr,
164                    StridedData::View(v) => {
165                        let mut dest = StridedArray::<f64>::col_major(v.dims());
166                        let mut dest_view = dest.view_mut();
167                        copy_into(&mut dest_view, &v).expect("copy_into failed in to_c64_owned");
168                        dest
169                    }
170                };
171                let c64_data: Vec<Complex64> = f64_array
172                    .data()
173                    .iter()
174                    .map(|&x| Complex64::new(x, 0.0))
175                    .collect();
176                let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
177                    .expect("from_parts failed in to_c64_owned");
178                EinsumOperand::C64(StridedData::Owned(c64_array))
179            }
180        }
181    }
182}
183
184impl From<StridedArray<f64>> for EinsumOperand<'static> {
185    fn from(arr: StridedArray<f64>) -> Self {
186        EinsumOperand::F64(StridedData::Owned(arr))
187    }
188}
189
190impl From<StridedArray<Complex64>> for EinsumOperand<'static> {
191    fn from(arr: StridedArray<Complex64>) -> Self {
192        EinsumOperand::C64(StridedData::Owned(arr))
193    }
194}
195
196// ---------------------------------------------------------------------------
197// EinsumScalar trait — sealed, implemented for f64 and Complex64
198// ---------------------------------------------------------------------------
199
200mod private {
201    pub trait Sealed {}
202    impl Sealed for f64 {}
203    impl Sealed for num_complex::Complex64 {}
204}
205
206/// Scalar types that can be used as the output element type for `einsum_into`.
207///
208/// Sealed trait: only implemented for `f64` and `Complex64`.
209pub trait EinsumScalar: private::Sealed + Scalar + Default + 'static {
210    /// Human-readable type name for error messages.
211    fn type_name() -> &'static str;
212
213    /// Wrap typed `StridedData` into a type-erased `EinsumOperand`.
214    fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_>;
215
216    /// Wrap an owned `StridedArray` into a type-erased `EinsumOperand`.
217    fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static>;
218
219    /// Wrap an owned `StridedArray` into a `TypedTensor`.
220    fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor;
221
222    /// Extract typed data from a type-erased `EinsumOperand`, promoting if needed.
223    ///
224    /// For `f64`: returns error if operand is `C64`.
225    /// For `Complex64`: promotes `F64` operands to `C64`.
226    fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, Self>>;
227
228    /// Check whether any operand requires this type or is incompatible.
229    /// Returns error early if `T = f64` but any operand is `C64`.
230    fn validate_operands(ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()>;
231}
232
233impl EinsumScalar for f64 {
234    fn type_name() -> &'static str {
235        "f64"
236    }
237
238    fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_> {
239        EinsumOperand::F64(data)
240    }
241
242    fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static> {
243        EinsumOperand::F64(StridedData::Owned(arr))
244    }
245
246    fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor {
247        TypedTensor::F64(arr)
248    }
249
250    fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, f64>> {
251        match op {
252            EinsumOperand::F64(data) => Ok(data),
253            EinsumOperand::C64(_) => Err(crate::EinsumError::TypeMismatch {
254                output_type: "f64",
255                computed_type: "Complex64",
256            }),
257        }
258    }
259
260    fn validate_operands(ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()> {
261        for op in ops.iter().flatten() {
262            if op.is_c64() {
263                return Err(crate::EinsumError::TypeMismatch {
264                    output_type: "f64",
265                    computed_type: "Complex64",
266                });
267            }
268        }
269        Ok(())
270    }
271}
272
273impl EinsumScalar for Complex64 {
274    fn type_name() -> &'static str {
275        "Complex64"
276    }
277
278    fn wrap_data(data: StridedData<'_, Self>) -> EinsumOperand<'_> {
279        EinsumOperand::C64(data)
280    }
281
282    fn wrap_array(arr: StridedArray<Self>) -> EinsumOperand<'static> {
283        EinsumOperand::C64(StridedData::Owned(arr))
284    }
285
286    fn wrap_typed_tensor(arr: StridedArray<Self>) -> TypedTensor {
287        TypedTensor::C64(arr)
288    }
289
290    fn extract_data<'a>(op: EinsumOperand<'a>) -> crate::Result<StridedData<'a, Complex64>> {
291        match op {
292            EinsumOperand::C64(data) => Ok(data),
293            EinsumOperand::F64(data) => {
294                // Promote f64 → Complex64.
295                // First materialize to col-major f64, then convert elements.
296                let view = data.as_view();
297                let dims = view.dims().to_vec();
298                let strides = col_major_strides(&dims);
299                let mut f64_col = StridedArray::<f64>::col_major(&dims);
300                copy_into(&mut f64_col.view_mut(), &view)
301                    .expect("copy_into failed in extract_data");
302                let c64_data: Vec<Complex64> = f64_col
303                    .data()
304                    .iter()
305                    .map(|&x| Complex64::new(x, 0.0))
306                    .collect();
307                let c64_array = StridedArray::from_parts(c64_data, &dims, &strides, 0)
308                    .expect("from_parts failed in extract_data");
309                Ok(StridedData::Owned(c64_array))
310            }
311        }
312    }
313
314    fn validate_operands(_ops: &[Option<EinsumOperand<'_>>]) -> crate::Result<()> {
315        // Complex64 output accepts both f64 and c64 operands
316        Ok(())
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;
    use strided_view::StridedArray;

    // -----------------------------------------------------------------------
    // Fixture helpers
    // -----------------------------------------------------------------------

    /// Shorthand for `Complex64::new`.
    fn c64(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Column-major f64 array of shape `dims`; its leading storage slots are `values`.
    fn f64_array(dims: &[usize], values: &[f64]) -> StridedArray<f64> {
        let mut arr = StridedArray::<f64>::col_major(dims);
        for (slot, &value) in arr.data_mut().iter_mut().zip(values) {
            *slot = value;
        }
        arr
    }

    /// Column-major Complex64 array of shape `dims`; its leading storage slots are `values`.
    fn c64_array(dims: &[usize], values: &[Complex64]) -> StridedArray<Complex64> {
        let mut arr = StridedArray::<Complex64>::col_major(dims);
        for (slot, &value) in arr.data_mut().iter_mut().zip(values) {
            *slot = value;
        }
        arr
    }

    /// Column-major f64 array whose storage slot `i` holds `i as f64 * scale`.
    fn ramp(dims: &[usize], scale: f64) -> StridedArray<f64> {
        let mut arr = StridedArray::<f64>::col_major(dims);
        for (i, slot) in arr.data_mut().iter_mut().enumerate() {
            *slot = i as f64 * scale;
        }
        arr
    }

    /// Borrow the owned Complex64 array inside `op`, panicking for any other shape.
    fn owned_c64<'x>(op: &'x EinsumOperand<'_>) -> &'x StridedArray<Complex64> {
        match op {
            EinsumOperand::C64(StridedData::Owned(arr)) => arr,
            _ => panic!("expected C64 Owned"),
        }
    }

    /// Two owned 2x2 f64 operands.
    fn real_ops() -> Vec<Option<EinsumOperand<'static>>> {
        vec![
            Some(EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 2]))),
            Some(EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 2]))),
        ]
    }

    /// One owned 2x2 f64 operand followed by one owned 2x2 Complex64 operand.
    fn mixed_ops() -> Vec<Option<EinsumOperand<'static>>> {
        vec![
            Some(EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 2]))),
            Some(EinsumOperand::from(StridedArray::<Complex64>::col_major(&[2, 2]))),
        ]
    }

    // -----------------------------------------------------------------------
    // Construction and variant queries
    // -----------------------------------------------------------------------

    #[test]
    fn test_f64_owned() {
        let op = EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 3]));
        assert!(op.is_f64());
        assert!(!op.is_c64());
        assert_eq!(op.dims(), &[2, 3]);
    }

    #[test]
    fn test_c64_owned() {
        let op = EinsumOperand::from(StridedArray::<Complex64>::col_major(&[4, 5]));
        assert!(op.is_c64());
        assert_eq!(op.dims(), &[4, 5]);
    }

    #[test]
    fn test_f64_view() {
        let arr = StridedArray::<f64>::col_major(&[2, 3]);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        assert!(op.is_f64());
        assert_eq!(op.dims(), &[2, 3]);
    }

    #[test]
    fn test_promote_f64_to_c64() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0]));
        let promoted = op.to_c64_owned();
        assert!(promoted.is_c64());
        let arr = owned_c64(&promoted);
        assert_eq!(arr.data()[0], c64(1.0, 0.0));
        assert_eq!(arr.data()[1], c64(2.0, 0.0));
    }

    // -----------------------------------------------------------------------
    // to_c64_owned_ref tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_to_c64_owned_ref_from_f64_owned() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0]));
        let promoted = op.to_c64_owned_ref();
        assert!(promoted.is_c64());
        let arr = owned_c64(&promoted);
        assert_eq!(arr.dims(), &[2, 2]);
        for (i, &re) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
            assert_eq!(arr.data()[i], c64(re, 0.0));
        }
    }

    #[test]
    fn test_to_c64_owned_ref_from_f64_view() {
        let arr = f64_array(&[2, 2], &[5.0, 6.0, 7.0, 8.0]);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        let promoted = op.to_c64_owned_ref();
        assert!(promoted.is_c64());
        let c_arr = owned_c64(&promoted);
        assert_eq!(c_arr.dims(), &[2, 2]);
        for (i, &re) in [5.0, 6.0, 7.0, 8.0].iter().enumerate() {
            assert_eq!(c_arr.data()[i], c64(re, 0.0));
        }
    }

    #[test]
    fn test_to_c64_owned_ref_from_c64_owned() {
        let values = [c64(1.0, 2.0), c64(3.0, 4.0), c64(5.0, 6.0), c64(7.0, 8.0)];
        let op = EinsumOperand::from(c64_array(&[2, 2], &values));
        let copied = op.to_c64_owned_ref();
        assert!(copied.is_c64());
        let c_arr = owned_c64(&copied);
        assert_eq!(c_arr.dims(), &[2, 2]);
        for (i, expected) in values.iter().enumerate() {
            assert_eq!(c_arr.data()[i], *expected);
        }
    }

    #[test]
    fn test_to_c64_owned_ref_from_c64_view() {
        let values = [c64(1.0, -1.0), c64(2.0, -2.0), c64(3.0, -3.0)];
        let arr = c64_array(&[3], &values);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        let copied = op.to_c64_owned_ref();
        assert!(copied.is_c64());
        let c_arr = owned_c64(&copied);
        assert_eq!(c_arr.dims(), &[3]);
        for (i, expected) in values.iter().enumerate() {
            assert_eq!(c_arr.data()[i], *expected);
        }
    }

    // -----------------------------------------------------------------------
    // StridedData::into_array tests (View variant)
    // -----------------------------------------------------------------------

    #[test]
    fn test_strided_data_into_array_from_owned() {
        let result = StridedData::Owned(ramp(&[2, 3], 1.0)).into_array();
        assert_eq!(result.dims(), &[2, 3]);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[5], 5.0);
    }

    #[test]
    fn test_strided_data_into_array_from_view() {
        let arr = ramp(&[2, 3], 10.0);
        let result = StridedData::<f64>::View(arr.view()).into_array();
        assert_eq!(result.dims(), &[2, 3]);
        // The copied values must land at the same logical indices.
        assert_eq!(result.get(&[0, 0]), 0.0);
        assert_eq!(result.get(&[1, 0]), 10.0);
    }

    #[test]
    fn test_strided_data_into_array_from_view_c64() {
        let values = [c64(1.0, 2.0), c64(3.0, 4.0), c64(5.0, 6.0), c64(7.0, 8.0)];
        let arr = c64_array(&[2, 2], &values);
        let result = StridedData::<Complex64>::View(arr.view()).into_array();
        assert_eq!(result.dims(), &[2, 2]);
        assert_eq!(result.get(&[0, 0]), c64(1.0, 2.0));
        assert_eq!(result.get(&[1, 1]), c64(7.0, 8.0));
    }

    // -----------------------------------------------------------------------
    // StridedData::as_array tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_strided_data_as_array_owned() {
        let data = StridedData::Owned(StridedArray::<f64>::col_major(&[3, 2]));
        assert_eq!(data.as_array().dims(), &[3, 2]);
    }

    #[test]
    #[should_panic(expected = "StridedData::as_array called on a View variant")]
    fn test_strided_data_as_array_view_panics() {
        let arr = StridedArray::<f64>::col_major(&[3, 2]);
        let data = StridedData::<f64>::View(arr.view());
        let _ = data.as_array(); // must panic
    }

    // -----------------------------------------------------------------------
    // EinsumScalar::validate_operands tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_operands_f64_all_f64() {
        let ops = real_ops();
        assert!(f64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_f64_with_none() {
        // `None` entries are skipped without error.
        let ops: Vec<Option<EinsumOperand>> = vec![
            Some(EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 2]))),
            None,
        ];
        assert!(f64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_f64_with_c64_returns_error() {
        let ops = mixed_ops();
        let err = f64::validate_operands(&ops).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::TypeMismatch {
                output_type: "f64",
                computed_type: "Complex64",
            }
        ));
    }

    #[test]
    fn test_validate_operands_c64_accepts_anything() {
        let ops = mixed_ops();
        assert!(Complex64::validate_operands(&ops).is_ok());
    }

    #[test]
    fn test_validate_operands_c64_all_f64() {
        let ops = real_ops();
        assert!(Complex64::validate_operands(&ops).is_ok());
    }

    // -----------------------------------------------------------------------
    // EinsumScalar::extract_data tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_extract_data_f64_from_f64() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[42.0]));
        let data = f64::extract_data(op).unwrap();
        assert_eq!(data.as_view().get(&[0, 0]), 42.0);
    }

    #[test]
    fn test_extract_data_f64_from_c64_returns_error() {
        let op = EinsumOperand::from(StridedArray::<Complex64>::col_major(&[2, 2]));
        let err = f64::extract_data(op).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::TypeMismatch {
                output_type: "f64",
                computed_type: "Complex64",
            }
        ));
    }

    #[test]
    fn test_extract_data_c64_from_c64() {
        let op = EinsumOperand::from(c64_array(&[2, 2], &[c64(1.0, 2.0)]));
        let data = Complex64::extract_data(op).unwrap();
        assert_eq!(data.as_view().get(&[0, 0]), c64(1.0, 2.0));
    }

    #[test]
    fn test_extract_data_c64_from_f64_promotes() {
        let op = EinsumOperand::from(f64_array(&[2, 2], &[5.0, 6.0, 7.0, 8.0]));
        let data = Complex64::extract_data(op).unwrap();
        // The f64 data must come back promoted to an owned Complex64 array.
        let c_arr = match &data {
            StridedData::Owned(c_arr) => c_arr,
            StridedData::View(_) => panic!("expected Owned after promotion"),
        };
        assert_eq!(c_arr.dims(), &[2, 2]);
        for (i, &re) in [5.0, 6.0, 7.0, 8.0].iter().enumerate() {
            assert_eq!(c_arr.data()[i], c64(re, 0.0));
        }
    }

    // -----------------------------------------------------------------------
    // EinsumScalar::type_name tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_type_name() {
        assert_eq!(f64::type_name(), "f64");
        assert_eq!(Complex64::type_name(), "Complex64");
    }

    // -----------------------------------------------------------------------
    // StridedData::dims and as_view tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_strided_data_dims_and_as_view() {
        let arr = ramp(&[3, 4], 1.0);

        // Owned variant
        let owned = StridedData::Owned(arr.clone());
        assert_eq!(owned.dims(), &[3, 4]);
        assert_eq!(owned.as_view().dims(), &[3, 4]);

        // View variant
        let data_view = StridedData::<f64>::View(arr.view());
        assert_eq!(data_view.dims(), &[3, 4]);
        assert_eq!(data_view.as_view().dims(), &[3, 4]);
    }

    // -----------------------------------------------------------------------
    // StridedData::permuted tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_strided_data_permuted_owned() {
        let permuted = StridedData::Owned(ramp(&[2, 3], 1.0))
            .permuted(&[1, 0])
            .unwrap();
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    #[test]
    fn test_strided_data_permuted_view() {
        let arr = ramp(&[2, 3], 1.0);
        let permuted = StridedData::<f64>::View(arr.view())
            .permuted(&[1, 0])
            .unwrap();
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    // -----------------------------------------------------------------------
    // EinsumOperand::permuted tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_einsum_operand_permuted_f64() {
        let op = EinsumOperand::from(StridedArray::<f64>::col_major(&[2, 3]));
        let permuted = op.permuted(&[1, 0]).unwrap();
        assert!(permuted.is_f64());
        assert_eq!(permuted.dims(), &[3, 2]);
    }

    #[test]
    fn test_einsum_operand_permuted_c64() {
        let op = EinsumOperand::from(StridedArray::<Complex64>::col_major(&[4, 5]));
        let permuted = op.permuted(&[1, 0]).unwrap();
        assert!(permuted.is_c64());
        assert_eq!(permuted.dims(), &[5, 4]);
    }

    // -----------------------------------------------------------------------
    // to_c64_owned edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_to_c64_owned_c64_view() {
        // A C64 view is materialized into an owned array.
        let values = [c64(1.0, -1.0), c64(2.0, -2.0), c64(3.0, -3.0), c64(4.0, -4.0)];
        let arr = c64_array(&[2, 2], &values);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        let owned = op.to_c64_owned();
        assert!(owned.is_c64());
        let c_arr = owned_c64(&owned);
        assert_eq!(c_arr.dims(), &[2, 2]);
        assert_eq!(c_arr.data()[0], c64(1.0, -1.0));
        assert_eq!(c_arr.data()[3], c64(4.0, -4.0));
    }

    #[test]
    fn test_to_c64_owned_c64_already_owned() {
        // An owned C64 operand passes through with its values intact.
        let op = EinsumOperand::from(c64_array(&[2], &[c64(10.0, 20.0), c64(30.0, 40.0)]));
        let owned = op.to_c64_owned();
        assert!(owned.is_c64());
        let c_arr = owned_c64(&owned);
        assert_eq!(c_arr.data()[0], c64(10.0, 20.0));
        assert_eq!(c_arr.data()[1], c64(30.0, 40.0));
    }

    #[test]
    fn test_to_c64_owned_f64_view() {
        // An f64 view is materialized and promoted.
        let arr = f64_array(&[3], &[10.0, 20.0, 30.0]);
        let view = arr.view();
        let op = EinsumOperand::from_view(&view);
        let promoted = op.to_c64_owned();
        assert!(promoted.is_c64());
        let c_arr = owned_c64(&promoted);
        assert_eq!(c_arr.dims(), &[3]);
        for (i, &re) in [10.0, 20.0, 30.0].iter().enumerate() {
            assert_eq!(c_arr.data()[i], c64(re, 0.0));
        }
    }
}
800}