Skip to main content

tensor4all_core/
any_scalar.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use anyhow::{anyhow, Result};
6use num_complex::{Complex32, Complex64};
7use num_traits::{One, Zero};
8use tensor4all_tensorbackend::AnyScalar as BackendScalar;
9
10use crate::defaults::tensordynlen::TensorDynLen;
11use crate::TensorElement;
12use tensor4all_tensorbackend::{Storage, SumFromStorage};
13
14#[derive(Clone, Copy, Debug, PartialEq)]
15enum ScalarValue {
16    F32(f32),
17    F64(f64),
18    C32(Complex32),
19    C64(Complex64),
20}
21
22impl ScalarValue {
23    fn real(self) -> f64 {
24        match self {
25            Self::F32(value) => value as f64,
26            Self::F64(value) => value,
27            Self::C32(value) => value.re as f64,
28            Self::C64(value) => value.re,
29        }
30    }
31
32    fn imag(self) -> f64 {
33        match self {
34            Self::F32(_) | Self::F64(_) => 0.0,
35            Self::C32(value) => value.im as f64,
36            Self::C64(value) => value.im,
37        }
38    }
39
40    fn abs(self) -> f64 {
41        match self {
42            Self::F32(value) => value.abs() as f64,
43            Self::F64(value) => value.abs(),
44            Self::C32(value) => value.norm() as f64,
45            Self::C64(value) => value.norm(),
46        }
47    }
48
49    fn is_complex(self) -> bool {
50        matches!(self, Self::C32(_) | Self::C64(_))
51    }
52
53    fn is_zero(self) -> bool {
54        match self {
55            Self::F32(value) => value == 0.0,
56            Self::F64(value) => value == 0.0,
57            Self::C32(value) => value == Complex32::new(0.0, 0.0),
58            Self::C64(value) => value == Complex64::new(0.0, 0.0),
59        }
60    }
61
62    fn into_complex(self) -> Complex64 {
63        match self {
64            Self::F32(value) => Complex64::new(value as f64, 0.0),
65            Self::F64(value) => Complex64::new(value, 0.0),
66            Self::C32(value) => Complex64::new(value.re as f64, value.im as f64),
67            Self::C64(value) => value,
68        }
69    }
70}
71
72trait ScalarTensorElement: TensorElement {
73    fn scalar_value(value: Self) -> ScalarValue;
74}
75
76impl ScalarTensorElement for f32 {
77    fn scalar_value(value: Self) -> ScalarValue {
78        ScalarValue::F32(value)
79    }
80}
81
82impl ScalarTensorElement for f64 {
83    fn scalar_value(value: Self) -> ScalarValue {
84        ScalarValue::F64(value)
85    }
86}
87
88impl ScalarTensorElement for Complex32 {
89    fn scalar_value(value: Self) -> ScalarValue {
90        ScalarValue::C32(value)
91    }
92}
93
94impl ScalarTensorElement for Complex64 {
95    fn scalar_value(value: Self) -> ScalarValue {
96        ScalarValue::C64(value)
97    }
98}
99
100/// Dynamic scalar compatibility wrapper for tensor4all-core.
101///
102/// This owns a rank-0 [`TensorDynLen`] so that scalar values can participate in
103/// the same eager autodiff graph as tensors while preserving the existing
104/// dynamic scalar API shape.
105#[derive(Clone)]
106pub struct AnyScalar {
107    tensor: Option<TensorDynLen>,
108    value: ScalarValue,
109}
110
111impl AnyScalar {
112    fn wrap_tensor(tensor: TensorDynLen) -> Result<Self> {
113        let dims = tensor.dims();
114        anyhow::ensure!(
115            dims.is_empty(),
116            "AnyScalar requires a rank-0 tensor, got dims {:?}",
117            dims
118        );
119        let value = Self::scalar_value_from_tensor(&tensor)?;
120        Ok(Self {
121            tensor: Some(tensor),
122            value,
123        })
124    }
125
126    fn from_tensor_result(tensor: Result<TensorDynLen>, op: &'static str) -> Result<Self> {
127        Self::wrap_tensor(
128            tensor.map_err(|e| anyhow!("AnyScalar::{op} returned invalid scalar tensor: {e}"))?,
129        )
130        .map_err(|e| anyhow!("AnyScalar::{op} returned non-scalar tensor: {e}"))
131    }
132
133    fn from_eager_binary<E>(
134        lhs: &Self,
135        rhs: &Self,
136        op: &'static str,
137        f: impl FnOnce(
138            &tenferro_ad::EagerTensor,
139            &tenferro_ad::EagerTensor,
140        ) -> std::result::Result<tenferro_ad::EagerTensor, E>,
141    ) -> Result<Self>
142    where
143        E: fmt::Display,
144    {
145        let result = f(lhs.as_tensor()?.as_inner()?, rhs.as_tensor()?.as_inner()?)
146            .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
147        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
148    }
149
150    fn from_eager_unary<E>(
151        input: &Self,
152        op: &'static str,
153        f: impl FnOnce(&tenferro_ad::EagerTensor) -> std::result::Result<tenferro_ad::EagerTensor, E>,
154    ) -> Result<Self>
155    where
156        E: fmt::Display,
157    {
158        let result = f(input.as_tensor()?.as_inner()?)
159            .map_err(|e| anyhow!("AnyScalar::{op} failed: {e}"))?;
160        Self::from_tensor_result(TensorDynLen::from_inner(vec![], result), op)
161    }
162
163    fn scalar_value_from_tensor(tensor: &TensorDynLen) -> Result<ScalarValue> {
164        let storage = tensor.storage();
165        if storage.is_c64() {
166            let values = storage
167                .payload_c64_col_major_vec()
168                .map_err(|e| anyhow!("failed to read c64 scalar storage: {e}"))?;
169            values
170                .first()
171                .copied()
172                .map(ScalarValue::C64)
173                .ok_or_else(|| anyhow!("rank-0 c64 scalar storage is empty"))
174        } else {
175            let values = storage
176                .payload_f64_col_major_vec()
177                .map_err(|e| anyhow!("failed to read f64 scalar storage: {e}"))?;
178            values
179                .first()
180                .copied()
181                .map(ScalarValue::F64)
182                .ok_or_else(|| anyhow!("rank-0 f64 scalar storage is empty"))
183        }
184    }
185
186    fn value(&self) -> ScalarValue {
187        self.value
188    }
189
190    fn from_backend_scalar(value: BackendScalar) -> Self {
191        if let Some(value) = value.as_c64() {
192            Self::from_value(value)
193        } else {
194            Self::from_value(value.real())
195        }
196    }
197
198    pub(crate) fn from_tensor(tensor: TensorDynLen) -> Result<Self> {
199        Self::wrap_tensor(tensor)
200    }
201
202    pub(crate) fn as_tensor(&self) -> Result<&TensorDynLen> {
203        self.tensor
204            .as_ref()
205            .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))
206    }
207
208    /// Creates an `AnyScalar` from a tensor element.
209    ///
210    /// Use this when you already have a scalar value that implements
211    /// [`TensorElement`] and want to lift it into the dynamic scalar wrapper.
212    ///
213    /// # Arguments
214    ///
215    /// * `value` - The scalar value to wrap.
216    ///
217    /// # Returns
218    ///
219    /// A rank-0 `AnyScalar` containing `value`.
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// use tensor4all_core::AnyScalar;
225    ///
226    /// let scalar = AnyScalar::from_value(3.0f64);
227    /// assert_eq!(scalar.real(), 3.0);
228    /// assert!(scalar.is_real());
229    /// ```
230    #[allow(private_bounds)]
231    pub fn from_value<T: ScalarTensorElement>(value: T) -> Self {
232        Self {
233            tensor: TensorDynLen::scalar(value).ok(),
234            value: T::scalar_value(value),
235        }
236    }
237
238    /// Creates a real-valued `AnyScalar`.
239    ///
240    /// This is a convenience wrapper around [`AnyScalar::from_value`].
241    ///
242    /// # Arguments
243    ///
244    /// * `x` - The real scalar value to wrap.
245    ///
246    /// # Returns
247    ///
248    /// A rank-0 `AnyScalar` with real dtype.
249    ///
250    /// # Examples
251    ///
252    /// ```
253    /// use tensor4all_core::AnyScalar;
254    ///
255    /// let scalar = AnyScalar::from_real(1.25);
256    /// assert_eq!(scalar.as_f64(), Some(1.25));
257    /// assert!(scalar.is_real());
258    /// ```
259    pub fn from_real(x: f64) -> Self {
260        Self::from_value(x)
261    }
262
263    /// Creates a complex-valued `AnyScalar`.
264    ///
265    /// This is a convenience wrapper around [`AnyScalar::from_value`].
266    ///
267    /// # Arguments
268    ///
269    /// * `re` - The real part of the complex value.
270    /// * `im` - The imaginary part of the complex value.
271    ///
272    /// # Returns
273    ///
274    /// A rank-0 `AnyScalar` containing the requested complex number.
275    ///
276    /// # Examples
277    ///
278    /// ```
279    /// use tensor4all_core::AnyScalar;
280    ///
281    /// let scalar = AnyScalar::from_complex(1.0, -2.0);
282    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((1.0, -2.0)));
283    /// assert!(scalar.is_complex());
284    /// ```
285    pub fn from_complex(re: f64, im: f64) -> Self {
286        Self::from_value(Complex64::new(re, im))
287    }
288
289    /// Creates a real-valued `AnyScalar`.
290    ///
291    /// This is an alias for [`AnyScalar::from_real`].
292    ///
293    /// # Arguments
294    ///
295    /// * `x` - The real scalar value to wrap.
296    ///
297    /// # Returns
298    ///
299    /// A rank-0 `AnyScalar` with real dtype.
300    ///
301    /// # Examples
302    ///
303    /// ```
304    /// use tensor4all_core::AnyScalar;
305    ///
306    /// let scalar = AnyScalar::new_real(2.5);
307    /// assert_eq!(scalar.real(), 2.5);
308    /// assert!(scalar.is_real());
309    /// ```
310    pub fn new_real(x: f64) -> Self {
311        Self::from_real(x)
312    }
313
314    /// Creates a complex-valued `AnyScalar`.
315    ///
316    /// This is an alias for [`AnyScalar::from_complex`].
317    ///
318    /// # Arguments
319    ///
320    /// * `re` - The real part of the complex value.
321    /// * `im` - The imaginary part of the complex value.
322    ///
323    /// # Returns
324    ///
325    /// A rank-0 `AnyScalar` containing the requested complex number.
326    ///
327    /// # Examples
328    ///
329    /// ```
330    /// use tensor4all_core::AnyScalar;
331    ///
332    /// let scalar = AnyScalar::new_complex(2.0, 3.0);
333    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((2.0, 3.0)));
334    /// assert!(scalar.is_complex());
335    /// ```
336    pub fn new_complex(re: f64, im: f64) -> Self {
337        Self::from_complex(re, im)
338    }
339
340    /// Returns the detached primal value of this scalar.
341    ///
342    /// This is an alias for [`AnyScalar::detach`].
343    ///
344    /// # Returns
345    ///
346    /// A scalar with the same value and no gradient tracking.
347    ///
348    /// # Examples
349    ///
350    /// ```
351    /// use tensor4all_core::AnyScalar;
352    ///
353    /// let primal = AnyScalar::new_real(5.0).enable_grad().unwrap().primal().unwrap();
354    /// assert_eq!(primal.real(), 5.0);
355    /// assert!(!primal.tracks_grad());
356    /// ```
357    pub fn primal(&self) -> Result<Self> {
358        self.detach()
359    }
360
361    /// Enables gradient tracking for this scalar.
362    ///
363    /// # Returns
364    ///
365    /// A new scalar that shares the same value but participates in autodiff.
366    ///
367    /// # Examples
368    ///
369    /// ```
370    /// use tensor4all_core::AnyScalar;
371    ///
372    /// let scalar = AnyScalar::new_real(2.0).enable_grad().unwrap();
373    /// assert!(scalar.tracks_grad());
374    /// ```
375    pub fn enable_grad(self) -> Result<Self> {
376        let tensor = self
377            .tensor
378            .ok_or_else(|| anyhow!("AnyScalar has no backend tensor representation"))?;
379        Self::from_tensor(tensor.enable_grad()?)
380    }
381
382    /// Returns whether this scalar tracks gradients.
383    ///
384    /// # Returns
385    ///
386    /// `true` when the scalar participates in autodiff and can accumulate a
387    /// gradient, otherwise `false`.
388    ///
389    /// # Examples
390    ///
391    /// ```
392    /// use tensor4all_core::AnyScalar;
393    ///
394    /// let scalar = AnyScalar::new_real(1.0);
395    /// assert!(!scalar.tracks_grad());
396    /// ```
397    pub fn tracks_grad(&self) -> bool {
398        self.tensor.as_ref().is_some_and(TensorDynLen::tracks_grad)
399    }
400
401    /// Returns the stored gradient, if any.
402    ///
403    /// # Returns
404    ///
405    /// `Ok(Some(_))` when a gradient is available, `Ok(None)` when no gradient
406    /// has been recorded, or an error if the backend cannot read it.
407    ///
408    /// # Errors
409    ///
410    /// Propagates autodiff or tensor access failures from the underlying
411    /// tensor runtime.
412    ///
413    /// # Examples
414    ///
415    /// ```
416    /// use tensor4all_core::AnyScalar;
417    ///
418    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
419    /// let y = &x * &x;
420    /// y.backward().unwrap();
421    ///
422    /// let grad = x.grad().unwrap().unwrap();
423    /// assert_eq!(grad.real(), 4.0);
424    /// ```
425    pub fn grad(&self) -> Result<Option<Self>> {
426        self.as_tensor()?
427            .grad()
428            .and_then(|maybe_grad| maybe_grad.map(Self::from_tensor).transpose())
429    }
430
431    /// Clears the stored gradient for this scalar.
432    ///
433    /// # Returns
434    ///
435    /// `Ok(())` when the gradient buffer was cleared successfully.
436    ///
437    /// # Errors
438    ///
439    /// Propagates tensor runtime failures from the underlying autodiff state.
440    ///
441    /// # Examples
442    ///
443    /// ```
444    /// use tensor4all_core::AnyScalar;
445    ///
446    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
447    /// let y = &x * &x;
448    /// y.backward().unwrap();
449    /// assert!(x.grad().unwrap().is_some());
450    ///
451    /// x.clear_grad().unwrap();
452    /// assert!(x.grad().unwrap().is_none());
453    /// ```
454    pub fn clear_grad(&self) -> Result<()> {
455        self.as_tensor()?.clear_grad()
456    }
457
458    /// Runs reverse-mode autodiff starting from this scalar.
459    ///
460    /// # Returns
461    ///
462    /// `Ok(())` when gradients were accumulated successfully.
463    ///
464    /// # Errors
465    ///
466    /// Propagates failures from the underlying tensor autodiff engine.
467    ///
468    /// # Examples
469    ///
470    /// ```
471    /// use tensor4all_core::AnyScalar;
472    ///
473    /// let x = AnyScalar::new_real(2.0).enable_grad().unwrap();
474    /// let y = &x * &x;
475    /// y.backward().unwrap();
476    ///
477    /// let grad = x.grad().unwrap().unwrap();
478    /// assert_eq!(grad.real(), 4.0);
479    /// ```
480    pub fn backward(&self) -> Result<()> {
481        self.as_tensor()?.backward()
482    }
483
484    /// Returns a detached copy of this scalar.
485    ///
486    /// # Returns
487    ///
488    /// A scalar with the same value but without gradient tracking.
489    ///
490    /// # Examples
491    ///
492    /// ```
493    /// use tensor4all_core::AnyScalar;
494    ///
495    /// let detached = AnyScalar::new_real(7.0)
496    ///     .enable_grad()
497    ///     .unwrap()
498    ///     .detach()
499    ///     .unwrap();
500    /// assert_eq!(detached.real(), 7.0);
501    /// assert!(!detached.tracks_grad());
502    /// ```
503    pub fn detach(&self) -> Result<Self> {
504        Self::from_tensor(self.as_tensor()?.detach()?)
505    }
506
507    /// Returns the real part of this scalar.
508    ///
509    /// # Returns
510    ///
511    /// The real component as an `f64`, regardless of the underlying storage
512    /// type.
513    ///
514    /// # Examples
515    ///
516    /// ```
517    /// use tensor4all_core::AnyScalar;
518    ///
519    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
520    /// assert_eq!(scalar.real(), 3.0);
521    /// ```
522    pub fn real(&self) -> f64 {
523        self.value().real()
524    }
525
526    /// Returns the imaginary part of this scalar.
527    ///
528    /// # Returns
529    ///
530    /// The imaginary component as an `f64`. Real-valued scalars return `0.0`.
531    ///
532    /// # Examples
533    ///
534    /// ```
535    /// use tensor4all_core::AnyScalar;
536    ///
537    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
538    /// assert_eq!(scalar.imag(), -4.0);
539    /// ```
540    pub fn imag(&self) -> f64 {
541        self.value().imag()
542    }
543
544    /// Returns the magnitude of this scalar.
545    ///
546    /// # Returns
547    ///
548    /// The absolute value for real scalars or the complex norm for complex
549    /// scalars.
550    ///
551    /// # Examples
552    ///
553    /// ```
554    /// use tensor4all_core::AnyScalar;
555    ///
556    /// let scalar = AnyScalar::new_complex(3.0, -4.0);
557    /// assert_eq!(scalar.abs(), 5.0);
558    /// ```
559    pub fn abs(&self) -> f64 {
560        self.value().abs()
561    }
562
563    /// Returns whether this scalar is complex-valued.
564    ///
565    /// # Returns
566    ///
567    /// `true` for complex dtypes and `false` for real or integer dtypes.
568    ///
569    /// # Examples
570    ///
571    /// ```
572    /// use tensor4all_core::AnyScalar;
573    ///
574    /// assert!(AnyScalar::new_complex(1.0, 2.0).is_complex());
575    /// assert!(!AnyScalar::new_real(1.0).is_complex());
576    /// ```
577    pub fn is_complex(&self) -> bool {
578        self.value().is_complex()
579    }
580
581    /// Returns whether this scalar is real-valued.
582    ///
583    /// # Returns
584    ///
585    /// `true` when the scalar is not complex-valued.
586    ///
587    /// # Examples
588    ///
589    /// ```
590    /// use tensor4all_core::AnyScalar;
591    ///
592    /// assert!(AnyScalar::new_real(1.0).is_real());
593    /// assert!(!AnyScalar::new_complex(1.0, 2.0).is_real());
594    /// ```
595    pub fn is_real(&self) -> bool {
596        !self.is_complex()
597    }
598
599    /// Returns whether this scalar is exactly zero.
600    ///
601    /// # Returns
602    ///
603    /// `true` for exact zeros and `false` for any nonzero value.
604    ///
605    /// # Examples
606    ///
607    /// ```
608    /// use tensor4all_core::AnyScalar;
609    ///
610    /// assert!(AnyScalar::new_real(0.0).is_zero());
611    /// assert!(!AnyScalar::new_complex(0.0, 1.0).is_zero());
612    /// ```
613    pub fn is_zero(&self) -> bool {
614        self.value().is_zero()
615    }
616
617    /// Returns this scalar as an `f64` when it is real-valued.
618    ///
619    /// # Returns
620    ///
621    /// `Some(value)` for real and integer scalars, or `None` for complex
622    /// scalars.
623    ///
624    /// # Examples
625    ///
626    /// ```
627    /// use tensor4all_core::AnyScalar;
628    ///
629    /// assert_eq!(AnyScalar::new_real(2.5).as_f64(), Some(2.5));
630    /// assert_eq!(AnyScalar::new_complex(2.5, 1.0).as_f64(), None);
631    /// ```
632    pub fn as_f64(&self) -> Option<f64> {
633        match self.value() {
634            ScalarValue::F32(value) => Some(value as f64),
635            ScalarValue::F64(value) => Some(value),
636            ScalarValue::C32(_) | ScalarValue::C64(_) => None,
637        }
638    }
639
640    /// Returns this scalar as a `Complex64` when it is complex-valued.
641    ///
642    /// # Returns
643    ///
644    /// `Some(value)` for complex scalars or `None` for real and integer
645    /// scalars.
646    ///
647    /// # Examples
648    ///
649    /// ```
650    /// use tensor4all_core::AnyScalar;
651    ///
652    /// let scalar = AnyScalar::new_complex(2.5, 1.0);
653    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((2.5, 1.0)));
654    /// assert_eq!(AnyScalar::new_real(2.5).as_c64(), None);
655    /// ```
656    pub fn as_c64(&self) -> Option<Complex64> {
657        match self.value() {
658            ScalarValue::F32(_) | ScalarValue::F64(_) => None,
659            ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)),
660            ScalarValue::C64(value) => Some(value),
661        }
662    }
663
664    /// Returns the complex conjugate of this scalar.
665    ///
666    /// # Returns
667    ///
668    /// The conjugated scalar. Real-valued inputs are returned unchanged.
669    ///
670    /// # Examples
671    ///
672    /// ```
673    /// use tensor4all_core::AnyScalar;
674    ///
675    /// let scalar = AnyScalar::new_complex(3.0, -4.0).conj();
676    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((3.0, 4.0)));
677    /// ```
678    pub fn try_conj(&self) -> Result<Self> {
679        if !self.tracks_grad() {
680            return Ok(Self::from_backend_scalar(self.to_backend_scalar().conj()));
681        }
682        Self::from_eager_unary(self, "conj", |tensor| tensor.conj())
683    }
684
685    /// Returns the complex conjugate of this scalar.
686    pub fn conj(&self) -> Self {
687        self.try_conj()
688            .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().conj()))
689    }
690
691    /// Returns the real part as a real-valued scalar.
692    ///
693    /// # Returns
694    ///
695    /// A real-valued scalar containing the real component of `self`.
696    ///
697    /// # Examples
698    ///
699    /// ```
700    /// use tensor4all_core::AnyScalar;
701    ///
702    /// let scalar = AnyScalar::new_complex(3.0, -4.0).real_part();
703    /// assert_eq!(scalar.real(), 3.0);
704    /// assert!(scalar.is_real());
705    /// ```
706    pub fn real_part(&self) -> Self {
707        Self::from_real(self.real())
708    }
709
710    /// Returns the imaginary part as a real-valued scalar.
711    ///
712    /// # Returns
713    ///
714    /// A real-valued scalar containing the imaginary component of `self`.
715    ///
716    /// # Examples
717    ///
718    /// ```
719    /// use tensor4all_core::AnyScalar;
720    ///
721    /// let scalar = AnyScalar::new_complex(3.0, -4.0).imag_part();
722    /// assert_eq!(scalar.real(), -4.0);
723    /// assert!(scalar.is_real());
724    /// ```
725    pub fn imag_part(&self) -> Self {
726        Self::from_real(self.imag())
727    }
728
729    /// Combines two real-valued scalars into a complex scalar.
730    ///
731    /// # Arguments
732    ///
733    /// * `real` - The real component.
734    /// * `imag` - The imaginary component.
735    ///
736    /// # Returns
737    ///
738    /// A complex `AnyScalar` whose real and imaginary parts come from the
739    /// inputs.
740    ///
741    /// # Errors
742    ///
743    /// Returns an error if either input is not real-valued.
744    ///
745    /// # Examples
746    ///
747    /// ```
748    /// use tensor4all_core::AnyScalar;
749    ///
750    /// let scalar = AnyScalar::compose_complex(
751    ///     AnyScalar::new_real(3.0),
752    ///     AnyScalar::new_real(-4.0),
753    /// )
754    /// .unwrap();
755    /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((3.0, -4.0)));
756    /// ```
757    pub fn compose_complex(real: Self, imag: Self) -> Result<Self> {
758        if !real.is_real() || !imag.is_real() {
759            return Err(anyhow!("compose_complex requires real-valued inputs"));
760        }
761        let imag_term = imag.try_mul(&Self::new_complex(0.0, 1.0))?;
762        real.try_add(&imag_term)
763    }
764
765    /// Returns the square root of this scalar.
766    ///
767    /// # Returns
768    ///
769    /// The principal square root. Negative real inputs and complex inputs use
770    /// complex arithmetic.
771    ///
772    /// # Examples
773    ///
774    /// ```
775    /// use tensor4all_core::AnyScalar;
776    ///
777    /// let scalar = AnyScalar::new_real(9.0).sqrt();
778    /// assert_eq!(scalar.real(), 3.0);
779    /// assert!(scalar.is_real());
780    /// ```
781    pub fn sqrt(&self) -> Self {
782        if !self.tracks_grad() || self.is_complex() || self.real() < 0.0 {
783            Self::from_backend_scalar(self.to_backend_scalar().sqrt())
784        } else {
785            Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt())
786                .unwrap_or_else(|_| Self::from_backend_scalar(self.to_backend_scalar().sqrt()))
787        }
788    }
789
790    /// Raises this scalar to a floating-point power.
791    ///
792    /// # Arguments
793    ///
794    /// * `exponent` - The exponent to apply.
795    ///
796    /// # Returns
797    ///
798    /// The value of `self^exponent`.
799    ///
800    /// # Examples
801    ///
802    /// ```
803    /// use tensor4all_core::AnyScalar;
804    ///
805    /// let scalar = AnyScalar::new_real(2.0).powf(3.0);
806    /// assert_eq!(scalar.real(), 8.0);
807    /// ```
808    pub fn powf(&self, exponent: f64) -> Self {
809        Self::from_backend_scalar(self.to_backend_scalar().powf(exponent))
810    }
811
812    /// Raises this scalar to an integer power.
813    ///
814    /// # Arguments
815    ///
816    /// * `exponent` - The integer exponent to apply. Negative exponents return
817    ///   the reciprocal power.
818    ///
819    /// # Returns
820    ///
821    /// The value of `self^exponent`. Zero exponents return `1`.
822    ///
823    /// # Examples
824    ///
825    /// ```
826    /// use tensor4all_core::AnyScalar;
827    ///
828    /// assert_eq!(AnyScalar::new_real(2.0).powi(3).real(), 8.0);
829    /// assert_eq!(AnyScalar::new_real(2.0).powi(-1).real(), 0.5);
830    /// ```
831    pub fn powi(&self, exponent: i32) -> Self {
832        if exponent == 0 {
833            return Self::one();
834        }
835
836        let mut base = self.clone();
837        let mut power = exponent.unsigned_abs();
838        let mut acc = Self::one();
839
840        while power > 0 {
841            if power % 2 == 1 {
842                acc = acc.try_mul(&base).unwrap_or_else(|_| {
843                    Self::from_backend_scalar(acc.to_backend_scalar() * base.to_backend_scalar())
844                });
845            }
846            power /= 2;
847            if power > 0 {
848                base = base.try_mul(&base).unwrap_or_else(|_| {
849                    Self::from_backend_scalar(base.to_backend_scalar() * base.to_backend_scalar())
850                });
851            }
852        }
853
854        if exponent < 0 {
855            Self::one().try_div(&acc).unwrap_or_else(|_| {
856                Self::from_backend_scalar(Self::one().to_backend_scalar() / acc.to_backend_scalar())
857            })
858        } else {
859            acc
860        }
861    }
862
863    pub(crate) fn to_backend_scalar(&self) -> BackendScalar {
864        match self.value() {
865            ScalarValue::F32(value) => BackendScalar::from_value(value),
866            ScalarValue::F64(value) => BackendScalar::from_value(value),
867            ScalarValue::C32(value) => BackendScalar::from_value(value),
868            ScalarValue::C64(value) => BackendScalar::from_value(value),
869        }
870    }
871
872    pub(crate) fn try_add(&self, rhs: &Self) -> Result<Self> {
873        if !self.tracks_grad() && !rhs.tracks_grad() {
874            return Ok(Self::from_backend_scalar(
875                self.to_backend_scalar() + rhs.to_backend_scalar(),
876            ));
877        }
878        Self::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs))
879    }
880
881    pub(crate) fn try_mul(&self, rhs: &Self) -> Result<Self> {
882        if !self.tracks_grad() && !rhs.tracks_grad() {
883            return Ok(Self::from_backend_scalar(
884                self.to_backend_scalar() * rhs.to_backend_scalar(),
885            ));
886        }
887        Self::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs))
888    }
889
890    pub(crate) fn try_div(&self, rhs: &Self) -> Result<Self> {
891        if !self.tracks_grad() && !rhs.tracks_grad() {
892            return Ok(Self::from_backend_scalar(
893                self.to_backend_scalar() / rhs.to_backend_scalar(),
894            ));
895        }
896        if self.as_tensor()?.as_native()?.dtype() == rhs.as_tensor()?.as_native()?.dtype() {
897            Self::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs))
898        } else {
899            Ok(Self::from_backend_scalar(
900                self.to_backend_scalar() / rhs.to_backend_scalar(),
901            ))
902        }
903    }
904
905    pub(crate) fn try_neg(&self) -> Result<Self> {
906        if !self.tracks_grad() {
907            return Ok(Self::from_backend_scalar(-self.to_backend_scalar()));
908        }
909        Self::from_eager_unary(self, "neg", |tensor| tensor.neg())
910    }
911}
912
913impl SumFromStorage for AnyScalar {
914    fn sum_from_storage(storage: &Storage) -> Self {
915        Self::from_backend_scalar(BackendScalar::sum_from_storage(storage))
916    }
917}
918
919impl From<f32> for AnyScalar {
920    fn from(value: f32) -> Self {
921        Self::from_value(value)
922    }
923}
924
925impl From<f64> for AnyScalar {
926    fn from(value: f64) -> Self {
927        Self::from_value(value)
928    }
929}
930
931impl From<Complex32> for AnyScalar {
932    fn from(value: Complex32) -> Self {
933        Self::from_value(value)
934    }
935}
936
937impl From<Complex64> for AnyScalar {
938    fn from(value: Complex64) -> Self {
939        Self::from_value(value)
940    }
941}
942
943impl TryFrom<AnyScalar> for f64 {
944    type Error = &'static str;
945
946    fn try_from(value: AnyScalar) -> std::result::Result<Self, Self::Error> {
947        value.as_f64().ok_or("cannot convert complex scalar to f64")
948    }
949}
950
951impl From<AnyScalar> for Complex64 {
952    fn from(value: AnyScalar) -> Self {
953        value.value().into_complex()
954    }
955}
956
957impl Add<&AnyScalar> for &AnyScalar {
958    type Output = AnyScalar;
959
960    fn add(self, rhs: &AnyScalar) -> Self::Output {
961        self.try_add(rhs).unwrap_or_else(|_| {
962            AnyScalar::from_backend_scalar(self.to_backend_scalar() + rhs.to_backend_scalar())
963        })
964    }
965}
966
967impl Add<AnyScalar> for AnyScalar {
968    type Output = AnyScalar;
969
970    fn add(self, rhs: AnyScalar) -> Self::Output {
971        Add::add(&self, &rhs)
972    }
973}
974
975impl Add<AnyScalar> for &AnyScalar {
976    type Output = AnyScalar;
977
978    fn add(self, rhs: AnyScalar) -> Self::Output {
979        Add::add(self, &rhs)
980    }
981}
982
983impl Add<&AnyScalar> for AnyScalar {
984    type Output = AnyScalar;
985
986    fn add(self, rhs: &AnyScalar) -> Self::Output {
987        Add::add(&self, rhs)
988    }
989}
990
991impl Sub<&AnyScalar> for &AnyScalar {
992    type Output = AnyScalar;
993
994    fn sub(self, rhs: &AnyScalar) -> Self::Output {
995        Add::add(self, &Neg::neg(rhs))
996    }
997}
998
999impl Sub<AnyScalar> for AnyScalar {
1000    type Output = AnyScalar;
1001
1002    fn sub(self, rhs: AnyScalar) -> Self::Output {
1003        Sub::sub(&self, &rhs)
1004    }
1005}
1006
1007impl Sub<AnyScalar> for &AnyScalar {
1008    type Output = AnyScalar;
1009
1010    fn sub(self, rhs: AnyScalar) -> Self::Output {
1011        Sub::sub(self, &rhs)
1012    }
1013}
1014
1015impl Sub<&AnyScalar> for AnyScalar {
1016    type Output = AnyScalar;
1017
1018    fn sub(self, rhs: &AnyScalar) -> Self::Output {
1019        Sub::sub(&self, rhs)
1020    }
1021}
1022
1023impl Mul<&AnyScalar> for &AnyScalar {
1024    type Output = AnyScalar;
1025
1026    fn mul(self, rhs: &AnyScalar) -> Self::Output {
1027        self.try_mul(rhs).unwrap_or_else(|_| {
1028            AnyScalar::from_backend_scalar(self.to_backend_scalar() * rhs.to_backend_scalar())
1029        })
1030    }
1031}
1032
1033impl Mul<AnyScalar> for AnyScalar {
1034    type Output = AnyScalar;
1035
1036    fn mul(self, rhs: AnyScalar) -> Self::Output {
1037        Mul::mul(&self, &rhs)
1038    }
1039}
1040
1041impl Mul<AnyScalar> for &AnyScalar {
1042    type Output = AnyScalar;
1043
1044    fn mul(self, rhs: AnyScalar) -> Self::Output {
1045        Mul::mul(self, &rhs)
1046    }
1047}
1048
1049impl Mul<&AnyScalar> for AnyScalar {
1050    type Output = AnyScalar;
1051
1052    fn mul(self, rhs: &AnyScalar) -> Self::Output {
1053        Mul::mul(&self, rhs)
1054    }
1055}
1056
1057impl Div<&AnyScalar> for &AnyScalar {
1058    type Output = AnyScalar;
1059
1060    fn div(self, rhs: &AnyScalar) -> Self::Output {
1061        self.try_div(rhs).unwrap_or_else(|_| {
1062            AnyScalar::from_backend_scalar(self.to_backend_scalar() / rhs.to_backend_scalar())
1063        })
1064    }
1065}
1066
1067impl Div<AnyScalar> for AnyScalar {
1068    type Output = AnyScalar;
1069
1070    fn div(self, rhs: AnyScalar) -> Self::Output {
1071        Div::div(&self, &rhs)
1072    }
1073}
1074
1075impl Div<AnyScalar> for &AnyScalar {
1076    type Output = AnyScalar;
1077
1078    fn div(self, rhs: AnyScalar) -> Self::Output {
1079        Div::div(self, &rhs)
1080    }
1081}
1082
1083impl Div<&AnyScalar> for AnyScalar {
1084    type Output = AnyScalar;
1085
1086    fn div(self, rhs: &AnyScalar) -> Self::Output {
1087        Div::div(&self, rhs)
1088    }
1089}
1090
1091impl Neg for &AnyScalar {
1092    type Output = AnyScalar;
1093
1094    fn neg(self) -> Self::Output {
1095        self.try_neg()
1096            .unwrap_or_else(|_| AnyScalar::from_backend_scalar(-self.to_backend_scalar()))
1097    }
1098}
1099
1100impl Neg for AnyScalar {
1101    type Output = AnyScalar;
1102
1103    fn neg(self) -> Self::Output {
1104        Neg::neg(&self)
1105    }
1106}
1107
1108impl Mul<AnyScalar> for f64 {
1109    type Output = AnyScalar;
1110
1111    fn mul(self, rhs: AnyScalar) -> Self::Output {
1112        AnyScalar::from_real(self) * rhs
1113    }
1114}
1115
1116impl Mul<AnyScalar> for Complex64 {
1117    type Output = AnyScalar;
1118
1119    fn mul(self, rhs: AnyScalar) -> Self::Output {
1120        AnyScalar::from(self) * rhs
1121    }
1122}
1123
1124impl Div<AnyScalar> for Complex64 {
1125    type Output = AnyScalar;
1126
1127    fn div(self, rhs: AnyScalar) -> Self::Output {
1128        AnyScalar::from(self) / rhs
1129    }
1130}
1131
1132impl Default for AnyScalar {
1133    fn default() -> Self {
1134        Self::zero()
1135    }
1136}
1137
1138impl Zero for AnyScalar {
1139    fn zero() -> Self {
1140        Self::from_real(0.0)
1141    }
1142
1143    fn is_zero(&self) -> bool {
1144        AnyScalar::is_zero(self)
1145    }
1146}
1147
1148impl One for AnyScalar {
1149    fn one() -> Self {
1150        Self::from_real(1.0)
1151    }
1152}
1153
1154impl PartialEq for AnyScalar {
1155    fn eq(&self, other: &Self) -> bool {
1156        self.value() == other.value()
1157    }
1158}
1159
1160impl PartialOrd for AnyScalar {
1161    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1162        match (self.value(), other.value()) {
1163            (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs),
1164            (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs),
1165            (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)),
1166            (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs),
1167            _ => None,
1168        }
1169    }
1170}
1171
1172impl fmt::Display for AnyScalar {
1173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1174        match self.value() {
1175            ScalarValue::F32(value) => value.fmt(f),
1176            ScalarValue::F64(value) => value.fmt(f),
1177            ScalarValue::C32(value) => value.fmt(f),
1178            ScalarValue::C64(value) => value.fmt(f),
1179        }
1180    }
1181}
1182
1183impl fmt::Debug for AnyScalar {
1184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1185        let dtype = match self.value {
1186            ScalarValue::F32(_) => "f32",
1187            ScalarValue::F64(_) => "f64",
1188            ScalarValue::C32(_) => "c32",
1189            ScalarValue::C64(_) => "c64",
1190        };
1191        f.debug_struct("AnyScalar")
1192            .field("dtype", &dtype)
1193            .field("value", &self.value())
1194            .field("tracks_grad", &self.tracks_grad())
1195            .finish()
1196    }
1197}
1198
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain (non-tracked) scalars go through the ordinary arithmetic path and stay
    /// untracked.
    #[test]
    fn non_grad_scalar_arithmetic_uses_plain_values() {
        let three = AnyScalar::new_real(3.0);
        let four = AnyScalar::new_real(4.0);
        let eight = AnyScalar::new_real(8.0);
        let two = AnyScalar::new_real(2.0);

        // ((3 + 4) * 4 - 8) / 2 == 10, using by-value operands throughout.
        let sum = three + four.clone();
        let product = sum * four;
        let difference = product - eight;
        let result = difference / two;

        assert_eq!(result.as_f64(), Some(10.0));
        assert!(!result.tracks_grad());
        assert!(result.as_tensor().is_ok());
    }

    /// Multiplying a grad-enabled scalar by itself keeps the autodiff graph:
    /// d(x^2)/dx at x = 2 is 4.
    #[test]
    fn tracked_scalar_arithmetic_preserves_autodiff() {
        let x = AnyScalar::new_real(2.0)
            .enable_grad()
            .expect("enable_grad on a real scalar");
        let square = &x * &x;

        assert!(square.tracks_grad());
        square.backward().expect("backward through x * x");

        let dx = x
            .grad()
            .expect("grad lookup")
            .expect("x should have a gradient after backward");
        assert_eq!(dx.as_f64(), Some(4.0));
    }
}