Skip to main content

tensor4all_core/
any_scalar.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tensor4all_tensorbackend::AnyScalar as BackendScalar;
9
10use crate::defaults::tensordynlen::TensorDynLen;
11use crate::TensorElement;
12use tensor4all_tensorbackend::{Storage, SumFromStorage};
13
14#[derive(Clone, Copy, Debug, PartialEq)]
15enum ScalarValue {
16    F32(f32),
17    F64(f64),
18    C32(Complex32),
19    C64(Complex64),
20}
21
22impl ScalarValue {
23    fn real(self) -> f64 {
24        match self {
25            Self::F32(value) => value as f64,
26            Self::F64(value) => value,
27            Self::C32(value) => value.re as f64,
28            Self::C64(value) => value.re,
29        }
30    }
31
32    fn imag(self) -> f64 {
33        match self {
34            Self::F32(_) | Self::F64(_) => 0.0,
35            Self::C32(value) => value.im as f64,
36            Self::C64(value) => value.im,
37        }
38    }
39
40    fn abs(self) -> f64 {
41        match self {
42            Self::F32(value) => value.abs() as f64,
43            Self::F64(value) => value.abs(),
44            Self::C32(value) => value.norm() as f64,
45            Self::C64(value) => value.norm(),
46        }
47    }
48
49    fn is_complex(self) -> bool {
50        matches!(self, Self::C32(_) | Self::C64(_))
51    }
52
53    fn is_zero(self) -> bool {
54        match self {
55            Self::F32(value) => value == 0.0,
56            Self::F64(value) => value == 0.0,
57            Self::C32(value) => value == Complex32::new(0.0, 0.0),
58            Self::C64(value) => value == Complex64::new(0.0, 0.0),
59        }
60    }
61
62    fn into_complex(self) -> Complex64 {
63        match self {
64            Self::F32(value) => Complex64::new(value as f64, 0.0),
65            Self::F64(value) => Complex64::new(value, 0.0),
66            Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
67            Self::C64(value) => value,
68        }
69    }
70}
71
72trait ScalarTensorElement: TensorElement {
73    fn scalar_value(value: Self) -> ScalarValue;
74}
75
76impl ScalarTensorElement for f32 {
77    fn scalar_value(value: Self) -> ScalarValue {
78        ScalarValue::F32(value)
79    }
80}
81
82impl ScalarTensorElement for f64 {
83    fn scalar_value(value: Self) -> ScalarValue {
84        ScalarValue::F64(value)
85    }
86}
87
88impl ScalarTensorElement for Complex32 {
89    fn scalar_value(value: Self) -> ScalarValue {
90        ScalarValue::C32(value)
91    }
92}
93
94impl ScalarTensorElement for Complex64 {
95    fn scalar_value(value: Self) -> ScalarValue {
96        ScalarValue::C64(value)
97    }
98}
99
100/// Dynamic scalar compatibility wrapper for tensor4all-core.
101///
102/// This owns a rank-0 [`TensorDynLen`] so that scalar values can participate in
103/// the same eager autodiff graph as tensors while preserving the existing
104/// dynamic scalar API shape.
105#[derive(Clone)]
106pub struct AnyScalar {
107    tensor: Option<TensorDynLen>,
108    value: ScalarValue,
109}
110
111impl AnyScalar {
112    fn wrap_tensor(tensor: TensorDynLen) -> Result<Self> {
113        let dims = tensor.dims();
114        anyhow::ensure!(
115            dims.is_empty(),
116            "AnyScalar requires a rank-0 tensor, got dims {:?}",
117            dims
118        );
119        let value = Self::scalar_value_from_tensor(&tensor)?;
120        Ok(Self {
121            tensor: Some(tensor),
122            value,
123        })
124    }
125
126    fn from_tensor_result(tensor: Result<TensorDynLen>, op: &'static str) -> Result<Self> {
127        Self::wrap_tensor(
128            tensor.map_err(|e| anyhow!("AnyScalar::{op} returned invalid scalar tensor: {e}"))?,
129        )
130        .map_err(|e| anyhow!("AnyScalar::{op} returned non-scalar tensor: {e}"))
131    }
132
133    fn from_eager_binary<E>(
134        lhs: &Self,
135        rhs: &Self,
136        op: &'static str,
137        f: impl FnOnce(
138            &tenferro::EagerTensor<tenferro::CpuBackend>,
139            &tenferro::EagerTensor<tenferro::CpuBackend>,
140        ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
141    ) -> Result<Self>
142    where
143        E: fmt::Display,
144    {
145        let result = f(lhs.as_tensor()?.as_inner()?, rhs.as_tensor()?.as_inner()?)
146            .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
147        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
148    }
149
150    fn from_eager_unary<E>(
151        input: &Self,
152        op: &'static str,
153        f: impl FnOnce(
154            &tenferro::EagerTensor<tenferro::CpuBackend>,
155        ) -> std::result::Result<tenferro::EagerTensor<tenferro::CpuBackend>, E>,
156    ) -> Result<Self>
157    where
158        E: fmt::Display,
159    {
160        let result = f(input.as_tensor()?.as_inner()?)
161            .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
162        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
163    }
164
165    fn scalar_value_from_tensor(tensor: &TensorDynLen) -> Result<ScalarValue> {
166        let storage = tensor.storage();
167        if storage.is_c64() {
168            let values = storage
169                .payload_c64_col_major_vec()
170                .map_err(|e| anyhow!("failed to read c64 scalar storage: {e}"))?;
171            values
172                .first()
173                .copied()
174                .map(ScalarValue::C64)
175                .ok_or_else(|| anyhow!("rank-0 c64 scalar storage is empty"))
176        } else {
177            let values = storage
178                .payload_f64_col_major_vec()
179                .map_err(|e| anyhow!("failed to read f64 scalar storage: {e}"))?;
180            values
181                .first()
182                .copied()
183                .map(ScalarValue::F64)
184                .ok_or_else(|| anyhow!("rank-0 f64 scalar storage is empty"))
185        }
186    }
187
188    fn value(&self) -> ScalarValue {
189        self.value
190    }
191
192    fn from_backend_scalar(value: BackendScalar) -> Self {
193        if let Some(value) = value.as_c64() {
194            Self::from_value(value)
195        } else {
196            Self::from_value(value.real())
197        }
198    }
199
200    pub(crate) fn from_tensor(tensor: TensorDynLen) -> Result<Self> {
201        Self::wrap_tensor(tensor)
202    }
203
204    pub(crate) fn as_tensor(&self) -> Result<&TensorDynLen> {
205        self.tensor
206            .as_ref()
207            .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))
208    }
209
210    /// Creates an `AnyScalar` from a tensor element.
211    ///
212    /// Use this when you already have a scalar value that implements
213    /// [`TensorElement`] and want to lift it into the dynamic scalar wrapper.
214    ///
215    /// # Arguments
216    ///
217    /// * `value` - The scalar value to wrap.
218    ///
219    /// # Returns
220    ///
221    /// A rank-0 `AnyScalar` containing `value`.
222    ///
223    /// # Examples
224    ///
225    /// ```
226    /// use tensor4all_core::AnyScalar;
227    ///
228    /// let scalar = AnyScalar::from_value(3.0f64);
229    /// assert_eq!(scalar.real(), 3.0);
230    /// assert!(scalar.is_real());
231    /// ```
232    #[allow(private_bounds)]
233    pub fn from_value<T: ScalarTensorElement>(value: T) -> Self {
234        Self {
235            tensor: TensorDynLen::scalar(value).ok(),
236            value: T::scalar_value(value),
237        }
238    }
239
240    /// Creates a real-valued `AnyScalar`.
241    ///
242    /// This is a convenience wrapper around [`AnyScalar::from_value`].
243    ///
244    /// # Arguments
245    ///
246    /// * `x` - The real scalar value to wrap.
247    ///
248    /// # Returns
249    ///
250    /// A rank-0 `AnyScalar` with real dtype.
251    ///
252    /// # Examples
253    ///
254    /// ```
255    /// use tensor4all_core::AnyScalar;
256    ///
257    /// let scalar = AnyScalar::from_real(1.25);
258    /// assert_eq!(scalar.as_f64(), Some(1.25));
259    /// assert!(scalar.is_real());
260    /// ```
261    pub fn from_real(x: f64) -> Self {
262        Self::from_value(x)
263    }
264
265    /// Creates a complex-valued `AnyScalar`.
266    ///
267    /// This is a convenience wrapper around [`AnyScalar::from_value`].
268    ///
269    /// # Arguments
270    ///
271    /// * `re` - The real part of the complex value.
272    /// * `im` - The imaginary part of the complex value.
273    ///
274    /// # Returns
275    ///
276    /// A rank-0 `AnyScalar` containing the requested complex number.
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// use tensor4all_core::AnyScalar;
282    ///
283    /// let scalar = AnyScalar::from_complex(1.0, -2.0);
284    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((1.0, -2.0)));
285    /// assert!(scalar.is_complex());
286    /// ```
287    pub fn from_complex(re: f64, im: f64) -> Self {
288        Self::from_value(Complex64::new(re, im))
289    }
290
291    /// Creates a real-valued `AnyScalar`.
292    ///
293    /// This is an alias for [`AnyScalar::from_real`].
294    ///
295    /// # Arguments
296    ///
297    /// * `x` - The real scalar value to wrap.
298    ///
299    /// # Returns
300    ///
301    /// A rank-0 `AnyScalar` with real dtype.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// use tensor4all_core::AnyScalar;
307    ///
308    /// let scalar = AnyScalar::new_real(2.5);
309    /// assert_eq!(scalar.real(), 2.5);
310    /// assert!(scalar.is_real());
311    /// ```
312    pub fn new_real(x: f64) -> Self {
313        Self::from_real(x)
314    }
315
316    /// Creates a complex-valued `AnyScalar`.
317    ///
318    /// This is an alias for [`AnyScalar::from_complex`].
319    ///
320    /// # Arguments
321    ///
322    /// * `re` - The real part of the complex value.
323    /// * `im` - The imaginary part of the complex value.
324    ///
325    /// # Returns
326    ///
327    /// A rank-0 `AnyScalar` containing the requested complex number.
328    ///
329    /// # Examples
330    ///
331    /// ```
332    /// use tensor4all_core::AnyScalar;
333    ///
334    /// let scalar = AnyScalar::new_complex(2.0, 3.0);
335    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((2.0, 3.0)));
336    /// assert!(scalar.is_complex());
337    /// ```
338    pub fn new_complex(re: f64, im: f64) -> Self {
339        Self::from_complex(re, im)
340    }
341
342    /// Returns the detached primal value of this scalar.
343    ///
344    /// This is an alias for [`AnyScalar::detach`].
345    ///
346    /// # Returns
347    ///
348    /// A scalar with the same value and no gradient tracking.
349    ///
350    /// # Examples
351    ///
352    /// ```
353    /// use tensor4all_core::AnyScalar;
354    ///
355    /// let primal = AnyScalar::new_real(5.0).enable_grad().unwrap().primal().unwrap();
356    /// assert_eq!(primal.real(), 5.0);
357    /// assert!(!primal.tracks_grad());
358    /// ```
359    pub fn primal(&self) -> Result<Self> {
360        self.detach()
361    }
362
363    /// Enables gradient tracking for this scalar.
364    ///
365    /// # Returns
366    ///
367    /// A new scalar that shares the same value but participates in autodiff.
368    ///
369    /// # Examples
370    ///
371    /// ```
372    /// use tensor4all_core::AnyScalar;
373    ///
374    /// let scalar = AnyScalar::new_real(2.0).enable_grad().unwrap();
375    /// assert!(scalar.tracks_grad());
376    /// ```
377    pub fn enable_grad(self) -> Result<Self> {
378        let tensor = self
379            .tensor
380            .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))?;
381        Self::from_tensor(tensor.enable_grad()?)
382    }
383
384    /// Returns whether this scalar tracks gradients.
385    ///
386    /// # Returns
387    ///
388    /// `true` when the scalar participates in autodiff and can accumulate a
389    /// gradient, otherwise `false`.
390    ///
391    /// # Examples
392    ///
393    /// ```
394    /// use tensor4all_core::AnyScalar;
395    ///
396    /// let scalar = AnyScalar::new_real(1.0);
397    /// assert!(!scalar.tracks_grad());
398    /// ```
399    pub fn tracks_grad(&self) -> bool {
400        self.tensor.as_ref().is_some_and(TensorDynLen::tracks_grad)
401    }
402
403    /// Returns the stored gradient, if any.
404    ///
405    /// # Returns
406    ///
407    /// `Ok(Some(_))` when a gradient is available, `Ok(None)` when no gradient
408    /// has been recorded, or an error if the backend cannot read it.
409    ///
410    /// # Errors
411    ///
412    /// Propagates autodiff or tensor access failures from the underlying
413    /// tensor runtime.
414    ///
415    /// # Examples
416    ///
417    /// ```
418    /// use tensor4all_core::AnyScalar;
419    ///
420    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
421    /// let y = &x * &x;
422    /// y.backward().unwrap();
423    ///
424    /// let grad = x.grad().unwrap().unwrap();
425    /// assert_eq!(grad.real(), 4.0);
426    /// ```
427    pub fn grad(&self) -> Result<Option<Self>> {
428        self.as_tensor()?
429            .grad()
430            .and_then(|maybe_grad| maybe_grad.map(Self::from_tensor).transpose())
431    }
432
433    /// Clears the stored gradient for this scalar.
434    ///
435    /// # Returns
436    ///
437    /// `Ok(())` when the gradient buffer was cleared successfully.
438    ///
439    /// # Errors
440    ///
441    /// Propagates tensor runtime failures from the underlying autodiff state.
442    ///
443    /// # Examples
444    ///
445    /// ```
446    /// use tensor4all_core::AnyScalar;
447    ///
448    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
449    /// let y = &x * &x;
450    /// y.backward().unwrap();
451    /// assert!(x.grad().unwrap().is_some());
452    ///
453    /// x.clear_grad().unwrap();
454    /// assert!(x.grad().unwrap().is_none());
455    /// ```
456    pub fn clear_grad(&self) -> Result<()> {
457        self.as_tensor()?.clear_grad()
458    }
459
460    /// Runs reverse-mode autodiff starting from this scalar.
461    ///
462    /// # Returns
463    ///
464    /// `Ok(())` when gradients were accumulated successfully.
465    ///
466    /// # Errors
467    ///
468    /// Propagates failures from the underlying tensor autodiff engine.
469    ///
470    /// # Examples
471    ///
472    /// ```
473    /// use tensor4all_core::AnyScalar;
474    ///
475    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
476    /// let y = &x * &x;
477    /// y.backward().unwrap();
478    ///
479    /// let grad = x.grad().unwrap().unwrap();
480    /// assert_eq!(grad.real(), 4.0);
481    /// ```
482    pub fn backward(&self) -> Result<()> {
483        self.as_tensor()?.backward()
484    }
485
486    /// Returns a detached copy of this scalar.
487    ///
488    /// # Returns
489    ///
490    /// A scalar with the same value but without gradient tracking.
491    ///
492    /// # Examples
493    ///
494    /// ```
495    /// use tensor4all_core::AnyScalar;
496    ///
497    /// let detached = AnyScalar::new_real(7.0)
498    ///     .enable_grad()
499    ///     .unwrap()
500    ///     .detach()
501    ///     .unwrap();
502    /// assert_eq!(detached.real(), 7.0);
503    /// assert!(!detached.tracks_grad());
504    /// ```
505    pub fn detach(&self) -> Result<Self> {
506        Self::from_tensor(self.as_tensor()?.detach()?)
507    }
508
509    /// Returns the real part of this scalar.
510    ///
511    /// # Returns
512    ///
513    /// The real component as an `f64`, regardless of the underlying storage
514    /// type.
515    ///
516    /// # Examples
517    ///
518    /// ```
519    /// use tensor4all_core::AnyScalar;
520    ///
521    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
522    /// assert_eq!(scalar.real(), 3.0);
523    /// ```
524    pub fn real(&self) -> f64 {
525        self.value().real()
526    }
527
528    /// Returns the imaginary part of this scalar.
529    ///
530    /// # Returns
531    ///
532    /// The imaginary component as an `f64`. Real-valued scalars return `0.0`.
533    ///
534    /// # Examples
535    ///
536    /// ```
537    /// use tensor4all_core::AnyScalar;
538    ///
539    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
540    /// assert_eq!(scalar.imag(), -4.0);
541    /// ```
542    pub fn imag(&self) -> f64 {
543        self.value().imag()
544    }
545
546    /// Returns the magnitude of this scalar.
547    ///
548    /// # Returns
549    ///
550    /// The absolute value for real scalars or the complex norm for complex
551    /// scalars.
552    ///
553    /// # Examples
554    ///
555    /// ```
556    /// use tensor4all_core::AnyScalar;
557    ///
558    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
559    /// assert_eq!(scalar.abs(), 5.0);
560    /// ```
561    pub fn abs(&self) -> f64 {
562        self.value().abs()
563    }
564
565    /// Returns whether this scalar is complex-valued.
566    ///
567    /// # Returns
568    ///
569    /// `true` for complex dtypes and `false` for real or integer dtypes.
570    ///
571    /// # Examples
572    ///
573    /// ```
574    /// use tensor4all_core::AnyScalar;
575    ///
576    /// assert!(AnyScalar::new_complex(1.0, 2.0).is_complex());
577    /// assert!(!AnyScalar::new_real(1.0).is_complex());
578    /// ```
579    pub fn is_complex(&self) -> bool {
580        self.value().is_complex()
581    }
582
583    /// Returns whether this scalar is real-valued.
584    ///
585    /// # Returns
586    ///
587    /// `true` when the scalar is not complex-valued.
588    ///
589    /// # Examples
590    ///
591    /// ```
592    /// use tensor4all_core::AnyScalar;
593    ///
594    /// assert!(AnyScalar::new_real(1.0).is_real());
595    /// assert!(!AnyScalar::new_complex(1.0, 2.0).is_real());
596    /// ```
597    pub fn is_real(&self) -> bool {
598        !self.is_complex()
599    }
600
601    /// Returns whether this scalar is exactly zero.
602    ///
603    /// # Returns
604    ///
605    /// `true` for exact zeros and `false` for any nonzero value.
606    ///
607    /// # Examples
608    ///
609    /// ```
610    /// use tensor4all_core::AnyScalar;
611    ///
612    /// assert!(AnyScalar::new_real(0.0).is_zero());
613    /// assert!(!AnyScalar::new_complex(0.0, 1.0).is_zero());
614    /// ```
615    pub fn is_zero(&self) -> bool {
616        self.value().is_zero()
617    }
618
619    /// Returns this scalar as an `f64` when it is real-valued.
620    ///
621    /// # Returns
622    ///
623    /// `Some(value)` for real and integer scalars, or `None` for complex
624    /// scalars.
625    ///
626    /// # Examples
627    ///
628    /// ```
629    /// use tensor4all_core::AnyScalar;
630    ///
631    /// assert_eq!(AnyScalar::new_real(2.5).as_f64(), Some(2.5));
632    /// assert_eq!(AnyScalar::new_complex(2.5, 1.0).as_f64(), None);
633    /// ```
634    pub fn as_f64(&self) -> Option<f64> {
635        match self.value() {
636            ScalarValue::F32(value) => Some(value as f64),
637            ScalarValue::F64(value) => Some(value),
638            ScalarValue::C32(_) | ScalarValue::C64(_) => None,
639        }
640    }
641
642    /// Returns this scalar as a `Complex64` when it is complex-valued.
643    ///
644    /// # Returns
645    ///
646    /// `Some(value)` for complex scalars or `None` for real and integer
647    /// scalars.
648    ///
649    /// # Examples
650    ///
651    /// ```
652    /// use tensor4all_core::AnyScalar;
653    ///
654    /// let scalar = AnyScalar::new_complex(2.5, 1.0);
655    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((2.5, 1.0)));
656    /// assert_eq!(AnyScalar::new_real(2.5).as_c64(), None);
657    /// ```
658    pub fn as_c64(&self) -> Option<Complex64> {
659        match self.value() {
660            ScalarValue::F32(_) | ScalarValue::F64(_) => None,
661            ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
662            ScalarValue::C64(value) => Some(value),
663        }
664    }
665
666    /// Returns the complex conjugate of this scalar.
667    ///
668    /// # Returns
669    ///
670    /// The conjugated scalar. Real-valued inputs are returned unchanged.
671    ///
672    /// # Examples
673    ///
674    /// ```
675    /// use tensor4all_core::AnyScalar;
676    ///
677    /// let scalar = AnyScalar::new_complex(3.0, -4.0).conj();
678    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((3.0, 4.0)));
679    /// ```
680    pub fn try_conj(&self) -> Result<Self> {
681        if !self.tracks_grad() {
682            return Ok(Self::from_backend_scalar(self.to_backend_scalar().conj()));
683        }
684        Self::from_eager_unary(self, "conj", |tensor| tensor.conj())
685    }
686
687    /// Returns the complex conjugate of this scalar.
688    pub fn conj(&self) -> Self {
689        self.try_conj()
690            .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().conj()))
691    }
692
693    /// Returns the real part as a real-valued scalar.
694    ///
695    /// # Returns
696    ///
697    /// A real-valued scalar containing the real component of `self`.
698    ///
699    /// # Examples
700    ///
701    /// ```
702    /// use tensor4all_core::AnyScalar;
703    ///
704    /// let scalar = AnyScalar::new_complex(3.0, -4.0).real_part();
705    /// assert_eq!(scalar.real(), 3.0);
706    /// assert!(scalar.is_real());
707    /// ```
708    pub fn real_part(&self) -> Self {
709        Self::from_real(self.real())
710    }
711
712    /// Returns the imaginary part as a real-valued scalar.
713    ///
714    /// # Returns
715    ///
716    /// A real-valued scalar containing the imaginary component of `self`.
717    ///
718    /// # Examples
719    ///
720    /// ```
721    /// use tensor4all_core::AnyScalar;
722    ///
723    /// let scalar = AnyScalar::new_complex(3.0, -4.0).imag_part();
724    /// assert_eq!(scalar.real(), -4.0);
725    /// assert!(scalar.is_real());
726    /// ```
727    pub fn imag_part(&self) -> Self {
728        Self::from_real(self.imag())
729    }
730
731    /// Combines two real-valued scalars into a complex scalar.
732    ///
733    /// # Arguments
734    ///
735    /// * `real` - The real component.
736    /// * `imag` - The imaginary component.
737    ///
738    /// # Returns
739    ///
740    /// A complex `AnyScalar` whose real and imaginary parts come from the
741    /// inputs.
742    ///
743    /// # Errors
744    ///
745    /// Returns an error if either input is not real-valued.
746    ///
747    /// # Examples
748    ///
749    /// ```
750    /// use tensor4all_core::AnyScalar;
751    ///
752    /// let scalar = AnyScalar::compose_complex(
753    ///     AnyScalar::new_real(3.0),
754    ///     AnyScalar::new_real(-4.0),
755    /// )
756    /// .unwrap();
757    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((3.0, -4.0)));
758    /// ```
759    pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
760        if !real.is_real() || !imag.is_real() {
761            return Err(anyhow!("compose_complex requires real-valued inputs"));
762        }
763        let imag_term = imag.try_mul(&Self::new_complex(0.0, 1.0))?;
764        real.try_add(&imag_term)
765    }
766
767    /// Returns the square root of this scalar.
768    ///
769    /// # Returns
770    ///
771    /// The principal square root. Negative real inputs and complex inputs use
772    /// complex arithmetic.
773    ///
774    /// # Examples
775    ///
776    /// ```
777    /// use tensor4all_core::AnyScalar;
778    ///
779    /// let scalar = AnyScalar::new_real(9.0).sqrt();
780    /// assert_eq!(scalar.real(), 3.0);
781    /// assert!(scalar.is_real());
782    /// ```
783    pub fn sqrt(&self) -> Self {
784        if !self.tracks_grad() || self.is_complex() || self.real() < 0.0 {
785            Self::from_backend_scalar(self.to_backend_scalar().sqrt())
786        } else {
787            Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt())
788                .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().sqrt()))
789        }
790    }
791
792    /// Raises this scalar to a floating-point power.
793    ///
794    /// # Arguments
795    ///
796    /// * `exponent` - The exponent to apply.
797    ///
798    /// # Returns
799    ///
800    /// The value of `self^exponent`.
801    ///
802    /// # Examples
803    ///
804    /// ```
805    /// use tensor4all_core::AnyScalar;
806    ///
807    /// let scalar = AnyScalar::new_real(2.0).powf(3.0);
808    /// assert_eq!(scalar.real(), 8.0);
809    /// ```
810    pub fn powf(&self, exponent: f64) -> Self {
811        Self::from_backend_scalar(self.to_backend_scalar().powf(exponent))
812    }
813
814    /// Raises this scalar to an integer power.
815    ///
816    /// # Arguments
817    ///
818    /// * `exponent` - The integer exponent to apply. Negative exponents return
819    ///   the reciprocal power.
820    ///
821    /// # Returns
822    ///
823    /// The value of `self^exponent`. Zero exponents return `1`.
824    ///
825    /// # Examples
826    ///
827    /// ```
828    /// use tensor4all_core::AnyScalar;
829    ///
830    /// assert_eq!(AnyScalar::new_real(2.0).powi(3).real(), 8.0);
831    /// assert_eq!(AnyScalar::new_real(2.0).powi(-1).real(), 0.5);
832    /// ```
833    pub fn powi(&self, exponent: i32) -> Self {
834        if exponent == 0 {
835            return Self::one();
836        }
837
838        let mut base = self.clone();
839        let mut power = exponent.unsigned_abs();
840        let mut acc = Self::one();
841
842        while power > 0 {
843            if power % 2 == 1 {
844                acc = acc.try_mul(&base).unwrap_or_else(|_| {
845                    Self::from_backend_scalar(acc.to_backend_scalar() * base.to_backend_scalar())
846                });
847            }
848            power /= 2;
849            if power > 0 {
850                base = base.try_mul(&base).unwrap_or_else(|_| {
851                    Self::from_backend_scalar(base.to_backend_scalar() * base.to_backend_scalar())
852                });
853            }
854        }
855
856        if exponent < 0 {
857            Self::one().try_div(&acc).unwrap_or_else(|_| {
858                Self::from_backend_scalar(Self::one().to_backend_scalar() / acc.to_backend_scalar())
859            })
860        } else {
861            acc
862        }
863    }
864
865    pub(crate) fn to_backend_scalar(&self) -> BackendScalar {
866        match self.value() {
867            ScalarValue::F32(value) => BackendScalar::from_value(value),
868            ScalarValue::F64(value) => BackendScalar::from_value(value),
869            ScalarValue::C32(value) => BackendScalar::from_value(value),
870            ScalarValue::C64(value) => BackendScalar::from_value(value),
871        }
872    }
873
874    pub(crate) fn try_add(&self, rhs: &Self) -> Result<Self> {
875        if !self.tracks_grad() && !rhs.tracks_grad() {
876            return Ok(Self::from_backend_scalar(
877                self.to_backend_scalar() + rhs.to_backend_scalar(),
878            ));
879        }
880        Self::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs))
881    }
882
883    pub(crate) fn try_mul(&self, rhs: &Self) -> Result<Self> {
884        if !self.tracks_grad() && !rhs.tracks_grad() {
885            return Ok(Self::from_backend_scalar(
886                self.to_backend_scalar() * rhs.to_backend_scalar(),
887            ));
888        }
889        Self::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs))
890    }
891
892    pub(crate) fn try_div(&self, rhs: &Self) -> Result<Self> {
893        if !self.tracks_grad() && !rhs.tracks_grad() {
894            return Ok(Self::from_backend_scalar(
895                self.to_backend_scalar() / rhs.to_backend_scalar(),
896            ));
897        }
898        if self.as_tensor()?.as_native()?.dtype() == rhs.as_tensor()?.as_native()?.dtype() {
899            Self::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs))
900        } else {
901            Ok(Self::from_backend_scalar(
902                self.to_backend_scalar() / rhs.to_backend_scalar(),
903            ))
904        }
905    }
906
907    pub(crate) fn try_neg(&self) -> Result<Self> {
908        if !self.tracks_grad() {
909            return Ok(Self::from_backend_scalar(-self.to_backend_scalar()));
910        }
911        Self::from_eager_unary(self, "neg", |tensor| tensor.neg())
912    }
913}
914
915impl SumFromStorage for AnyScalar {
916    fn sum_from_storage(storage: &Storage) -> Self {
917        Self::from_backend_scalar(BackendScalar::sum_from_storage(storage))
918    }
919}
920
921impl From<f32> for AnyScalar {
922    fn from(value: f32) -> Self {
923        Self::from_value(value)
924    }
925}
926
927impl From<f64> for AnyScalar {
928    fn from(value: f64) -> Self {
929        Self::from_value(value)
930    }
931}
932
933impl From<Complex32> for AnyScalar {
934    fn from(value: Complex32) -> Self {
935        Self::from_value(value)
936    }
937}
938
939impl From<Complex64> for AnyScalar {
940    fn from(value: Complex64) -> Self {
941        Self::from_value(value)
942    }
943}
944
945impl TryFrom<AnyScalar> for f64 {
946    type Error = &'static str;
947
948    fn try_from(value: AnyScalar) -> std::result::Result<Self, Self::Error> {
949        value.as_f64().ok_or("cannot convert complex scalar to f64")
950    }
951}
952
953impl From<AnyScalar> for Complex64 {
954    fn from(value: AnyScalar) -> Self {
955        value.value().into_complex()
956    }
957}
958
959impl Add<&AnyScalar> for &AnyScalar {
960    type Output = AnyScalar;
961
962    fn add(self, rhs: &AnyScalar) -> Self::Output {
963        self.try_add(rhs).unwrap_or_else(|_| {
964            AnyScalar::from_backend_scalar(self.to_backend_scalar() + rhs.to_backend_scalar())
965        })
966    }
967}
968
969impl Add<AnyScalar> for AnyScalar {
970    type Output = AnyScalar;
971
972    fn add(self, rhs: AnyScalar) -> Self::Output {
973        Add::add(&self, &rhs)
974    }
975}
976
977impl Add<AnyScalar> for &AnyScalar {
978    type Output = AnyScalar;
979
980    fn add(self, rhs: AnyScalar) -> Self::Output {
981        Add::add(self, &rhs)
982    }
983}
984
985impl Add<&AnyScalar> for AnyScalar {
986    type Output = AnyScalar;
987
988    fn add(self, rhs: &AnyScalar) -> Self::Output {
989        Add::add(&self, rhs)
990    }
991}
992
993impl Sub<&AnyScalar> for &AnyScalar {
994    type Output = AnyScalar;
995
996    fn sub(self, rhs: &AnyScalar) -> Self::Output {
997        Add::add(self, &Neg::neg(rhs))
998    }
999}
1000
1001impl Sub<AnyScalar> for AnyScalar {
1002    type Output = AnyScalar;
1003
1004    fn sub(self, rhs: AnyScalar) -> Self::Output {
1005        Sub::sub(&self, &rhs)
1006    }
1007}
1008
1009impl Sub<AnyScalar> for &AnyScalar {
1010    type Output = AnyScalar;
1011
1012    fn sub(self, rhs: AnyScalar) -> Self::Output {
1013        Sub::sub(self, &rhs)
1014    }
1015}
1016
1017impl Sub<&AnyScalar> for AnyScalar {
1018    type Output = AnyScalar;
1019
1020    fn sub(self, rhs: &AnyScalar) -> Self::Output {
1021        Sub::sub(&self, rhs)
1022    }
1023}
1024
1025impl Mul<&AnyScalar> for &AnyScalar {
1026    type Output = AnyScalar;
1027
1028    fn mul(self, rhs: &AnyScalar) -> Self::Output {
1029        self.try_mul(rhs).unwrap_or_else(|_| {
1030            AnyScalar::from_backend_scalar(self.to_backend_scalar() * rhs.to_backend_scalar())
1031        })
1032    }
1033}
1034
1035impl Mul<AnyScalar> for AnyScalar {
1036    type Output = AnyScalar;
1037
1038    fn mul(self, rhs: AnyScalar) -> Self::Output {
1039        Mul::mul(&self, &rhs)
1040    }
1041}
1042
1043impl Mul<AnyScalar> for &AnyScalar {
1044    type Output = AnyScalar;
1045
1046    fn mul(self, rhs: AnyScalar) -> Self::Output {
1047        Mul::mul(self, &rhs)
1048    }
1049}
1050
1051impl Mul<&AnyScalar> for AnyScalar {
1052    type Output = AnyScalar;
1053
1054    fn mul(self, rhs: &AnyScalar) -> Self::Output {
1055        Mul::mul(&self, rhs)
1056    }
1057}
1058
1059impl Div<&AnyScalar> for &AnyScalar {
1060    type Output = AnyScalar;
1061
1062    fn div(self, rhs: &AnyScalar) -> Self::Output {
1063        self.try_div(rhs).unwrap_or_else(|_| {
1064            AnyScalar::from_backend_scalar(self.to_backend_scalar() / rhs.to_backend_scalar())
1065        })
1066    }
1067}
1068
1069impl Div<AnyScalar> for AnyScalar {
1070    type Output = AnyScalar;
1071
1072    fn div(self, rhs: AnyScalar) -> Self::Output {
1073        Div::div(&self, &rhs)
1074    }
1075}
1076
1077impl Div<AnyScalar> for &AnyScalar {
1078    type Output = AnyScalar;
1079
1080    fn div(self, rhs: AnyScalar) -> Self::Output {
1081        Div::div(self, &rhs)
1082    }
1083}
1084
1085impl Div<&AnyScalar> for AnyScalar {
1086    type Output = AnyScalar;
1087
1088    fn div(self, rhs: &AnyScalar) -> Self::Output {
1089        Div::div(&self, rhs)
1090    }
1091}
1092
1093impl Neg for &AnyScalar {
1094    type Output = AnyScalar;
1095
1096    fn neg(self) -> Self::Output {
1097        self.try_neg()
1098            .unwrap_or_else(|_| AnyScalar::from_backend_scalar(-self.to_backend_scalar()))
1099    }
1100}
1101
1102impl Neg for AnyScalar {
1103    type Output = AnyScalar;
1104
1105    fn neg(self) -> Self::Output {
1106        Neg::neg(&self)
1107    }
1108}
1109
1110impl Mul<AnyScalar> for f64 {
1111    type Output = AnyScalar;
1112
1113    fn mul(self, rhs: AnyScalar) -> Self::Output {
1114        AnyScalar::from_real(self) * rhs
1115    }
1116}
1117
1118impl Mul<AnyScalar> for Complex64 {
1119    type Output = AnyScalar;
1120
1121    fn mul(self, rhs: AnyScalar) -> Self::Output {
1122        AnyScalar::from(self) * rhs
1123    }
1124}
1125
1126impl Div<AnyScalar> for Complex64 {
1127    type Output = AnyScalar;
1128
1129    fn div(self, rhs: AnyScalar) -> Self::Output {
1130        AnyScalar::from(self) / rhs
1131    }
1132}
1133
1134impl Default for AnyScalar {
1135    fn default() -> Self {
1136        Self::zero()
1137    }
1138}
1139
1140impl Zero for AnyScalar {
1141    fn zero() -> Self {
1142        Self::from_real(0.0)
1143    }
1144
1145    fn is_zero(&self) -> bool {
1146        AnyScalar::is_zero(self)
1147    }
1148}
1149
1150impl One for AnyScalar {
1151    fn one() -> Self {
1152        Self::from_real(1.0)
1153    }
1154}
1155
1156impl PartialEq for AnyScalar {
1157    fn eq(&self, other: &Self) -> bool {
1158        self.value() == other.value()
1159    }
1160}
1161
1162impl PartialOrd for AnyScalar {
1163    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1164        match (self.value(), other.value()) {
1165            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
1166            (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
1167            (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
1168            (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
1169            _ => None,
1170        }
1171    }
1172}
1173
1174impl fmt::Display for AnyScalar {
1175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1176        match self.value() {
1177            ScalarValue::F32(value) => value.fmt(f),
1178            ScalarValue::F64(value) => value.fmt(f),
1179            ScalarValue::C32(value) => value.fmt(f),
1180            ScalarValue::C64(value) => value.fmt(f),
1181        }
1182    }
1183}
1184
1185impl fmt::Debug for AnyScalar {
1186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1187        let dtype = match self.value {
1188            ScalarValue::F32(_) => "f32",
1189            ScalarValue::F64(_) => "f64",
1190            ScalarValue::C32(_) => "c32",
1191            ScalarValue::C64(_) => "c64",
1192        };
1193        f.debug_struct("AnyScalar")
1194            .field("dtype", &dtype)
1195            .field("value", &self.value())
1196            .field("tracks_grad", &self.tracks_grad())
1197            .finish()
1198    }
1199}
1200
#[cfg(test)]
mod tests {
    use super::*;

    /// Arithmetic on scalars without grad tracking should give the right value,
    /// stay untracked, and still convert to a tensor.
    #[test]
    fn non_grad_scalar_arithmetic_uses_plain_values() {
        let three = AnyScalar::new_real(3.0);
        let four = AnyScalar::new_real(4.0);

        // ((3 + 4) * 4 - 8) / 2 == 10
        let scaled = (&three + &four) * &four;
        let result = (scaled - AnyScalar::new_real(8.0)) / AnyScalar::new_real(2.0);

        assert_eq!(result.as_f64(), Some(10.0));
        assert!(!result.tracks_grad());
        assert!(result.as_tensor().is_ok());
    }

    /// Squaring a grad-enabled scalar must stay tracked.
    /// The gradient d(x*x)/dx at x = 2 is 4.
    #[test]
    fn tracked_scalar_arithmetic_preserves_autodiff() {
        let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
        let square = &x * &x;

        assert!(square.tracks_grad());
        square.backward().unwrap();

        let gradient = x.grad().unwrap().unwrap();
        assert_eq!(gradient.as_f64(), Some(4.0));
    }
}