tenferro_internal_frontend_core/
autodiff.rs1use chainrules_core::Differentiable;
2
3use crate::{DynTensor, StructuredTensor};
4
5impl Differentiable for DynTensor {
6 type Tangent = DynTensor;
7
8 fn zero_tangent(&self) -> Self::Tangent {
9 match self {
10 Self::F32(value) => Self::F32(value.zero_tangent()),
11 Self::F64(value) => Self::F64(value.zero_tangent()),
12 Self::C32(value) => Self::C32(value.zero_tangent()),
13 Self::C64(value) => Self::C64(value.zero_tangent()),
14 }
15 }
16
17 fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent {
18 match (a, b) {
19 (Self::F32(lhs), Self::F32(rhs)) => {
20 Self::F32(StructuredTensor::<f32>::accumulate_tangent(lhs, rhs))
21 }
22 (Self::F64(lhs), Self::F64(rhs)) => {
23 Self::F64(StructuredTensor::<f64>::accumulate_tangent(lhs, rhs))
24 }
25 (Self::C32(lhs), Self::C32(rhs)) => {
26 Self::C32(StructuredTensor::<num_complex::Complex32>::accumulate_tangent(lhs, rhs))
27 }
28 (Self::C64(lhs), Self::C64(rhs)) => {
29 Self::C64(StructuredTensor::<num_complex::Complex64>::accumulate_tangent(lhs, rhs))
30 }
31 (lhs, rhs) => unreachable!(
32 "DynTensor::accumulate_tangent requires matching dtypes, got lhs={:?}, rhs={:?}",
33 lhs.scalar_type(),
34 rhs.scalar_type()
35 ),
36 }
37 }
38
39 fn num_elements(&self) -> usize {
40 match self {
41 Self::F32(value) => value.num_elements(),
42 Self::F64(value) => value.num_elements(),
43 Self::C32(value) => value.num_elements(),
44 Self::C64(value) => value.num_elements(),
45 }
46 }
47
48 fn seed_cotangent(&self) -> Self::Tangent {
49 match self {
50 Self::F32(value) => Self::F32(value.seed_cotangent()),
51 Self::F64(value) => Self::F64(value.seed_cotangent()),
52 Self::C32(value) => Self::C32(value.seed_cotangent()),
53 Self::C64(value) => Self::C64(value.seed_cotangent()),
54 }
55 }
56}