tenferro_algebra/lib.rs
1//! Algebra traits for the tenferro workspace.
2//!
3//! This crate provides the minimal algebra foundation:
4//!
5//! - [`Scalar`]: Minimum requirements for tensor element types
6//! (`Copy + Send + Sync + Add + Mul + Zero + One + PartialEq`).
7//! - [`Conjugate`]: Complex conjugation (identity for real types).
8//! - [`HasAlgebra`]: Maps a scalar type `T` to its default algebra `A`.
9//! Enables automatic inference: `Tensor<f64>` → `Standard`,
10//! `Tensor<MaxPlus<f64>>` → `MaxPlus` (in external crate).
11//! - [`Semiring`]: Defines zero, one, add, mul for algebra-generic operations.
12//! - [`Standard`]: Standard arithmetic algebra (add = `+`, mul = `*`).
13//!
14//! # Extensibility
15//!
16//! External crates define new algebras by implementing `HasAlgebra` for their
17//! scalar types and `TensorPrims<MyAlgebra>` for `CpuBackend` (orphan rule
18//! compatible). For example, `tenferro-tropical` defines `MaxPlus<T>`.
19//!
20//! # Examples
21//!
22//! ```
23//! use tenferro_algebra::{HasAlgebra, Scalar, Standard};
24//!
25//! // f64 maps to Standard algebra automatically
26//! fn check_algebra<T: HasAlgebra<Algebra = Standard>>() {}
27//! check_algebra::<f64>();
28//! check_algebra::<f32>();
29//!
30//! // Scalar is automatically implemented for numeric types
31//! fn needs_scalar<T: Scalar>() {}
32//! needs_scalar::<f64>();
33//! needs_scalar::<f32>();
34//! ```
35
36use num_complex::{Complex32, Complex64};
37
38/// Scalar element type for tensors.
39///
40/// Minimum requirements for a type to be stored in a `Tensor<T>`.
41/// All standard numeric types (`f32`, `f64`, `Complex32`, `Complex64`)
42/// satisfy this trait automatically via the blanket implementation.
43///
44/// # Examples
45///
46/// ```
47/// use tenferro_algebra::Scalar;
48///
49/// fn needs_scalar<T: Scalar>() {}
50/// needs_scalar::<f64>();
51/// needs_scalar::<f32>();
52/// ```
53pub trait Scalar:
54 Copy
55 + Send
56 + Sync
57 + std::ops::Add<Output = Self>
58 + std::ops::Mul<Output = Self>
59 + num_traits::Zero
60 + num_traits::One
61 + PartialEq
62{
63}
64
65impl<T> Scalar for T where
66 T: Copy
67 + Send
68 + Sync
69 + std::ops::Add<Output = Self>
70 + std::ops::Mul<Output = Self>
71 + num_traits::Zero
72 + num_traits::One
73 + PartialEq
74{
75}
76
/// Complex conjugation for tensor element types.
///
/// The provided `conj` is the identity, which is the correct result for
/// real-valued scalars; such types can opt in with an empty
/// `impl Conjugate for T {}`. Complex types override `conj` with true
/// conjugation (negating the imaginary part).
///
/// # Examples
///
/// ```
/// use tenferro_algebra::Conjugate;
///
/// // Real scalars are their own conjugate.
/// let x = 2.5_f64;
/// assert_eq!(x.conj(), x);
///
/// // Complex scalars flip the sign of the imaginary part.
/// use num_complex::Complex64;
/// let z = Complex64::new(1.0, 2.0);
/// assert_eq!(z.conj(), Complex64::new(1.0, -2.0));
/// ```
pub trait Conjugate: Copy {
    /// Return the complex conjugate of `self`.
    ///
    /// The default implementation returns `self` unchanged.
    fn conj(self) -> Self {
        self
    }
}
101
102impl Conjugate for f32 {}
103impl Conjugate for f64 {}
104
105impl Conjugate for Complex32 {
106 fn conj(self) -> Self {
107 Complex32::conj(&self)
108 }
109}
110
111impl Conjugate for Complex64 {
112 fn conj(self) -> Self {
113 Complex64::conj(&self)
114 }
115}
116
/// Maps a scalar type to its default algebra.
///
/// The associated [`Algebra`](HasAlgebra::Algebra) type lets generic code
/// infer which algebra to use from the element type alone:
/// `Tensor<f64>` → [`Standard`], `Tensor<MaxPlus<f64>>` → `MaxPlus`
/// (the latter defined in an external crate).
///
/// # Implementing for custom types
///
/// ```
/// use tenferro_algebra::HasAlgebra;
///
/// struct MyScalar(f64);
/// struct MyAlgebra;
///
/// impl HasAlgebra for MyScalar {
///     type Algebra = MyAlgebra;
/// }
///
/// // Generic code can now constrain on the algebra implied by the scalar.
/// fn requires_my_algebra<T: HasAlgebra<Algebra = MyAlgebra>>() {}
/// requires_my_algebra::<MyScalar>();
/// ```
pub trait HasAlgebra {
    /// The algebra associated with this scalar type.
    type Algebra;
}
136
/// Standard arithmetic algebra (add = `+`, mul = `*`).
///
/// This is the default algebra for built-in numeric types (`f32`, `f64`,
/// `Complex32`, `Complex64`), i.e. their [`HasAlgebra::Algebra`].
///
/// `Standard` carries no data; it is a zero-sized marker used at the type
/// level. The derived traits let it be printed, copied, compared, hashed and
/// defaulted like any other tag type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Standard;
142
143impl HasAlgebra for f32 {
144 type Algebra = Standard;
145}
146
147impl HasAlgebra for f64 {
148 type Algebra = Standard;
149}
150
151impl HasAlgebra for Complex32 {
152 type Algebra = Standard;
153}
154
155impl HasAlgebra for Complex64 {
156 type Algebra = Standard;
157}
158
159/// Semiring trait for algebra-generic operations.
160///
161/// Defines the four fundamental operations needed for tensor contractions
162/// under a given algebra:
163///
164/// - `zero()`: Additive identity
165/// - `one()`: Multiplicative identity
166/// - `add(a, b)`: Semiring addition (e.g., `+` for Standard, `max` for MaxPlus)
167/// - `mul(a, b)`: Semiring multiplication (e.g., `*` for Standard, `+` for MaxPlus)
168///
169/// # Examples
170///
171/// Standard arithmetic:
172/// - `zero() = 0`, `one() = 1`, `add = +`, `mul = *`
173///
174/// Tropical (MaxPlus) semiring (in external crate):
175/// - `zero() = -∞`, `one() = 0`, `add = max`, `mul = +`
176pub trait Semiring {
177 /// The scalar type for this semiring.
178 type Scalar: Scalar;
179
180 /// Additive identity element.
181 fn zero() -> Self::Scalar;
182
183 /// Multiplicative identity element.
184 fn one() -> Self::Scalar;
185
186 /// Semiring addition.
187 fn add(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
188
189 /// Semiring multiplication.
190 fn mul(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
191}