tenferro_ops/input_key.rs
1#[cfg(feature = "autodiff")]
2use tidu::{ADKey, DiffPassId};
3
/// Identifies one tensor input.
///
/// A key is either an input supplied directly by the user, or (with the
/// `autodiff` feature) the tangent of another key for one differentiation
/// pass. Tangent keys nest through `of`, so following `of` repeatedly always
/// ends at a `User` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorInputKey {
    /// A user-provided input, distinguished by its numeric `id`.
    User { id: u64 },
    /// The tangent of the key `of` for differentiation pass `pass`.
    #[cfg(feature = "autodiff")]
    Tangent {
        of: Box<TensorInputKey>,
        pass: DiffPassId,
    },
}
15
16impl TensorInputKey {
17 /// Returns `true` when this key names an AD tangent input.
18 ///
19 /// # Examples
20 ///
21 /// ```rust
22 /// use tenferro_ops::input_key::TensorInputKey;
23 ///
24 /// let key = TensorInputKey::User { id: 0 };
25 /// assert!(!key.is_tangent());
26 /// ```
27 pub fn is_tangent(&self) -> bool {
28 match self {
29 TensorInputKey::User { .. } => false,
30 #[cfg(feature = "autodiff")]
31 TensorInputKey::Tangent { .. } => true,
32 }
33 }
34
35 /// Returns the user input key that owns this input's concrete primal data.
36 ///
37 /// For non-AD keys this returns `self`; for tangent keys it recursively
38 /// follows the `of` chain to the original user input.
39 ///
40 /// # Examples
41 ///
42 /// ```rust
43 /// use tenferro_ops::input_key::TensorInputKey;
44 ///
45 /// let key = TensorInputKey::User { id: 0 };
46 /// assert_eq!(key.primal_root(), &key);
47 /// ```
48 pub fn primal_root(&self) -> &Self {
49 match self {
50 TensorInputKey::User { .. } => self,
51 #[cfg(feature = "autodiff")]
52 TensorInputKey::Tangent { of, .. } => of.primal_root(),
53 }
54 }
55}
56
#[cfg(feature = "autodiff")]
impl ADKey for TensorInputKey {
    /// Builds the tangent key of `self` for differentiation pass `pass`.
    ///
    /// The returned key's `of` field holds a boxed clone of `self`, so
    /// [`TensorInputKey::primal_root`] on the result leads back to the same
    /// root as `self`.
    fn tangent_of(&self, pass: DiffPassId) -> Self {
        let primal = Box::new(self.clone());
        Self::Tangent { of: primal, pass }
    }
}