use super::*;

/// Shape of the right-hand side (and of its gradient): a vector core when
/// `nrhs == 1`, a matrix core otherwise, extended by `batch_dims`.
fn rhs_output_dims(core_rows: usize, nrhs: usize, batch_dims: &[usize]) -> Vec<usize> {
    let core_dims = if nrhs == 1 {
        vec![core_rows]
    } else {
        vec![core_rows, nrhs]
    };
    output_dims(&core_dims, batch_dims)
}

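/// Reverse rule for `lstsq` on the least-squares objective `||A x - b||^2`.
///
/// The solution cotangent `dx` is pulled back through `x = pinv(A) * b`: the
/// contribution to `grad_b` is `pinv(A)^H * dx`, and the contribution to
/// `grad_a` reuses `pinv_rrule` with the pseudoinverse cotangent `dx * b^T`
/// (`T` is real here, so transpose and adjoint coincide). A residual
/// cotangent `dr` weights the per-column squared residual `r = A x - b`,
/// contributing `-2 dr r` to `grad_b` and `2 dr r x^T` to `grad_a`; the
/// implicit dependence of `x` on `A` and `b` drops out because `A^H r = 0`
/// at the least-squares optimum.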
pub fn lstsq_rrule<
    T: KernelLinalgScalar<Real = T>
        + num_traits::Float
        + tenferro_algebra::Conjugate
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_tensor::KeepCountScalar,
    C,
>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent_solution: Option<&Tensor<T>>,
    cotangent_residuals: Option<&Tensor<T::Real>>,
) -> AdResult<LstsqGrad<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq_rrule")
        .map_err(to_ad_err)?;

    let result = lstsq(ctx, a, b)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(a)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    let nrhs = if b.ndim() == 1 + batch_dims.len() {
        1
    } else {
        b.dims()[1]
    };
    let rhs_is_vector = nrhs == 1 && b.ndim() == 1 + batch_dims.len();

    if let Some(cotangent_solution) = cotangent_solution {
        if cotangent_solution.dims() != result.solution.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lstsq_rrule solution cotangent shape mismatch: expected {:?}, got {:?}",
                result.solution.dims(),
                cotangent_solution.dims()
            ))));
        }
    }
    if let Some(cotangent_residuals) = cotangent_residuals {
        if cotangent_residuals.dims() != result.residuals.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lstsq_rrule residual cotangent shape mismatch: expected {:?}, got {:?}",
                result.residuals.dims(),
                cotangent_residuals.dims()
            ))));
        }
    }
    if cotangent_solution.is_none() && cotangent_residuals.is_none() {
        // Both cotangents absent: the pullback is identically zero.
        let a_dims = output_dims(&[m, n], batch_dims);
        let b_dims = rhs_output_dims(m, nrhs, batch_dims);
        return Ok(LstsqGrad {
            a: Tensor::<T>::zeros(
                &a_dims,
                a.logical_memory_space(),
                tenferro_tensor::MemoryOrder::ColumnMajor,
            )
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
            b: Tensor::<T>::zeros(
                &b_dims,
                b.logical_memory_space(),
                tenferro_tensor::MemoryOrder::ColumnMajor,
            )
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        });
    }

    let (a_data, _) = extract_data(a)?;
    let (b_data, _) = extract_data(b)?;
    let (x_data, _) = extract_data(&result.solution)?;
    let dx_data = cotangent_solution.map(|tensor| extract_data(tensor).map(|(data, _)| data));
    let dx_data = dx_data.transpose()?;
    let dresidual_data =
        cotangent_residuals.map(|tensor| extract_data(tensor).map(|(data, _)| data));
    let dresidual_data = dresidual_data.transpose()?;
    let two = scalar_from::<T>(2.0)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let mut grad_a_data = vec![T::zero(); m * n * bc];
    let mut grad_b_data = vec![T::zero(); m * nrhs * bc];

    if let Some(dx_data) = dx_data.as_ref() {
        // Solution path: x = pinv(A) * b, so the pinv(A) cotangent is
        // dx * b^T and the b gradient accumulates pinv(A)^H * dx.
        let pinv_a = pinv(ctx, a, None)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let (ap_data, _) = extract_data(&pinv_a)?;
        let mut cotangent_pinv_data = vec![T::zero(); n * m * bc];

        for batch in 0..bc {
            let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
            let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];
            let dx_b = &dx_data[batch * n * nrhs..(batch + 1) * n * nrhs];
            let cotangent_pinv_b =
                backend_mat_mul(ctx, dx_b, n, nrhs, &transpose(b_b, m, nrhs), m)?;
            cotangent_pinv_data[batch * n * m..(batch + 1) * n * m]
                .copy_from_slice(&cotangent_pinv_b);

            let grad_b_solution =
                backend_mat_mul(ctx, &adjoint_transpose(ap_b, n, m), m, n, dx_b, nrhs)?;
            for i in 0..m * nrhs {
                grad_b_data[batch * m * nrhs + i] =
                    grad_b_data[batch * m * nrhs + i] + grad_b_solution[i];
            }
        }

        // Pull the pinv(A) cotangent back to A through the pinv reverse rule.
        let cotangent_pinv =
            tensor_from_data(cotangent_pinv_data, &output_dims(&[n, m], batch_dims))
                .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let grad_a_solution = pinv_rrule(ctx, a, &cotangent_pinv, None)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let (grad_a_solution_data, _) = extract_data(&grad_a_solution)?;
        for (slot, value) in grad_a_data.iter_mut().zip(grad_a_solution_data.into_iter()) {
            *slot = *slot + value;
        }
    }

    // Residual path: each residual entry is the squared norm of one column of
    // A x - b, so its cotangent scales -2 (A x - b) into grad_b and
    // 2 (A x - b) x^T into grad_a.
    for batch in 0..bc {
        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
        let x_b = &x_data[batch * n * nrhs..(batch + 1) * n * nrhs];
        let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];

        if let Some(dresidual_data) = dresidual_data.as_ref().filter(|data| !data.is_empty()) {
            let ax = backend_mat_mul(ctx, a_b, m, n, x_b, nrhs)?;
            for col in 0..nrhs {
                let weight = if rhs_is_vector {
                    dresidual_data[batch]
                } else {
                    dresidual_data[batch * nrhs + col]
                };
                for row in 0..m {
                    let rhs_idx = row + col * m;
                    let residual = ax[rhs_idx] - b_b[rhs_idx];
                    grad_b_data[batch * m * nrhs + rhs_idx] =
                        grad_b_data[batch * m * nrhs + rhs_idx] - two * weight * residual;
                    for k in 0..n {
                        let a_idx = batch * m * n + row + k * m;
                        let x_idx = if nrhs == 1 { k } else { k + col * n };
                        grad_a_data[a_idx] =
                            grad_a_data[a_idx] + two * weight * residual * x_b[x_idx];
                    }
                }
            }
        }
    }

    let a_dims = output_dims(&[m, n], batch_dims);
    let b_dims = rhs_output_dims(m, nrhs, batch_dims);
    Ok(LstsqGrad {
        a: tensor_from_data(grad_a_data, &a_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        b: tensor_from_data(grad_b_data, &b_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    })
}

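/// Reverse rule for `cholesky`: with `A = L L^H`, the pullback has the
/// standard form `grad_A = L^{-H} * Phi(L^H * dL) * L^{-1}`, where `Phi` here
/// is the lower-triangular projection `tril` followed by `phi_star`'s
/// symmetrization, so the gradient is valid for Hermitian inputs. The two
/// triangular solves below apply `L^{-H}` from the left and `L^{-1}` from the
/// right (the latter through an adjoint) without forming any inverse.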
pub fn cholesky_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<T>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let l = cholesky(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (l_data, _) = extract_data(&l)?;
    let (dl_data, _) = extract_data(cotangent)?;

    let mut grad_a = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let l_b = &l_data[b * n * n..(b + 1) * n * n];
        let dl_b = &dl_data[b * n * n..(b + 1) * n * n];

        // S = L^H * dL, projected onto the lower triangle.
        let lt_dl = backend_mat_mul(ctx, &adjoint_transpose(l_b, n, n), n, n, dl_b, n)?;
        let s = tril(&lt_dl, n);

        // Symmetrize the projection so the gradient stays Hermitian.
        let s_sym = phi_star(&s, n)?;

        // Left solve: X = L^{-H} * S.
        let x = backend_solve_tri(ctx, &adjoint_transpose(l_b, n, n), &s_sym, n, n, true)?;

        // Right solve via the adjoint: dA = X * L^{-1} = (L^{-H} * X^H)^H.
        let xh = adjoint_transpose(&x, n, n);
        let result_h = backend_solve_tri(ctx, &adjoint_transpose(l_b, n, n), &xh, n, n, true)?;
        let da_b = adjoint_transpose(&result_h, n, n);

        grad_a[b * n * n..(b + 1) * n * n].copy_from_slice(&da_b);
    }

    let dims = output_dims(&[n, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_prims::CpuContext;
    use tenferro_tensor::MemoryOrder;

    fn tensor_data<T: tenferro_algebra::Scalar + Copy>(tensor: &Tensor<T>) -> Vec<T> {
        let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
        let offset = contiguous.offset() as usize;
        let len = contiguous.dims().iter().product::<usize>().max(1);
        contiguous.buffer().as_slice().unwrap()[offset..offset + len].to_vec()
    }

    #[test]
    fn lstsq_rrule_returns_zero_grads_when_both_cotangents_are_none() {
        let mut ctx = CpuContext::new(1);
        let a = Tensor::from_slice(&[1.0_f64, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
            .unwrap();
        let b = Tensor::from_slice(&[2.0_f64, 3.0], &[2], MemoryOrder::ColumnMajor).unwrap();

        let grad = lstsq_rrule(&mut ctx, &a, &b, None, None).unwrap();
        assert_eq!(tensor_data(&grad.a), vec![0.0, 0.0, 0.0, 0.0]);
        assert_eq!(tensor_data(&grad.b), vec![0.0, 0.0]);
    }

285
286 #[test]
287 fn lstsq_rrule_accepts_residual_summary_cotangent_for_multi_rhs() {
288 let mut ctx = CpuContext::new(1);
289 let a = Tensor::from_slice(
290 &[2.0_f64, 0.0, 0.0, 1.0, 1.0, 1.0],
291 &[3, 2],
292 MemoryOrder::ColumnMajor,
293 )
294 .unwrap();
295 let b = Tensor::from_slice(
296 &[1.0_f64, 3.0, 2.0, 0.0, 1.0, 4.0],
297 &[3, 2],
298 MemoryOrder::ColumnMajor,
299 )
300 .unwrap();
301 let dresiduals =
302 Tensor::from_slice(&[0.5_f64, -0.25], &[2], MemoryOrder::ColumnMajor).unwrap();
303
304 let grad = lstsq_rrule(&mut ctx, &a, &b, None, Some(&dresiduals)).unwrap();
305 assert_eq!(grad.a.dims(), &[3, 2]);
306 assert_eq!(grad.b.dims(), &[3, 2]);
307 assert!(tensor_data(&grad.a).iter().any(|value| value.abs() > 0.0));
308 assert!(tensor_data(&grad.b).iter().any(|value| value.abs() > 0.0));
309 }
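
    // A minimal shape-level sketch for cholesky_rrule, assuming f64 with
    // CpuContext satisfies its trait bounds (as it does for lstsq_rrule
    // above). It pins only the output shape and that the cotangent propagates
    // to a nonzero gradient, not the exact values produced by phi_star.
    #[test]
    fn cholesky_rrule_returns_square_grad_for_identity_input() {
        let mut ctx = CpuContext::new(1);
        let a = Tensor::from_slice(&[1.0_f64, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
            .unwrap();
        let dl = Tensor::from_slice(&[1.0_f64, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
            .unwrap();

        let grad = cholesky_rrule(&mut ctx, &a, &dl).unwrap();
        assert_eq!(grad.dims(), &[2, 2]);
        assert!(tensor_data(&grad).iter().any(|value| value.abs() > 0.0));
    }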
}