1use std::sync::{Arc, Mutex};
2
3use crate::reverse_graph::{
4 backward_from, grad_wrt, leaf_grad, leaf_handle, shares_graph, zero_leaf_grad, LeafHandle,
5 ReverseEdge, ReverseInput,
6};
7use crate::{AdResult, AutodiffError, Differentiable};
8
/// Link between a [`Value`] and the reverse-mode graph.
enum ReverseHandle<V: Differentiable> {
    /// No handle is attached. This is the state produced by `Value::new` and by
    /// disabling gradient tracking through `Value::with_requires_grad(false)`.
    None,
    /// The value is a leaf; the handle is obtained from `leaf_handle()` and is the
    /// only variant `Value::grad` and `Value::zero_grad` read from or write to.
    Leaf(LeafHandle<V>),
    /// The value was built by `Value::from_reverse_edge` and carries that edge.
    Edge(ReverseEdge<V>),
}
14
/// Mutable reverse-mode bookkeeping stored behind the mutex in [`Value`].
struct ReverseState<V: Differentiable> {
    /// Whether gradient tracking is enabled for the owning value.
    requires_grad: bool,
    /// Handle into the reverse graph; see [`ReverseHandle`] for the variants.
    handle: ReverseHandle<V>,
}
19
/// A primal value paired with its reverse-mode differentiation state.
pub struct Value<V: Differentiable> {
    /// The wrapped primal, reference-counted so `shared_primal` can hand out
    /// additional owners without copying `V`.
    primal: Arc<V>,
    /// Reverse-mode state; kept behind a mutex because `&self` methods such as
    /// `reverse_input` may install a leaf handle.
    reverse: Mutex<ReverseState<V>>,
}
29
30impl<V: Differentiable + 'static> Value<V> {
31 pub fn new(primal: V) -> Self {
33 Self {
34 primal: Arc::new(primal),
35 reverse: Mutex::new(ReverseState {
36 requires_grad: false,
37 handle: ReverseHandle::None,
38 }),
39 }
40 }
41
42 pub(crate) fn from_reverse_edge(primal: V, edge: ReverseEdge<V>) -> Self {
43 Self {
44 primal: Arc::new(primal),
45 reverse: Mutex::new(ReverseState {
46 requires_grad: true,
47 handle: ReverseHandle::Edge(edge),
48 }),
49 }
50 }
51
52 pub fn primal(&self) -> &V {
54 self.primal.as_ref()
55 }
56
57 pub(crate) fn shared_primal(&self) -> Arc<V> {
58 self.primal.clone()
59 }
60
61 pub fn requires_grad(&self) -> bool {
63 self.reverse
64 .lock()
65 .expect("reverse state poisoned")
66 .requires_grad
67 }
68
69 pub fn with_requires_grad(self, enabled: bool) -> Self {
71 {
72 let mut reverse = self.reverse.lock().expect("reverse state poisoned");
73 reverse.requires_grad = enabled;
74 reverse.handle = if enabled {
75 match &reverse.handle {
76 ReverseHandle::Leaf(existing) => ReverseHandle::Leaf(existing.clone()),
77 ReverseHandle::Edge(existing) => ReverseHandle::Edge(existing.clone()),
78 ReverseHandle::None => ReverseHandle::Leaf(leaf_handle()),
79 }
80 } else {
81 ReverseHandle::None
82 };
83 }
84 self
85 }
86
87 pub(crate) fn reverse_input(&self) -> Option<ReverseInput<V>> {
88 let mut reverse = self.reverse.lock().expect("reverse state poisoned");
89 if !reverse.requires_grad {
90 return None;
91 }
92 let input = match &reverse.handle {
93 ReverseHandle::Leaf(handle) => ReverseInput::Leaf(handle.clone()),
94 ReverseHandle::Edge(edge) => ReverseInput::Edge(edge.clone()),
95 ReverseHandle::None => {
96 let handle = leaf_handle();
97 reverse.handle = ReverseHandle::Leaf(handle.clone());
98 ReverseInput::Leaf(handle)
99 }
100 };
101 Some(input)
102 }
103
104 pub fn grad(&self) -> AdResult<Option<V::Tangent>>
106 where
107 V::Tangent: Clone,
108 {
109 let reverse = self.reverse.lock().expect("reverse state poisoned");
110 if !reverse.requires_grad {
111 return Ok(None);
112 }
113 match &reverse.handle {
114 ReverseHandle::Leaf(handle) => Ok(leaf_grad::<V>(handle)),
115 ReverseHandle::Edge(_) | ReverseHandle::None => Ok(None),
116 }
117 }
118
119 pub fn zero_grad(&self) -> AdResult<()> {
121 let reverse = self.reverse.lock().expect("reverse state poisoned");
122 if !reverse.requires_grad {
123 return Ok(());
124 }
125 if let ReverseHandle::Leaf(handle) = &reverse.handle {
126 zero_leaf_grad::<V>(handle);
127 }
128 Ok(())
129 }
130
131 pub fn backward(&self) -> AdResult<()>
133 where
134 V::Tangent: Clone,
135 {
136 let n = self.primal.as_ref().num_elements();
137 if n != 1 {
138 return Err(AutodiffError::NonScalarLoss { num_elements: n });
139 }
140 self.backward_with_seed(self.primal.as_ref().seed_cotangent())
141 }
142
143 pub fn backward_with_seed(&self, seed: V::Tangent) -> AdResult<()>
145 where
146 V::Tangent: Clone,
147 {
148 let input = self.reverse_input().ok_or(AutodiffError::MissingNode)?;
149 backward_from(input, seed)
150 }
151
152 pub fn grad_wrt_with_seed(
154 &self,
155 seed: V::Tangent,
156 wrt: &[&Self],
157 ) -> AdResult<Vec<Option<V::Tangent>>>
158 where
159 V::Tangent: Clone,
160 {
161 let input = self.reverse_input().ok_or(AutodiffError::MissingNode)?;
162 let wrt_inputs = wrt
163 .iter()
164 .map(|value| value.reverse_input())
165 .collect::<Vec<_>>();
166 grad_wrt(input, seed, &wrt_inputs)
167 }
168
169 pub fn shares_reverse_graph(&self, other: &Self) -> bool {
171 match (self.reverse_input(), other.reverse_input()) {
172 (Some(lhs), Some(rhs)) => shares_graph(&lhs, &rhs),
173 (None, None) => true,
174 _ => false,
175 }
176 }
177}