1use super::*;
2
3fn rhs_output_dims(core_rows: usize, nrhs: usize, batch_dims: &[usize]) -> Vec<usize> {
4 let core_dims = if nrhs == 1 {
5 vec![core_rows]
6 } else {
7 vec![core_rows, nrhs]
8 };
9 output_dims(&core_dims, batch_dims)
10}
11
12pub fn lstsq_frule<
32 T: KernelLinalgScalar<Real = T>
33 + num_traits::Float
34 + tenferro_algebra::Conjugate
35 + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
36 C,
37>(
38 ctx: &mut C,
39 a: &Tensor<T>,
40 b: &Tensor<T>,
41 tangent_a: &Tensor<T>,
42 tangent_b: &Tensor<T>,
43) -> AdResult<(LstsqResult<T, T::Real>, LstsqResult<T, T::Real>)>
44where
45 T: KernelLinalgScalar,
46 T::Real: LinalgScalar<Real = T::Real> + num_traits::Float + tenferro_tensor::KeepCountScalar,
47 C: backend::TensorLinalgContextFor<T>
48 + tenferro_prims::TensorResolveConjContextFor<T>
49 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
50 + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
51 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
52 C::Backend: 'static,
53{
54 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq_frule")
55 .map_err(to_ad_err)?;
56
57 let result = lstsq(ctx, a, b)
60 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
61 let (pinv_a, dpinv_a) = pinv_frule(ctx, a, tangent_a, None)
62 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
63 let (m, n, batch_dims) = validate_2d(a)
64 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
65 let bc = batch_count(batch_dims);
66
67 let (x_data, _) = extract_data(&result.solution)?;
68 let (da_data, _) = extract_data(tangent_a)?;
69 let (ap_data, _) = extract_data(&pinv_a)?;
70 let (dap_data, _) = extract_data(&dpinv_a)?;
71 let (b_data, _) = extract_data(b)?;
72 let (db_data, _) = extract_data(tangent_b)?;
73 let nrhs = if b.ndim() == 1 + batch_dims.len() {
74 1
75 } else {
76 b.dims()[1]
77 };
78 let rhs_is_vector = nrhs == 1 && b.ndim() == 1 + batch_dims.len();
79 let aux = lstsq_aux(ctx, a)
80 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
81 let summarize_residuals = m > n
82 && crate::primal::lstsq_has_full_rank(&aux.rank, n)
83 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
84 let two = scalar_from::<T::Real>(2.0)
85 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
86
87 let mut dx_data = vec![T::zero(); n * nrhs * bc];
88 let mut dresidual_data = vec![T::Real::zero(); bc * nrhs];
89 let (a_data, _) = extract_data(a)?;
90
91 for batch in 0..bc {
92 let x_b = &x_data[batch * n * nrhs..(batch + 1) * n * nrhs];
93 let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
94 let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
95 let dap_b = &dap_data[batch * n * m..(batch + 1) * n * m];
96 let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];
97 let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
98 let db_b = &db_data[batch * m * nrhs..(batch + 1) * m * nrhs];
99
100 let dpinv_b = backend_mat_mul(ctx, dap_b, n, m, b_b, nrhs)?;
101 let pinv_db = backend_mat_mul(ctx, ap_b, n, m, db_b, nrhs)?;
102 let dx_b_vec = add_vec(&dpinv_b, &pinv_db);
103 dx_data[batch * n * nrhs..(batch + 1) * n * nrhs].copy_from_slice(&dx_b_vec);
104
105 if summarize_residuals {
106 let ax = backend_mat_mul(ctx, a_b, m, n, x_b, nrhs)?;
107 let da_x = backend_mat_mul(ctx, da_b, m, n, x_b, nrhs)?;
108 for col in 0..nrhs {
109 let mut acc = T::Real::zero();
110 for row in 0..m {
111 let idx = row + col * m;
112 let residual = ax[idx] - b_b[idx];
113 let dresidual = da_x[idx] - db_b[idx];
114 acc = acc + residual * dresidual;
115 }
116 dresidual_data[batch * nrhs + col] = two * acc;
117 }
118 }
119 }
120
121 let x_dims = rhs_output_dims(n, nrhs, batch_dims);
122 let dx = tensor_from_data(dx_data, &x_dims)
123 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
124 let dresiduals = if summarize_residuals {
125 let dims = crate::primal::residual_summary_output_dims(batch_dims, nrhs, rhs_is_vector);
126 tensor_from_data(dresidual_data, &dims)
127 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?
128 } else {
129 crate::primal::empty_residual_summary::<T::Real>(a.logical_memory_space())
130 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?
131 };
132 let dresult = LstsqResult {
133 solution: dx,
134 residuals: dresiduals,
135 };
136 Ok((result, dresult))
137}
138
139pub fn cholesky_frule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
157 ctx: &mut C,
158 tensor: &Tensor<T>,
159 tangent: &Tensor<T>,
160) -> AdResult<(Tensor<T>, Tensor<T>)>
161where
162 T: KernelLinalgScalar + tenferro_algebra::Conjugate,
163 C: backend::TensorLinalgContextFor<T>,
164 C::Backend: 'static,
165{
166 let l = cholesky(ctx, tensor)
168 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
169 let (n, batch_dims) = validate_square(tensor)
170 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
171 let bc = batch_count(batch_dims);
172
173 let (l_data, _) = extract_data(&l)?;
174 let (da_data, _) = extract_data(tangent)?;
175
176 let mut dl_data = vec![T::zero(); n * n * bc];
177
178 for b in 0..bc {
179 let l_b = &l_data[b * n * n..(b + 1) * n * n];
180 let da_b = &da_data[b * n * n..(b + 1) * n * n];
181
182 let linv_da = backend_solve_tri(ctx, l_b, da_b, n, n, false)?;
184 let linv_da_linvh_h =
186 backend_solve_tri(ctx, l_b, &adjoint_transpose(&linv_da, n, n), n, n, false)?;
187 let inner = adjoint_transpose(&linv_da_linvh_h, n, n);
188
189 let phi_inner = phi(&inner, n)?;
191
192 let dl_b_vec = backend_mat_mul(ctx, l_b, n, n, &phi_inner, n)?;
194 dl_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dl_b_vec);
195 }
196
197 let dims = output_dims(&[n, n], batch_dims);
198 let dl = tensor_from_data(dl_data, &dims)
199 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
200 Ok((l, dl))
201}