Skip to main content

tenferro_ad/
traced.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use computegraph::resolve::resolve;
5use computegraph::types::ValueKey;
6use tenferro_ops::input_key::TensorInputKey;
7use tenferro_ops::ExtensionRuleSet;
8use tenferro_ops::ShapeGuardContext;
9use tenferro_runtime::ad_support::{
10    checkpoint_chain as tensor_checkpoint_chain, checkpoint_tensor,
11    extra_roots as tensor_extra_roots, inputs_map as tensor_inputs_map, leaf_input_key,
12    linear_input_key, metadata_scopes as tensor_metadata_scopes, metadata_scopes_with_new,
13    ones_tensor, push_metadata_scope, register_scoped_graph_metadata, registered_meta,
14    resolve_roots as tensor_resolve_roots, shape_hint as tensor_shape_hint, tensor_from_parts,
15    tensor_meta_from_tensor, TracedTensorParts,
16};
17use tenferro_runtime::{Error, GraphCompiler, GraphExecutor, Result, TracedTensor};
18use tenferro_tensor::TensorBackend;
19use tidu::{linear_transpose, linearize, ADRuleError};
20
/// Process-wide counter from which every differentiation pass draws its id.
static NEXT_DIFF_PASS_ID: AtomicU64 = AtomicU64::new(0);

