tenferro_tensor/tensor/
autodiff.rs

1use tenferro_algebra::Scalar;
2#[cfg(feature = "cuda")]
3use tenferro_device::LogicalMemorySpace;
4
5use super::Tensor;
6use crate::layout::{add_strided, compute_contiguous_strides, StridedInput};
7use crate::MemoryOrder;
8
9impl<T: Scalar> chainrules_core::Differentiable for Tensor<T> {
10    type Tangent = Tensor<T>;
11
12    fn zero_tangent(&self) -> Tensor<T> {
13        self.zeros_like()
14            .unwrap_or_else(|err| panic!("zero_tangent: failed to create zero tensor: {err}"))
15    }
16
17    fn num_elements(&self) -> usize {
18        self.len()
19    }
20
21    fn seed_cotangent(&self) -> Tensor<T> {
22        self.ones_like()
23            .unwrap_or_else(|err| panic!("seed_cotangent: failed to create ones tensor: {err}"))
24    }
25
26    fn accumulate_tangent(a: Tensor<T>, b: &Tensor<T>) -> Tensor<T> {
27        assert_eq!(
28            a.dims, b.dims,
29            "tangent shape mismatch in accumulate_tangent"
30        );
31
32        let a_fw = a.fw_grad().cloned();
33        let b_fw = b.fw_grad().cloned();
34
35        #[cfg(feature = "cuda")]
36        let result_primal =
37            if matches!(a.logical_memory_space, LogicalMemorySpace::GpuMemory { .. }) {
38                crate::cuda_runtime::add_strided_tensor(&a, b)
39                    .unwrap_or_else(|err| panic!("accumulate_tangent: GPU addition failed: {err}"))
40            } else {
41                accumulate_tangent_cpu(&a, b)
42            };
43
44        #[cfg(not(feature = "cuda"))]
45        let result_primal = accumulate_tangent_cpu(&a, b);
46
47        let fw_grad = match (a_fw, b_fw) {
48            (Some(fa), Some(fb)) => Some(Self::accumulate_tangent(fa, &fb)),
49            (Some(fa), None) => Some(fa),
50            (None, Some(fb)) => Some(fb.clone()),
51            (None, None) => None,
52        };
53
54        let mut result = result_primal;
55        if let Some(fw_grad) = fw_grad {
56            result.set_fw_grad(fw_grad);
57        }
58        result
59    }
60}
61
62/// CPU path for element-wise tangent accumulation.
63fn accumulate_tangent_cpu<T: Scalar>(a: &Tensor<T>, b: &Tensor<T>) -> Tensor<T> {
64    let dst_strides = compute_contiguous_strides(&a.dims, MemoryOrder::ColumnMajor);
65    let mut data = vec![T::zero(); a.len()];
66    if !data.is_empty() {
67        add_strided(
68            &a.dims,
69            StridedInput {
70                data: a.cpu_backed_slice_or_panic("accumulate_tangent"),
71                strides: &a.strides,
72                offset: a.offset,
73            },
74            StridedInput {
75                data: b.cpu_backed_slice_or_panic("accumulate_tangent"),
76                strides: &b.strides,
77                offset: b.offset,
78            },
79            &mut data,
80            &dst_strides,
81        );
82    }
83    Tensor::from_owned_contiguous_data(
84        data,
85        a.dims.clone(),
86        MemoryOrder::ColumnMajor,
87        a.logical_memory_space,
88        a.preferred_compute_device,
89        false,
90    )
91}