1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tenferro::DType;
9use tensor4all_tensorbackend::AnyScalar as BackendScalar;
10
11use crate::defaults::tensordynlen::TensorDynLen;
12use crate::storage::{Storage, SumFromStorage};
13use crate::TensorElement;
14
/// Plain Rust copy of the single element of an `AnyScalar`, tagged by precision.
///
/// `AnyScalar::value` produces exactly one variant per `DType` it matches
/// (`F32`, `F64`, `C32`, `C64`). The derived `PartialEq` compares the variant
/// first, so values of different precision are never equal, even when they are
/// numerically the same.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ScalarValue {
    /// Single-precision real element.
    F32(f32),
    /// Double-precision real element.
    F64(f64),
    /// Single-precision complex element.
    C32(Complex32),
    /// Double-precision complex element.
    C64(Complex64),
}
22
23impl ScalarValue {
24 fn real(self) -> f64 {
25 match self {
26 Self::F32(value) => value as f64,
27 Self::F64(value) => value,
28 Self::C32(value) => value.re as f64,
29 Self::C64(value) => value.re,
30 }
31 }
32
33 fn imag(self) -> f64 {
34 match self {
35 Self::F32(_) | Self::F64(_) => 0.0,
36 Self::C32(value) => value.im as f64,
37 Self::C64(value) => value.im,
38 }
39 }
40
41 fn abs(self) -> f64 {
42 match self {
43 Self::F32(value) => value.abs() as f64,
44 Self::F64(value) => value.abs(),
45 Self::C32(value) => value.norm() as f64,
46 Self::C64(value) => value.norm(),
47 }
48 }
49
50 fn is_complex(self) -> bool {
51 matches!(self, Self::C32(_) | Self::C64(_))
52 }
53
54 fn is_zero(self) -> bool {
55 match self {
56 Self::F32(value) => value == 0.0,
57 Self::F64(value) => value == 0.0,
58 Self::C32(value) => value == Complex32::new(0.0, 0.0),
59 Self::C64(value) => value == Complex64::new(0.0, 0.0),
60 }
61 }
62
63 fn into_complex(self) -> Complex64 {
64 match self {
65 Self::F32(value) => Complex64::new(value as f64, 0.0),
66 Self::F64(value) => Complex64::new(value, 0.0),
67 Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
68 Self::C64(value) => value,
69 }
70 }
71}
72
/// Dynamically typed scalar stored as a rank-0 [`TensorDynLen`].
///
/// Construction goes through `wrap_tensor`, which rejects tensors with
/// non-empty dims. Reading the value back (`AnyScalar::value`) handles the
/// `F32`, `F64`, `C32` and `C64` dtypes. Because the value lives in a
/// tensor, the scalar also exposes the tensor's gradient API
/// (`enable_grad`, `backward`, `grad`, `detach`).
#[derive(Clone)]
pub struct AnyScalar {
    /// The rank-0 tensor holding the value and whatever grad state the tensor carries.
    tensor: TensorDynLen,
}
82
83#[allow(missing_docs)]
84impl AnyScalar {
85 fn wrap_tensor(tensor: TensorDynLen) -> Result<Self> {
86 let dims = tensor.dims();
87 anyhow::ensure!(
88 dims.is_empty(),
89 "AnyScalar requires a rank-0 tensor, got dims {:?}",
90 dims
91 );
92 Ok(Self { tensor })
93 }
94
95 fn from_tensor_result(tensor: Result<TensorDynLen>, op: &'static str) -> Self {
96 Self::wrap_tensor(
97 tensor
98 .unwrap_or_else(|e| panic!("AnyScalar::{op} returned invalid scalar tensor: {e}")),
99 )
100 .unwrap_or_else(|e| panic!("AnyScalar::{op} returned non-scalar tensor: {e}"))
101 }
102
103 fn from_eager_binary<E>(
104 lhs: &Self,
105 rhs: &Self,
106 op: &'static str,
107 f: impl FnOnce(
108 &tenferro::EagerTensor<tenferro::CpuBackend>,
109 &tenferro::EagerTensor<tenferro::CpuBackend>,
110 ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
111 ) -> Self
112 where
113 E: fmt::Display,
114 {
115 let result = f(lhs.tensor.inner.as_ref(), rhs.tensor.inner.as_ref())
116 .unwrap_or_else(|e| panic!("AnyScalar::{op} failed: {e}"));
117 Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
118 }
119
120 fn from_eager_unary<E>(
121 input: &Self,
122 op: &'static str,
123 f: impl FnOnce(
124 &tenferro::EagerTensor<tenferro::CpuBackend>,
125 ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
126 ) -> Self
127 where
128 E: fmt::Display,
129 {
130 let result = f(input.tensor.inner.as_ref())
131 .unwrap_or_else(|e| panic!("AnyScalar::{op} failed: {e}"));
132 Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
133 }
134
135 fn value(&self) -> ScalarValue {
136 match self.tensor.as_native().dtype() {
137 DType::F32 => ScalarValue::F32(
138 self.tensor
139 .to_vec::<f32>()
140 .unwrap_or_else(|e| panic!("failed to read f32 scalar value: {e}"))[0],
141 ),
142 DType::F64 => ScalarValue::F64(
143 self.tensor
144 .to_vec::<f64>()
145 .unwrap_or_else(|e| panic!("failed to read f64 scalar value: {e}"))[0],
146 ),
147 DType::C32 => ScalarValue::C32(
148 self.tensor
149 .to_vec::<Complex32>()
150 .unwrap_or_else(|e| panic!("failed to read c32 scalar value: {e}"))[0],
151 ),
152 DType::C64 => ScalarValue::C64(
153 self.tensor
154 .to_vec::<Complex64>()
155 .unwrap_or_else(|e| panic!("failed to read c64 scalar value: {e}"))[0],
156 ),
157 }
158 }
159
160 fn from_backend_scalar(value: BackendScalar) -> Self {
161 if let Some(value) = value.as_c64() {
162 Self::from_value(value)
163 } else {
164 Self::from_value(value.real())
165 }
166 }
167
168 pub(crate) fn from_tensor_unchecked(tensor: TensorDynLen) -> Self {
169 Self::wrap_tensor(tensor).unwrap_or_else(|e| panic!("AnyScalar tensor wrapper failed: {e}"))
170 }
171
172 pub(crate) fn as_tensor(&self) -> &TensorDynLen {
173 &self.tensor
174 }
175
176 pub fn from_value<T: TensorElement>(value: T) -> Self {
177 Self::from_tensor_result(TensorDynLen::scalar(value), "from_value")
178 }
179
180 pub fn from_real(x: f64) -> Self {
181 Self::from_value(x)
182 }
183
184 pub fn from_complex(re: f64, im: f64) -> Self {
185 Self::from_value(Complex64::new(re, im))
186 }
187
188 pub fn new_real(x: f64) -> Self {
189 Self::from_real(x)
190 }
191
192 pub fn new_complex(re: f64, im: f64) -> Self {
193 Self::from_complex(re, im)
194 }
195
196 pub fn primal(&self) -> Self {
197 self.detach()
198 }
199
200 pub fn enable_grad(self) -> Self {
201 Self::from_tensor_unchecked(self.tensor.enable_grad())
202 }
203
204 pub fn tracks_grad(&self) -> bool {
205 self.tensor.tracks_grad()
206 }
207
208 pub fn grad(&self) -> Result<Option<Self>> {
209 self.tensor
210 .grad()
211 .map(|maybe_grad| maybe_grad.map(Self::from_tensor_unchecked))
212 }
213
214 pub fn clear_grad(&self) -> Result<()> {
215 self.tensor.clear_grad()
216 }
217
218 pub fn backward(&self) -> Result<()> {
219 self.tensor.backward()
220 }
221
222 pub fn detach(&self) -> Self {
223 Self::from_tensor_unchecked(self.tensor.detach())
224 }
225
226 pub fn real(&self) -> f64 {
227 self.value().real()
228 }
229
230 pub fn imag(&self) -> f64 {
231 self.value().imag()
232 }
233
234 pub fn abs(&self) -> f64 {
235 self.value().abs()
236 }
237
238 pub fn is_complex(&self) -> bool {
239 self.value().is_complex()
240 }
241
242 pub fn is_real(&self) -> bool {
243 !self.is_complex()
244 }
245
246 pub fn is_zero(&self) -> bool {
247 self.value().is_zero()
248 }
249
250 pub fn as_f64(&self) -> Option<f64> {
251 match self.value() {
252 ScalarValue::F32(value) => Some(value as f64),
253 ScalarValue::F64(value) => Some(value),
254 ScalarValue::C32(_) | ScalarValue::C64(_) => None,
255 }
256 }
257
258 pub fn as_c64(&self) -> Option<Complex64> {
259 match self.value() {
260 ScalarValue::F32(_) | ScalarValue::F64(_) => None,
261 ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
262 ScalarValue::C64(value) => Some(value),
263 }
264 }
265
266 pub fn conj(&self) -> Self {
267 Self::from_eager_unary(self, "conj", |tensor| tensor.conj())
268 }
269
270 pub fn real_part(&self) -> Self {
271 Self::from_real(self.real())
272 }
273
274 pub fn imag_part(&self) -> Self {
275 Self::from_real(self.imag())
276 }
277
278 pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
279 if !real.is_real() || !imag.is_real() {
280 return Err(anyhow!(
281 "compose_complex requires real-valued inputs, got real={:?}, imag={:?}",
282 real.tensor.as_native().dtype(),
283 imag.tensor.as_native().dtype()
284 ));
285 }
286 let imag_term = &imag * &Self::new_complex(0.0, 1.0);
287 Ok(&real + &imag_term)
288 }
289
290 pub fn sqrt(&self) -> Self {
291 if self.is_complex() || self.real() < 0.0 {
292 Self::from_backend_scalar(self.to_backend_scalar().sqrt())
293 } else {
294 Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt())
295 }
296 }
297
298 pub fn powf(&self, exponent: f64) -> Self {
299 Self::from_backend_scalar(self.to_backend_scalar().powf(exponent))
300 }
301
302 pub fn powi(&self, exponent: i32) -> Self {
303 if exponent == 0 {
304 return Self::one();
305 }
306
307 let mut base = self.clone();
308 let mut power = exponent.unsigned_abs();
309 let mut acc = Self::one();
310
311 while power > 0 {
312 if power % 2 == 1 {
313 acc = &acc * &base;
314 }
315 power /= 2;
316 if power > 0 {
317 base = &base * &base;
318 }
319 }
320
321 if exponent < 0 {
322 Self::one() / acc
323 } else {
324 acc
325 }
326 }
327
328 pub(crate) fn to_backend_scalar(&self) -> BackendScalar {
329 match self.value() {
330 ScalarValue::F32(value) => BackendScalar::from_value(value),
331 ScalarValue::F64(value) => BackendScalar::from_value(value),
332 ScalarValue::C32(value) => BackendScalar::from_value(value),
333 ScalarValue::C64(value) => BackendScalar::from_value(value),
334 }
335 }
336}
337
338impl SumFromStorage for AnyScalar {
339 fn sum_from_storage(storage: &Storage) -> Self {
340 Self::from_backend_scalar(BackendScalar::sum_from_storage(storage))
341 }
342}
343
344impl From<f32> for AnyScalar {
345 fn from(value: f32) -> Self {
346 Self::from_value(value)
347 }
348}
349
350impl From<f64> for AnyScalar {
351 fn from(value: f64) -> Self {
352 Self::from_value(value)
353 }
354}
355
356impl From<Complex32> for AnyScalar {
357 fn from(value: Complex32) -> Self {
358 Self::from_value(value)
359 }
360}
361
362impl From<Complex64> for AnyScalar {
363 fn from(value: Complex64) -> Self {
364 Self::from_value(value)
365 }
366}
367
368impl TryFrom<AnyScalar> for f64 {
369 type Error = &'static str;
370
371 fn try_from(value: AnyScalar) -> std::result::Result<Self, Self::Error> {
372 value.as_f64().ok_or("cannot convert complex scalar to f64")
373 }
374}
375
376impl From<AnyScalar> for Complex64 {
377 fn from(value: AnyScalar) -> Self {
378 value.value().into_complex()
379 }
380}
381
382impl Add<&AnyScalar> for &AnyScalar {
383 type Output = AnyScalar;
384
385 fn add(self, rhs: &AnyScalar) -> Self::Output {
386 AnyScalar::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs))
387 }
388}
389
390impl Add<AnyScalar> for AnyScalar {
391 type Output = AnyScalar;
392
393 fn add(self, rhs: AnyScalar) -> Self::Output {
394 Add::add(&self, &rhs)
395 }
396}
397
398impl Add<AnyScalar> for &AnyScalar {
399 type Output = AnyScalar;
400
401 fn add(self, rhs: AnyScalar) -> Self::Output {
402 Add::add(self, &rhs)
403 }
404}
405
406impl Add<&AnyScalar> for AnyScalar {
407 type Output = AnyScalar;
408
409 fn add(self, rhs: &AnyScalar) -> Self::Output {
410 Add::add(&self, rhs)
411 }
412}
413
414impl Sub<&AnyScalar> for &AnyScalar {
415 type Output = AnyScalar;
416
417 fn sub(self, rhs: &AnyScalar) -> Self::Output {
418 Add::add(self, &Neg::neg(rhs))
419 }
420}
421
422impl Sub<AnyScalar> for AnyScalar {
423 type Output = AnyScalar;
424
425 fn sub(self, rhs: AnyScalar) -> Self::Output {
426 Sub::sub(&self, &rhs)
427 }
428}
429
430impl Sub<AnyScalar> for &AnyScalar {
431 type Output = AnyScalar;
432
433 fn sub(self, rhs: AnyScalar) -> Self::Output {
434 Sub::sub(self, &rhs)
435 }
436}
437
438impl Sub<&AnyScalar> for AnyScalar {
439 type Output = AnyScalar;
440
441 fn sub(self, rhs: &AnyScalar) -> Self::Output {
442 Sub::sub(&self, rhs)
443 }
444}
445
446impl Mul<&AnyScalar> for &AnyScalar {
447 type Output = AnyScalar;
448
449 fn mul(self, rhs: &AnyScalar) -> Self::Output {
450 AnyScalar::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs))
451 }
452}
453
454impl Mul<AnyScalar> for AnyScalar {
455 type Output = AnyScalar;
456
457 fn mul(self, rhs: AnyScalar) -> Self::Output {
458 Mul::mul(&self, &rhs)
459 }
460}
461
462impl Mul<AnyScalar> for &AnyScalar {
463 type Output = AnyScalar;
464
465 fn mul(self, rhs: AnyScalar) -> Self::Output {
466 Mul::mul(self, &rhs)
467 }
468}
469
470impl Mul<&AnyScalar> for AnyScalar {
471 type Output = AnyScalar;
472
473 fn mul(self, rhs: &AnyScalar) -> Self::Output {
474 Mul::mul(&self, rhs)
475 }
476}
477
478impl Div<&AnyScalar> for &AnyScalar {
479 type Output = AnyScalar;
480
481 fn div(self, rhs: &AnyScalar) -> Self::Output {
482 if self.tensor.as_native().dtype() == rhs.tensor.as_native().dtype() {
483 AnyScalar::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs))
484 } else {
485 AnyScalar::from_backend_scalar(self.to_backend_scalar() / rhs.to_backend_scalar())
486 }
487 }
488}
489
490impl Div<AnyScalar> for AnyScalar {
491 type Output = AnyScalar;
492
493 fn div(self, rhs: AnyScalar) -> Self::Output {
494 Div::div(&self, &rhs)
495 }
496}
497
498impl Div<AnyScalar> for &AnyScalar {
499 type Output = AnyScalar;
500
501 fn div(self, rhs: AnyScalar) -> Self::Output {
502 Div::div(self, &rhs)
503 }
504}
505
506impl Div<&AnyScalar> for AnyScalar {
507 type Output = AnyScalar;
508
509 fn div(self, rhs: &AnyScalar) -> Self::Output {
510 Div::div(&self, rhs)
511 }
512}
513
514impl Neg for &AnyScalar {
515 type Output = AnyScalar;
516
517 fn neg(self) -> Self::Output {
518 AnyScalar::from_eager_unary(self, "neg", |tensor| tensor.neg())
519 }
520}
521
522impl Neg for AnyScalar {
523 type Output = AnyScalar;
524
525 fn neg(self) -> Self::Output {
526 Neg::neg(&self)
527 }
528}
529
530impl Mul<AnyScalar> for f64 {
531 type Output = AnyScalar;
532
533 fn mul(self, rhs: AnyScalar) -> Self::Output {
534 AnyScalar::from_real(self) * rhs
535 }
536}
537
538impl Mul<AnyScalar> for Complex64 {
539 type Output = AnyScalar;
540
541 fn mul(self, rhs: AnyScalar) -> Self::Output {
542 AnyScalar::from(self) * rhs
543 }
544}
545
546impl Div<AnyScalar> for Complex64 {
547 type Output = AnyScalar;
548
549 fn div(self, rhs: AnyScalar) -> Self::Output {
550 AnyScalar::from(self) / rhs
551 }
552}
553
554impl Default for AnyScalar {
555 fn default() -> Self {
556 Self::zero()
557 }
558}
559
560impl Zero for AnyScalar {
561 fn zero() -> Self {
562 Self::from_real(0.0)
563 }
564
565 fn is_zero(&self) -> bool {
566 AnyScalar::is_zero(self)
567 }
568}
569
570impl One for AnyScalar {
571 fn one() -> Self {
572 Self::from_real(1.0)
573 }
574}
575
576impl PartialEq for AnyScalar {
577 fn eq(&self, other: &Self) -> bool {
578 self.tensor.as_native().dtype() == other.tensor.as_native().dtype()
579 && self.value() == other.value()
580 }
581}
582
583impl PartialOrd for AnyScalar {
584 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
585 match (self.value(), other.value()) {
586 (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
587 (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
588 (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
589 (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
590 _ => None,
591 }
592 }
593}
594
595impl fmt::Display for AnyScalar {
596 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597 match self.value() {
598 ScalarValue::F32(value) => value.fmt(f),
599 ScalarValue::F64(value) => value.fmt(f),
600 ScalarValue::C32(value) => value.fmt(f),
601 ScalarValue::C64(value) => value.fmt(f),
602 }
603 }
604}
605
606impl fmt::Debug for AnyScalar {
607 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
608 f.debug_struct("AnyScalar")
609 .field("dtype", &self.tensor.as_native().dtype())
610 .field("value", &self.value())
611 .field("tracks_grad", &self.tracks_grad())
612 .finish()
613 }
614}