Skip to main content

tensor4all_tensorbackend/
any_scalar.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, ensure, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tenferro::{DType, Tensor as NativeTensor};
9
10use crate::storage::{Storage, SumFromStorage};
11use crate::tensor_element::TensorElement;
12
13#[derive(Clone, Copy, Debug, PartialEq)]
14enum ScalarValue {
15    F32(f32),
16    F64(f64),
17    I64(i64),
18    C32(Complex32),
19    C64(Complex64),
20}
21
22impl ScalarValue {
23    fn real(self) -> f64 {
24        match self {
25            Self::F32(value) => value as f64,
26            Self::F64(value) => value,
27            Self::I64(value) => value as f64,
28            Self::C32(value) => value.re as f64,
29            Self::C64(value) => value.re,
30        }
31    }
32
33    fn imag(self) -> f64 {
34        match self {
35            Self::F32(_) | Self::F64(_) | Self::I64(_) => 0.0,
36            Self::C32(value) => value.im as f64,
37            Self::C64(value) => value.im,
38        }
39    }
40
41    fn abs(self) -> f64 {
42        match self {
43            Self::F32(value) => value.abs() as f64,
44            Self::F64(value) => value.abs(),
45            Self::I64(value) => value.abs() as f64,
46            Self::C32(value) => value.norm() as f64,
47            Self::C64(value) => value.norm(),
48        }
49    }
50
51    fn is_complex(self) -> bool {
52        matches!(self, Self::C32(_) | Self::C64(_))
53    }
54
55    fn is_zero(self) -> bool {
56        match self {
57            Self::F32(value) => value == 0.0,
58            Self::F64(value) => value == 0.0,
59            Self::I64(value) => value == 0,
60            Self::C32(value) => value == Complex32::new(0.0, 0.0),
61            Self::C64(value) => value == Complex64::new(0.0, 0.0),
62        }
63    }
64
65    fn into_complex(self) -> Complex64 {
66        match self {
67            Self::F32(value) => Complex64::new(value as f64, 0.0),
68            Self::F64(value) => Complex64::new(value, 0.0),
69            Self::I64(value) => Complex64::new(value as f64, 0.0),
70            Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
71            Self::C64(value) => value,
72        }
73    }
74}
75
76fn scalar_value_from_storage(storage: &Storage) -> ScalarValue {
77    if storage.is_f64() {
78        ScalarValue::F64(f64::sum_from_storage(storage))
79    } else {
80        ScalarValue::C64(Complex64::sum_from_storage(storage))
81    }
82}
83
84fn scalar_value_from_native(native: &NativeTensor) -> Result<ScalarValue> {
85    ensure!(
86        native.shape().is_empty(),
87        "expected rank-0 scalar tensor, got shape {:?}",
88        native.shape()
89    );
90
91    match native.dtype() {
92        DType::F32 => native
93            .as_slice::<f32>()
94            .and_then(|values| values.first().copied())
95            .map(ScalarValue::F32)
96            .ok_or_else(|| anyhow!("failed to read f32 scalar tensor value")),
97        DType::F64 => native
98            .as_slice::<f64>()
99            .and_then(|values| values.first().copied())
100            .map(ScalarValue::F64)
101            .ok_or_else(|| anyhow!("failed to read f64 scalar tensor value")),
102        DType::I64 => native
103            .as_slice::<i64>()
104            .and_then(|values| values.first().copied())
105            .map(ScalarValue::I64)
106            .ok_or_else(|| anyhow!("failed to read i64 scalar tensor value")),
107        DType::C32 => native
108            .as_slice::<Complex32>()
109            .and_then(|values| values.first().copied())
110            .map(ScalarValue::C32)
111            .ok_or_else(|| anyhow!("failed to read c32 scalar tensor value")),
112        DType::C64 => native
113            .as_slice::<Complex64>()
114            .and_then(|values| values.first().copied())
115            .map(ScalarValue::C64)
116            .ok_or_else(|| anyhow!("failed to read c64 scalar tensor value")),
117    }
118}
119
120trait ScalarTensorElement: TensorElement {
121    fn scalar_value(value: Self) -> ScalarValue;
122}
123
124impl ScalarTensorElement for f32 {
125    fn scalar_value(value: Self) -> ScalarValue {
126        ScalarValue::F32(value)
127    }
128}
129
130impl ScalarTensorElement for f64 {
131    fn scalar_value(value: Self) -> ScalarValue {
132        ScalarValue::F64(value)
133    }
134}
135
136impl ScalarTensorElement for Complex32 {
137    fn scalar_value(value: Self) -> ScalarValue {
138        ScalarValue::C32(value)
139    }
140}
141
142impl ScalarTensorElement for Complex64 {
143    fn scalar_value(value: Self) -> ScalarValue {
144        ScalarValue::C64(value)
145    }
146}
147
148pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Result<NativeTensor> {
149    let promoted = match (scalar_value_from_native(native)?, target) {
150        (ScalarValue::F32(value), DType::F32) => Scalar::from_value(value),
151        (ScalarValue::F32(value), DType::F64) => Scalar::from_value(value as f64),
152        (ScalarValue::F32(value), DType::C32) => Scalar::from_value(Complex32::new(value, 0.0)),
153        (ScalarValue::F32(value), DType::C64) => {
154            Scalar::from_value(Complex64::new(value as f64, 0.0))
155        }
156        (ScalarValue::F32(_), DType::I64) => {
157            return Err(anyhow!(
158                "cannot promote f32 scalar to i64 without truncation"
159            ));
160        }
161        (ScalarValue::F64(value), DType::F32) => Scalar::from_value(value as f32),
162        (ScalarValue::F64(value), DType::F64) => Scalar::from_value(value),
163        (ScalarValue::F64(_), DType::I64) => {
164            return Err(anyhow!(
165                "cannot promote f64 scalar to i64 without truncation"
166            ));
167        }
168        (ScalarValue::F64(value), DType::C32) => {
169            Scalar::from_value(Complex32::new(value as f32, 0.0))
170        }
171        (ScalarValue::F64(value), DType::C64) => Scalar::from_value(Complex64::new(value, 0.0)),
172        (ScalarValue::I64(value), DType::F32) => Scalar::from_value(value as f32),
173        (ScalarValue::I64(value), DType::F64) => Scalar::from_value(value as f64),
174        (ScalarValue::I64(value), DType::I64) => Scalar::from_i64(value),
175        (ScalarValue::I64(value), DType::C32) => {
176            Scalar::from_value(Complex32::new(value as f32, 0.0))
177        }
178        (ScalarValue::I64(value), DType::C64) => {
179            Scalar::from_value(Complex64::new(value as f64, 0.0))
180        }
181        (ScalarValue::C32(value), DType::F32) => Scalar::from_value(value.re),
182        (ScalarValue::C32(value), DType::F64) => Scalar::from_value(value.re as f64),
183        (ScalarValue::C32(_), DType::I64) => {
184            return Err(anyhow!("cannot promote c32 scalar to i64"));
185        }
186        (ScalarValue::C32(value), DType::C32) => Scalar::from_value(value),
187        (ScalarValue::C32(value), DType::C64) => {
188            Scalar::from_value(Complex64::new(value.re as f64, value.im as f64))
189        }
190        (ScalarValue::C64(value), DType::F32) => Scalar::from_value(value.re as f32),
191        (ScalarValue::C64(value), DType::F64) => Scalar::from_value(value.re),
192        (ScalarValue::C64(_), DType::I64) => {
193            return Err(anyhow!("cannot promote c64 scalar to i64"));
194        }
195        (ScalarValue::C64(value), DType::C32) => {
196            Scalar::from_value(Complex32::new(value.re as f32, value.im as f32))
197        }
198        (ScalarValue::C64(value), DType::C64) => Scalar::from_value(value),
199    };
200    Ok(promoted.native)
201}
202
203/// Dynamic scalar used across tensor4all backends.
204///
205/// This is a tensor4all-owned rank-0 wrapper over tenferro's dynamic tensor.
206///
207/// # Examples
208///
209/// ```
210/// use tensor4all_tensorbackend::AnyScalar;
211///
212/// // Real scalar
213/// let a = AnyScalar::new_real(3.14);
214/// assert!((a.real() - 3.14).abs() < 1e-10);
215/// assert_eq!(a.imag(), 0.0);
216/// assert!(a.is_real());
217/// assert!(!a.is_complex());
218///
219/// // Complex scalar
220/// let b = AnyScalar::new_complex(1.0, 2.0);
221/// assert!((b.real() - 1.0).abs() < 1e-10);
222/// assert!((b.imag() - 2.0).abs() < 1e-10);
223/// assert!(b.is_complex());
224///
225/// // Arithmetic
226/// let c = AnyScalar::new_real(2.0);
227/// let d = a + c;
228/// assert!((d.real() - 5.14).abs() < 1e-10);
229/// ```
230pub struct Scalar {
231    native: NativeTensor,
232    value: ScalarValue,
233}
234
235/// Backward-compatible scalar type name used across tensor4all APIs.
236pub type AnyScalar = Scalar;
237
238impl Scalar {
239    fn wrap_native(native: NativeTensor) -> Result<Self> {
240        if native.shape().is_empty() {
241            let value = scalar_value_from_native(&native)?;
242            Ok(Self { native, value })
243        } else {
244            Err(anyhow!(
245                "Scalar requires a rank-0 tensor, got shape {:?}",
246                native.shape()
247            ))
248        }
249    }
250
251    fn value(&self) -> ScalarValue {
252        self.value
253    }
254
255    fn from_i64(value: i64) -> Self {
256        Self {
257            native: NativeTensor::from_vec(vec![], vec![value]),
258            value: ScalarValue::I64(value),
259        }
260    }
261
262    pub(crate) fn from_native(value: NativeTensor) -> Result<Self> {
263        Self::wrap_native(value)
264    }
265
266    pub(crate) fn as_native(&self) -> &NativeTensor {
267        &self.native
268    }
269
270    /// Creates a scalar from any supported public tensor element type.
271    ///
272    /// # Examples
273    ///
274    /// ```
275    /// use tensor4all_tensorbackend::AnyScalar;
276    ///
277    /// let s = AnyScalar::from_value(3.14_f64);
278    /// assert!((s.real() - 3.14).abs() < 1e-10);
279    ///
280    /// use num_complex::Complex64;
281    /// let z = AnyScalar::from_value(Complex64::new(1.0, 2.0));
282    /// assert!(z.is_complex());
283    /// assert!((z.real() - 1.0).abs() < 1e-10);
284    /// assert!((z.imag() - 2.0).abs() < 1e-10);
285    /// ```
286    #[allow(private_bounds)]
287    pub fn from_value<T: ScalarTensorElement>(value: T) -> Self {
288        let native = NativeTensor::from_vec(vec![], vec![value]);
289        Self {
290            native,
291            value: T::scalar_value(value),
292        }
293    }
294
295    /// Creates a real scalar from an `f64` value.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// use tensor4all_tensorbackend::AnyScalar;
301    ///
302    /// let s = AnyScalar::from_real(2.5);
303    /// assert!((s.real() - 2.5).abs() < 1e-10);
304    /// assert!(s.is_real());
305    /// ```
306    pub fn from_real(x: f64) -> Self {
307        Self::from_value(x)
308    }
309
310    /// Creates a complex scalar from real and imaginary parts.
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// use tensor4all_tensorbackend::AnyScalar;
316    ///
317    /// let s = AnyScalar::from_complex(1.0, -1.0);
318    /// assert!((s.real() - 1.0).abs() < 1e-10);
319    /// assert!((s.imag() - (-1.0)).abs() < 1e-10);
320    /// assert!(s.is_complex());
321    /// ```
322    pub fn from_complex(re: f64, im: f64) -> Self {
323        Self::from_value(Complex64::new(re, im))
324    }
325
326    /// Backward-compatible constructor for a real scalar.
327    ///
328    /// # Examples
329    ///
330    /// ```
331    /// use tensor4all_tensorbackend::AnyScalar;
332    ///
333    /// let s = AnyScalar::new_real(42.0);
334    /// assert!((s.real() - 42.0).abs() < 1e-10);
335    /// ```
336    pub fn new_real(x: f64) -> Self {
337        Self::from_real(x)
338    }
339
340    /// Backward-compatible constructor for a complex scalar.
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use tensor4all_tensorbackend::AnyScalar;
346    ///
347    /// let s = AnyScalar::new_complex(3.0, 4.0);
348    /// assert!((s.abs() - 5.0).abs() < 1e-10); // |3 + 4i| = 5
349    /// ```
350    pub fn new_complex(re: f64, im: f64) -> Self {
351        Self::from_complex(re, im)
352    }
353
354    /// Returns the scalar's plain rank-0 tensor value.
355    ///
356    /// # Examples
357    ///
358    /// ```
359    /// use tensor4all_tensorbackend::AnyScalar;
360    ///
361    /// let s = AnyScalar::new_real(5.0);
362    /// let p = s.primal();
363    /// assert!((p.real() - 5.0).abs() < 1e-10);
364    /// ```
365    pub fn primal(&self) -> Self {
366        self.clone()
367    }
368
369    /// Returns the real part while intentionally dropping AD metadata.
370    ///
371    /// # Examples
372    ///
373    /// ```
374    /// use tensor4all_tensorbackend::AnyScalar;
375    ///
376    /// let s = AnyScalar::new_complex(3.0, 4.0);
377    /// assert!((s.real() - 3.0).abs() < 1e-10);
378    /// ```
379    pub fn real(&self) -> f64 {
380        self.value().real()
381    }
382
383    /// Returns the imaginary part while intentionally dropping AD metadata.
384    ///
385    /// Returns `0.0` for real scalars.
386    ///
387    /// # Examples
388    ///
389    /// ```
390    /// use tensor4all_tensorbackend::AnyScalar;
391    ///
392    /// let s = AnyScalar::new_complex(3.0, 4.0);
393    /// assert!((s.imag() - 4.0).abs() < 1e-10);
394    ///
395    /// let r = AnyScalar::new_real(5.0);
396    /// assert_eq!(r.imag(), 0.0);
397    /// ```
398    pub fn imag(&self) -> f64 {
399        self.value().imag()
400    }
401
402    /// Returns the absolute value while intentionally dropping AD metadata.
403    ///
404    /// For complex scalars, returns the complex modulus (`sqrt(re^2 + im^2)`).
405    ///
406    /// # Examples
407    ///
408    /// ```
409    /// use tensor4all_tensorbackend::AnyScalar;
410    ///
411    /// let s = AnyScalar::new_complex(3.0, 4.0);
412    /// assert!((s.abs() - 5.0).abs() < 1e-10);
413    ///
414    /// let r = AnyScalar::new_real(-7.0);
415    /// assert!((r.abs() - 7.0).abs() < 1e-10);
416    /// ```
417    pub fn abs(&self) -> f64 {
418        self.value().abs()
419    }
420
421    /// Returns true when the scalar is complex.
422    ///
423    /// # Examples
424    ///
425    /// ```
426    /// use tensor4all_tensorbackend::AnyScalar;
427    ///
428    /// assert!(AnyScalar::new_complex(1.0, 0.0).is_complex());
429    /// assert!(!AnyScalar::new_real(1.0).is_complex());
430    /// ```
431    pub fn is_complex(&self) -> bool {
432        self.value().is_complex()
433    }
434
435    /// Returns true when the scalar is real.
436    ///
437    /// # Examples
438    ///
439    /// ```
440    /// use tensor4all_tensorbackend::AnyScalar;
441    ///
442    /// assert!(AnyScalar::new_real(1.0).is_real());
443    /// assert!(!AnyScalar::new_complex(1.0, 2.0).is_real());
444    /// ```
445    pub fn is_real(&self) -> bool {
446        !self.is_complex()
447    }
448
449    /// Returns true when the scalar is zero.
450    ///
451    /// # Examples
452    ///
453    /// ```
454    /// use tensor4all_tensorbackend::AnyScalar;
455    ///
456    /// assert!(AnyScalar::new_real(0.0).is_zero());
457    /// assert!(!AnyScalar::new_real(1.0).is_zero());
458    /// ```
459    pub fn is_zero(&self) -> bool {
460        self.value().is_zero()
461    }
462
463    /// Returns the underlying value as `f64` when real.
464    ///
465    /// Returns `None` for complex scalars.
466    ///
467    /// # Examples
468    ///
469    /// ```
470    /// use tensor4all_tensorbackend::AnyScalar;
471    ///
472    /// let r = AnyScalar::new_real(2.5);
473    /// assert_eq!(r.as_f64(), Some(2.5));
474    ///
475    /// let c = AnyScalar::new_complex(1.0, 1.0);
476    /// assert_eq!(c.as_f64(), None);
477    /// ```
478    pub fn as_f64(&self) -> Option<f64> {
479        match self.value() {
480            ScalarValue::F32(value) => Some(value as f64),
481            ScalarValue::F64(value) => Some(value),
482            ScalarValue::I64(value) => Some(value as f64),
483            ScalarValue::C32(_) | ScalarValue::C64(_) => None,
484        }
485    }
486
487    /// Returns the underlying value as `Complex64` when complex.
488    ///
489    /// Returns `None` for real scalars.
490    ///
491    /// # Examples
492    ///
493    /// ```
494    /// use tensor4all_tensorbackend::AnyScalar;
495    /// use num_complex::Complex64;
496    ///
497    /// let c = AnyScalar::new_complex(1.0, 2.0);
498    /// assert_eq!(c.as_c64(), Some(Complex64::new(1.0, 2.0)));
499    ///
500    /// let r = AnyScalar::new_real(1.0);
501    /// assert_eq!(r.as_c64(), None);
502    /// ```
503    pub fn as_c64(&self) -> Option<Complex64> {
504        match self.value() {
505            ScalarValue::F32(_) | ScalarValue::F64(_) | ScalarValue::I64(_) => None,
506            ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
507            ScalarValue::C64(value) => Some(value),
508        }
509    }
510
511    /// Returns the complex conjugate.
512    ///
513    /// For real scalars, returns a copy (conjugate of a real number is itself).
514    ///
515    /// # Examples
516    ///
517    /// ```
518    /// use tensor4all_tensorbackend::AnyScalar;
519    ///
520    /// let c = AnyScalar::new_complex(3.0, 4.0);
521    /// let cc = c.conj();
522    /// assert!((cc.real() - 3.0).abs() < 1e-10);
523    /// assert!((cc.imag() - (-4.0)).abs() < 1e-10);
524    ///
525    /// let r = AnyScalar::new_real(5.0);
526    /// assert!((r.conj().real() - 5.0).abs() < 1e-10);
527    /// ```
528    pub fn conj(&self) -> Self {
529        match self.value() {
530            ScalarValue::F32(value) => Self::from_value(value),
531            ScalarValue::F64(value) => Self::from_value(value),
532            ScalarValue::I64(value) => Self::from_i64(value),
533            ScalarValue::C32(value) => Self::from_value(value.conj()),
534            ScalarValue::C64(value) => Self::from_value(value.conj()),
535        }
536    }
537
538    /// Returns the real part as a scalar, preserving scalar semantics.
539    ///
540    /// Unlike [`real`](Self::real), this returns an `AnyScalar` rather than raw `f64`.
541    ///
542    /// # Examples
543    ///
544    /// ```
545    /// use tensor4all_tensorbackend::AnyScalar;
546    ///
547    /// let c = AnyScalar::new_complex(3.0, 4.0);
548    /// let re = c.real_part();
549    /// assert!(re.is_real());
550    /// assert!((re.real() - 3.0).abs() < 1e-10);
551    /// ```
552    pub fn real_part(&self) -> Self {
553        Self::from_real(self.real())
554    }
555
556    /// Returns the imaginary part as a scalar, preserving scalar semantics.
557    ///
558    /// Unlike [`imag`](Self::imag), this returns an `AnyScalar` rather than raw `f64`.
559    /// The result is always a real scalar.
560    ///
561    /// # Examples
562    ///
563    /// ```
564    /// use tensor4all_tensorbackend::AnyScalar;
565    ///
566    /// let c = AnyScalar::new_complex(3.0, 4.0);
567    /// let im = c.imag_part();
568    /// assert!(im.is_real());
569    /// assert!((im.real() - 4.0).abs() < 1e-10);
570    /// ```
571    pub fn imag_part(&self) -> Self {
572        Self::from_real(self.imag())
573    }
574
575    /// Compose a complex scalar from real-valued parts.
576    ///
577    /// # Errors
578    ///
579    /// Returns an error if either input is not a real scalar.
580    ///
581    /// # Examples
582    ///
583    /// ```
584    /// use tensor4all_tensorbackend::AnyScalar;
585    ///
586    /// let re = AnyScalar::new_real(3.0);
587    /// let im = AnyScalar::new_real(4.0);
588    /// let c = AnyScalar::compose_complex(re, im).unwrap();
589    /// assert!(c.is_complex());
590    /// assert!((c.real() - 3.0).abs() < 1e-10);
591    /// assert!((c.imag() - 4.0).abs() < 1e-10);
592    /// ```
593    pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
594        if !real.is_real() || !imag.is_real() {
595            return Err(anyhow!(
596                "compose_complex requires real-valued inputs, got real={:?}, imag={:?}",
597                real.native.dtype(),
598                imag.native.dtype()
599            ));
600        }
601        Ok(Self::from_complex(real.real(), imag.real()))
602    }
603
604    /// Square root, preserving AD metadata.
605    ///
606    /// Automatically promotes to complex if the value is negative.
607    ///
608    /// # Examples
609    ///
610    /// ```
611    /// use tensor4all_tensorbackend::AnyScalar;
612    ///
613    /// let s = AnyScalar::new_real(9.0);
614    /// assert!((s.sqrt().real() - 3.0).abs() < 1e-10);
615    /// ```
616    pub fn sqrt(&self) -> Self {
617        if self.is_complex() || self.real() < 0.0 {
618            let value = self.value().into_complex().sqrt();
619            if value.im == 0.0 {
620                Self::from_real(value.re)
621            } else {
622                Self::from_value(value)
623            }
624        } else {
625            Self::from_real(self.real().sqrt())
626        }
627    }
628
629    /// Real exponent power, preserving AD metadata.
630    ///
631    /// Automatically promotes to complex when the base is negative and the
632    /// exponent is non-integer.
633    ///
634    /// # Examples
635    ///
636    /// ```
637    /// use tensor4all_tensorbackend::AnyScalar;
638    ///
639    /// let s = AnyScalar::new_real(2.0);
640    /// assert!((s.powf(3.0).real() - 8.0).abs() < 1e-10);
641    /// ```
642    pub fn powf(&self, exponent: f64) -> Self {
643        let needs_complex_promotion =
644            self.is_complex() || (self.real() < 0.0 && exponent.fract() != 0.0);
645        if needs_complex_promotion {
646            let value = self.value().into_complex().powf(exponent);
647            if value.im == 0.0 {
648                Self::from_real(value.re)
649            } else {
650                Self::from_value(value)
651            }
652        } else {
653            Self::from_real(self.real().powf(exponent))
654        }
655    }
656
657    /// Integer exponent power, preserving AD metadata.
658    ///
659    /// # Examples
660    ///
661    /// ```
662    /// use tensor4all_tensorbackend::AnyScalar;
663    ///
664    /// let s = AnyScalar::new_real(3.0);
665    /// assert!((s.powi(2).real() - 9.0).abs() < 1e-10);
666    /// ```
667    pub fn powi(&self, exponent: i32) -> Self {
668        self.powf(exponent as f64)
669    }
670}
671
672impl SumFromStorage for Scalar {
673    fn sum_from_storage(storage: &Storage) -> Self {
674        match scalar_value_from_storage(storage) {
675            ScalarValue::F32(value) => Self::from_value(value),
676            ScalarValue::F64(value) => Self::from_value(value),
677            ScalarValue::I64(value) => Self::from_i64(value),
678            ScalarValue::C32(value) => Self::from_value(value),
679            ScalarValue::C64(value) => Self::from_value(value),
680        }
681    }
682}
683
684impl From<f32> for Scalar {
685    fn from(value: f32) -> Self {
686        Self::from_value(value)
687    }
688}
689
690impl From<f64> for Scalar {
691    fn from(value: f64) -> Self {
692        Self::from_value(value)
693    }
694}
695
696impl From<Complex32> for Scalar {
697    fn from(value: Complex32) -> Self {
698        Self::from_value(value)
699    }
700}
701
702impl From<Complex64> for Scalar {
703    fn from(value: Complex64) -> Self {
704        Self::from_value(value)
705    }
706}
707
708impl TryFrom<Scalar> for f64 {
709    type Error = &'static str;
710
711    fn try_from(value: Scalar) -> std::result::Result<Self, Self::Error> {
712        match value.value() {
713            ScalarValue::F32(real) => Ok(real as f64),
714            ScalarValue::F64(real) => Ok(real),
715            ScalarValue::I64(real) => Ok(real as f64),
716            ScalarValue::C32(_) | ScalarValue::C64(_) => {
717                Err("cannot convert complex scalar to f64")
718            }
719        }
720    }
721}
722
723impl From<Scalar> for Complex64 {
724    fn from(value: Scalar) -> Self {
725        value.value().into_complex()
726    }
727}
728
729impl Add for Scalar {
730    type Output = Self;
731
732    fn add(self, rhs: Self) -> Self::Output {
733        match (self.value(), rhs.value()) {
734            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs + rhs),
735            (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
736                Self::from_value(lhs.into_complex() + rhs.into_complex())
737            }
738            (lhs, rhs) => Self::from_real(lhs.real() + rhs.real()),
739        }
740    }
741}
742
743impl Sub for Scalar {
744    type Output = Self;
745
746    fn sub(self, rhs: Self) -> Self::Output {
747        self + (-rhs)
748    }
749}
750
751impl Mul for Scalar {
752    type Output = Self;
753
754    fn mul(self, rhs: Self) -> Self::Output {
755        match (self.value(), rhs.value()) {
756            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs * rhs),
757            (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
758                Self::from_value(lhs.into_complex() * rhs.into_complex())
759            }
760            (lhs, rhs) => Self::from_real(lhs.real() * rhs.real()),
761        }
762    }
763}
764
765impl Div for Scalar {
766    type Output = Self;
767
768    fn div(self, rhs: Self) -> Self::Output {
769        match (self.value(), rhs.value()) {
770            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs / rhs),
771            (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => {
772                Self::from_value(lhs.into_complex() / rhs.into_complex())
773            }
774            (lhs, rhs) => Self::from_real(lhs.real() / rhs.real()),
775        }
776    }
777}
778
779impl Neg for Scalar {
780    type Output = Self;
781
782    fn neg(self) -> Self::Output {
783        match self.value() {
784            ScalarValue::F32(value) => Self::from_value(-value),
785            ScalarValue::F64(value) => Self::from_value(-value),
786            ScalarValue::I64(value) => Self::from_i64(-value),
787            ScalarValue::C32(value) => Self::from_value(-value),
788            ScalarValue::C64(value) => Self::from_value(-value),
789        }
790    }
791}
792
793impl Mul<Scalar> for f64 {
794    type Output = Scalar;
795
796    fn mul(self, rhs: Scalar) -> Self::Output {
797        Scalar::from_real(self) * rhs
798    }
799}
800
801impl Mul<Scalar> for Complex64 {
802    type Output = Scalar;
803
804    fn mul(self, rhs: Scalar) -> Self::Output {
805        Scalar::from(self) * rhs
806    }
807}
808
809impl Div<Scalar> for Complex64 {
810    type Output = Scalar;
811
812    fn div(self, rhs: Scalar) -> Self::Output {
813        Scalar::from(self) / rhs
814    }
815}
816
817impl Default for Scalar {
818    fn default() -> Self {
819        Self::zero()
820    }
821}
822
823impl Zero for Scalar {
824    fn zero() -> Self {
825        Self::from_real(0.0)
826    }
827
828    fn is_zero(&self) -> bool {
829        Scalar::is_zero(self)
830    }
831}
832
833impl One for Scalar {
834    fn one() -> Self {
835        Self::from_real(1.0)
836    }
837}
838
839impl PartialEq for Scalar {
840    fn eq(&self, other: &Self) -> bool {
841        self.native.dtype() == other.native.dtype() && self.value() == other.value()
842    }
843}
844
845impl PartialOrd for Scalar {
846    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
847        match (self.value(), other.value()) {
848            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
849            (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
850            (ScalarValue::F32(lhs), ScalarValue::I64(rhs)) => {
851                (lhs as f64).partial_cmp(&(rhs as f64))
852            }
853            (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
854            (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
855            (ScalarValue::F64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&(rhs as f64)),
856            (ScalarValue::I64(lhs), ScalarValue::F32(rhs)) => {
857                (lhs as f64).partial_cmp(&(rhs as f64))
858            }
859            (ScalarValue::I64(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
860            (ScalarValue::I64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&rhs),
861            _ => None,
862        }
863    }
864}
865
866impl fmt::Display for Scalar {
867    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
868        match self.value() {
869            ScalarValue::F32(value) => value.fmt(f),
870            ScalarValue::F64(value) => value.fmt(f),
871            ScalarValue::I64(value) => value.fmt(f),
872            ScalarValue::C32(value) => value.fmt(f),
873            ScalarValue::C64(value) => value.fmt(f),
874        }
875    }
876}
877
878impl fmt::Debug for Scalar {
879    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
880        f.debug_struct("Scalar")
881            .field("dtype", &self.native.dtype())
882            .field("value", &self.value())
883            .finish()
884    }
885}
886
887impl Clone for Scalar {
888    fn clone(&self) -> Self {
889        Self {
890            native: self.native.clone(),
891            value: self.value,
892        }
893    }
894}
895
896#[cfg(test)]
897mod tests;