strided_opteinsum/
typed_tensor.rs

1use num_complex::Complex64;
2use strided_view::StridedArray;
3
4/// A type-erased tensor that dispatches over f64 and Complex64 at runtime.
5pub enum TypedTensor {
6    F64(StridedArray<f64>),
7    C64(StridedArray<Complex64>),
8}
9
/// Absolute tolerance used by `TypedTensor::try_demote_to_f64`: an imaginary part `im`
/// counts as zero only when `im.abs()` is strictly below this value.
const DEMOTE_THRESHOLD: f64 = 1.0e-15;
11
12impl TypedTensor {
13    /// Returns `true` if this tensor holds `f64` data.
14    pub fn is_f64(&self) -> bool {
15        matches!(self, TypedTensor::F64(_))
16    }
17    /// Returns `true` if this tensor holds `Complex64` data.
18    pub fn is_c64(&self) -> bool {
19        matches!(self, TypedTensor::C64(_))
20    }
21
22    /// Returns the dimensions of the underlying array.
23    pub fn dims(&self) -> &[usize] {
24        match self {
25            TypedTensor::F64(a) => a.dims(),
26            TypedTensor::C64(a) => a.dims(),
27        }
28    }
29
30    /// Promote to Complex64. If already C64, returns self unchanged.
31    pub fn to_c64(self) -> TypedTensor {
32        match self {
33            TypedTensor::C64(_) => self,
34            TypedTensor::F64(a) => {
35                let c64_data: Vec<Complex64> =
36                    a.data().iter().map(|&x| Complex64::new(x, 0.0)).collect();
37                let arr = StridedArray::from_parts(c64_data, a.dims(), a.strides(), 0).unwrap();
38                TypedTensor::C64(arr)
39            }
40        }
41    }
42
43    /// Try to demote a Complex64 array to f64 if all imaginary parts are negligible.
44    pub fn try_demote_to_f64(arr: StridedArray<Complex64>) -> TypedTensor {
45        let all_real = arr.data().iter().all(|c| c.im.abs() < DEMOTE_THRESHOLD);
46        if all_real {
47            let f64_data: Vec<f64> = arr.data().iter().map(|c| c.re).collect();
48            let f_arr = StridedArray::from_parts(f64_data, arr.dims(), arr.strides(), 0).unwrap();
49            TypedTensor::F64(f_arr)
50        } else {
51            TypedTensor::C64(arr)
52        }
53    }
54}
55
56/// Returns true if any input is Complex64 (triggering promotion for all).
57pub fn needs_c64_promotion(inputs: &[&TypedTensor]) -> bool {
58    inputs.iter().any(|t| t.is_c64())
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_typed_f64() {
        let arr = StridedArray::<f64>::col_major(&[2, 3]);
        let t = TypedTensor::F64(arr);
        assert!(t.is_f64());
        assert!(!t.is_c64());
        assert_eq!(t.dims(), &[2, 3]);
    }

    #[test]
    fn test_typed_c64() {
        let arr = StridedArray::<Complex64>::col_major(&[3, 4]);
        let t = TypedTensor::C64(arr);
        assert!(t.is_c64());
        assert!(!t.is_f64());
        assert_eq!(t.dims(), &[3, 4]);
    }

    #[test]
    fn test_promote_to_c64() {
        let mut arr = StridedArray::<f64>::col_major(&[2]);
        arr.data_mut()[0] = 3.0;
        arr.data_mut()[1] = 7.0;
        let promoted = TypedTensor::F64(arr).to_c64();
        assert_eq!(promoted.dims(), &[2]);
        match promoted {
            TypedTensor::C64(a) => {
                assert_abs_diff_eq!(a.data()[0].re, 3.0);
                assert_abs_diff_eq!(a.data()[1].re, 7.0);
                assert_abs_diff_eq!(a.data()[0].im, 0.0);
                assert_abs_diff_eq!(a.data()[1].im, 0.0);
            }
            TypedTensor::F64(_) => panic!("expected C64"),
        }
    }

    #[test]
    fn test_promote_c64_is_identity() {
        let mut arr = StridedArray::<Complex64>::col_major(&[2]);
        arr.data_mut()[0] = Complex64::new(1.0, 2.0);
        arr.data_mut()[1] = Complex64::new(-3.0, 0.5);
        match TypedTensor::C64(arr).to_c64() {
            TypedTensor::C64(a) => {
                assert_eq!(a.data()[0], Complex64::new(1.0, 2.0));
                assert_eq!(a.data()[1], Complex64::new(-3.0, 0.5));
            }
            TypedTensor::F64(_) => panic!("expected C64"),
        }
    }

    #[test]
    fn test_demote_to_f64() {
        let mut arr = StridedArray::<Complex64>::col_major(&[2]);
        arr.data_mut()[0] = Complex64::new(1.0, 0.0);
        arr.data_mut()[1] = Complex64::new(2.0, 1e-16);
        let t = TypedTensor::try_demote_to_f64(arr);
        assert!(t.is_f64());
        assert_eq!(t.dims(), &[2]);
        match t {
            TypedTensor::F64(a) => {
                // Real parts must survive demotion unchanged.
                assert_abs_diff_eq!(a.data()[0], 1.0);
                assert_abs_diff_eq!(a.data()[1], 2.0);
            }
            TypedTensor::C64(_) => panic!("expected F64"),
        }
    }

    #[test]
    fn test_no_demote_if_complex() {
        let mut arr = StridedArray::<Complex64>::col_major(&[2]);
        arr.data_mut()[0] = Complex64::new(1.0, 0.5);
        arr.data_mut()[1] = Complex64::new(2.0, 0.0);
        match TypedTensor::try_demote_to_f64(arr) {
            TypedTensor::C64(a) => {
                // The complex array is handed back untouched.
                assert_eq!(a.data()[0], Complex64::new(1.0, 0.5));
                assert_eq!(a.data()[1], Complex64::new(2.0, 0.0));
            }
            TypedTensor::F64(_) => panic!("expected C64"),
        }
    }

    #[test]
    fn test_no_demote_at_threshold() {
        // The tolerance is strict: |im| == DEMOTE_THRESHOLD keeps the array complex.
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, DEMOTE_THRESHOLD);
        assert!(TypedTensor::try_demote_to_f64(arr).is_c64());
    }

    #[test]
    fn test_no_demote_negative_imag() {
        // The magnitude of the imaginary part is what matters, not its sign.
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, -0.5);
        assert!(TypedTensor::try_demote_to_f64(arr).is_c64());
    }

    #[test]
    fn test_needs_c64_promotion() {
        let f = TypedTensor::F64(StridedArray::<f64>::col_major(&[1]));
        let c = TypedTensor::C64(StridedArray::<Complex64>::col_major(&[1]));
        assert!(!needs_c64_promotion(&[]));
        assert!(!needs_c64_promotion(&[&f]));
        assert!(needs_c64_promotion(&[&f, &c]));
        assert!(needs_c64_promotion(&[&c]));
    }
}