// tidu/linearized.rs

1use std::sync::Arc;
2
3use crate::checkpoint::{current_ad_policy, storage_decision, CheckpointHint, StorageDecision};
4use crate::reverse_graph::{ReverseEdge, ReverseNode, StoredNodeLinearization};
5use crate::{AdResult, AutodiffError, Differentiable, Value};
6
/// AD-role metadata for one input or output slot.
///
/// A slot may be marked `differentiable` (gradients flow through it) or
/// `auxiliary` (carried alongside but excluded from differentiation).
/// The two flags are mutually exclusive; the combination is rejected by
/// schema validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotSchema {
    // True when tangents/cotangents flow through this slot.
    pub differentiable: bool,
    // True when the slot carries non-differentiable side data.
    pub auxiliary: bool,
}
13
14impl SlotSchema {
15    fn validate(self, kind: &str, index: usize) -> AdResult<Self> {
16        if self.auxiliary && self.differentiable {
17            return Err(AutodiffError::InvalidArgument(format!(
18                "{kind} schema slot {index} cannot be auxiliary and differentiable at the same time"
19            )));
20        }
21        Ok(self)
22    }
23}
24
/// Runtime schema for op inputs or outputs.
///
/// One [`SlotSchema`] per value slot, in positional order; the slot count
/// must match the corresponding value count (checked by `validate_len`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Schema {
    // Per-slot AD-role metadata, index-aligned with the op's values.
    pub slots: Vec<SlotSchema>,
}
30
31impl Schema {
32    pub(crate) fn validate_len(&self, kind: &str, expected_len: usize) -> AdResult<()> {
33        if self.slots.len() != expected_len {
34            return Err(AutodiffError::InvalidArgument(format!(
35                "{kind} schema returned {} slots for {expected_len} values",
36                self.slots.len()
37            )));
38        }
39        for (index, slot) in self.slots.iter().copied().enumerate() {
40            slot.validate(kind, index)?;
41        }
42        Ok(())
43    }
44}
45
46pub trait LinearizableOp<V: Differentiable + Send + Sync + 'static>: Send + Sync + 'static {
47    type Linearized: LinearizedOp<V> + Send + Sync + 'static;
48
49    fn primal(&self, inputs: &[&V]) -> AdResult<Vec<V>>;
50    fn input_schema(&self, inputs: &[&V]) -> AdResult<Schema>;
51    fn output_schema(&self, inputs: &[&V], outputs: &[V]) -> AdResult<Schema>;
52    fn linearize(&self, inputs: &[&V], outputs: &[V]) -> AdResult<Self::Linearized>;
53
54    /// Return a retain-vs-replay hint for the runtime checkpoint policy.
55    fn checkpoint_hint(&self) -> CheckpointHint {
56        CheckpointHint::CheapReplay
57    }
58
59    fn apply(&self, inputs: &[&Value<V>]) -> AdResult<Vec<Value<V>>>
60    where
61        Self: Sized + Clone,
62    {
63        let primals: Vec<&V> = inputs.iter().map(|input| input.primal()).collect();
64        let input_schema = self.input_schema(&primals)?;
65        input_schema.validate_len("input", inputs.len())?;
66
67        let outputs = self.primal(&primals)?;
68        let output_schema = self.output_schema(&primals, &outputs)?;
69        output_schema.validate_len("output", outputs.len())?;
70
71        let input_grad_mask: Vec<bool> = inputs
72            .iter()
73            .zip(&input_schema.slots)
74            .map(|(input, slot)| input.requires_grad() && slot.differentiable)
75            .collect();
76
77        let differentiable_output_slots: Vec<usize> = output_schema
78            .slots
79            .iter()
80            .enumerate()
81            .filter_map(|(index, slot)| slot.differentiable.then_some(index))
82            .collect();
83
84        if !input_grad_mask.iter().any(|needed| *needed) || differentiable_output_slots.is_empty() {
85            return Ok(outputs.into_iter().map(Value::new).collect());
86        }
87
88        let input_nodes = inputs
89            .iter()
90            .zip(&input_schema.slots)
91            .map(|(input, slot)| {
92                if slot.differentiable {
93                    Ok(input.reverse_input())
94                } else {
95                    Ok(None)
96                }
97            })
98            .collect::<AdResult<Vec<_>>>()?;
99
100        let stored_linearization =
101            match storage_decision(current_ad_policy(), self.checkpoint_hint()) {
102                StorageDecision::Retain => {
103                    StoredNodeLinearization::retained(self.linearize(&primals, &outputs)?)
104                }
105                StorageDecision::Replay => StoredNodeLinearization::replay(
106                    self.clone(),
107                    inputs.iter().map(|input| input.shared_primal()).collect(),
108                ),
109            };
110
111        let output_count = output_schema.slots.len();
112        let node = Arc::new(ReverseNode::new(
113            input_nodes,
114            output_count,
115            input_grad_mask,
116            stored_linearization,
117        ));
118
119        Ok(outputs
120            .into_iter()
121            .enumerate()
122            .map(|(index, output)| {
123                if output_schema.slots[index].differentiable {
124                    Value::from_reverse_edge(
125                        output,
126                        ReverseEdge {
127                            node: node.clone(),
128                            output_slot: index,
129                        },
130                    )
131                } else {
132                    Value::new(output)
133                }
134            })
135            .collect())
136    }
137
138    fn apply_one(&self, inputs: &[&Value<V>]) -> AdResult<Value<V>>
139    where
140        Self: Sized + Clone,
141    {
142        let mut outputs = self.apply(inputs)?;
143        if outputs.len() != 1 {
144            return Err(AutodiffError::InvalidArgument(format!(
145                "LinearizableOp::apply_one expected exactly 1 output, got {}",
146                outputs.len()
147            )));
148        }
149        Ok(outputs.remove(0))
150    }
151}
152
/// The linearized form of an op at a fixed evaluation point: a linear map
/// supporting forward-mode (JVP) and reverse-mode (VJP) products.
pub trait LinearizedOp<V: Differentiable + Send + Sync + 'static>: Send + Sync + 'static {
    /// Jacobian-vector product: push input tangents forward to output
    /// tangents. `None` entries denote slots with no tangent.
    fn jvp(&self, input_tangents: &[Option<V::Tangent>]) -> AdResult<Vec<Option<V::Tangent>>>;

    /// Vector-Jacobian product: pull output cotangents back to input
    /// cotangents. `input_grad_mask` flags which input slots actually need a
    /// gradient; masked-out slots may be skipped (presumably returned as
    /// `None` — confirm against implementors).
    fn vjp(
        &self,
        output_cotangents: &[Option<V::Tangent>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<V::Tangent>>>;
}
161}