Skip to main content

tenferro_ops/
input_key.rs

1#[cfg(feature = "autodiff")]
2use tidu::{ADKey, DiffPassId};
3
/// Identifies an input tensor.
///
/// A key is either supplied directly by the user (`User`) or, with the
/// `autodiff` feature, derived from another key as a tangent (`Tangent`).
/// Tangent keys nest through their `of` field, so any chain of tangents
/// ends at a `User` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorInputKey {
    /// A user input identified by a numeric id.
    User {
        /// Numeric identifier of this input.
        id: u64,
    },
    /// Tangent of the key `of`, tagged with the differentiation pass `pass`.
    ///
    /// Only available with the `autodiff` feature.
    #[cfg(feature = "autodiff")]
    Tangent {
        /// The key whose tangent this is.
        of: Box<TensorInputKey>,
        /// Identifier of the differentiation pass this tangent belongs to.
        pass: DiffPassId,
    },
}
15
16impl TensorInputKey {
17    /// Returns `true` when this key names an AD tangent input.
18    ///
19    /// # Examples
20    ///
21    /// ```rust
22    /// use tenferro_ops::input_key::TensorInputKey;
23    ///
24    /// let key = TensorInputKey::User { id: 0 };
25    /// assert!(!key.is_tangent());
26    /// ```
27    pub fn is_tangent(&self) -> bool {
28        match self {
29            TensorInputKey::User { .. } => false,
30            #[cfg(feature = "autodiff")]
31            TensorInputKey::Tangent { .. } => true,
32        }
33    }
34
35    /// Returns the user input key that owns this input's concrete primal data.
36    ///
37    /// For non-AD keys this returns `self`; for tangent keys it recursively
38    /// follows the `of` chain to the original user input.
39    ///
40    /// # Examples
41    ///
42    /// ```rust
43    /// use tenferro_ops::input_key::TensorInputKey;
44    ///
45    /// let key = TensorInputKey::User { id: 0 };
46    /// assert_eq!(key.primal_root(), &key);
47    /// ```
48    pub fn primal_root(&self) -> &Self {
49        match self {
50            TensorInputKey::User { .. } => self,
51            #[cfg(feature = "autodiff")]
52            TensorInputKey::Tangent { of, .. } => of.primal_root(),
53        }
54    }
55}
56
#[cfg(feature = "autodiff")]
impl ADKey for TensorInputKey {
    /// Builds the tangent key of `self` for the differentiation pass `pass`.
    ///
    /// A copy of `self` is boxed into the `of` field, so applying this
    /// repeatedly produces nested tangents. [`TensorInputKey::primal_root`]
    /// unwinds that nesting back to the user key.
    fn tangent_of(&self, pass: DiffPassId) -> Self {
        let primal = Box::new(self.clone());
        Self::Tangent { of: primal, pass }
    }
}