ad_tensors_rs/traits.rs
1use std::hash::Hash;
2
3use tenferro_algebra::Scalar;
4use tenferro_tensor::Tensor;
5
6use crate::{AdScalar, AdTensor, AdValue, DiffPolicy, DynScalar, Result};
7
/// Result type returned by the [`OpRule`] rule methods (`rrule`, `frule`, `hvp`).
///
/// This is an alias for the crate-level [`Result`], so its error type is the crate's
/// [`Error`](crate::Error).
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::{AdResult, Error};
///
/// let ok: AdResult<()> = Ok(());
/// let err: AdResult<()> = Err(Error::InvalidAdTensor {
///     message: "demo".into(),
/// });
/// assert!(ok.is_ok());
/// assert!(err.is_err());
/// ```
pub type AdResult<T> = Result<T>;
23
/// Marker trait for the index labels accepted by contraction and factorization APIs.
///
/// It adds no methods of its own. Every type that is `Clone + Eq + Hash` implements it
/// automatically through the blanket impl that follows, so integers, strings, tuples and
/// user-defined label types deriving those traits all qualify.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::IndexLike;
///
/// fn accepts_index<I: IndexLike>(_index: I) {}
/// accepts_index(3_u32);
/// ```
pub trait IndexLike
where
    Self: Clone + Eq + Hash,
{
}

