// tidu/rules/ad_key.rs

1use std::hash::Hash;
2
/// Unique identifier for a [`crate::linearize`] call.
///
/// A value of this type is handed to [`ADKey::tangent_of`] as its `pass`
/// argument.
pub type DiffPassId = u64;

/// Constraint on `GraphOperation::InputKey` for AD use.
///
/// `tidu` uses this trait to generate tangent input keys during
/// [`crate::linearize`].
///
/// The supertrait bounds require keys to be cloneable, printable with
/// `{:?}`, hashable and comparable for equality (`Hash + Eq`), and
/// transferable/shareable across threads with no borrowed data
/// (`Send + Sync + 'static`).
///
/// NOTE(review): implementations presumably need `tangent_of` to return a key
/// that differs from `self` and from tangents produced for a different
/// `pass`. The example below achieves this by wrapping both the original key
/// and the pass id in a dedicated variant. Confirm against how `linearize`
/// consumes these keys.
///
/// # Examples
///
/// ```
/// use tidu::{ADKey, DiffPassId};
///
/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// enum MyKey {
///     User(String),
///     Tangent { of: Box<MyKey>, pass: DiffPassId },
/// }
///
/// impl ADKey for MyKey {
///     fn tangent_of(&self, pass: DiffPassId) -> Self {
///         MyKey::Tangent {
///             of: Box::new(self.clone()),
///             pass,
///         }
///     }
/// }
/// ```
pub trait ADKey: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static {
    /// Create a tangent input key derived from this key.
    ///
    /// `pass` is a unique identifier for the `linearize` call.
    ///
    /// The returned key is the method's only output, so ignoring it is
    /// always a mistake. Hence `#[must_use]`.
    #[must_use]
    fn tangent_of(&self, pass: DiffPassId) -> Self;
}