tenferro_internal_ad_surface/
autograd_api.rs

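//! Autograd entry points for the tensor surface: [`grad`] computes gradients of
//! outputs with respect to explicit inputs, and [`backward`] drives
//! reverse-mode accumulation from one or more outputs.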
use chainrules_core::Differentiable as _;

use crate::{Error, Result, Tensor};

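/// Options for [`backward`].
///
/// `retain_graph` is currently accepted but not yet consulted by the
/// implementation below.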
#[derive(Debug, Clone, Default)]
pub struct BackwardOptions {
    pub retain_graph: bool,
}

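/// Options for [`grad`].
///
/// As with [`BackwardOptions`], `retain_graph` is currently accepted but not
/// yet consulted by [`grad`].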
#[derive(Debug, Clone, Default)]
pub struct GradOptions {
    pub retain_graph: bool,
}

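/// Converts a message into the crate-level [`Error`] via
/// `chainrules_core::AutodiffError::InvalidArgument`.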
fn invalid_argument(message: impl Into<String>) -> Error {
    chainrules_core::AutodiffError::InvalidArgument(message.into()).into()
}

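/// Produces the default cotangent seed for an output by asking its primal
/// value for `seed_cotangent()`.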
fn default_seed(output: &Tensor) -> crate::DynTensor {
    output.primal().seed_cotangent()
}

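/// Folds a newly computed gradient into an accumulation slot: absent values
/// pass through unchanged, and two present gradients are combined with
/// `Differentiable::accumulate_tangent`.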
fn accumulate_optional_grad(slot: &mut Option<crate::DynTensor>, grad: Option<crate::DynTensor>) {
    match (slot.take(), grad) {
        (None, None) => *slot = None,
        (Some(existing), None) => *slot = Some(existing),
        (None, Some(new_grad)) => *slot = Some(new_grad),
        (Some(existing), Some(new_grad)) => {
            *slot = Some(
                <crate::DynTensor as chainrules_core::Differentiable>::accumulate_tangent(
                    existing, &new_grad,
                ),
            );
        }
    }
}

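/// Computes gradients of `outputs` with respect to `inputs`, returning one
/// `Option<Tensor>` per input, in the order of `inputs`.
///
/// Each output is seeded with the matching entry of `grad_outputs`, or with its
/// default cotangent seed when `grad_outputs` is `None`; contributions from
/// multiple outputs are accumulated per input.
///
/// A minimal usage sketch; `Tensor::scalar` and `mul` are hypothetical
/// constructors used only for illustration:
///
/// ```ignore
/// let x = Tensor::scalar(3.0);
/// let y = x.mul(&x)?; // y = x * x
/// let grads = grad(&[&y], &[&x], None, GradOptions::default())?;
/// assert!(grads[0].is_some()); // dy/dx is populated
/// ```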
pub fn grad(
    outputs: &[&Tensor],
    inputs: &[&Tensor],
    grad_outputs: Option<&[Tensor]>,
    options: GradOptions,
) -> Result<Vec<Option<Tensor>>> {
    if let Some(grad_outputs) = grad_outputs {
        if grad_outputs.len() != outputs.len() {
            return Err(invalid_argument(format!(
                "grad_outputs length mismatch: expected {}, found {}",
                outputs.len(),
                grad_outputs.len()
            )));
        }
    }

    // One accumulation slot per input, filled as each output contributes.
    let mut accum = vec![None; inputs.len()];
    let wrt = inputs.iter().map(|input| input.value()).collect::<Vec<_>>();

    for (index, output) in outputs.iter().enumerate() {
        // Seed with the caller-provided cotangent, or fall back to the default seed.
        let seed = grad_outputs
            .map(|grads| grads[index].primal().clone())
            .unwrap_or_else(|| default_seed(output));
        let grads = output.value().grad_wrt_with_seed(seed, &wrt)?;
        for (slot, grad) in accum.iter_mut().zip(grads) {
            accumulate_optional_grad(slot, grad);
        }
    }

    // retain_graph is accepted for API compatibility but not consulted yet.
    let _ = options.retain_graph;
    Ok(accum
        .into_iter()
        .map(|grad| grad.map(Tensor::from))
        .collect())
}

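/// Runs reverse-mode accumulation from each of `outputs`, seeding every output
/// with the matching `grad_outputs` entry, or with its default cotangent seed
/// when `grad_outputs` is `None`.
///
/// A minimal usage sketch with the same hypothetical constructors as in
/// [`grad`]:
///
/// ```ignore
/// let x = Tensor::scalar(2.0);
/// let y = x.mul(&x)?;
/// backward(&[&y], None, BackwardOptions::default())?;
/// ```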
pub fn backward(
    outputs: &[&Tensor],
    grad_outputs: Option<&[Tensor]>,
    options: BackwardOptions,
) -> Result<()> {
    if let Some(grad_outputs) = grad_outputs {
        if grad_outputs.len() != outputs.len() {
            return Err(invalid_argument(format!(
                "grad_outputs length mismatch: expected {}, found {}",
                outputs.len(),
                grad_outputs.len()
            )));
        }
    }

    for (index, output) in outputs.iter().enumerate() {
        // Seed with the caller-provided cotangent, or fall back to the default seed.
        let seed = grad_outputs
            .map(|grads| grads[index].primal().clone())
            .unwrap_or_else(|| default_seed(output));
        output.value().backward_with_seed(seed)?;
    }

    // retain_graph is accepted for API compatibility but not consulted yet.
    let _ = options.retain_graph;
    Ok(())
}