1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, ensure, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tenferro::{DType, Tensor as NativeTensor};
9
10use crate::storage::{Storage, SumFromStorage};
11use crate::tensor_element::TensorElement;
12
13#[derive(Clone, Copy, Debug, PartialEq)]
14enum ScalarValue {
15 F32(f32),
16 F64(f64),
17 I64(i64),
18 C32(Complex32),
19 C64(Complex64),
20}
21
22impl ScalarValue {
23 fn real(self) -> f64 {
24 match self {
25 Self::F32(value) => value as f64,
26 Self::F64(value) => value,
27 Self::I64(value) => value as f64,
28 Self::C32(value) => value.re as f64,
29 Self::C64(value) => value.re,
30 }
31 }
32
33 fn imag(self) -> f64 {
34 match self {
35 Self::F32(_) | Self::F64(_) | Self::I64(_) => 0.0,
36 Self::C32(value) => value.im as f64,
37 Self::C64(value) => value.im,
38 }
39 }
40
41 fn abs(self) -> f64 {
42 match self {
43 Self::F32(value) => value.abs() as f64,
44 Self::F64(value) => value.abs(),
45 Self::I64(value) => value.abs() as f64,
46 Self::C32(value) => value.norm() as f64,
47 Self::C64(value) => value.norm(),
48 }
49 }
50
51 fn is_complex(self) -> bool {
52 matches!(self, Self::C32(_) | Self::C64(_))
53 }
54
55 fn is_zero(self) -> bool {
56 match self {
57 Self::F32(value) => value == 0.0,
58 Self::F64(value) => value == 0.0,
59 Self::I64(value) => value == 0,
60 Self::C32(value) => value == Complex32::new(0.0, 0.0),
61 Self::C64(value) => value == Complex64::new(0.0, 0.0),
62 }
63 }
64
65 fn into_complex(self) -> Complex64 {
66 match self {
67 Self::F32(value) => Complex64::new(value as f64, 0.0),
68 Self::F64(value) => Complex64::new(value, 0.0),
69 Self::I64(value) => Complex64::new(value as f64, 0.0),
70 Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
71 Self::C64(value) => value,
72 }
73 }
74}
75
76fn scalar_value_from_storage(storage: &Storage) -> ScalarValue {
77 if storage.is_f64() {
78 ScalarValue::F64(f64::sum_from_storage(storage))
79 } else {
80 ScalarValue::C64(Complex64::sum_from_storage(storage))
81 }
82}
83
84fn scalar_value_from_native(native: &NativeTensor) -> Result<ScalarValue> {
85 ensure!(
86 native.shape().is_empty(),
87 "expected rank-0 scalar tensor, got shape {:?}",
88 native.shape()
89 );
90
91 match native.dtype() {
92 DType::F32 => native
93 .as_slice::<f32>()
94 .and_then(|values| values.first().copied())
95 .map(ScalarValue::F32)
96 .ok_or_else(|| anyhow!("failed to read f32 scalar tensor value")),
97 DType::F64 => native
98 .as_slice::<f64>()
99 .and_then(|values| values.first().copied())
100 .map(ScalarValue::F64)
101 .ok_or_else(|| anyhow!("failed to read f64 scalar tensor value")),
102 DType::I64 => native
103 .as_slice::<i64>()
104 .and_then(|values| values.first().copied())
105 .map(ScalarValue::I64)
106 .ok_or_else(|| anyhow!("failed to read i64 scalar tensor value")),
107 DType::C32 => native
108 .as_slice::<Complex32>()
109 .and_then(|values| values.first().copied())
110 .map(ScalarValue::C32)
111 .ok_or_else(|| anyhow!("failed to read c32 scalar tensor value")),
112 DType::C64 => native
113 .as_slice::<Complex64>()
114 .and_then(|values| values.first().copied())
115 .map(ScalarValue::C64)
116 .ok_or_else(|| anyhow!("failed to read c64 scalar tensor value")),
117 }
118}
119
120trait ScalarTensorElement: TensorElement {
121 fn scalar_value(value: Self) -> ScalarValue;
122}
123
124impl ScalarTensorElement for f32 {
125 fn scalar_value(value: Self) -> ScalarValue {
126 ScalarValue::F32(value)
127 }
128}
129
130impl ScalarTensorElement for f64 {
131 fn scalar_value(value: Self) -> ScalarValue {
132 ScalarValue::F64(value)
133 }
134}
135
136impl ScalarTensorElement for Complex32 {
137 fn scalar_value(value: Self) -> ScalarValue {
138 ScalarValue::C32(value)
139 }
140}
141
142impl ScalarTensorElement for Complex64 {
143 fn scalar_value(value: Self) -> ScalarValue {
144 ScalarValue::C64(value)
145 }
146}
147
148pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Result<NativeTensor> {
149 let promoted = match (scalar_value_from_native(native)?, target) {
150 (ScalarValue::F32(value), DType::F32) => Scalar::from_value(value),
151 (ScalarValue::F32(value), DType::F64) => Scalar::from_value(value as f64),
152 (ScalarValue::F32(value), DType::C32) => Scalar::from_value(Complex32::new(value, 0.0)),
153 (ScalarValue::F32(value), DType::C64) => {
154 Scalar::from_value(Complex64::new(value as f64, 0.0))
155 }
156 (ScalarValue::F32(_), DType::I64) => {
157 return Err(anyhow!(
158 "cannot promote f32 scalar to i64 without truncation"
159 ));
160 }
161 (ScalarValue::F64(value), DType::F32) => Scalar::from_value(value as f32),
162 (ScalarValue::F64(value), DType::F64) => Scalar::from_value(value),
163 (ScalarValue::F64(_), DType::I64) => {
164 return Err(anyhow!(
165 "cannot promote f64 scalar to i64 without truncation"
166 ));
167 }
168 (ScalarValue::F64(value), DType::C32) => {
169 Scalar::from_value(Complex32::new(value as f32, 0.0))
170 }
171 (ScalarValue::F64(value), DType::C64) => Scalar::from_value(Complex64::new(value, 0.0)),
172 (ScalarValue::I64(value), DType::F32) => Scalar::from_value(value as f32),
173 (ScalarValue::I64(value), DType::F64) => Scalar::from_value(value as f64),
174 (ScalarValue::I64(value), DType::I64) => Scalar::from_i64(value),
175 (ScalarValue::I64(value), DType::C32) => {
176 Scalar::from_value(Complex32::new(value as f32, 0.0))
177 }
178 (ScalarValue::I64(value), DType::C64) => {
179 Scalar::from_value(Complex64::new(value as f64, 0.0))
180 }
181 (ScalarValue::C32(value), DType::F32) => Scalar::from_value(value.re),
182 (ScalarValue::C32(value), DType::F64) => Scalar::from_value(value.re as f64),
183 (ScalarValue::C32(_), DType::I64) => {
184 return Err(anyhow!("cannot promote c32 scalar to i64"));
185 }
186 (ScalarValue::C32(value), DType::C32) => Scalar::from_value(value),
187 (ScalarValue::C32(value), DType::C64) => {
188 Scalar::from_value(Complex64::new(value.re as f64, value.im as f64))
189 }
190 (ScalarValue::C64(value), DType::F32) => Scalar::from_value(value.re as f32),
191 (ScalarValue::C64(value), DType::F64) => Scalar::from_value(value.re),
192 (ScalarValue::C64(_), DType::I64) => {
193 return Err(anyhow!("cannot promote c64 scalar to i64"));
194 }
195 (ScalarValue::C64(value), DType::C32) => {
196 Scalar::from_value(Complex32::new(value.re as f32, value.im as f32))
197 }
198 (ScalarValue::C64(value), DType::C64) => Scalar::from_value(value),
199 };
200 Ok(promoted.native)
201}
202
203pub struct Scalar {
231 native: NativeTensor,
232 value: ScalarValue,
233}
234
235pub type AnyScalar = Scalar;
237
238impl Scalar {
239 fn wrap_native(native: NativeTensor) -> Result<Self> {
240 if native.shape().is_empty() {
241 let value = scalar_value_from_native(&native)?;
242 Ok(Self { native, value })
243 } else {
244 Err(anyhow!(
245 "Scalar requires a rank-0 tensor, got shape {:?}",
246 native.shape()
247 ))
248 }
249 }
250
251 fn value(&self) -> ScalarValue {
252 self.value
253 }
254
255 fn from_i64(value: i64) -> Self {
256 Self {
257 native: NativeTensor::from_vec(vec![], vec![value]),
258 value: ScalarValue::I64(value),
259 }
260 }
261
262 pub(crate) fn from_native(value: NativeTensor) -> Result<Self> {
263 Self::wrap_native(value)
264 }
265
266 pub(crate) fn as_native(&self) -> &NativeTensor {
267 &self.native
268 }
269
270 #[allow(private_bounds)]
287 pub fn from_value<T: ScalarTensorElement>(value: T) -> Self {
288 let native = NativeTensor::from_vec(vec![], vec![value]);
289 Self {
290 native,
291 value: T::scalar_value(value),
292 }
293 }
294
295 pub fn from_real(x: f64) -> Self {
307 Self::from_value(x)
308 }
309
310 pub fn from_complex(re: f64, im: f64) -> Self {
323 Self::from_value(Complex64::new(re, im))
324 }
325
326 pub fn new_real(x: f64) -> Self {
337 Self::from_real(x)
338 }
339
340 pub fn new_complex(re: f64, im: f64) -> Self {
351 Self::from_complex(re, im)
352 }
353
354 pub fn primal(&self) -> Self {
366 self.clone()
367 }
368
369 pub fn real(&self) -> f64 {
380 self.value().real()
381 }
382
383 pub fn imag(&self) -> f64 {
399 self.value().imag()
400 }
401
402 pub fn abs(&self) -> f64 {
418 self.value().abs()
419 }
420
421 pub fn is_complex(&self) -> bool {
432 self.value().is_complex()
433 }
434
435 pub fn is_real(&self) -> bool {
446 !self.is_complex()
447 }
448
449 pub fn is_zero(&self) -> bool {
460 self.value().is_zero()
461 }
462
463 pub fn as_f64(&self) -> Option<f64> {
479 match self.value() {
480 ScalarValue::F32(value) => Some(value as f64),
481 ScalarValue::F64(value) => Some(value),
482 ScalarValue::I64(value) => Some(value as f64),
483 ScalarValue::C32(_) | ScalarValue::C64(_) => None,
484 }
485 }
486
487 pub fn as_c64(&self) -> Option<Complex64> {
504 match self.value() {
505 ScalarValue::F32(_) | ScalarValue::F64(_) | ScalarValue::I64(_) => None,
506 ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
507 ScalarValue::C64(value) => Some(value),
508 }
509 }
510
511 pub fn conj(&self) -> Self {
529 match self.value() {
530 ScalarValue::F32(value) => Self::from_value(value),
531 ScalarValue::F64(value) => Self::from_value(value),
532 ScalarValue::I64(value) => Self::from_i64(value),
533 ScalarValue::C32(value) => Self::from_value(value.conj()),
534 ScalarValue::C64(value) => Self::from_value(value.conj()),
535 }
536 }
537
538 pub fn real_part(&self) -> Self {
553 Self::from_real(self.real())
554 }
555
556 pub fn imag_part(&self) -> Self {
572 Self::from_real(self.imag())
573 }
574
575 pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
594 if !real.is_real() || !imag.is_real() {
595 return Err(anyhow!(
596 "compose_complex requires real-valued inputs, got real={:?}, imag={:?}",
597 real.native.dtype(),
598 imag.native.dtype()
599 ));
600 }
601 Ok(Self::from_complex(real.real(), imag.real()))
602 }
603
604 pub fn sqrt(&self) -> Self {
617 if self.is_complex() || self.real() < 0.0 {
618 let value = self.value().into_complex().sqrt();
619 if value.im == 0.0 {
620 Self::from_real(value.re)
621 } else {
622 Self::from_value(value)
623 }
624 } else {
625 Self::from_real(self.real().sqrt())
626 }
627 }
628
629 pub fn powf(&self, exponent: f64) -> Self {
643 let needs_complex_promotion =
644 self.is_complex() || (self.real() < 0.0 && exponent.fract() != 0.0);
645 if needs_complex_promotion {
646 let value = self.value().into_complex().powf(exponent);
647 if value.im == 0.0 {
648 Self::from_real(value.re)
649 } else {
650 Self::from_value(value)
651 }
652 } else {
653 Self::from_real(self.real().powf(exponent))
654 }
655 }
656
657 pub fn powi(&self, exponent: i32) -> Self {
668 self.powf(exponent as f64)
669 }
670}
671
672impl SumFromStorage for Scalar {
673 fn sum_from_storage(storage: &Storage) -> Self {
674 match scalar_value_from_storage(storage) {
675 ScalarValue::F32(value) => Self::from_value(value),
676 ScalarValue::F64(value) => Self::from_value(value),
677 ScalarValue::I64(value) => Self::from_i64(value),
678 ScalarValue::C32(value) => Self::from_value(value),
679 ScalarValue::C64(value) => Self::from_value(value),
680 }
681 }
682}
683
684impl From<f32> for Scalar {
685 fn from(value: f32) -> Self {
686 Self::from_value(value)
687 }
688}
689
690impl From<f64> for Scalar {
691 fn from(value: f64) -> Self {
692 Self::from_value(value)
693 }
694}
695
696impl From<Complex32> for Scalar {
697 fn from(value: Complex32) -> Self {
698 Self::from_value(value)
699 }
700}
701
702impl From<Complex64> for Scalar {
703 fn from(value: Complex64) -> Self {
704 Self::from_value(value)
705 }
706}
707
708impl TryFrom<Scalar> for f64 {
709 type Error = &'static str;
710
711 fn try_from(value: Scalar) -> std::result::Result<Self, Self::Error> {
712 match value.value() {
713 ScalarValue::F32(real) => Ok(real as f64),
714 ScalarValue::F64(real) => Ok(real),
715 ScalarValue::I64(real) => Ok(real as f64),
716 ScalarValue::C32(_) | ScalarValue::C64(_) => {
717 Err("cannot convert complex scalar to f64")
718 }
719 }
720 }
721}
722
723impl From<Scalar> for Complex64 {
724 fn from(value: Scalar) -> Self {
725 value.value().into_complex()
726 }
727}
728
729impl Add for Scalar {
730 type Output = Self;
731
732 fn add(self, rhs: Self) -> Self::Output {
733 match (self.value(), rhs.value()) {
734 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs + rhs),
735 (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
736 Self::from_value(lhs.into_complex() + rhs.into_complex())
737 }
738 (lhs, rhs) => Self::from_real(lhs.real() + rhs.real()),
739 }
740 }
741}
742
743impl Sub for Scalar {
744 type Output = Self;
745
746 fn sub(self, rhs: Self) -> Self::Output {
747 self + (-rhs)
748 }
749}
750
751impl Mul for Scalar {
752 type Output = Self;
753
754 fn mul(self, rhs: Self) -> Self::Output {
755 match (self.value(), rhs.value()) {
756 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs * rhs),
757 (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
758 Self::from_value(lhs.into_complex() * rhs.into_complex())
759 }
760 (lhs, rhs) => Self::from_real(lhs.real() * rhs.real()),
761 }
762 }
763}
764
765impl Div for Scalar {
766 type Output = Self;
767
768 fn div(self, rhs: Self) -> Self::Output {
769 match (self.value(), rhs.value()) {
770 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs / rhs),
771 (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
772 Self::from_value(lhs.into_complex() / rhs.into_complex())
773 }
774 (lhs, rhs) => Self::from_real(lhs.real() / rhs.real()),
775 }
776 }
777}
778
779impl Neg for Scalar {
780 type Output = Self;
781
782 fn neg(self) -> Self::Output {
783 match self.value() {
784 ScalarValue::F32(value) => Self::from_value(-value),
785 ScalarValue::F64(value) => Self::from_value(-value),
786 ScalarValue::I64(value) => Self::from_i64(-value),
787 ScalarValue::C32(value) => Self::from_value(-value),
788 ScalarValue::C64(value) => Self::from_value(-value),
789 }
790 }
791}
792
793impl Mul<Scalar> for f64 {
794 type Output = Scalar;
795
796 fn mul(self, rhs: Scalar) -> Self::Output {
797 Scalar::from_real(self) * rhs
798 }
799}
800
801impl Mul<Scalar> for Complex64 {
802 type Output = Scalar;
803
804 fn mul(self, rhs: Scalar) -> Self::Output {
805 Scalar::from(self) * rhs
806 }
807}
808
809impl Div<Scalar> for Complex64 {
810 type Output = Scalar;
811
812 fn div(self, rhs: Scalar) -> Self::Output {
813 Scalar::from(self) / rhs
814 }
815}
816
817impl Default for Scalar {
818 fn default() -> Self {
819 Self::zero()
820 }
821}
822
823impl Zero for Scalar {
824 fn zero() -> Self {
825 Self::from_real(0.0)
826 }
827
828 fn is_zero(&self) -> bool {
829 Scalar::is_zero(self)
830 }
831}
832
833impl One for Scalar {
834 fn one() -> Self {
835 Self::from_real(1.0)
836 }
837}
838
839impl PartialEq for Scalar {
840 fn eq(&self, other: &Self) -> bool {
841 self.native.dtype() == other.native.dtype() && self.value() == other.value()
842 }
843}
844
845impl PartialOrd for Scalar {
846 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
847 match (self.value(), other.value()) {
848 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
849 (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
850 (ScalarValue::F32(lhs), ScalarValue::I64(rhs)) => {
851 (lhs as f64).partial_cmp(&(rhs as f64))
852 }
853 (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
854 (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
855 (ScalarValue::F64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&(rhs as f64)),
856 (ScalarValue::I64(lhs), ScalarValue::F32(rhs)) => {
857 (lhs as f64).partial_cmp(&(rhs as f64))
858 }
859 (ScalarValue::I64(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
860 (ScalarValue::I64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&rhs),
861 _ => None,
862 }
863 }
864}
865
866impl fmt::Display for Scalar {
867 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
868 match self.value() {
869 ScalarValue::F32(value) => value.fmt(f),
870 ScalarValue::F64(value) => value.fmt(f),
871 ScalarValue::I64(value) => value.fmt(f),
872 ScalarValue::C32(value) => value.fmt(f),
873 ScalarValue::C64(value) => value.fmt(f),
874 }
875 }
876}
877
878impl fmt::Debug for Scalar {
879 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
880 f.debug_struct("Scalar")
881 .field("dtype", &self.native.dtype())
882 .field("value", &self.value())
883 .finish()
884 }
885}
886
887impl Clone for Scalar {
888 fn clone(&self) -> Self {
889 Self {
890 native: self.native.clone(),
891 value: self.value,
892 }
893 }
894}
895
// Unit tests for this module live in the out-of-line `tests` submodule file.
#[cfg(test)]
mod tests;