// Blanket impl: the trait is purely a named bundle of bounds.
impl<L: Clone + Eq + Hash> IndexLike for L {}
37
38/// A value that can be observed as an [`AdValue`].
39///
40/// # Examples
41///
42/// ```rust
43/// use ad_tensors_rs::{AdMode, AdValue, Differentiable};
44///
45/// fn mode_of<V: Differentiable>(value: &V) -> AdMode {
46/// value.ad_value().mode()
47/// }
48///
49/// let x = AdValue::primal(2.0_f64);
50/// assert_eq!(mode_of(&x), AdMode::Primal);
51/// ```
52pub trait Differentiable: Clone {
53 /// Underlying primal payload type.
54 type Primal: Clone;
55
56 /// Borrow as an [`AdValue`].
57 fn ad_value(&self) -> &AdValue<Self::Primal>;
58}
59
60impl<T: Clone> Differentiable for AdValue<T> {
61 type Primal = T;
62
63 fn ad_value(&self) -> &AdValue<T> {
64 self
65 }
66}
67
68impl<T: Clone> Differentiable for AdScalar<T> {
69 type Primal = T;
70
71 fn ad_value(&self) -> &AdValue<T> {
72 self.as_value()
73 }
74}
75
76impl<T: Clone + Scalar> Differentiable for AdTensor<T> {
77 type Primal = Tensor<T>;
78
79 fn ad_value(&self) -> &AdValue<Tensor<T>> {
80 self.as_value()
81 }
82}
83
/// Restriction on which index pairs a contraction may use (passed to
/// [`TensorKernel::contract`]).
///
/// This is a borrowed view over a caller-owned slice. It is `Copy`, so it can be passed by
/// value freely. It also derives `PartialEq`, `Eq` and `Hash`, so two restrictions can be
/// compared and used as map or set keys; equality compares the pair lists element-wise.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::AllowedPairs;
///
/// let pairs = AllowedPairs { pairs: &[(0, 1), (2, 3)] };
/// assert_eq!(pairs.pairs.len(), 2);
/// assert_eq!(pairs, AllowedPairs { pairs: &[(0, 1), (2, 3)] });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AllowedPairs<'a> {
    /// Allowed `(usize, usize)` pairs.
    ///
    /// NOTE(review): whether each entry refers to positions in the `tensors` slice given to
    /// [`TensorKernel::contract`] or to axis positions is not established by this type.
    /// Confirm against the kernel implementations.
    pub pairs: &'a [(usize, usize)],
}
99
100/// Factorization options for generic tensor-kernel APIs.
101///
102/// # Examples
103///
104/// ```rust
105/// use ad_tensors_rs::{DiffPolicy, FactorizeOptions};
106///
107/// let opts = FactorizeOptions {
108/// max_rank: Some(32),
109/// diff_policy: DiffPolicy::StopGradient,
110/// };
111/// assert_eq!(opts.max_rank, Some(32));
112/// ```
113#[derive(Debug, Clone)]
114pub struct FactorizeOptions {
115 /// Optional truncation rank.
116 pub max_rank: Option<usize>,
117 /// Differentiation behavior at non-smooth points.
118 pub diff_policy: DiffPolicy,
119}
120
121impl Default for FactorizeOptions {
122 fn default() -> Self {
123 Self {
124 max_rank: None,
125 diff_policy: DiffPolicy::StopGradient,
126 }
127 }
128}
129
/// Two-factor output produced by [`TensorKernel::factorize`].
///
/// It is generic over the factor type, so it can hold kernel tensors or any other payload.
/// `PartialEq`/`Eq` are derived so results can be compared (and used in `assert_eq!`). The
/// derived impls compare `left` and `right` field-wise and exist only when `T` supports
/// equality, so this adds no requirements for existing factor types.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::FactorizeResult;
///
/// let result = FactorizeResult { left: 1_i32, right: 2_i32 };
/// assert_eq!(result.left + result.right, 3);
/// assert_eq!(result.clone(), FactorizeResult { left: 1, right: 2 });
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FactorizeResult<T> {
    /// Left factor.
    pub left: T,
    /// Right factor.
    pub right: T,
}
147
148/// Numeric tensor-kernel boundary for contraction/factorization operations.
149///
150/// # Examples
151///
152/// ```rust
153/// use ad_tensors_rs::{AllowedPairs, DynScalar, FactorizeOptions, FactorizeResult, Result, TensorKernel};
154///
155/// #[derive(Clone)]
156/// struct DummyKernel;
157///
158/// impl TensorKernel for DummyKernel {
159/// type Index = u8;
160///
161/// fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
162/// Ok(Self)
163/// }
164///
165/// fn factorize(
166/// &self,
167/// _left_inds: &[Self::Index],
168/// _options: &FactorizeOptions,
169/// ) -> Result<FactorizeResult<Self>> {
170/// Ok(FactorizeResult { left: Self, right: Self })
171/// }
172///
173/// fn axpby(&self, _a: DynScalar, _other: &Self, _b: DynScalar) -> Result<Self> {
174/// Ok(Self)
175/// }
176///
177/// fn scale(&self, _a: DynScalar) -> Result<Self> {
178/// Ok(Self)
179/// }
180///
181/// fn inner_product(&self, _other: &Self) -> Result<DynScalar> {
182/// Ok(DynScalar::F64(0.0))
183/// }
184/// }
185///
186/// let x = DummyKernel;
187/// let _ = x.scale(DynScalar::F64(2.0)).unwrap();
188/// ```
189pub trait TensorKernel: Clone {
190 /// Index label type used by this kernel.
191 type Index: IndexLike;
192
193 /// Contract a set of tensors under index-pair constraints.
194 fn contract(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;
195
196 /// Factorize a tensor into two factors.
197 fn factorize(
198 &self,
199 left_inds: &[Self::Index],
200 options: &FactorizeOptions,
201 ) -> Result<FactorizeResult<Self>>;
202
203 /// Affine combination `a * self + b * other`.
204 fn axpby(&self, a: DynScalar, other: &Self, b: DynScalar) -> Result<Self>;
205
206 /// Scalar multiply.
207 fn scale(&self, a: DynScalar) -> Result<Self>;
208
209 /// Inner product result.
210 fn inner_product(&self, other: &Self) -> Result<DynScalar>;
211}
212
213/// Operation-level AD rules (`rrule`, `frule`, `hvp`).
214///
215/// # Examples
216///
217/// ```rust
218/// use ad_tensors_rs::{AdResult, AdValue, Differentiable, OpRule, Result};
219///
220/// struct IdentityRule;
221///
222/// impl OpRule<AdValue<f64>> for IdentityRule {
223/// fn eval(&self, inputs: &[&AdValue<f64>]) -> Result<AdValue<f64>> {
224/// Ok((*inputs[0]).clone())
225/// }
226///
227/// fn rrule(
228/// &self,
229/// _inputs: &[&AdValue<f64>],
230/// _out: &AdValue<f64>,
231/// cotangent: &f64,
232/// ) -> AdResult<Vec<f64>> {
233/// Ok(vec![*cotangent])
234/// }
235///
236/// fn frule(
237/// &self,
238/// _inputs: &[&AdValue<f64>],
239/// tangents: &[Option<&f64>],
240/// ) -> AdResult<f64> {
241/// Ok(tangents[0].copied().unwrap_or(0.0))
242/// }
243///
244/// fn hvp(
245/// &self,
246/// _inputs: &[&AdValue<f64>],
247/// cotangent: &f64,
248/// cotangent_tangent: Option<&f64>,
249/// _input_tangents: &[Option<&f64>],
250/// ) -> AdResult<Vec<(f64, f64)>> {
251/// Ok(vec![(*cotangent, cotangent_tangent.copied().unwrap_or(0.0))])
252/// }
253/// }
254///
255/// let x = AdValue::primal(2.0_f64);
256/// let rule = IdentityRule;
257/// let y = rule.eval(&[&x]).unwrap();
258/// assert_eq!(y.primal_ref(), &2.0);
259/// ```
260pub trait OpRule<V: Differentiable> {
261 /// Compute primal output.
262 fn eval(&self, inputs: &[&V]) -> Result<V>;
263
264 /// Reverse-mode pullback.
265 fn rrule(&self, inputs: &[&V], out: &V, cotangent: &V::Primal) -> AdResult<Vec<V::Primal>>;
266
267 /// Forward-mode pushforward.
268 fn frule(&self, inputs: &[&V], tangents: &[Option<&V::Primal>]) -> AdResult<V::Primal>;
269
270 /// Hessian-vector product.
271 fn hvp(
272 &self,
273 inputs: &[&V],
274 cotangent: &V::Primal,
275 cotangent_tangent: Option<&V::Primal>,
276 input_tangents: &[Option<&V::Primal>],
277 ) -> AdResult<Vec<(V::Primal, V::Primal)>>;
278}