/// Hands out a fresh pass id for the next `linearize` call.
///
/// `Relaxed` ordering is enough: callers only need distinct values, not any
/// ordering relative to other memory operations.
fn next_pass_id() -> u64 {
    let counter: &AtomicU64 = &NEXT_DIFF_PASS_ID;
    counter.fetch_add(1, Ordering::Relaxed)
}
26
27pub(crate) fn next_input_key() -> TensorInputKey {
28    tenferro_runtime::ad_support::allocate_input_key()
29}
30
31fn error_shape_hint(tensor: &TracedTensor) -> Vec<usize> {
32    tensor
33        .try_concrete_shape()
34        .unwrap_or_else(|| vec![0; tensor.rank])
35}
36
37fn shape_guard_context(extension_rules: Option<&ExtensionRuleSet>) -> ShapeGuardContext {
38    let ctx = ShapeGuardContext::with_global_metadata();
39    match extension_rules {
40        Some(rules) => ctx.with_extension_rules(rules.clone()),
41        None => ctx,
42    }
43}
44
45fn ad_rule_error(transform: &'static str, err: ADRuleError) -> Error {
46    match err {
47        ADRuleError::Unsupported { op, .. } => {
48            Error::Internal(format!("unsupported {transform} AD rule for {op}"))
49        }
50        ADRuleError::InvalidInput { op, message, .. } => Error::InvalidGraphBuild {
51            op: transform,
52            message: format!("{op}: {message}"),
53        },
54    }
55}
56
57pub(crate) fn grad_with_rules(
58    output: &TracedTensor,
59    wrt: &TracedTensor,
60    extension_rules: &ExtensionRuleSet,
61) -> Result<TracedTensor> {
62    grad_with_optional_rules(output, wrt, Some(extension_rules))
63}
64
65pub(crate) fn jvp_with_rules(
66    output: &TracedTensor,
67    wrt: &TracedTensor,
68    tangent: &TracedTensor,
69    extension_rules: &ExtensionRuleSet,
70) -> Result<TracedTensor> {
71    let wrt_input_key = leaf_input_key(wrt)?;
72    jvp_optional_impl(output, wrt, tangent, Some(extension_rules))?
73        .ok_or_else(|| Error::Internal(format!("jvp output is inactive for {:?}", wrt_input_key)))
74}
75
76pub(crate) fn grad_optional_with_rules(
77    output: &TracedTensor,
78    wrt: &TracedTensor,
79    extension_rules: &ExtensionRuleSet,
80) -> Result<Option<TracedTensor>> {
81    if output.rank != 0 {
82        return Err(Error::NonScalarGrad {
83            shape: error_shape_hint(output),
84        });
85    }
86
87    let ones = ones_tensor(output.dtype, vec![])?;
88    let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
89    vjp_optional_impl(output, wrt, &seed, Some(extension_rules))
90}
91
92pub(crate) fn jvp_optional_with_rules(
93    output: &TracedTensor,
94    wrt: &TracedTensor,
95    tangent: &TracedTensor,
96    extension_rules: &ExtensionRuleSet,
97) -> Result<Option<TracedTensor>> {
98    jvp_optional_impl(output, wrt, tangent, Some(extension_rules))
99}
100
101pub(crate) fn vjp_with_rules(
102    output: &TracedTensor,
103    wrt: &TracedTensor,
104    cotangent: &TracedTensor,
105    extension_rules: &ExtensionRuleSet,
106) -> Result<TracedTensor> {
107    let wrt_input_key = leaf_input_key(wrt)?;
108    vjp_optional_impl(output, wrt, cotangent, Some(extension_rules))?
109        .ok_or_else(|| Error::Internal(format!("vjp output is inactive for {:?}", wrt_input_key)))
110}
111
112pub(crate) fn vjp_optional_with_rules(
113    output: &TracedTensor,
114    wrt: &TracedTensor,
115    cotangent: &TracedTensor,
116    extension_rules: &ExtensionRuleSet,
117) -> Result<Option<TracedTensor>> {
118    vjp_optional_impl(output, wrt, cotangent, Some(extension_rules))
119}
120
121fn grad_with_optional_rules(
122    output: &TracedTensor,
123    wrt: &TracedTensor,
124    extension_rules: Option<&ExtensionRuleSet>,
125) -> Result<TracedTensor> {
126    if output.rank != 0 {
127        return Err(Error::NonScalarGrad {
128            shape: error_shape_hint(output),
129        });
130    }
131
132    let ones = ones_tensor(output.dtype, vec![])?;
133    let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
134    let wrt_input_key = leaf_input_key(wrt)?;
135    vjp_optional_impl(output, wrt, &seed, extension_rules)?
136        .ok_or_else(|| Error::Internal(format!("grad output is inactive for {:?}", wrt_input_key)))
137}
138
/// Automatic differentiation helpers for [`TracedTensor`].
///
/// The differentiation methods (`grad`, `jvp`, `vjp` and their `_optional`
/// variants) return traced tensors with no attached data. They only extend the
/// graph, so the result must be compiled and executed to obtain values.
/// [`checkpoint`](Self::checkpoint) is the exception: it evaluates eagerly.
///
/// # Examples
///
/// ```rust
/// use tenferro_ad::TracedTensorAdExt;
/// use tenferro_runtime::TracedTensor;
///
/// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
/// let loss = x.scale_real(2.0);
/// let maybe_dx = loss.grad_optional(&x).unwrap();
/// assert!(maybe_dx.is_some());
/// ```
pub trait TracedTensorAdExt {
    /// Gradient of a scalar output with respect to a traced input.
    ///
    /// For complex scalar outputs, tenferro returns the Hermitian-adjoint
    /// cotangent. To compare seed-`1` scalar gradients with JAX's public
    /// `grad` values, use the complex conjugate of this result. See
    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
    ///
    /// # Errors
    ///
    /// - `Error::NonScalarGrad` if `self` is not rank 0.
    /// - `Error::Internal` if `wrt` is inactive for `self`. Use
    ///   [`grad_optional`](Self::grad_optional) to get `None` instead.
    /// - Errors from looking up `wrt`'s leaf input key (presumably raised when
    ///   `wrt` is not a traced leaf).
    /// - AD rule failures: `Error::Internal` for unsupported rules,
    ///   `Error::InvalidGraphBuild` for invalid rule inputs.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_cpu::CpuBackend;
    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
    ///
    /// fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    ///     let mut compiler = GraphCompiler::new();
    ///     let program = compiler.compile(tensor).unwrap();
    ///     let mut executor = GraphExecutor::new(CpuBackend::new());
    ///     executor.run(&program).unwrap()
    /// }
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let loss = (&x * &x).unwrap();
    /// let dx = loss.grad(&x).unwrap();
    ///
    /// assert_eq!(eval(&dx).as_slice::<f64>().unwrap(), &[6.0]);
    /// ```
    fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>;

