tenferro_tensor/tensor/
autodiff.rs1use tenferro_algebra::Scalar;
2#[cfg(feature = "cuda")]
3use tenferro_device::LogicalMemorySpace;
4
5use super::Tensor;
6use crate::layout::{add_strided, compute_contiguous_strides, StridedInput};
7use crate::MemoryOrder;
8
9impl<T: Scalar> chainrules_core::Differentiable for Tensor<T> {
10 type Tangent = Tensor<T>;
11
12 fn zero_tangent(&self) -> Tensor<T> {
13 self.zeros_like()
14 .unwrap_or_else(|err| panic!("zero_tangent: failed to create zero tensor: {err}"))
15 }
16
17 fn num_elements(&self) -> usize {
18 self.len()
19 }
20
21 fn seed_cotangent(&self) -> Tensor<T> {
22 self.ones_like()
23 .unwrap_or_else(|err| panic!("seed_cotangent: failed to create ones tensor: {err}"))
24 }
25
26 fn accumulate_tangent(a: Tensor<T>, b: &Tensor<T>) -> Tensor<T> {
27 assert_eq!(
28 a.dims, b.dims,
29 "tangent shape mismatch in accumulate_tangent"
30 );
31
32 let a_fw = a.fw_grad().cloned();
33 let b_fw = b.fw_grad().cloned();
34
35 #[cfg(feature = "cuda")]
36 let result_primal =
37 if matches!(a.logical_memory_space, LogicalMemorySpace::GpuMemory { .. }) {
38 crate::cuda_runtime::add_strided_tensor(&a, b)
39 .unwrap_or_else(|err| panic!("accumulate_tangent: GPU addition failed: {err}"))
40 } else {
41 accumulate_tangent_cpu(&a, b)
42 };
43
44 #[cfg(not(feature = "cuda"))]
45 let result_primal = accumulate_tangent_cpu(&a, b);
46
47 let fw_grad = match (a_fw, b_fw) {
48 (Some(fa), Some(fb)) => Some(Self::accumulate_tangent(fa, &fb)),
49 (Some(fa), None) => Some(fa),
50 (None, Some(fb)) => Some(fb.clone()),
51 (None, None) => None,
52 };
53
54 let mut result = result_primal;
55 if let Some(fw_grad) = fw_grad {
56 result.set_fw_grad(fw_grad);
57 }
58 result
59 }
60}
61
62fn accumulate_tangent_cpu<T: Scalar>(a: &Tensor<T>, b: &Tensor<T>) -> Tensor<T> {
64 let dst_strides = compute_contiguous_strides(&a.dims, MemoryOrder::ColumnMajor);
65 let mut data = vec![T::zero(); a.len()];
66 if !data.is_empty() {
67 add_strided(
68 &a.dims,
69 StridedInput {
70 data: a.cpu_backed_slice_or_panic("accumulate_tangent"),
71 strides: &a.strides,
72 offset: a.offset,
73 },
74 StridedInput {
75 data: b.cpu_backed_slice_or_panic("accumulate_tangent"),
76 strides: &b.strides,
77 offset: b.offset,
78 },
79 &mut data,
80 &dst_strides,
81 );
82 }
83 Tensor::from_owned_contiguous_data(
84 data,
85 a.dims.clone(),
86 MemoryOrder::ColumnMajor,
87 a.logical_memory_space,
88 a.preferred_compute_device,
89 false,
90 )
91}