tenferro_algebra/
lib.rs

1//! Algebra traits for the tenferro workspace.
2//!
3//! This crate provides the minimal algebra foundation:
4//!
5//! - [`Scalar`]: Minimum requirements for tensor element types
6//!   (equivalent to [`ScalarBase`](strided_traits::ScalarBase) from strided-traits).
7//! - [`Conjugate`]: Complex conjugation (identity for real types).
8//! - [`HasAlgebra`]: Maps a scalar type `T` to its default algebra `Alg`.
9//!   Enables automatic inference: `Tensor<f64>` → `Standard<f64>`,
10//!   `Tensor<MaxPlus<f64>>` → `MaxPlusAlgebra<f64>` (in external crate).
11//!   This is UX sugar — the core model is `Alg::Scalar`-centric.
12//! - [`Algebra`]: Associates an algebra marker with its scalar type (`Alg::Scalar`).
13//! - [`Semiring`]: Extends `Algebra` with zero, one, add, mul for algebra-generic operations.
14//! - [`Standard<T>`](Standard): Typed standard arithmetic algebra (add = `+`, mul = `*`).
15//!
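//! # Extensibility
//!
//! External crates define new algebras by introducing an algebra marker type
//! that implements [`Algebra`] (and [`Semiring`] when semiring operations are
//! needed), implementing [`HasAlgebra`] for their scalar types so that marker
//! is inferred by default, and then implementing the primitive family traits
//! they need (for example `TensorSemiringCore<MyAlgebra> for CpuBackend`).
//! The orphan rule permits such impls because the external crate owns
//! `MyAlgebra`. For example, `tenferro-ext-tropical` defines `MaxPlus<T>`.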
22//!
23//! # Examples
24//!
25//! ```
26//! use tenferro_algebra::{HasAlgebra, Scalar, Standard};
27//!
28//! // f64 maps to Standard<f64> algebra automatically
29//! fn check_algebra<T: HasAlgebra<Algebra = Standard<T>>>() {}
30//! check_algebra::<f64>();
31//! check_algebra::<f32>();
32//!
33//! // Scalar is automatically implemented for numeric types
34//! fn needs_scalar<T: Scalar>() {}
35//! needs_scalar::<f64>();
36//! needs_scalar::<f32>();
37//! ```
38
39use std::marker::PhantomData;
40
41use num_complex::{Complex32, Complex64};
42
43/// Scalar element type for tensors.
44///
45/// Minimum requirements for a type to be stored in a `Tensor<T>`.
46/// All standard numeric types (`f32`, `f64`, `Complex32`, `Complex64`)
47/// satisfy this trait automatically via the blanket implementation.
48///
49/// `Scalar` is a supertrait of [`ScalarBase`](strided_traits::ScalarBase),
50/// ensuring that any `Scalar` type can be used with strided-rs operations.
51///
52/// # Examples
53///
54/// ```
55/// use tenferro_algebra::Scalar;
56///
57/// fn needs_scalar<T: Scalar>() {}
58/// needs_scalar::<f64>();
59/// needs_scalar::<f32>();
60/// ```
61pub trait Scalar: strided_traits::ScalarBase {}
62
63impl<T: strided_traits::ScalarBase> Scalar for T {}
64
/// Complex conjugation for tensor element types.
///
/// The provided `conj` is the identity, which is exactly what real scalars
/// need. Complex scalars override it to negate the imaginary part. Be aware
/// that a type implementing this trait *without* overriding `conj` silently
/// gets the identity behavior.
///
/// # Examples
///
/// ```
/// use tenferro_algebra::Conjugate;
///
/// // Real scalars: conjugation leaves the value untouched.
/// assert_eq!(2.5_f64.conj(), 2.5_f64);
///
/// // Complex scalars: the imaginary part changes sign.
/// use num_complex::Complex64;
/// let z = Complex64::new(1.0, 2.0);
/// assert_eq!(z.conj(), Complex64::new(1.0, -2.0));
/// ```
pub trait Conjugate: Copy {
    /// Return the complex conjugate of this value.
    ///
    /// Defaults to returning `self` unchanged (correct for real types).
    fn conj(self) -> Self {
        self
    }
}

// Real floating-point types rely on the identity default.
impl Conjugate for f64 {}
impl Conjugate for f32 {}
92
93impl Conjugate for Complex32 {
94    fn conj(self) -> Self {
95        Complex32::conj(&self)
96    }
97}
98
99impl Conjugate for Complex64 {
100    fn conj(self) -> Self {
101        Complex64::conj(&self)
102    }
103}
104
/// Maps a scalar type (`Self`) to its default algebra (`Self::Algebra`).
///
/// Enables automatic algebra inference: `Tensor<f64>` → `Standard<f64>`,
/// `Tensor<MaxPlus<f64>>` → `MaxPlusAlgebra<f64>` (in external crate).
///
/// This trait is **UX sugar** for default algebra inference. The core
/// algebra model is `Alg::Scalar`-centric (see [`Semiring`]).
///
/// # Implementing for custom types
///
/// ```
/// use tenferro_algebra::HasAlgebra;
///
/// struct MyScalar(f64);
/// struct MyAlgebra;
///
/// impl HasAlgebra for MyScalar {
///     type Algebra = MyAlgebra;
/// }
///
/// // The default algebra can now be inferred from the scalar type.
/// fn default_algebra_is_mine<T: HasAlgebra<Algebra = MyAlgebra>>() {}
/// default_algebra_is_mine::<MyScalar>();
/// ```
pub trait HasAlgebra {
    /// The algebra associated with this scalar type.
    type Algebra;
}
127
/// Typed standard arithmetic algebra (add = `+`, mul = `*`).
///
/// The type parameter `T` carries the scalar type, making the algebra
/// `Alg::Scalar`-centric. This is the canonical core model — `HasAlgebra`
/// provides UX sugar for automatic inference (e.g., `f64` → `Standard<f64>`).
///
/// This is the default algebra for built-in numeric types (`f32`, `f64`,
/// `Complex32`, `Complex64`).
///
/// `Standard<T>` stores no `T` value (only `PhantomData<T>`), so it is a
/// zero-sized marker that is `Debug`, `Clone`, `Copy` and `Default` for
/// every `T`.
///
/// # Examples
///
/// ```
/// use tenferro_algebra::{HasAlgebra, Standard};
///
/// // f64 maps to Standard<f64> algebra automatically
/// fn check_algebra<T: HasAlgebra<Algebra = Standard<T>>>() {}
/// check_algebra::<f64>();
/// check_algebra::<f32>();
///
/// // The marker value can be created and copied freely.
/// let alg = Standard::<f64>::default();
/// let copy = alg;
/// assert_eq!(format!("{copy:?}"), "Standard");
/// assert_eq!(std::mem::size_of::<Standard<f64>>(), 0);
/// ```
pub struct Standard<T>(PhantomData<T>);

// The common traits below are implemented by hand instead of derived:
// `#[derive]` would add `T: Trait` bounds, but `Standard<T>` holds no `T`,
// so none of these impls should depend on what `T` implements.

impl<T> core::fmt::Debug for Standard<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Standard")
    }
}

impl<T> Clone for Standard<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Standard<T> {}

impl<T> Default for Standard<T> {
    fn default() -> Self {
        Standard(PhantomData)
    }
}
148
149impl HasAlgebra for f32 {
150    type Algebra = Standard<f32>;
151}
152
153impl HasAlgebra for f64 {
154    type Algebra = Standard<f64>;
155}
156
157impl HasAlgebra for Complex32 {
158    type Algebra = Standard<Complex32>;
159}
160
161impl HasAlgebra for Complex64 {
162    type Algebra = Standard<Complex64>;
163}
164
165/// Associates an algebra marker with its scalar type.
166///
167/// This is the minimal algebra marker trait required by the primitive family
168/// traits. It provides the scalar type without requiring semiring operations.
169///
170/// [`Semiring`] extends `Algebra` with zero/one/add/mul.
171///
172/// # Examples
173///
174/// ```
175/// use tenferro_algebra::{Algebra, Standard};
176///
177/// fn needs_algebra<A: Algebra>() {}
178/// needs_algebra::<Standard<f64>>();
179/// needs_algebra::<Standard<f32>>();
180/// ```
181pub trait Algebra {
182    /// The scalar element type for tensors under this algebra.
183    type Scalar: Scalar;
184}
185
186/// Semiring trait for algebra-generic operations.
187///
188/// The algebra type `Alg` carries its scalar type via `Alg::Scalar`. This
189/// is the **core algebra model** — primitive-family trait bounds and
190/// einsum/linalg contracts are centered on `Alg::Scalar`.
191///
192/// Defines the four fundamental operations needed for tensor contractions
193/// under a given algebra:
194///
195/// - `zero()`: Additive identity
196/// - `one()`: Multiplicative identity
197/// - `add(a, b)`: Semiring addition (e.g., `+` for Standard, `max` for MaxPlus)
198/// - `mul(a, b)`: Semiring multiplication (e.g., `*` for Standard, `+` for MaxPlus)
199///
200/// # Examples
201///
202/// Standard arithmetic (`Standard<f64>`):
203/// - `zero() = 0`, `one() = 1`, `add = +`, `mul = *`
204///
205/// Tropical (MaxPlus) semiring (in external crate):
206/// - `zero() = -∞`, `one() = 0`, `add = max`, `mul = +`
207pub trait Semiring: Algebra {
208    /// Additive identity element.
209    fn zero() -> Self::Scalar;
210
211    /// Multiplicative identity element.
212    fn one() -> Self::Scalar;
213
214    /// Semiring addition.
215    fn add(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
216
217    /// Semiring multiplication.
218    fn mul(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
219}
220
221impl<T: Scalar> Algebra for Standard<T> {
222    type Scalar = T;
223}
224
225/// Standard arithmetic implements `Semiring` with `+` and `*`.
226///
227/// # Examples
228///
229/// ```
230/// use tenferro_algebra::{Semiring, Standard};
231///
232/// let z = <Standard<f64> as Semiring>::zero();
233/// let o = <Standard<f64> as Semiring>::one();
234/// assert_eq!(z, 0.0);
235/// assert_eq!(o, 1.0);
236/// ```
237impl<T: Scalar + num_traits::Zero + num_traits::One> Semiring for Standard<T> {
238    fn zero() -> T {
239        T::zero()
240    }
241
242    fn one() -> T {
243        T::one()
244    }
245
246    fn add(a: T, b: T) -> T {
247        a + b
248    }
249
250    fn mul(a: T, b: T) -> T {
251        a * b
252    }
253}