    /// Like [`grad`](Self::grad), but returns `None` when `wrt` is inactive.
    ///
    /// # Errors
    ///
    /// - `Error::NonScalarGrad` if `self` is not rank 0.
    /// - Errors from looking up `wrt`'s leaf input key, and AD rule failures,
    ///   as for [`grad`](Self::grad).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_runtime::TracedTensor;
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
    /// let loss = (&y * &y).unwrap();
    ///
    /// assert!(loss.grad_optional(&x).unwrap().is_none());
    /// ```
    fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>;

    /// Evaluate this tensor and replace its graph with a concrete leaf while
    /// preserving the previous graph for AD replay.
    ///
    /// If data is already attached to `self`, it is reused and nothing is
    /// compiled or executed.
    ///
    /// # Errors
    ///
    /// Propagates errors from `compiler` and `executor` when `self` has no
    /// attached data and has to be evaluated.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_cpu::CpuBackend;
    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
    ///
    /// let mut compiler = GraphCompiler::new();
    /// let mut executor = GraphExecutor::new(CpuBackend::new());
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let mut y = (&x * &x).unwrap();
    ///
    /// y.checkpoint(&mut compiler, &mut executor).unwrap();
    ///
    /// let value = y.attached_data().unwrap();
    /// assert_eq!(value.as_slice::<f64>().unwrap(), &[9.0]);
    /// ```
    fn checkpoint<B: TensorBackend>(
        &mut self,
        compiler: &mut GraphCompiler,
        executor: &mut GraphExecutor<B>,
    ) -> Result<()>;

    /// Forward-mode Jacobian-vector product.
    ///
    /// # Errors
    ///
    /// - `Error::Internal` if `wrt` is inactive for `self`. Use
    ///   [`jvp_optional`](Self::jvp_optional) to get `None` instead.
    /// - `Error::InvalidGraphBuild` if `wrt` is active and `tangent` has no
    ///   attached concrete tensor data.
    /// - Errors from looking up `wrt`'s leaf input key, and AD rule failures,
    ///   as for [`grad`](Self::grad).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_cpu::CpuBackend;
    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
    ///
    /// fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    ///     let mut compiler = GraphCompiler::new();
    ///     let program = compiler.compile(tensor).unwrap();
    ///     let mut executor = GraphExecutor::new(CpuBackend::new());
    ///     executor.run(&program).unwrap()
    /// }
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let tangent = TracedTensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
    /// let y = (&x * &x).unwrap();
    /// let dy = y.jvp(&x, &tangent).unwrap();
    ///
    /// assert_eq!(eval(&dy).as_slice::<f64>().unwrap(), &[12.0]);
    /// ```
    fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Result<TracedTensor>;

    /// Like [`jvp`](Self::jvp), but returns `None` when `wrt` is inactive.
    ///
    /// The inactivity check happens before `tangent` is inspected. An inactive
    /// `wrt` therefore yields `Ok(None)` even if `tangent` carries no data.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidGraphBuild` if `wrt` is active and `tangent` has no
    ///   attached concrete tensor data.
    /// - Errors from looking up `wrt`'s leaf input key, and AD rule failures,
    ///   as for [`grad`](Self::grad).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_runtime::TracedTensor;
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
    /// let tangent = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
    /// let loss = (&y * &y).unwrap();
    ///
    /// assert!(loss.jvp_optional(&x, &tangent).unwrap().is_none());
    /// ```
    fn jvp_optional(
        &self,
        wrt: &TracedTensor,
        tangent: &TracedTensor,
    ) -> Result<Option<TracedTensor>>;

