// strided_opteinsum/typed_tensor.rs
use num_complex::Complex64;
use strided_view::StridedArray;

4pub enum TypedTensor {
6 F64(StridedArray<f64>),
7 C64(StridedArray<Complex64>),
8}
9
/// Absolute bound on an imaginary part's magnitude below which a complex entry
/// is treated as purely real by [`TypedTensor::try_demote_to_f64`].
const DEMOTE_THRESHOLD: f64 = 1.0e-15;

12impl TypedTensor {
13 pub fn is_f64(&self) -> bool {
15 matches!(self, TypedTensor::F64(_))
16 }
17 pub fn is_c64(&self) -> bool {
19 matches!(self, TypedTensor::C64(_))
20 }
21
22 pub fn dims(&self) -> &[usize] {
24 match self {
25 TypedTensor::F64(a) => a.dims(),
26 TypedTensor::C64(a) => a.dims(),
27 }
28 }
29
30 pub fn to_c64(self) -> TypedTensor {
32 match self {
33 TypedTensor::C64(_) => self,
34 TypedTensor::F64(a) => {
35 let c64_data: Vec<Complex64> =
36 a.data().iter().map(|&x| Complex64::new(x, 0.0)).collect();
37 let arr = StridedArray::from_parts(c64_data, a.dims(), a.strides(), 0).unwrap();
38 TypedTensor::C64(arr)
39 }
40 }
41 }
42
43 pub fn try_demote_to_f64(arr: StridedArray<Complex64>) -> TypedTensor {
45 let all_real = arr.data().iter().all(|c| c.im.abs() < DEMOTE_THRESHOLD);
46 if all_real {
47 let f64_data: Vec<f64> = arr.data().iter().map(|c| c.re).collect();
48 let f_arr = StridedArray::from_parts(f64_data, arr.dims(), arr.strides(), 0).unwrap();
49 TypedTensor::F64(f_arr)
50 } else {
51 TypedTensor::C64(arr)
52 }
53 }
54}
55
56pub fn needs_c64_promotion(inputs: &[&TypedTensor]) -> bool {
58 inputs.iter().any(|t| t.is_c64())
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_typed_f64() {
        let arr = StridedArray::<f64>::col_major(&[2, 3]);
        let t = TypedTensor::F64(arr);
        assert!(t.is_f64());
        assert!(!t.is_c64());
        assert_eq!(t.dims(), &[2, 3]);
    }

    #[test]
    fn test_typed_c64() {
        let arr = StridedArray::<Complex64>::col_major(&[3, 4]);
        let t = TypedTensor::C64(arr);
        assert!(t.is_c64());
        assert!(!t.is_f64());
        assert_eq!(t.dims(), &[3, 4]);
    }

    #[test]
    fn test_promote_to_c64() {
        let mut arr = StridedArray::<f64>::col_major(&[2]);
        arr.data_mut()[0] = 3.0;
        arr.data_mut()[1] = 7.0;
        let promoted = TypedTensor::F64(arr).to_c64();
        assert_eq!(promoted.dims(), &[2]);
        match promoted {
            TypedTensor::C64(a) => {
                assert_abs_diff_eq!(a.data()[0].re, 3.0);
                assert_abs_diff_eq!(a.data()[1].re, 7.0);
                assert_abs_diff_eq!(a.data()[0].im, 0.0);
                assert_abs_diff_eq!(a.data()[1].im, 0.0);
            }
            TypedTensor::F64(_) => panic!("expected C64"),
        }
    }

    #[test]
    fn test_promote_c64_is_identity() {
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, 2.0);
        match TypedTensor::C64(arr).to_c64() {
            TypedTensor::C64(a) => {
                assert_abs_diff_eq!(a.data()[0].re, 1.0);
                assert_abs_diff_eq!(a.data()[0].im, 2.0);
            }
            TypedTensor::F64(_) => panic!("expected C64"),
        }
    }

    #[test]
    fn test_demote_to_f64() {
        let mut arr = StridedArray::<Complex64>::col_major(&[2]);
        arr.data_mut()[0] = Complex64::new(1.0, 0.0);
        arr.data_mut()[1] = Complex64::new(2.0, 1e-16);
        match TypedTensor::try_demote_to_f64(arr) {
            TypedTensor::F64(a) => {
                assert_eq!(a.dims(), &[2]);
                assert_abs_diff_eq!(a.data()[0], 1.0);
                assert_abs_diff_eq!(a.data()[1], 2.0);
            }
            TypedTensor::C64(_) => panic!("expected F64"),
        }
    }

    #[test]
    fn test_demote_threshold_is_strict_and_symmetric() {
        // A negative imaginary part below the threshold in magnitude still demotes.
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, -1e-16);
        assert!(TypedTensor::try_demote_to_f64(arr).is_f64());

        // An imaginary part exactly at the threshold does not (strict `<`).
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, DEMOTE_THRESHOLD);
        assert!(TypedTensor::try_demote_to_f64(arr).is_c64());
    }

    #[test]
    fn test_no_demote_if_imag_is_nan() {
        let mut arr = StridedArray::<Complex64>::col_major(&[1]);
        arr.data_mut()[0] = Complex64::new(1.0, f64::NAN);
        assert!(TypedTensor::try_demote_to_f64(arr).is_c64());
    }

    #[test]
    fn test_no_demote_if_complex() {
        let mut arr = StridedArray::<Complex64>::col_major(&[2]);
        arr.data_mut()[0] = Complex64::new(1.0, 0.5);
        arr.data_mut()[1] = Complex64::new(2.0, 0.0);
        let t = TypedTensor::try_demote_to_f64(arr);
        assert!(t.is_c64());
    }

    #[test]
    fn test_needs_c64_promotion() {
        let f = TypedTensor::F64(StridedArray::<f64>::col_major(&[1]));
        let c = TypedTensor::C64(StridedArray::<Complex64>::col_major(&[1]));
        assert!(!needs_c64_promotion(&[]));
        assert!(!needs_c64_promotion(&[&f]));
        assert!(needs_c64_promotion(&[&f, &c]));
        assert!(needs_c64_promotion(&[&c]));
    }
}