// tidu/value.rs

use std::sync::{Arc, Mutex};

use crate::reverse_graph::{
    backward_from, grad_wrt, leaf_grad, leaf_handle, shares_graph, zero_leaf_grad, LeafHandle,
    ReverseEdge, ReverseInput,
};
use crate::{AdResult, AutodiffError, Differentiable};

enum ReverseHandle<V: Differentiable> {
    None,
    Leaf(LeafHandle<V>),
    Edge(ReverseEdge<V>),
}

struct ReverseState<V: Differentiable> {
    requires_grad: bool,
    handle: ReverseHandle<V>,
}

/// Public value handle for reverse-mode AD.
///
/// `Value` exposes a torch-like surface while keeping graph ownership hidden.
/// Internally it carries either a leaf gradient sink or an edge into a reverse
/// graph.
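///
/// # Example
///
/// A minimal sketch of intended usage; `Scalar` is a hypothetical type
/// implementing `Differentiable` with `Tangent = f64`, not part of this file.
///
/// ```ignore
/// let x = Value::new(Scalar(3.0)).with_requires_grad(true);
/// x.backward()?;                    // seed the leaf with the unit cotangent
/// assert_eq!(x.grad()?, Some(1.0)); // gradient accumulated at the leaf
/// ```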
pub struct Value<V: Differentiable> {
    primal: Arc<V>,
    reverse: Mutex<ReverseState<V>>,
}

impl<V: Differentiable + 'static> Value<V> {
    /// Create a detached value.
    pub fn new(primal: V) -> Self {
        Self {
            primal: Arc::new(primal),
            reverse: Mutex::new(ReverseState {
                requires_grad: false,
                handle: ReverseHandle::None,
            }),
        }
    }

    /// Build a gradient-tracking value whose reverse handle is an edge into
    /// an existing reverse graph. Used by crate-internal operators.
    pub(crate) fn from_reverse_edge(primal: V, edge: ReverseEdge<V>) -> Self {
        Self {
            primal: Arc::new(primal),
            reverse: Mutex::new(ReverseState {
                requires_grad: true,
                handle: ReverseHandle::Edge(edge),
            }),
        }
    }

    /// Borrow the primal value.
    pub fn primal(&self) -> &V {
        self.primal.as_ref()
    }

    /// Clone the reference-counted primal without copying the underlying value.
    pub(crate) fn shared_primal(&self) -> Arc<V> {
        self.primal.clone()
    }

    /// Return whether this value participates in reverse-mode AD.
    pub fn requires_grad(&self) -> bool {
        self.reverse
            .lock()
            .expect("reverse state poisoned")
            .requires_grad
    }

    /// Return a new value handle with gradient tracking enabled or disabled.
    ///
    /// Enabling tracking on a detached value installs a fresh leaf handle;
    /// disabling drops the current handle, discarding any graph linkage and
    /// accumulated leaf gradient.
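    ///
    /// A sketch of typical use, with a hypothetical `Scalar` type standing in
    /// for a real `Differentiable` impl:
    ///
    /// ```ignore
    /// let x = Value::new(Scalar(2.0)).with_requires_grad(true);
    /// assert!(x.requires_grad());
    /// let x = x.with_requires_grad(false); // detach: the leaf handle is dropped
    /// assert!(!x.requires_grad());
    /// ```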
    pub fn with_requires_grad(self, enabled: bool) -> Self {
        {
            let mut reverse = self.reverse.lock().expect("reverse state poisoned");
            reverse.requires_grad = enabled;
            reverse.handle = if enabled {
                match &reverse.handle {
                    ReverseHandle::Leaf(existing) => ReverseHandle::Leaf(existing.clone()),
                    ReverseHandle::Edge(existing) => ReverseHandle::Edge(existing.clone()),
                    ReverseHandle::None => ReverseHandle::Leaf(leaf_handle()),
                }
            } else {
                ReverseHandle::None
            };
        }
        self
    }

    /// Resolve this value's reverse-graph input, lazily promoting a tracked
    /// value without a handle to a fresh leaf. Returns `None` when gradient
    /// tracking is disabled.
    pub(crate) fn reverse_input(&self) -> Option<ReverseInput<V>> {
        let mut reverse = self.reverse.lock().expect("reverse state poisoned");
        if !reverse.requires_grad {
            return None;
        }
        let input = match &reverse.handle {
            ReverseHandle::Leaf(handle) => ReverseInput::Leaf(handle.clone()),
            ReverseHandle::Edge(edge) => ReverseInput::Edge(edge.clone()),
            ReverseHandle::None => {
                let handle = leaf_handle();
                reverse.handle = ReverseHandle::Leaf(handle.clone());
                ReverseInput::Leaf(handle)
            }
        };
        Some(input)
    }

    /// Read the accumulated leaf gradient, if available.
    ///
    /// Returns `Ok(None)` when the value does not require gradients or is an
    /// interior (edge) node rather than a leaf.
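    ///
    /// A sketch, with a hypothetical `Scalar` type whose tangent is `f64`:
    ///
    /// ```ignore
    /// let x = Value::new(Scalar(5.0)).with_requires_grad(true);
    /// assert_eq!(x.grad()?, None); // nothing accumulated before backward
    /// x.backward()?;
    /// assert_eq!(x.grad()?, Some(1.0));
    /// x.zero_grad()?;              // clears the accumulator again
    /// ```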
    pub fn grad(&self) -> AdResult<Option<V::Tangent>>
    where
        V::Tangent: Clone,
    {
        let reverse = self.reverse.lock().expect("reverse state poisoned");
        if !reverse.requires_grad {
            return Ok(None);
        }
        match &reverse.handle {
            ReverseHandle::Leaf(handle) => Ok(leaf_grad::<V>(handle)),
            ReverseHandle::Edge(_) | ReverseHandle::None => Ok(None),
        }
    }

    /// Clear the accumulated leaf gradient.
    pub fn zero_grad(&self) -> AdResult<()> {
        let reverse = self.reverse.lock().expect("reverse state poisoned");
        if !reverse.requires_grad {
            return Ok(());
        }
        if let ReverseHandle::Leaf(handle) = &reverse.handle {
            zero_leaf_grad::<V>(handle);
        }
        Ok(())
    }

    /// Run reverse-mode backward with the default scalar cotangent seed.
    ///
    /// Fails with `AutodiffError::NonScalarLoss` unless the primal holds
    /// exactly one element.
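    ///
    /// A sketch of the scalar check; `Scalar` and `Vector` are hypothetical
    /// `Differentiable` impls, not part of this file:
    ///
    /// ```ignore
    /// let loss = Value::new(Scalar(0.5)).with_requires_grad(true);
    /// loss.backward()?; // ok: exactly one element
    ///
    /// let batch = Value::new(Vector(vec![1.0, 2.0])).with_requires_grad(true);
    /// assert!(batch.backward().is_err()); // NonScalarLoss { num_elements: 2 }
    /// ```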
    pub fn backward(&self) -> AdResult<()>
    where
        V::Tangent: Clone,
    {
        let n = self.primal.as_ref().num_elements();
        if n != 1 {
            return Err(AutodiffError::NonScalarLoss { num_elements: n });
        }
        self.backward_with_seed(self.primal.as_ref().seed_cotangent())
    }

    /// Run reverse-mode backward with an explicit cotangent seed.
    ///
    /// Fails with `AutodiffError::MissingNode` if this value does not require
    /// gradients.
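    ///
    /// A sketch with an explicit seed; `Vector` is a hypothetical type whose
    /// tangent is `Vec<f64>`:
    ///
    /// ```ignore
    /// let v = Value::new(Vector(vec![1.0, 2.0])).with_requires_grad(true);
    /// v.backward_with_seed(vec![1.0, 0.0])?; // weight only the first component
    /// ```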
    pub fn backward_with_seed(&self, seed: V::Tangent) -> AdResult<()>
    where
        V::Tangent: Clone,
    {
        let input = self.reverse_input().ok_or(AutodiffError::MissingNode)?;
        backward_from(input, seed)
    }

    /// Compute functional gradients with respect to the requested inputs,
    /// returning one entry per input in `wrt` order.
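    ///
    /// A sketch, assuming a hypothetical `Scalar` type and arithmetic ops
    /// defined elsewhere in the crate:
    ///
    /// ```ignore
    /// let x = Value::new(Scalar(2.0)).with_requires_grad(true);
    /// let y = Value::new(Scalar(3.0)).with_requires_grad(true);
    /// let z = &x * &y; // hypothetical operator building a reverse edge
    /// let grads = z.grad_wrt_with_seed(1.0, &[&x, &y])?;
    /// assert_eq!(grads, vec![Some(3.0), Some(2.0)]);
    /// ```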
    pub fn grad_wrt_with_seed(
        &self,
        seed: V::Tangent,
        wrt: &[&Self],
    ) -> AdResult<Vec<Option<V::Tangent>>>
    where
        V::Tangent: Clone,
    {
        let input = self.reverse_input().ok_or(AutodiffError::MissingNode)?;
        let wrt_inputs = wrt
            .iter()
            .map(|value| value.reverse_input())
            .collect::<Vec<_>>();
        grad_wrt(input, seed, &wrt_inputs)
    }

    /// Return whether `self` and `other` share any reachable reverse graph.
    ///
    /// Two detached values trivially count as sharing; a tracked and a
    /// detached value never do.
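    ///
    /// A sketch, assuming hypothetical `Scalar` values and an op defined
    /// elsewhere in the crate:
    ///
    /// ```ignore
    /// let x = Value::new(Scalar(1.0)).with_requires_grad(true);
    /// let y = Value::new(Scalar(2.0)).with_requires_grad(true);
    /// let z = &x + &y; // hypothetical op joining both inputs into one graph
    /// assert!(z.shares_reverse_graph(&x));
    /// ```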
    pub fn shares_reverse_graph(&self, other: &Self) -> bool {
        match (self.reverse_input(), other.reverse_input()) {
            (Some(lhs), Some(rhs)) => shares_graph(&lhs, &rhs),
            (None, None) => true,
            _ => false,
        }
    }
}