tidu/rules/ad_key.rs
1use std::hash::Hash;
2
/// Identifies a single [`crate::linearize`] call; each call has its own unique value.
pub type DiffPassId = u64;

/// Requirements an input key type (`GraphOperation::InputKey`) must satisfy for AD.
///
/// During [`crate::linearize`], `tidu` derives tangent input keys from existing keys
/// by calling [`ADKey::tangent_of`].
///
/// The bounds require a key to be cloneable, printable with `{:?}`, hashable and
/// comparable for equality, and shareable across threads (`Send + Sync + 'static`).
///
/// # Examples
///
/// ```
/// use tidu::{ADKey, DiffPassId};
///
/// #[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// enum MyKey {
///     User(String),
///     Tangent { of: Box<MyKey>, pass: DiffPassId },
/// }
///
/// impl ADKey for MyKey {
///     fn tangent_of(&self, pass: DiffPassId) -> Self {
///         MyKey::Tangent {
///             of: Box::new(self.clone()),
///             pass,
///         }
///     }
/// }
///
/// // This implementation embeds both the source key and the pass id, so tangents
/// // requested by different passes never compare equal.
/// let x = MyKey::User("x".to_string());
/// assert_ne!(x.tangent_of(0), x.tangent_of(1));
/// ```
pub trait ADKey
where
    Self: Send + Sync + 'static,
    Self: Clone + Eq + Hash + std::fmt::Debug,
{
    /// Builds the tangent input key corresponding to `self`.
    ///
    /// `pass` uniquely identifies the `linearize` call that requests the key.
    fn tangent_of(&self, pass: DiffPassId) -> Self;
}
35}