tenferro_internal_ad_surface/
autograd_api.rs1use chainrules_core::Differentiable as _;
2
3use crate::{Error, Result, Tensor};
4
/// Options accepted by [`backward`].
///
/// `Default` yields `retain_graph: false`.
#[derive(Debug, Clone, Default)]
pub struct BackwardOptions {
    /// NOTE(review): presumably asks for the autodiff graph to be kept alive
    /// after the pass — confirm against the underlying engine. In this file
    /// `backward` reads the flag and discards it, so it has no effect here.
    pub retain_graph: bool,
}
9
/// Options accepted by [`grad`].
///
/// `Default` yields `retain_graph: false`.
#[derive(Debug, Clone, Default)]
pub struct GradOptions {
    /// NOTE(review): presumably asks for the autodiff graph to be kept alive
    /// after the call — confirm against the underlying engine. In this file
    /// `grad` reads the flag and discards it, so it has no effect here.
    pub retain_graph: bool,
}
14
15fn invalid_argument(message: impl Into<String>) -> Error {
16 chainrules_core::AutodiffError::InvalidArgument(message.into()).into()
17}
18
19fn default_seed(output: &Tensor) -> crate::DynTensor {
20 output.primal().seed_cotangent()
21}
22
23fn accumulate_optional_grad(slot: &mut Option<crate::DynTensor>, grad: Option<crate::DynTensor>) {
24 match (slot.take(), grad) {
25 (None, None) => *slot = None,
26 (Some(existing), None) => *slot = Some(existing),
27 (None, Some(new_grad)) => *slot = Some(new_grad),
28 (Some(existing), Some(new_grad)) => {
29 *slot = Some(
30 <crate::DynTensor as chainrules_core::Differentiable>::accumulate_tangent(
31 existing, &new_grad,
32 ),
33 );
34 }
35 }
36}
37
38pub fn grad(
39 outputs: &[&Tensor],
40 inputs: &[&Tensor],
41 grad_outputs: Option<&[Tensor]>,
42 options: GradOptions,
43) -> Result<Vec<Option<Tensor>>> {
44 if let Some(grad_outputs) = grad_outputs {
45 if grad_outputs.len() != outputs.len() {
46 return Err(invalid_argument(format!(
47 "grad_outputs length mismatch: expected {}, found {}",
48 outputs.len(),
49 grad_outputs.len()
50 )));
51 }
52 }
53
54 let mut accum = vec![None; inputs.len()];
55 let wrt = inputs.iter().map(|input| input.value()).collect::<Vec<_>>();
56
57 for (index, output) in outputs.iter().enumerate() {
58 let seed = grad_outputs
59 .map(|grads| grads[index].primal().clone())
60 .unwrap_or_else(|| default_seed(output));
61 let grads = output.value().grad_wrt_with_seed(seed, &wrt)?;
62 for (slot, grad) in accum.iter_mut().zip(grads) {
63 accumulate_optional_grad(slot, grad);
64 }
65 }
66
67 let _ = options.retain_graph;
68 Ok(accum
69 .into_iter()
70 .map(|grad| grad.map(Tensor::from))
71 .collect())
72}
73
74pub fn backward(
75 outputs: &[&Tensor],
76 grad_outputs: Option<&[Tensor]>,
77 options: BackwardOptions,
78) -> Result<()> {
79 if let Some(grad_outputs) = grad_outputs {
80 if grad_outputs.len() != outputs.len() {
81 return Err(invalid_argument(format!(
82 "grad_outputs length mismatch: expected {}, found {}",
83 outputs.len(),
84 grad_outputs.len()
85 )));
86 }
87 }
88
89 for (index, output) in outputs.iter().enumerate() {
90 let seed = grad_outputs
91 .map(|grads| grads[index].primal().clone())
92 .unwrap_or_else(|| default_seed(output));
93 output.value().backward_with_seed(seed)?;
94 }
95
96 let _ = options.retain_graph;
97 Ok(())
98}