tenferro_internal_frontend_core/
autodiff.rs

1use chainrules_core::Differentiable;
2
3use crate::{DynTensor, StructuredTensor};
4
5impl Differentiable for DynTensor {
6    type Tangent = DynTensor;
7
8    fn zero_tangent(&self) -> Self::Tangent {
9        match self {
10            Self::F32(value) => Self::F32(value.zero_tangent()),
11            Self::F64(value) => Self::F64(value.zero_tangent()),
12            Self::C32(value) => Self::C32(value.zero_tangent()),
13            Self::C64(value) => Self::C64(value.zero_tangent()),
14        }
15    }
16
17    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent {
18        match (a, b) {
19            (Self::F32(lhs), Self::F32(rhs)) => {
20                Self::F32(StructuredTensor::<f32>::accumulate_tangent(lhs, rhs))
21            }
22            (Self::F64(lhs), Self::F64(rhs)) => {
23                Self::F64(StructuredTensor::<f64>::accumulate_tangent(lhs, rhs))
24            }
25            (Self::C32(lhs), Self::C32(rhs)) => {
26                Self::C32(StructuredTensor::<num_complex::Complex32>::accumulate_tangent(lhs, rhs))
27            }
28            (Self::C64(lhs), Self::C64(rhs)) => {
29                Self::C64(StructuredTensor::<num_complex::Complex64>::accumulate_tangent(lhs, rhs))
30            }
31            (lhs, rhs) => unreachable!(
32                "DynTensor::accumulate_tangent requires matching dtypes, got lhs={:?}, rhs={:?}",
33                lhs.scalar_type(),
34                rhs.scalar_type()
35            ),
36        }
37    }
38
39    fn num_elements(&self) -> usize {
40        match self {
41            Self::F32(value) => value.num_elements(),
42            Self::F64(value) => value.num_elements(),
43            Self::C32(value) => value.num_elements(),
44            Self::C64(value) => value.num_elements(),
45        }
46    }
47
48    fn seed_cotangent(&self) -> Self::Tangent {
49        match self {
50            Self::F32(value) => Self::F32(value.seed_cotangent()),
51            Self::F64(value) => Self::F64(value.seed_cotangent()),
52            Self::C32(value) => Self::C32(value.seed_cotangent()),
53            Self::C64(value) => Self::C64(value.seed_cotangent()),
54        }
55    }
56}