ad_tensors_rs/
traits.rs

1use std::hash::Hash;
2
3use tenferro_algebra::Scalar;
4use tenferro_tensor::Tensor;
5
6use crate::{AdScalar, AdTensor, AdValue, DiffPolicy, DynScalar, Result};
7
8/// AD rule method result type alias.
9///
10/// # Examples
11///
12/// ```rust
13/// use ad_tensors_rs::{AdResult, Error};
14///
15/// let ok: AdResult<()> = Ok(());
16/// let err: AdResult<()> = Err(Error::InvalidAdTensor {
17///     message: "demo".into(),
18/// });
19/// assert!(ok.is_ok());
20/// assert!(err.is_err());
21/// ```
22pub type AdResult<T> = Result<T>;
23
/// Marker trait for index labels accepted by contraction/factorization APIs.
///
/// It carries no methods of its own. Every type that is `Clone + Eq + Hash`
/// receives it through the blanket implementation that follows, so it never
/// has to be implemented by hand.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::IndexLike;
///
/// fn accepts_index<I: IndexLike>(_index: I) {}
/// accepts_index(3_u32);
/// ```
pub trait IndexLike: Clone + Eq + Hash {}

// Blanket implementation: any cloneable, comparable, hashable type is an index label.
impl<T: Clone + Eq + Hash> IndexLike for T {}
37
38/// A value that can be observed as an [`AdValue`].
39///
40/// # Examples
41///
42/// ```rust
43/// use ad_tensors_rs::{AdMode, AdValue, Differentiable};
44///
45/// fn mode_of<V: Differentiable>(value: &V) -> AdMode {
46///     value.ad_value().mode()
47/// }
48///
49/// let x = AdValue::primal(2.0_f64);
50/// assert_eq!(mode_of(&x), AdMode::Primal);
51/// ```
52pub trait Differentiable: Clone {
53    /// Underlying primal payload type.
54    type Primal: Clone;
55
56    /// Borrow as an [`AdValue`].
57    fn ad_value(&self) -> &AdValue<Self::Primal>;
58}
59
60impl<T: Clone> Differentiable for AdValue<T> {
61    type Primal = T;
62
63    fn ad_value(&self) -> &AdValue<T> {
64        self
65    }
66}
67
68impl<T: Clone> Differentiable for AdScalar<T> {
69    type Primal = T;
70
71    fn ad_value(&self) -> &AdValue<T> {
72        self.as_value()
73    }
74}
75
76impl<T: Clone + Scalar> Differentiable for AdTensor<T> {
77    type Primal = Tensor<T>;
78
79    fn ad_value(&self) -> &AdValue<Tensor<T>> {
80        self.as_value()
81    }
82}
83
/// Allowed index-pair restrictions for contraction planning.
///
/// This is a cheap, copyable view over a caller-owned slice of pairs. Two
/// values compare equal (and hash identically) when their slices hold the
/// same pairs in the same order.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::AllowedPairs;
///
/// let pairs = AllowedPairs { pairs: &[(0, 1), (2, 3)] };
/// assert_eq!(pairs.pairs.len(), 2);
/// assert_eq!(pairs, AllowedPairs { pairs: &[(0, 1), (2, 3)] });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AllowedPairs<'a> {
    /// Allowed input-index pairs.
    pub pairs: &'a [(usize, usize)],
}
99
100/// Factorization options for generic tensor-kernel APIs.
101///
102/// # Examples
103///
104/// ```rust
105/// use ad_tensors_rs::{DiffPolicy, FactorizeOptions};
106///
107/// let opts = FactorizeOptions {
108///     max_rank: Some(32),
109///     diff_policy: DiffPolicy::StopGradient,
110/// };
111/// assert_eq!(opts.max_rank, Some(32));
112/// ```
113#[derive(Debug, Clone)]
114pub struct FactorizeOptions {
115    /// Optional truncation rank.
116    pub max_rank: Option<usize>,
117    /// Differentiation behavior at non-smooth points.
118    pub diff_policy: DiffPolicy,
119}
120
121impl Default for FactorizeOptions {
122    fn default() -> Self {
123        Self {
124            max_rank: None,
125            diff_policy: DiffPolicy::StopGradient,
126        }
127    }
128}
129
/// Generic factorization output container holding the two factors.
///
/// The container is comparable whenever the factor type `T` is: `PartialEq`
/// and `Eq` are derived, so they are only available when `T` implements them.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::FactorizeResult;
///
/// let result = FactorizeResult { left: 1_i32, right: 2_i32 };
/// assert_eq!(result.left + result.right, 3);
/// assert_eq!(result.clone(), result);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FactorizeResult<T> {
    /// Left factor.
    pub left: T,
    /// Right factor.
    pub right: T,
}
147
148/// Numeric tensor-kernel boundary for contraction/factorization operations.
149///
150/// # Examples
151///
152/// ```rust
153/// use ad_tensors_rs::{AllowedPairs, DynScalar, FactorizeOptions, FactorizeResult, Result, TensorKernel};
154///
155/// #[derive(Clone)]
156/// struct DummyKernel;
157///
158/// impl TensorKernel for DummyKernel {
159///     type Index = u8;
160///
161///     fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
162///         Ok(Self)
163///     }
164///
165///     fn factorize(
166///         &self,
167///         _left_inds: &[Self::Index],
168///         _options: &FactorizeOptions,
169///     ) -> Result<FactorizeResult<Self>> {
170///         Ok(FactorizeResult { left: Self, right: Self })
171///     }
172///
173///     fn axpby(&self, _a: DynScalar, _other: &Self, _b: DynScalar) -> Result<Self> {
174///         Ok(Self)
175///     }
176///
177///     fn scale(&self, _a: DynScalar) -> Result<Self> {
178///         Ok(Self)
179///     }
180///
181///     fn inner_product(&self, _other: &Self) -> Result<DynScalar> {
182///         Ok(DynScalar::F64(0.0))
183///     }
184/// }
185///
186/// let x = DummyKernel;
187/// let _ = x.scale(DynScalar::F64(2.0)).unwrap();
188/// ```
189pub trait TensorKernel: Clone {
190    /// Index label type used by this kernel.
191    type Index: IndexLike;
192
193    /// Contract a set of tensors under index-pair constraints.
194    fn contract(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;
195
196    /// Factorize a tensor into two factors.
197    fn factorize(
198        &self,
199        left_inds: &[Self::Index],
200        options: &FactorizeOptions,
201    ) -> Result<FactorizeResult<Self>>;
202
203    /// Affine combination `a * self + b * other`.
204    fn axpby(&self, a: DynScalar, other: &Self, b: DynScalar) -> Result<Self>;
205
206    /// Scalar multiply.
207    fn scale(&self, a: DynScalar) -> Result<Self>;
208
209    /// Inner product result.
210    fn inner_product(&self, other: &Self) -> Result<DynScalar>;
211}
212
213/// Operation-level AD rules (`rrule`, `frule`, `hvp`).
214///
215/// # Examples
216///
217/// ```rust
218/// use ad_tensors_rs::{AdResult, AdValue, Differentiable, OpRule, Result};
219///
220/// struct IdentityRule;
221///
222/// impl OpRule<AdValue<f64>> for IdentityRule {
223///     fn eval(&self, inputs: &[&AdValue<f64>]) -> Result<AdValue<f64>> {
224///         Ok((*inputs[0]).clone())
225///     }
226///
227///     fn rrule(
228///         &self,
229///         _inputs: &[&AdValue<f64>],
230///         _out: &AdValue<f64>,
231///         cotangent: &f64,
232///     ) -> AdResult<Vec<f64>> {
233///         Ok(vec![*cotangent])
234///     }
235///
236///     fn frule(
237///         &self,
238///         _inputs: &[&AdValue<f64>],
239///         tangents: &[Option<&f64>],
240///     ) -> AdResult<f64> {
241///         Ok(tangents[0].copied().unwrap_or(0.0))
242///     }
243///
244///     fn hvp(
245///         &self,
246///         _inputs: &[&AdValue<f64>],
247///         cotangent: &f64,
248///         cotangent_tangent: Option<&f64>,
249///         _input_tangents: &[Option<&f64>],
250///     ) -> AdResult<Vec<(f64, f64)>> {
251///         Ok(vec![(*cotangent, cotangent_tangent.copied().unwrap_or(0.0))])
252///     }
253/// }
254///
255/// let x = AdValue::primal(2.0_f64);
256/// let rule = IdentityRule;
257/// let y = rule.eval(&[&x]).unwrap();
258/// assert_eq!(y.primal_ref(), &2.0);
259/// ```
260pub trait OpRule<V: Differentiable> {
261    /// Compute primal output.
262    fn eval(&self, inputs: &[&V]) -> Result<V>;
263
264    /// Reverse-mode pullback.
265    fn rrule(&self, inputs: &[&V], out: &V, cotangent: &V::Primal) -> AdResult<Vec<V::Primal>>;
266
267    /// Forward-mode pushforward.
268    fn frule(&self, inputs: &[&V], tangents: &[Option<&V::Primal>]) -> AdResult<V::Primal>;
269
270    /// Hessian-vector product.
271    fn hvp(
272        &self,
273        inputs: &[&V],
274        cotangent: &V::Primal,
275        cotangent_tangent: Option<&V::Primal>,
276        input_tangents: &[Option<&V::Primal>],
277    ) -> AdResult<Vec<(V::Primal, V::Primal)>>;
278}