1use std::sync::Arc;
2
3use crate::checkpoint::{current_ad_policy, storage_decision, CheckpointHint, StorageDecision};
4use crate::reverse_graph::{ReverseEdge, ReverseNode, StoredNodeLinearization};
5use crate::{AdResult, AutodiffError, Differentiable, Value};
6
/// Per-value flags describing how one input/output slot of an op
/// participates in differentiation.
///
/// A slot may be marked `differentiable` or `auxiliary`, but not both
/// (enforced by `SlotSchema::validate`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotSchema {
    /// Gradients flow through this slot when true.
    pub differentiable: bool,
    /// Slot carries auxiliary data that never participates in differentiation.
    pub auxiliary: bool,
}
13
14impl SlotSchema {
15 fn validate(self, kind: &str, index: usize) -> AdResult<Self> {
16 if self.auxiliary && self.differentiable {
17 return Err(AutodiffError::InvalidArgument(format!(
18 "{kind} schema slot {index} cannot be auxiliary and differentiable at the same time"
19 )));
20 }
21 Ok(self)
22 }
23}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct Schema {
28 pub slots: Vec<SlotSchema>,
29}
30
31impl Schema {
32 pub(crate) fn validate_len(&self, kind: &str, expected_len: usize) -> AdResult<()> {
33 if self.slots.len() != expected_len {
34 return Err(AutodiffError::InvalidArgument(format!(
35 "{kind} schema returned {} slots for {expected_len} values",
36 self.slots.len()
37 )));
38 }
39 for (index, slot) in self.slots.iter().copied().enumerate() {
40 slot.validate(kind, index)?;
41 }
42 Ok(())
43 }
44}
45
46pub trait LinearizableOp<V: Differentiable + Send + Sync + 'static>: Send + Sync + 'static {
47 type Linearized: LinearizedOp<V> + Send + Sync + 'static;
48
49 fn primal(&self, inputs: &[&V]) -> AdResult<Vec<V>>;
50 fn input_schema(&self, inputs: &[&V]) -> AdResult<Schema>;
51 fn output_schema(&self, inputs: &[&V], outputs: &[V]) -> AdResult<Schema>;
52 fn linearize(&self, inputs: &[&V], outputs: &[V]) -> AdResult<Self::Linearized>;
53
54 fn checkpoint_hint(&self) -> CheckpointHint {
56 CheckpointHint::CheapReplay
57 }
58
59 fn apply(&self, inputs: &[&Value<V>]) -> AdResult<Vec<Value<V>>>
60 where
61 Self: Sized + Clone,
62 {
63 let primals: Vec<&V> = inputs.iter().map(|input| input.primal()).collect();
64 let input_schema = self.input_schema(&primals)?;
65 input_schema.validate_len("input", inputs.len())?;
66
67 let outputs = self.primal(&primals)?;
68 let output_schema = self.output_schema(&primals, &outputs)?;
69 output_schema.validate_len("output", outputs.len())?;
70
71 let input_grad_mask: Vec<bool> = inputs
72 .iter()
73 .zip(&input_schema.slots)
74 .map(|(input, slot)| input.requires_grad() && slot.differentiable)
75 .collect();
76
77 let differentiable_output_slots: Vec<usize> = output_schema
78 .slots
79 .iter()
80 .enumerate()
81 .filter_map(|(index, slot)| slot.differentiable.then_some(index))
82 .collect();
83
84 if !input_grad_mask.iter().any(|needed| *needed) || differentiable_output_slots.is_empty() {
85 return Ok(outputs.into_iter().map(Value::new).collect());
86 }
87
88 let input_nodes = inputs
89 .iter()
90 .zip(&input_schema.slots)
91 .map(|(input, slot)| {
92 if slot.differentiable {
93 Ok(input.reverse_input())
94 } else {
95 Ok(None)
96 }
97 })
98 .collect::<AdResult<Vec<_>>>()?;
99
100 let stored_linearization =
101 match storage_decision(current_ad_policy(), self.checkpoint_hint()) {
102 StorageDecision::Retain => {
103 StoredNodeLinearization::retained(self.linearize(&primals, &outputs)?)
104 }
105 StorageDecision::Replay => StoredNodeLinearization::replay(
106 self.clone(),
107 inputs.iter().map(|input| input.shared_primal()).collect(),
108 ),
109 };
110
111 let output_count = output_schema.slots.len();
112 let node = Arc::new(ReverseNode::new(
113 input_nodes,
114 output_count,
115 input_grad_mask,
116 stored_linearization,
117 ));
118
119 Ok(outputs
120 .into_iter()
121 .enumerate()
122 .map(|(index, output)| {
123 if output_schema.slots[index].differentiable {
124 Value::from_reverse_edge(
125 output,
126 ReverseEdge {
127 node: node.clone(),
128 output_slot: index,
129 },
130 )
131 } else {
132 Value::new(output)
133 }
134 })
135 .collect())
136 }
137
138 fn apply_one(&self, inputs: &[&Value<V>]) -> AdResult<Value<V>>
139 where
140 Self: Sized + Clone,
141 {
142 let mut outputs = self.apply(inputs)?;
143 if outputs.len() != 1 {
144 return Err(AutodiffError::InvalidArgument(format!(
145 "LinearizableOp::apply_one expected exactly 1 output, got {}",
146 outputs.len()
147 )));
148 }
149 Ok(outputs.remove(0))
150 }
151}
152
/// The linearization of a `LinearizableOp` at a fixed primal point.
///
/// NOTE(review): the trait's closing brace lies beyond this view; only the
/// two method signatures below are visible here.
pub trait LinearizedOp<V: Differentiable + Send + Sync + 'static>: Send + Sync + 'static {
    /// Jacobian-vector product (forward mode): map per-input tangents to
    /// per-output tangents. `None` marks a slot with no tangent.
    fn jvp(&self, input_tangents: &[Option<V::Tangent>]) -> AdResult<Vec<Option<V::Tangent>>>;

    /// Vector-Jacobian product (reverse mode): map per-output cotangents back
    /// to per-input cotangents. `input_grad_mask` flags which input slots need
    /// a cotangent (matching the mask built in `LinearizableOp::apply`).
    fn vjp(
        &self,
        output_cotangents: &[Option<V::Tangent>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<V::Tangent>>>;
161}