Skip to main content

tensor4all_core/
any_scalar.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tenferro::DType;
9use tensor4all_tensorbackend::AnyScalar as BackendScalar;
10
11use crate::defaults::tensordynlen::TensorDynLen;
12use crate::storage::{Storage, SumFromStorage};
13use crate::TensorElement;
14
15#[derive(Clone, Copy, Debug, PartialEq)]
16enum ScalarValue {
17    F32(f32),
18    F64(f64),
19    C32(Complex32),
20    C64(Complex64),
21}
22
23impl ScalarValue {
24    fn real(self) -> f64 {
25        match self {
26            Self::F32(value) => value as f64,
27            Self::F64(value) => value,
28            Self::C32(value) => value.re as f64,
29            Self::C64(value) => value.re,
30        }
31    }
32
33    fn imag(self) -> f64 {
34        match self {
35            Self::F32(_) | Self::F64(_) => 0.0,
36            Self::C32(value) => value.im as f64,
37            Self::C64(value) => value.im,
38        }
39    }
40
41    fn abs(self) -> f64 {
42        match self {
43            Self::F32(value) => value.abs() as f64,
44            Self::F64(value) => value.abs(),
45            Self::C32(value) => value.norm() as f64,
46            Self::C64(value) => value.norm(),
47        }
48    }
49
50    fn is_complex(self) -> bool {
51        matches!(self, Self::C32(_) | Self::C64(_))
52    }
53
54    fn is_zero(self) -> bool {
55        match self {
56            Self::F32(value) => value == 0.0,
57            Self::F64(value) => value == 0.0,
58            Self::C32(value) => value == Complex32::new(0.0, 0.0),
59            Self::C64(value) => value == Complex64::new(0.0, 0.0),
60        }
61    }
62
63    fn into_complex(self) -> Complex64 {
64        match self {
65            Self::F32(value) => Complex64::new(value as f64, 0.0),
66            Self::F64(value) => Complex64::new(value, 0.0),
67            Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
68            Self::C64(value) => value,
69        }
70    }
71}
72
73/// Dynamic scalar compatibility wrapper for tensor4all-core.
74///
75/// This owns a rank-0 [`TensorDynLen`] so that scalar values can participate in
76/// the same eager autodiff graph as tensors while preserving the existing
77/// dynamic scalar API shape.
78#[derive(Clone)]
79pub struct AnyScalar {
80    tensor: TensorDynLen,
81}
82
83#[allow(missing_docs)]
84impl AnyScalar {
85    fn wrap_tensor(tensor: TensorDynLen) -> Result<Self> {
86        let dims = tensor.dims();
87        anyhow::ensure!(
88            dims.is_empty(),
89            "AnyScalar requires a rank-0 tensor, got dims {:?}",
90            dims
91        );
92        Ok(Self { tensor })
93    }
94
95    fn from_tensor_result(tensor: Result<TensorDynLen>, op: &'static str) -> Self {
96        Self::wrap_tensor(
97            tensor
98                .unwrap_or_else(|e| panic!("AnyScalar::{op} returned invalid scalar tensor: {e}")),
99        )
100        .unwrap_or_else(|e| panic!("AnyScalar::{op} returned non-scalar tensor: {e}"))
101    }
102
103    fn from_eager_binary<E>(
104        lhs: &Self,
105        rhs: &Self,
106        op: &'static str,
107        f: impl FnOnce(
108            &tenferro::EagerTensor<tenferro::CpuBackend>,
109            &tenferro::EagerTensor<tenferro::CpuBackend>,
110        ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
111    ) -> Self
112    where
113        E: fmt::Display,
114    {
115        let result = f(lhs.tensor.inner.as_ref(), rhs.tensor.inner.as_ref())
116            .unwrap_or_else(|e| panic!("AnyScalar::{op} failed: {e}"));
117        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
118    }
119
120    fn from_eager_unary<E>(
121        input: &Self,
122        op: &'static str,
123        f: impl FnOnce(
124            &tenferro::EagerTensor<tenferro::CpuBackend>,
125        ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
126    ) -> Self
127    where
128        E: fmt::Display,
129    {
130        let result = f(input.tensor.inner.as_ref())
131            .unwrap_or_else(|e| panic!("AnyScalar::{op} failed: {e}"));
132        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
133    }
134
135    fn value(&self) -> ScalarValue {
136        match self.tensor.as_native().dtype() {
137            DType::F32 => ScalarValue::F32(
138                self.tensor
139                    .to_vec::<f32>()
140                    .unwrap_or_else(|e| panic!("failed to read f32 scalar value: {e}"))[0],
141            ),
142            DType::F64 => ScalarValue::F64(
143                self.tensor
144                    .to_vec::<f64>()
145                    .unwrap_or_else(|e| panic!("failed to read f64 scalar value: {e}"))[0],
146            ),
147            DType::C32 => ScalarValue::C32(
148                self.tensor
149                    .to_vec::<Complex32>()
150                    .unwrap_or_else(|e| panic!("failed to read c32 scalar value: {e}"))[0],
151            ),
152            DType::C64 => ScalarValue::C64(
153                self.tensor
154                    .to_vec::<Complex64>()
155                    .unwrap_or_else(|e| panic!("failed to read c64 scalar value: {e}"))[0],
156            ),
157        }
158    }
159
160    fn from_backend_scalar(value: BackendScalar) -> Self {
161        if let Some(value) = value.as_c64() {
162            Self::from_value(value)
163        } else {
164            Self::from_value(value.real())
165        }
166    }
167
168    pub(crate) fn from_tensor_unchecked(tensor: TensorDynLen) -> Self {
169        Self::wrap_tensor(tensor).unwrap_or_else(|e| panic!("AnyScalar tensor wrapper failed: {e}"))
170    }
171
172    pub(crate) fn as_tensor(&self) -> &TensorDynLen {
173        &self.tensor
174    }
175
176    pub fn from_value<T: TensorElement>(value: T) -> Self {
177        Self::from_tensor_result(TensorDynLen::scalar(value), "from_value")
178    }
179
180    pub fn from_real(x: f64) -> Self {
181        Self::from_value(x)
182    }
183
184    pub fn from_complex(re: f64, im: f64) -> Self {
185        Self::from_value(Complex64::new(re, im))
186    }
187
188    pub fn new_real(x: f64) -> Self {
189        Self::from_real(x)
190    }
191
192    pub fn new_complex(re: f64, im: f64) -> Self {
193        Self::from_complex(re, im)
194    }
195
196    pub fn primal(&self) -> Self {
197        self.detach()
198    }
199
200    pub fn enable_grad(self) -> Self {
201        Self::from_tensor_unchecked(self.tensor.enable_grad())
202    }
203
204    pub fn tracks_grad(&self) -> bool {
205        self.tensor.tracks_grad()
206    }
207
208    pub fn grad(&self) -> Result<Option<Self>> {
209        self.tensor
210            .grad()
211            .map(|maybe_grad| maybe_grad.map(Self::from_tensor_unchecked))
212    }
213
214    pub fn clear_grad(&self) -> Result<()> {
215        self.tensor.clear_grad()
216    }
217
218    pub fn backward(&self) -> Result<()> {
219        self.tensor.backward()
220    }
221
222    pub fn detach(&self) -> Self {
223        Self::from_tensor_unchecked(self.tensor.detach())
224    }
225
226    pub fn real(&self) -> f64 {
227        self.value().real()
228    }
229
230    pub fn imag(&self) -> f64 {
231        self.value().imag()
232    }
233
234    pub fn abs(&self) -> f64 {
235        self.value().abs()
236    }
237
238    pub fn is_complex(&self) -> bool {
239        self.value().is_complex()
240    }
241
242    pub fn is_real(&self) -> bool {
243        !self.is_complex()
244    }
245
246    pub fn is_zero(&self) -> bool {
247        self.value().is_zero()
248    }
249
250    pub fn as_f64(&self) -> Option<f64> {
251        match self.value() {
252            ScalarValue::F32(value) => Some(value as f64),
253            ScalarValue::F64(value) => Some(value),
254            ScalarValue::C32(_) | ScalarValue::C64(_) => None,
255        }
256    }
257
258    pub fn as_c64(&self) -> Option<Complex64> {
259        match self.value() {
260            ScalarValue::F32(_) | ScalarValue::F64(_) => None,
261            ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
262            ScalarValue::C64(value) => Some(value),
263        }
264    }
265
266    pub fn conj(&self) -> Self {
267        Self::from_eager_unary(self, "conj", |tensor| tensor.conj())
268    }
269
270    pub fn real_part(&self) -> Self {
271        Self::from_real(self.real())
272    }
273
274    pub fn imag_part(&self) -> Self {
275        Self::from_real(self.imag())
276    }
277
278    pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
279        if !real.is_real() || !imag.is_real() {
280            return Err(anyhow!(
281                "compose_complex requires real-valued inputs, got real={:?}, imag={:?}",
282                real.tensor.as_native().dtype(),
283                imag.tensor.as_native().dtype()
284            ));
285        }
286        let imag_term = &imag * &Self::new_complex(0.0, 1.0);
287        Ok(&real + &imag_term)
288    }
289
290    pub fn sqrt(&self) -> Self {
291        if self.is_complex() || self.real() < 0.0 {
292            Self::from_backend_scalar(self.to_backend_scalar().sqrt())
293        } else {
294            Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt())
295        }
296    }
297
298    pub fn powf(&self, exponent: f64) -> Self {
299        Self::from_backend_scalar(self.to_backend_scalar().powf(exponent))
300    }
301
302    pub fn powi(&self, exponent: i32) -> Self {
303        if exponent == 0 {
304            return Self::one();
305        }
306
307        let mut base = self.clone();
308        let mut power = exponent.unsigned_abs();
309        let mut acc = Self::one();
310
311        while power > 0 {
312            if power % 2 == 1 {
313                acc = &acc * &base;
314            }
315            power /= 2;
316            if power > 0 {
317                base = &base * &base;
318            }
319        }
320
321        if exponent < 0 {
322            Self::one() / acc
323        } else {
324            acc
325        }
326    }
327
328    pub(crate) fn to_backend_scalar(&self) -> BackendScalar {
329        match self.value() {
330            ScalarValue::F32(value) => BackendScalar::from_value(value),
331            ScalarValue::F64(value) => BackendScalar::from_value(value),
332            ScalarValue::C32(value) => BackendScalar::from_value(value),
333            ScalarValue::C64(value) => BackendScalar::from_value(value),
334        }
335    }
336}
337
338impl SumFromStorage for AnyScalar {
339    fn sum_from_storage(storage: &Storage) -> Self {
340        Self::from_backend_scalar(BackendScalar::sum_from_storage(storage))
341    }
342}
343
344impl From<f32> for AnyScalar {
345    fn from(value: f32) -> Self {
346        Self::from_value(value)
347    }
348}
349
350impl From<f64> for AnyScalar {
351    fn from(value: f64) -> Self {
352        Self::from_value(value)
353    }
354}
355
356impl From<Complex32> for AnyScalar {
357    fn from(value: Complex32) -> Self {
358        Self::from_value(value)
359    }
360}
361
362impl From<Complex64> for AnyScalar {
363    fn from(value: Complex64) -> Self {
364        Self::from_value(value)
365    }
366}
367
368impl TryFrom<AnyScalar> for f64 {
369    type Error = &'static str;
370
371    fn try_from(value: AnyScalar) -> std::result::Result<Self, Self::Error> {
372        value.as_f64().ok_or("cannot convert complex scalar to f64")
373    }
374}
375
376impl From<AnyScalar> for Complex64 {
377    fn from(value: AnyScalar) -> Self {
378        value.value().into_complex()
379    }
380}
381
382impl Add<&AnyScalar> for &AnyScalar {
383    type Output = AnyScalar;
384
385    fn add(self, rhs: &AnyScalar) -> Self::Output {
386        AnyScalar::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs))
387    }
388}
389
390impl Add<AnyScalar> for AnyScalar {
391    type Output = AnyScalar;
392
393    fn add(self, rhs: AnyScalar) -> Self::Output {
394        Add::add(&self, &rhs)
395    }
396}
397
398impl Add<AnyScalar> for &AnyScalar {
399    type Output = AnyScalar;
400
401    fn add(self, rhs: AnyScalar) -> Self::Output {
402        Add::add(self, &rhs)
403    }
404}
405
406impl Add<&AnyScalar> for AnyScalar {
407    type Output = AnyScalar;
408
409    fn add(self, rhs: &AnyScalar) -> Self::Output {
410        Add::add(&self, rhs)
411    }
412}
413
414impl Sub<&AnyScalar> for &AnyScalar {
415    type Output = AnyScalar;
416
417    fn sub(self, rhs: &AnyScalar) -> Self::Output {
418        Add::add(self, &Neg::neg(rhs))
419    }
420}
421
422impl Sub<AnyScalar> for AnyScalar {
423    type Output = AnyScalar;
424
425    fn sub(self, rhs: AnyScalar) -> Self::Output {
426        Sub::sub(&self, &rhs)
427    }
428}
429
430impl Sub<AnyScalar> for &AnyScalar {
431    type Output = AnyScalar;
432
433    fn sub(self, rhs: AnyScalar) -> Self::Output {
434        Sub::sub(self, &rhs)
435    }
436}
437
438impl Sub<&AnyScalar> for AnyScalar {
439    type Output = AnyScalar;
440
441    fn sub(self, rhs: &AnyScalar) -> Self::Output {
442        Sub::sub(&self, rhs)
443    }
444}
445
446impl Mul<&AnyScalar> for &AnyScalar {
447    type Output = AnyScalar;
448
449    fn mul(self, rhs: &AnyScalar) -> Self::Output {
450        AnyScalar::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs))
451    }
452}
453
454impl Mul<AnyScalar> for AnyScalar {
455    type Output = AnyScalar;
456
457    fn mul(self, rhs: AnyScalar) -> Self::Output {
458        Mul::mul(&self, &rhs)
459    }
460}
461
462impl Mul<AnyScalar> for &AnyScalar {
463    type Output = AnyScalar;
464
465    fn mul(self, rhs: AnyScalar) -> Self::Output {
466        Mul::mul(self, &rhs)
467    }
468}
469
470impl Mul<&AnyScalar> for AnyScalar {
471    type Output = AnyScalar;
472
473    fn mul(self, rhs: &AnyScalar) -> Self::Output {
474        Mul::mul(&self, rhs)
475    }
476}
477
478impl Div<&AnyScalar> for &AnyScalar {
479    type Output = AnyScalar;
480
481    fn div(self, rhs: &AnyScalar) -> Self::Output {
482        if self.tensor.as_native().dtype() == rhs.tensor.as_native().dtype() {
483            AnyScalar::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs))
484        } else {
485            AnyScalar::from_backend_scalar(self.to_backend_scalar() / rhs.to_backend_scalar())
486        }
487    }
488}
489
490impl Div<AnyScalar> for AnyScalar {
491    type Output = AnyScalar;
492
493    fn div(self, rhs: AnyScalar) -> Self::Output {
494        Div::div(&self, &rhs)
495    }
496}
497
498impl Div<AnyScalar> for &AnyScalar {
499    type Output = AnyScalar;
500
501    fn div(self, rhs: AnyScalar) -> Self::Output {
502        Div::div(self, &rhs)
503    }
504}
505
506impl Div<&AnyScalar> for AnyScalar {
507    type Output = AnyScalar;
508
509    fn div(self, rhs: &AnyScalar) -> Self::Output {
510        Div::div(&self, rhs)
511    }
512}
513
514impl Neg for &AnyScalar {
515    type Output = AnyScalar;
516
517    fn neg(self) -> Self::Output {
518        AnyScalar::from_eager_unary(self, "neg", |tensor| tensor.neg())
519    }
520}
521
522impl Neg for AnyScalar {
523    type Output = AnyScalar;
524
525    fn neg(self) -> Self::Output {
526        Neg::neg(&self)
527    }
528}
529
530impl Mul<AnyScalar> for f64 {
531    type Output = AnyScalar;
532
533    fn mul(self, rhs: AnyScalar) -> Self::Output {
534        AnyScalar::from_real(self) * rhs
535    }
536}
537
538impl Mul<AnyScalar> for Complex64 {
539    type Output = AnyScalar;
540
541    fn mul(self, rhs: AnyScalar) -> Self::Output {
542        AnyScalar::from(self) * rhs
543    }
544}
545
546impl Div<AnyScalar> for Complex64 {
547    type Output = AnyScalar;
548
549    fn div(self, rhs: AnyScalar) -> Self::Output {
550        AnyScalar::from(self) / rhs
551    }
552}
553
554impl Default for AnyScalar {
555    fn default() -> Self {
556        Self::zero()
557    }
558}
559
560impl Zero for AnyScalar {
561    fn zero() -> Self {
562        Self::from_real(0.0)
563    }
564
565    fn is_zero(&self) -> bool {
566        AnyScalar::is_zero(self)
567    }
568}
569
570impl One for AnyScalar {
571    fn one() -> Self {
572        Self::from_real(1.0)
573    }
574}
575
576impl PartialEq for AnyScalar {
577    fn eq(&self, other: &Self) -> bool {
578        self.tensor.as_native().dtype() == other.tensor.as_native().dtype()
579            && self.value() == other.value()
580    }
581}
582
583impl PartialOrd for AnyScalar {
584    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
585        match (self.value(), other.value()) {
586            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
587            (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
588            (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
589            (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
590            _ => None,
591        }
592    }
593}
594
595impl fmt::Display for AnyScalar {
596    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597        match self.value() {
598            ScalarValue::F32(value) => value.fmt(f),
599            ScalarValue::F64(value) => value.fmt(f),
600            ScalarValue::C32(value) => value.fmt(f),
601            ScalarValue::C64(value) => value.fmt(f),
602        }
603    }
604}
605
606impl fmt::Debug for AnyScalar {
607    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
608        f.debug_struct("AnyScalar")
609            .field("dtype", &self.tensor.as_native().dtype())
610            .field("value", &self.value())
611            .field("tracks_grad", &self.tracks_grad())
612            .finish()
613    }
614}