    /// Reverse-mode vector-Jacobian product.
    ///
    /// Complex cotangents use tenferro's Hermitian real-inner-product
    /// convention. Non-real complex cotangent seeds therefore need an explicit
    /// seed-convention comparison when matching JAX. See
    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
    ///
    /// # Errors
    ///
    /// - `Error::Internal` if `wrt` is inactive for `self`. Use
    ///   [`vjp_optional`](Self::vjp_optional) to get `None` instead.
    /// - `Error::InvalidGraphBuild` if `wrt` is active and `cotangent` has no
    ///   attached concrete tensor data.
    /// - Errors from looking up `wrt`'s leaf input key or its registered
    ///   metadata, and AD rule failures, as for [`grad`](Self::grad).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_cpu::CpuBackend;
    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
    ///
    /// fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    ///     let mut compiler = GraphCompiler::new();
    ///     let program = compiler.compile(tensor).unwrap();
    ///     let mut executor = GraphExecutor::new(CpuBackend::new());
    ///     executor.run(&program).unwrap()
    /// }
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let cotangent = TracedTensor::from_vec_col_major(vec![], vec![0.5_f64]).unwrap();
    /// let y = (&x * &x).unwrap();
    /// let dx = y.vjp(&x, &cotangent).unwrap();
    ///
    /// assert_eq!(eval(&dx).as_slice::<f64>().unwrap(), &[3.0]);
    /// ```
    fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Result<TracedTensor>;

