1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tensor4all_tensorbackend::AnyScalar as BackendScalar;
9
10use crate::defaults::tensordynlen::TensorDynLen;
11use crate::TensorElement;
12use tensor4all_tensorbackend::{Storage, SumFromStorage};
13
/// Plain numeric payload of a scalar, tagged with its element type.
///
/// The value is `Copy`, so accessors take it by value.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ScalarValue {
    /// Single-precision real.
    F32(f32),
    /// Double-precision real.
    F64(f64),
    /// Single-precision complex.
    C32(Complex32),
    /// Double-precision complex.
    C64(Complex64),
}
21
22impl ScalarValue {
23 fn real(self) -> f64 {
24 match self {
25 Self::F32(value) => value as f64,
26 Self::F64(value) => value,
27 Self::C32(value) => value.re as f64,
28 Self::C64(value) => value.re,
29 }
30 }
31
32 fn imag(self) -> f64 {
33 match self {
34 Self::F32(_) | Self::F64(_) => 0.0,
35 Self::C32(value) => value.im as f64,
36 Self::C64(value) => value.im,
37 }
38 }
39
40 fn abs(self) -> f64 {
41 match self {
42 Self::F32(value) => value.abs() as f64,
43 Self::F64(value) => value.abs(),
44 Self::C32(value) => value.norm() as f64,
45 Self::C64(value) => value.norm(),
46 }
47 }
48
49 fn is_complex(self) -> bool {
50 matches!(self, Self::C32(_) | Self::C64(_))
51 }
52
53 fn is_zero(self) -> bool {
54 match self {
55 Self::F32(value) => value == 0.0,
56 Self::F64(value) => value == 0.0,
57 Self::C32(value) => value == Complex32::new(0.0, 0.0),
58 Self::C64(value) => value == Complex64::new(0.0, 0.0),
59 }
60 }
61
62 fn into_complex(self) -> Complex64 {
63 match self {
64 Self::F32(value) => Complex64::new(value as f64, 0.0),
65 Self::F64(value) => Complex64::new(value, 0.0),
66 Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
67 Self::C64(value) => value,
68 }
69 }
70}
71
/// Tensor element types that can be stored in a [`ScalarValue`].
trait ScalarTensorElement: TensorElement {
    /// Wraps `value` in the `ScalarValue` variant matching `Self`.
    fn scalar_value(value: Self) -> ScalarValue;
}
75
76impl ScalarTensorElement for f32 {
77 fn scalar_value(value: Self) -> ScalarValue {
78 ScalarValue::F32(value)
79 }
80}
81
82impl ScalarTensorElement for f64 {
83 fn scalar_value(value: Self) -> ScalarValue {
84 ScalarValue::F64(value)
85 }
86}
87
88impl ScalarTensorElement for Complex32 {
89 fn scalar_value(value: Self) -> ScalarValue {
90 ScalarValue::C32(value)
91 }
92}
93
94impl ScalarTensorElement for Complex64 {
95 fn scalar_value(value: Self) -> ScalarValue {
96 ScalarValue::C64(value)
97 }
98}
99
/// Dynamically typed scalar: a cached plain value plus an optional rank-0
/// backend tensor.
#[derive(Clone)]
pub struct AnyScalar {
    /// Backend tensor for this scalar; `None` when no tensor could be built
    /// for the value (see `from_value`).
    tensor: Option<TensorDynLen>,
    /// Plain copy of the numeric value; every value accessor reads this field.
    value: ScalarValue,
}
110
111impl AnyScalar {
112 fn wrap_tensor(tensor: TensorDynLen) -> Result<Self> {
113 let dims = tensor.dims();
114 anyhow::ensure!(
115 dims.is_empty(),
116 "AnyScalar requires a rank-0 tensor, got dims {:?}",
117 dims
118 );
119 let value = Self::scalar_value_from_tensor(&tensor)?;
120 Ok(Self {
121 tensor: Some(tensor),
122 value,
123 })
124 }
125
126 fn from_tensor_result(tensor: Result<TensorDynLen>, op: &'static str) -> Result<Self> {
127 Self::wrap_tensor(
128 tensor.map_err(|e| anyhow!("AnyScalar::{op} returned invalid scalar tensor: {e}"))?,
129 )
130 .map_err(|e| anyhow!("AnyScalar::{op} returned non-scalar tensor: {e}"))
131 }
132
133 fn from_eager_binary<E>(
134 lhs: &Self,
135 rhs: &Self,
136 op: &'static str,
137 f: impl FnOnce(
138 &tenferro_ad::EagerTensor,
139 &tenferro_ad::EagerTensor,
140 ) -> std::result::Result<tenferro_ad::EagerTensor, E>,
141 ) -> Result<Self>
142 where
143 E: fmt::Display,
144 {
145 let result = f(lhs.as_tensor()?.as_inner()?, rhs.as_tensor()?.as_inner()?)
146 .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
147 Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
148 }
149
150 fn from_eager_unary<E>(
151 input: &Self,
152 op: &'static str,
153 f: impl FnOnce(&tenferro_ad::EagerTensor) -> std::result::Result<tenferro_ad::EagerTensor, E>,
154 ) -> Result<Self>
155 where
156 E: fmt::Display,
157 {
158 let result = f(input.as_tensor()?.as_inner()?)
159 .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
160 Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
161 }
162
163 fn scalar_value_from_tensor(tensor: &TensorDynLen) -> Result<ScalarValue> {
164 let storage = tensor.storage();
165 if storage.is_c64() {
166 let values = storage
167 .payload_c64_col_major_vec()
168 .map_err(|e| anyhow!("failed to read c64 scalar storage: {e}"))?;
169 values
170 .first()
171 .copied()
172 .map(ScalarValue::C64)
173 .ok_or_else(|| anyhow!("rank-0 c64 scalar storage is empty"))
174 } else {
175 let values = storage
176 .payload_f64_col_major_vec()
177 .map_err(|e| anyhow!("failed to read f64 scalar storage: {e}"))?;
178 values
179 .first()
180 .copied()
181 .map(ScalarValue::F64)
182 .ok_or_else(|| anyhow!("rank-0 f64 scalar storage is empty"))
183 }
184 }
185
186 fn value(&self) -> ScalarValue {
187 self.value
188 }
189
190 fn from_backend_scalar(value: BackendScalar) -> Self {
191 if let Some(value) = value.as_c64() {
192 Self::from_value(value)
193 } else {
194 Self::from_value(value.real())
195 }
196 }
197
198 pub(crate) fn from_tensor(tensor: TensorDynLen) -> Result<Self> {
199 Self::wrap_tensor(tensor)
200 }
201
202 pub(crate) fn as_tensor(&self) -> Result<&TensorDynLen> {
203 self.tensor
204 .as_ref()
205 .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))
206 }
207
208 #[allow(private_bounds)]
231 pub fn from_value<T: ScalarTensorElement>(value: T) -> Self {
232 Self {
233 tensor: TensorDynLen::scalar(value).ok(),
234 value: T::scalar_value(value),
235 }
236 }
237
238 pub fn from_real(x: f64) -> Self {
260 Self::from_value(x)
261 }
262
263 pub fn from_complex(re: f64, im: f64) -> Self {
286 Self::from_value(Complex64::new(re, im))
287 }
288
289 pub fn new_real(x: f64) -> Self {
311 Self::from_real(x)
312 }
313
314 pub fn new_complex(re: f64, im: f64) -> Self {
337 Self::from_complex(re, im)
338 }
339
340 pub fn primal(&self) -> Result<Self> {
358 self.detach()
359 }
360
361 pub fn enable_grad(self) -> Result<Self> {
376 let tensor = self
377 .tensor
378 .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))?;
379 Self::from_tensor(tensor.enable_grad()?)
380 }
381
382 pub fn tracks_grad(&self) -> bool {
398 self.tensor.as_ref().is_some_and(TensorDynLen::tracks_grad)
399 }
400
401 pub fn grad(&self) -> Result<Option<Self>> {
426 self.as_tensor()?
427 .grad()
428 .and_then(|maybe_grad| maybe_grad.map(Self::from_tensor).transpose())
429 }
430
431 pub fn clear_grad(&self) -> Result<()> {
455 self.as_tensor()?.clear_grad()
456 }
457
458 pub fn backward(&self) -> Result<()> {
481 self.as_tensor()?.backward()
482 }
483
484 pub fn detach(&self) -> Result<Self> {
504 Self::from_tensor(self.as_tensor()?.detach()?)
505 }
506
507 pub fn real(&self) -> f64 {
523 self.value().real()
524 }
525
526 pub fn imag(&self) -> f64 {
541 self.value().imag()
542 }
543
544 pub fn abs(&self) -> f64 {
560 self.value().abs()
561 }
562
563 pub fn is_complex(&self) -> bool {
578 self.value().is_complex()
579 }
580
581 pub fn is_real(&self) -> bool {
596 !self.is_complex()
597 }
598
599 pub fn is_zero(&self) -> bool {
614 self.value().is_zero()
615 }
616
617 pub fn as_f64(&self) -> Option<f64> {
633 match self.value() {
634 ScalarValue::F32(value) => Some(value as f64),
635 ScalarValue::F64(value) => Some(value),
636 ScalarValue::C32(_) | ScalarValue::C64(_) => None,
637 }
638 }
639
640 pub fn as_c64(&self) -> Option<Complex64> {
657 match self.value() {
658 ScalarValue::F32(_) | ScalarValue::F64(_) => None,
659 ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
660 ScalarValue::C64(value) => Some(value),
661 }
662 }
663
664 pub fn try_conj(&self) -> Result<Self> {
679 if !self.tracks_grad() {
680 return Ok(Self::from_backend_scalar(self.to_backend_scalar().conj()));
681 }
682 Self::from_eager_unary(self, "conj", |tensor| tensor.conj())
683 }
684
685 pub fn conj(&self) -> Self {
687 self.try_conj()
688 .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().conj()))
689 }
690
691 pub fn real_part(&self) -> Self {
707 Self::from_real(self.real())
708 }
709
710 pub fn imag_part(&self) -> Self {
726 Self::from_real(self.imag())
727 }
728
729 pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
758 if !real.is_real() || !imag.is_real() {
759 return Err(anyhow!("compose_complex requires real-valued inputs"));
760 }
761 let imag_term = imag.try_mul(&Self::new_complex(0.0, 1.0))?;
762 real.try_add(&imag_term)
763 }
764
765 pub fn sqrt(&self) -> Self {
782 if !self.tracks_grad() || self.is_complex() || self.real() < 0.0 {
783 Self::from_backend_scalar(self.to_backend_scalar().sqrt())
784 } else {
785 Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt())
786 .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().sqrt()))
787 }
788 }
789
790 pub fn powf(&self, exponent: f64) -> Self {
809 Self::from_backend_scalar(self.to_backend_scalar().powf(exponent))
810 }
811
812 pub fn powi(&self, exponent: i32) -> Self {
832 if exponent == 0 {
833 return Self::one();
834 }
835
836 let mut base = self.clone();
837 let mut power = exponent.unsigned_abs();
838 let mut acc = Self::one();
839
840 while power > 0 {
841 if power % 2 == 1 {
842 acc = acc.try_mul(&base).unwrap_or_else(|_| {
843 Self::from_backend_scalar(acc.to_backend_scalar() * base.to_backend_scalar())
844 });
845 }
846 power /= 2;
847 if power > 0 {
848 base = base.try_mul(&base).unwrap_or_else(|_| {
849 Self::from_backend_scalar(base.to_backend_scalar() * base.to_backend_scalar())
850 });
851 }
852 }
853
854 if exponent < 0 {
855 Self::one().try_div(&acc).unwrap_or_else(|_| {
856 Self::from_backend_scalar(Self::one().to_backend_scalar() / acc.to_backend_scalar())
857 })
858 } else {
859 acc
860 }
861 }
862
863 pub(crate) fn to_backend_scalar(&self) -> BackendScalar {
864 match self.value() {
865 ScalarValue::F32(value) => BackendScalar::from_value(value),
866 ScalarValue::F64(value) => BackendScalar::from_value(value),
867 ScalarValue::C32(value) => BackendScalar::from_value(value),
868 ScalarValue::C64(value) => BackendScalar::from_value(value),
869 }
870 }
871
872 pub(crate) fn try_add(&self, rhs: &Self) -> Result<Self> {
873 if !self.tracks_grad() && !rhs.tracks_grad() {
874 return Ok(Self::from_backend_scalar(
875 self.to_backend_scalar() + rhs.to_backend_scalar(),
876 ));
877 }
878 Self::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs))
879 }
880
881 pub(crate) fn try_mul(&self, rhs: &Self) -> Result<Self> {
882 if !self.tracks_grad() && !rhs.tracks_grad() {
883 return Ok(Self::from_backend_scalar(
884 self.to_backend_scalar() * rhs.to_backend_scalar(),
885 ));
886 }
887 Self::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs))
888 }
889
890 pub(crate) fn try_div(&self, rhs: &Self) -> Result<Self> {
891 if !self.tracks_grad() && !rhs.tracks_grad() {
892 return Ok(Self::from_backend_scalar(
893 self.to_backend_scalar() / rhs.to_backend_scalar(),
894 ));
895 }
896 if self.as_tensor()?.as_native()?.dtype() == rhs.as_tensor()?.as_native()?.dtype() {
897 Self::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs))
898 } else {
899 Ok(Self::from_backend_scalar(
900 self.to_backend_scalar() / rhs.to_backend_scalar(),
901 ))
902 }
903 }
904
905 pub(crate) fn try_neg(&self) -> Result<Self> {
906 if !self.tracks_grad() {
907 return Ok(Self::from_backend_scalar(-self.to_backend_scalar()));
908 }
909 Self::from_eager_unary(self, "neg", |tensor| tensor.neg())
910 }
911}
912
913impl SumFromStorage for AnyScalar {
914 fn sum_from_storage(storage: &Storage) -> Self {
915 Self::from_backend_scalar(BackendScalar::sum_from_storage(storage))
916 }
917}
918
919impl From<f32> for AnyScalar {
920 fn from(value: f32) -> Self {
921 Self::from_value(value)
922 }
923}
924
925impl From<f64> for AnyScalar {
926 fn from(value: f64) -> Self {
927 Self::from_value(value)
928 }
929}
930
931impl From<Complex32> for AnyScalar {
932 fn from(value: Complex32) -> Self {
933 Self::from_value(value)
934 }
935}
936
937impl From<Complex64> for AnyScalar {
938 fn from(value: Complex64) -> Self {
939 Self::from_value(value)
940 }
941}
942
943impl TryFrom<AnyScalar> for f64 {
944 type Error = &'static str;
945
946 fn try_from(value: AnyScalar) -> std::result::Result<Self, Self::Error> {
947 value.as_f64().ok_or("cannot convert complex scalar to f64")
948 }
949}
950
951impl From<AnyScalar> for Complex64 {
952 fn from(value: AnyScalar) -> Self {
953 value.value().into_complex()
954 }
955}
956
957impl Add<&AnyScalar> for &AnyScalar {
958 type Output = AnyScalar;
959
960 fn add(self, rhs: &AnyScalar) -> Self::Output {
961 self.try_add(rhs).unwrap_or_else(|_| {
962 AnyScalar::from_backend_scalar(self.to_backend_scalar() + rhs.to_backend_scalar())
963 })
964 }
965}
966
967impl Add<AnyScalar> for AnyScalar {
968 type Output = AnyScalar;
969
970 fn add(self, rhs: AnyScalar) -> Self::Output {
971 Add::add(&self, &rhs)
972 }
973}
974
975impl Add<AnyScalar> for &AnyScalar {
976 type Output = AnyScalar;
977
978 fn add(self, rhs: AnyScalar) -> Self::Output {
979 Add::add(self, &rhs)
980 }
981}
982
983impl Add<&AnyScalar> for AnyScalar {
984 type Output = AnyScalar;
985
986 fn add(self, rhs: &AnyScalar) -> Self::Output {
987 Add::add(&self, rhs)
988 }
989}
990
991impl Sub<&AnyScalar> for &AnyScalar {
992 type Output = AnyScalar;
993
994 fn sub(self, rhs: &AnyScalar) -> Self::Output {
995 Add::add(self, &Neg::neg(rhs))
996 }
997}
998
999impl Sub<AnyScalar> for AnyScalar {
1000 type Output = AnyScalar;
1001
1002 fn sub(self, rhs: AnyScalar) -> Self::Output {
1003 Sub::sub(&self, &rhs)
1004 }
1005}
1006
1007impl Sub<AnyScalar> for &AnyScalar {
1008 type Output = AnyScalar;
1009
1010 fn sub(self, rhs: AnyScalar) -> Self::Output {
1011 Sub::sub(self, &rhs)
1012 }
1013}
1014
1015impl Sub<&AnyScalar> for AnyScalar {
1016 type Output = AnyScalar;
1017
1018 fn sub(self, rhs: &AnyScalar) -> Self::Output {
1019 Sub::sub(&self, rhs)
1020 }
1021}
1022
1023impl Mul<&AnyScalar> for &AnyScalar {
1024 type Output = AnyScalar;
1025
1026 fn mul(self, rhs: &AnyScalar) -> Self::Output {
1027 self.try_mul(rhs).unwrap_or_else(|_| {
1028 AnyScalar::from_backend_scalar(self.to_backend_scalar() * rhs.to_backend_scalar())
1029 })
1030 }
1031}
1032
1033impl Mul<AnyScalar> for AnyScalar {
1034 type Output = AnyScalar;
1035
1036 fn mul(self, rhs: AnyScalar) -> Self::Output {
1037 Mul::mul(&self, &rhs)
1038 }
1039}
1040
1041impl Mul<AnyScalar> for &AnyScalar {
1042 type Output = AnyScalar;
1043
1044 fn mul(self, rhs: AnyScalar) -> Self::Output {
1045 Mul::mul(self, &rhs)
1046 }
1047}
1048
1049impl Mul<&AnyScalar> for AnyScalar {
1050 type Output = AnyScalar;
1051
1052 fn mul(self, rhs: &AnyScalar) -> Self::Output {
1053 Mul::mul(&self, rhs)
1054 }
1055}
1056
1057impl Div<&AnyScalar> for &AnyScalar {
1058 type Output = AnyScalar;
1059
1060 fn div(self, rhs: &AnyScalar) -> Self::Output {
1061 self.try_div(rhs).unwrap_or_else(|_| {
1062 AnyScalar::from_backend_scalar(self.to_backend_scalar() / rhs.to_backend_scalar())
1063 })
1064 }
1065}
1066
1067impl Div<AnyScalar> for AnyScalar {
1068 type Output = AnyScalar;
1069
1070 fn div(self, rhs: AnyScalar) -> Self::Output {
1071 Div::div(&self, &rhs)
1072 }
1073}
1074
1075impl Div<AnyScalar> for &AnyScalar {
1076 type Output = AnyScalar;
1077
1078 fn div(self, rhs: AnyScalar) -> Self::Output {
1079 Div::div(self, &rhs)
1080 }
1081}
1082
1083impl Div<&AnyScalar> for AnyScalar {
1084 type Output = AnyScalar;
1085
1086 fn div(self, rhs: &AnyScalar) -> Self::Output {
1087 Div::div(&self, rhs)
1088 }
1089}
1090
1091impl Neg for &AnyScalar {
1092 type Output = AnyScalar;
1093
1094 fn neg(self) -> Self::Output {
1095 self.try_neg()
1096 .unwrap_or_else(|_| AnyScalar::from_backend_scalar(-self.to_backend_scalar()))
1097 }
1098}
1099
1100impl Neg for AnyScalar {
1101 type Output = AnyScalar;
1102
1103 fn neg(self) -> Self::Output {
1104 Neg::neg(&self)
1105 }
1106}
1107
1108impl Mul<AnyScalar> for f64 {
1109 type Output = AnyScalar;
1110
1111 fn mul(self, rhs: AnyScalar) -> Self::Output {
1112 AnyScalar::from_real(self) * rhs
1113 }
1114}
1115
1116impl Mul<AnyScalar> for Complex64 {
1117 type Output = AnyScalar;
1118
1119 fn mul(self, rhs: AnyScalar) -> Self::Output {
1120 AnyScalar::from(self) * rhs
1121 }
1122}
1123
1124impl Div<AnyScalar> for Complex64 {
1125 type Output = AnyScalar;
1126
1127 fn div(self, rhs: AnyScalar) -> Self::Output {
1128 AnyScalar::from(self) / rhs
1129 }
1130}
1131
1132impl Default for AnyScalar {
1133 fn default() -> Self {
1134 Self::zero()
1135 }
1136}
1137
1138impl Zero for AnyScalar {
1139 fn zero() -> Self {
1140 Self::from_real(0.0)
1141 }
1142
1143 fn is_zero(&self) -> bool {
1144 AnyScalar::is_zero(self)
1145 }
1146}
1147
1148impl One for AnyScalar {
1149 fn one() -> Self {
1150 Self::from_real(1.0)
1151 }
1152}
1153
1154impl PartialEq for AnyScalar {
1155 fn eq(&self, other: &Self) -> bool {
1156 self.value() == other.value()
1157 }
1158}
1159
1160impl PartialOrd for AnyScalar {
1161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1162 match (self.value(), other.value()) {
1163 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
1164 (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
1165 (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
1166 (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
1167 _ => None,
1168 }
1169 }
1170}
1171
1172impl fmt::Display for AnyScalar {
1173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1174 match self.value() {
1175 ScalarValue::F32(value) => value.fmt(f),
1176 ScalarValue::F64(value) => value.fmt(f),
1177 ScalarValue::C32(value) => value.fmt(f),
1178 ScalarValue::C64(value) => value.fmt(f),
1179 }
1180 }
1181}
1182
1183impl fmt::Debug for AnyScalar {
1184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1185 let dtype = match self.value {
1186 ScalarValue::F32(_) => "f32",
1187 ScalarValue::F64(_) => "f64",
1188 ScalarValue::C32(_) => "c32",
1189 ScalarValue::C64(_) => "c64",
1190 };
1191 f.debug_struct("AnyScalar")
1192 .field("dtype", &dtype)
1193 .field("value", &self.value())
1194 .field("tracks_grad", &self.tracks_grad())
1195 .finish()
1196 }
1197}
1198
#[cfg(test)]
mod tests {
    use super::*;

    /// Arithmetic on scalars without gradient tracking yields the expected
    /// plain value, does not start tracking, and still carries a tensor.
    #[test]
    fn non_grad_scalar_arithmetic_uses_plain_values() {
        let three = AnyScalar::new_real(3.0);
        let four = AnyScalar::new_real(4.0);

        // ((3 + 4) * 4 - 8) / 2 == 10
        let sum = three.clone() + four.clone();
        let product = sum * four.clone();
        let difference = product - AnyScalar::new_real(8.0);
        let value = difference / AnyScalar::new_real(2.0);

        assert_eq!(value.as_f64(), Some(10.0));
        assert!(!value.tracks_grad());
        assert!(value.as_tensor().is_ok());
    }

    /// Multiplying a tracked scalar by itself keeps tracking, and the backward
    /// pass yields d(x * x)/dx = 2x.
    #[test]
    fn tracked_scalar_arithmetic_preserves_autodiff() {
        let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
        let squared = &x * &x;

        assert!(squared.tracks_grad());
        squared.backward().unwrap();

        let gradient = x.grad().unwrap().unwrap();
        assert_eq!(gradient.as_f64(), Some(4.0));
    }
}