    /// Like [`vjp`](Self::vjp), but returns `None` when `wrt` is inactive.
    ///
    /// The inactivity check happens before `cotangent` is inspected. An
    /// inactive `wrt` therefore yields `Ok(None)` even if `cotangent` carries
    /// no data.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidGraphBuild` if `wrt` is active and `cotangent` has no
    ///   attached concrete tensor data.
    /// - Errors from looking up `wrt`'s leaf input key or its registered
    ///   metadata, and AD rule failures, as for [`grad`](Self::grad).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_ad::TracedTensorAdExt;
    /// use tenferro_runtime::TracedTensor;
    ///
    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
    /// let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
    /// let cotangent = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
    /// let loss = (&y * &y).unwrap();
    ///
    /// assert!(loss.vjp_optional(&x, &cotangent).unwrap().is_none());
    /// ```
    fn vjp_optional(
        &self,
        wrt: &TracedTensor,
        cotangent: &TracedTensor,
    ) -> Result<Option<TracedTensor>>;
}
321
322impl TracedTensorAdExt for TracedTensor {
323    fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor> {
324        grad_with_optional_rules(self, wrt, None)
325    }
326
327    fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>> {
328        if self.rank != 0 {
329            return Err(Error::NonScalarGrad {
330                shape: error_shape_hint(self),
331            });
332        }
333
334        let ones = ones_tensor(self.dtype, vec![])?;
335        let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
336        vjp_optional_impl(self, wrt, &seed, None)
337    }
338
339    fn checkpoint<B: TensorBackend>(
340        &mut self,
341        compiler: &mut GraphCompiler,
342        executor: &mut GraphExecutor<B>,
343    ) -> Result<()> {
344        let data = if let Some(data) = self.attached_data() {
345            Arc::clone(data)
346        } else {
347            let program = compiler.compile(self)?;
348            Arc::new(executor.run(&program)?)
349        };
350        checkpoint_tensor(self, data);
351        Ok(())
352    }
353
354    fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Result<TracedTensor> {
355        let wrt_input_key = leaf_input_key(wrt)?;
356        self.jvp_optional(wrt, tangent)?.ok_or_else(|| {
357            Error::Internal(format!("jvp output is inactive for {:?}", wrt_input_key))
358        })
359    }
360
361    fn jvp_optional(
362        &self,
363        wrt: &TracedTensor,
364        tangent: &TracedTensor,
365    ) -> Result<Option<TracedTensor>> {
366        jvp_optional_impl(self, wrt, tangent, None)
367    }
368
369    fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Result<TracedTensor> {
370        let wrt_input_key = leaf_input_key(wrt)?;
371        self.vjp_optional(wrt, cotangent)?.ok_or_else(|| {
372            Error::Internal(format!("vjp output is inactive for {:?}", wrt_input_key))
373        })
374    }
375
376    fn vjp_optional(
377        &self,
378        wrt: &TracedTensor,
379        cotangent: &TracedTensor,
380    ) -> Result<Option<TracedTensor>> {
381        vjp_optional_impl(self, wrt, cotangent, None)
382    }
383}
384
/// Forward-mode core shared by `jvp*`: builds the traced tangent of `output`
/// with respect to the leaf `wrt`, fed by the concrete `tangent`.
///
/// Returns `Ok(None)` when linearization produces no tangent output for `wrt`
/// (`wrt` is inactive). The returned tensor has `output`'s rank, dtype and
/// shape hint, and carries no data (`data: None`).
///
/// # Errors
///
/// - Propagates failures from `leaf_input_key(wrt)`, `linear_input_key` and
///   metadata registration.
/// - AD rule failures are mapped by `ad_rule_error` with transform `"jvp"`.
/// - `Error::InvalidGraphBuild` if `tangent` has no attached data. This is
///   checked only after the inactive case, so an inactive `wrt` returns
///   `Ok(None)` even for a data-less tangent.
fn jvp_optional_impl(
    output: &TracedTensor,
    wrt: &TracedTensor,
    tangent: &TracedTensor,
    extension_rules: Option<&ExtensionRuleSet>,
) -> Result<Option<TracedTensor>> {
    let wrt_input_key = leaf_input_key(wrt)?;
    let output_key = output.graph().values()[output.val].key.clone();
    // Graphs preserved by `checkpoint`, plus their aliases, are handed to
    // linearization so AD can replay through checkpointed tensors.
    let checkpoint_chain = tensor_checkpoint_chain(output);
    let aliases = checkpoint_chain
        .as_ref()
        .map(|chain| chain.collect_aliases())
        .unwrap_or_default();
    let checkpoint_graphs = checkpoint_chain
        .as_ref()
        .map(|chain| chain.collect_graphs())
        .unwrap_or_default();
    let mut roots = tensor_resolve_roots(output);
    roots.extend(checkpoint_graphs.iter().cloned());
    let view = resolve(roots);
    let mut ad_ctx = shape_guard_context(extension_rules);
    // Each linearization gets a fresh pass id from the global counter.
    let linear = linearize(
        &view,
        std::slice::from_ref(&output_key),
        std::slice::from_ref(&wrt_input_key),
        next_pass_id(),
        &mut ad_ctx,
        &aliases,
    )
    .map_err(|err| ad_rule_error("jvp", err))?;
    // No tangent output means `output` does not depend on `wrt`.
    let Some(tangent_output) = linear.tangent_outputs()[0] else {
        return Ok(None);
    };
    let tangent_input_key = linear_input_key(linear.as_graph(), linear.tangent_inputs()[0].1)?;
    let tangent_data =
        tangent
            .attached_data()
            .cloned()
            .ok_or_else(|| Error::InvalidGraphBuild {
                op: "jvp",
                message: "jvp tangent must have concrete tensor data".to_string(),
            })?;
    // The tangent input's metadata is derived from `tangent`'s data; it is not
    // compared against `wrt` here. NOTE(review): confirm shape/dtype
    // compatibility with `wrt` is enforced elsewhere.
    let metadata_scope = register_scoped_graph_metadata(
        linear.as_graph(),
        vec![(
            ValueKey::Input(tangent_input_key.clone()),
            tensor_meta_from_tensor(tangent_data.as_ref()),
        )],
    )?;

    // Inputs of the linear graph: the primal inputs (including those of
    // checkpointed segments) plus the concrete tangent value.
    let mut inputs_map = (*tensor_inputs_map(output)).clone();
    if let Some(chain) = &checkpoint_chain {
        inputs_map.extend(chain.collect_inputs());
    }
    inputs_map.insert(tangent_input_key, tangent_data);

    // Keep the primal graph, checkpoint graphs and `output`'s extra roots
    // reachable alongside the linear graph (presumably because the linear
    // graph references their values).
    let mut extra_roots = vec![Arc::clone(output.graph())];
    extra_roots.extend(checkpoint_graphs);
    extra_roots.extend(tensor_extra_roots(output));

    Ok(Some(tensor_from_parts(TracedTensorParts {
        rank: output.rank,
        dtype: output.dtype,
        graph: Arc::new(linear.into_graph()),
        val: tangent_output,
        data: None,
        shape_hint: tensor_shape_hint(output),
        inputs_map: Arc::new(inputs_map),
        extra_roots,
        checkpoint_chain,
        metadata_scopes: metadata_scopes_with_new(
            metadata_scope,
            [
                tensor_metadata_scopes(output),
                tensor_metadata_scopes(wrt),
                tensor_metadata_scopes(tangent),
            ],
        ),
    })))
}
465
466fn vjp_optional_impl(
467    output: &TracedTensor,
468    wrt: &TracedTensor,
469    cotangent: &TracedTensor,
470    extension_rules: Option<&ExtensionRuleSet>,
471) -> Result<Option<TracedTensor>> {
472    let wrt_input_key = leaf_input_key(wrt)?;
473    let output_key = output.graph().values()[output.val].key.clone();
474    let checkpoint_chain = tensor_checkpoint_chain(output);
475    let aliases = checkpoint_chain
476        .as_ref()
477        .map(|chain| chain.collect_aliases())
478        .unwrap_or_default();
479    let checkpoint_graphs = checkpoint_chain
480        .as_ref()
481        .map(|chain| chain.collect_graphs())
482        .unwrap_or_default();
483    let mut roots = tensor_resolve_roots(output);
484    roots.extend(checkpoint_graphs.iter().cloned());
485    let view = resolve(roots);
486    let mut ad_ctx = shape_guard_context(extension_rules);
487    let linear = linearize(
488        &view,
489        std::slice::from_ref(&output_key),
490        std::slice::from_ref(&wrt_input_key),
491        next_pass_id(),
492        &mut ad_ctx,
493        &aliases,
494    )
495    .map_err(|err| ad_rule_error("vjp", err))?;
496    if linear.tangent_outputs()[0].is_none() {
497        return Ok(None);
498    }
499    let linear_seed_key = linear_input_key(linear.as_graph(), linear.tangent_inputs()[0].1)?;
500    let linear_metadata_scope = register_scoped_graph_metadata(
501        linear.as_graph(),
502        vec![(
503            ValueKey::Input(linear_seed_key),
504            registered_meta(&wrt.graph().values()[wrt.val].key)?,
505        )],
506    )?;
507    ad_ctx.refresh_global_metadata();
508    let transposed =
509        linear_transpose(&linear, &mut ad_ctx).map_err(|err| ad_rule_error("vjp", err))?;
510    let cotangent_input_key =
511        linear_input_key(transposed.as_graph(), transposed.tangent_inputs()[0].1)?;
512    let cotangent_data =
513        cotangent
514            .attached_data()
515            .cloned()
516            .ok_or_else(|| Error::InvalidGraphBuild {
517                op: "vjp",
518                message: "vjp cotangent must have concrete tensor data".to_string(),
519            })?;
520    let transposed_metadata_scope = register_scoped_graph_metadata(
521        transposed.as_graph(),
522        vec![(
523            ValueKey::Input(cotangent_input_key.clone()),
524            tensor_meta_from_tensor(cotangent_data.as_ref()),
525        )],
526    )?;
527    let linear_graph = Arc::new(linear.into_graph());
528    let Some(cotangent_output) = transposed.tangent_outputs()[0] else {
529        return Ok(None);
530    };
531
532    let mut inputs_map = (*tensor_inputs_map(output)).clone();
533    if let Some(chain) = &checkpoint_chain {
534        inputs_map.extend(chain.collect_inputs());
535    }
536    inputs_map.insert(cotangent_input_key.clone(), cotangent_data);
537
538    let mut extra_roots = vec![Arc::clone(output.graph()), linear_graph];
539    extra_roots.extend(checkpoint_graphs);
540    extra_roots.extend(tensor_extra_roots(output));
541
542    Ok(Some(tensor_from_parts(TracedTensorParts {
543        rank: wrt.rank,
544        dtype: wrt.dtype,
545        graph: Arc::new(transposed.into_graph()),
546        val: cotangent_output,
547        data: None,
548        shape_hint: tensor_shape_hint(wrt),
549        inputs_map: Arc::new(inputs_map),
550        extra_roots,
551        checkpoint_chain,
552        metadata_scopes: {
553            let mut scopes = metadata_scopes_with_new(
554                linear_metadata_scope,
555                [
556                    tensor_metadata_scopes(output),
557                    tensor_metadata_scopes(wrt),
558                    tensor_metadata_scopes(cotangent),
559                ],
560            );
561            push_metadata_scope(&mut scopes, Arc::new(transposed_metadata_scope));
562            scopes
563        },
564    })))
565}