tenferro_linalg/rrules/least_squares.rs

use super::*;

/// Output dims for the RHS/solution side: `[rows]` for a vector right-hand
/// side, `[rows, nrhs]` otherwise, with batch dims appended.
fn rhs_output_dims(core_rows: usize, nrhs: usize, batch_dims: &[usize]) -> Vec<usize> {
    let core_dims = if nrhs == 1 {
        vec![core_rows]
    } else {
        vec![core_rows, nrhs]
    };
    output_dims(&core_dims, batch_dims)
}

/// Reverse-mode AD rule for least squares (VJP / pullback).
///
/// Returns cotangents for both `A` and `b`.
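///
/// With `x = A⁺ b`, the solution cotangent `x̄` contributes `(A⁺)ᴴ x̄` to `b̄`
/// and reaches `Ā` through the pseudoinverse pullback (the cotangent of `A⁺`
/// is `x̄ bᵀ`). The residual cotangent weights the per-column sums
/// `‖A x − b‖²`; because `x` is the minimizer, these are differentiated with
/// `x` held fixed (envelope theorem).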
///
/// # Examples
///
/// ```
/// use tenferro_linalg::lstsq_rrule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(&[1.0, 0.0, 1.0, 0.0, 1.0, 1.0], &[3, 2], col).unwrap();
/// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3], col).unwrap();
/// let dx = Tensor::<f64>::ones(&[2], mem, col).unwrap();
/// let grad = lstsq_rrule(&mut ctx, &a, &b, Some(&dx), None).unwrap();
/// // grad.a: cotangent for A, grad.b: cotangent for b
/// ```
pub fn lstsq_rrule<
    T: KernelLinalgScalar<Real = T>
        + num_traits::Float
        + tenferro_algebra::Conjugate
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_tensor::KeepCountScalar,
    C,
>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent_solution: Option<&Tensor<T>>,
    cotangent_residuals: Option<&Tensor<T::Real>>,
) -> AdResult<LstsqGrad<T>>
where
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq_rrule")
        .map_err(to_ad_err)?;

    let result = lstsq(ctx, a, b)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(a)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    let nrhs = if b.ndim() == 1 + batch_dims.len() {
        1
    } else {
        b.dims()[1]
    };
    let rhs_is_vector = nrhs == 1 && b.ndim() == 1 + batch_dims.len();

    if let Some(cotangent_solution) = cotangent_solution {
        if cotangent_solution.dims() != result.solution.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lstsq_rrule solution cotangent shape mismatch: expected {:?}, got {:?}",
                result.solution.dims(),
                cotangent_solution.dims()
            ))));
        }
    }
    if let Some(cotangent_residuals) = cotangent_residuals {
        if cotangent_residuals.dims() != result.residuals.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lstsq_rrule residual cotangent shape mismatch: expected {:?}, got {:?}",
                result.residuals.dims(),
                cotangent_residuals.dims()
            ))));
        }
    }
    if cotangent_solution.is_none() && cotangent_residuals.is_none() {
        let a_dims = output_dims(&[m, n], batch_dims);
        let b_dims = rhs_output_dims(m, nrhs, batch_dims);
        return Ok(LstsqGrad {
            a: Tensor::<T>::zeros(
                &a_dims,
                a.logical_memory_space(),
                tenferro_tensor::MemoryOrder::ColumnMajor,
            )
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
            b: Tensor::<T>::zeros(
                &b_dims,
                b.logical_memory_space(),
                tenferro_tensor::MemoryOrder::ColumnMajor,
            )
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        });
    }

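    // Gradients accumulate from two independent paths: the solution cotangent
    // (routed through the pseudoinverse) and the residual cotangent (direct).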
    let (a_data, _) = extract_data(a)?;
    let (b_data, _) = extract_data(b)?;
    let (x_data, _) = extract_data(&result.solution)?;
    let dx_data = cotangent_solution.map(|tensor| extract_data(tensor).map(|(data, _)| data));
    let dx_data = dx_data.transpose()?;
    let dresidual_data =
        cotangent_residuals.map(|tensor| extract_data(tensor).map(|(data, _)| data));
    let dresidual_data = dresidual_data.transpose()?;
    let two = scalar_from::<T>(2.0)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let mut grad_a_data = vec![T::zero(); m * n * bc];
    let mut grad_b_data = vec![T::zero(); m * nrhs * bc];

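    // Solution path: with x = A⁺ b per batch, the cotangent of A⁺ is x̄ bᵀ
    // (the A contribution is delegated to pinv_rrule) and b̄ += (A⁺)ᴴ x̄.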
    if let Some(dx_data) = dx_data.as_ref() {
        let pinv_a = pinv(ctx, a, None)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let (ap_data, _) = extract_data(&pinv_a)?;
        let mut cotangent_pinv_data = vec![T::zero(); n * m * bc];

        for batch in 0..bc {
            let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
            let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];
            let dx_b = &dx_data[batch * n * nrhs..(batch + 1) * n * nrhs];
            let cotangent_pinv_b =
                backend_mat_mul(ctx, dx_b, n, nrhs, &transpose(b_b, m, nrhs), m)?;
            cotangent_pinv_data[batch * n * m..(batch + 1) * n * m]
                .copy_from_slice(&cotangent_pinv_b);

            let grad_b_solution =
                backend_mat_mul(ctx, &adjoint_transpose(ap_b, n, m), m, n, dx_b, nrhs)?;
            for i in 0..m * nrhs {
                grad_b_data[batch * m * nrhs + i] =
                    grad_b_data[batch * m * nrhs + i] + grad_b_solution[i];
            }
        }

        let cotangent_pinv =
            tensor_from_data(cotangent_pinv_data, &output_dims(&[n, m], batch_dims))
                .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let grad_a_solution = pinv_rrule(ctx, a, &cotangent_pinv, None)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
        let (grad_a_solution_data, _) = extract_data(&grad_a_solution)?;
        for (slot, value) in grad_a_data.iter_mut().zip(grad_a_solution_data.into_iter()) {
            *slot = *slot + value;
        }
    }

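    // Residual path: r_j = ‖A x_j − b_j‖² with x the minimizer, so by the
    // envelope theorem x is held fixed: b̄ -= 2 w_j (A x_j − b_j) and
    // Ā += 2 w_j (A x_j − b_j) x_jᵀ. For a vector RHS the weight w is one
    // scalar per batch; for a matrix RHS, one per column.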
    for batch in 0..bc {
        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
        let x_b = &x_data[batch * n * nrhs..(batch + 1) * n * nrhs];
        let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];

        if let Some(dresidual_data) = dresidual_data.as_ref().filter(|data| !data.is_empty()) {
            let ax = backend_mat_mul(ctx, a_b, m, n, x_b, nrhs)?;
            for col in 0..nrhs {
                let weight = if rhs_is_vector {
                    dresidual_data[batch]
                } else {
                    dresidual_data[batch * nrhs + col]
                };
                for row in 0..m {
                    let rhs_idx = row + col * m;
                    let residual = ax[rhs_idx] - b_b[rhs_idx];
                    grad_b_data[batch * m * nrhs + rhs_idx] =
                        grad_b_data[batch * m * nrhs + rhs_idx] - two * weight * residual;
                    for k in 0..n {
                        let a_idx = batch * m * n + row + k * m;
                        let x_idx = if nrhs == 1 { k } else { k + col * n };
                        grad_a_data[a_idx] =
                            grad_a_data[a_idx] + two * weight * residual * x_b[x_idx];
                    }
                }
            }
        }
    }

    let a_dims = output_dims(&[m, n], batch_dims);
    let b_dims = rhs_output_dims(m, nrhs, batch_dims);
    Ok(LstsqGrad {
        a: tensor_from_data(grad_a_data, &a_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        b: tensor_from_data(grad_b_data, &b_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    })
}

/// Reverse-mode AD rule for Cholesky (VJP / pullback).
///
/// Given `A = L L†` and cotangent `L̄`, computes `Ā`.
///
/// # Examples
///
/// ```no_run
/// use tenferro_linalg::cholesky_rrule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// // A symmetric positive-definite matrix (column-major).
/// let a = Tensor::from_slice(&[4.0_f64, 2.0, 2.0, 3.0], &[2, 2], col).unwrap();
/// let cotangent = Tensor::<f64>::ones(&[2, 2], mem, col).unwrap();
/// let grad_a = cholesky_rrule(&mut ctx, &a, &cotangent).unwrap();
/// ```
pub fn cholesky_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<T>,
) -> AdResult<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // With A = L L^H and cotangent L̄, the pullback is
    //   Ā = L^{-H} phi*(tril(L^H L̄)) L^{-1}
    let l = cholesky(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (l_data, _) = extract_data(&l)?;
    let (dl_data, _) = extract_data(cotangent)?;

    let mut grad_a = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let l_b = &l_data[b * n * n..(b + 1) * n * n];
        let dl_b = &dl_data[b * n * n..(b + 1) * n * n];

        // S = tril(L^H L̄)
        let lt_dl = backend_mat_mul(ctx, &adjoint_transpose(l_b, n, n), n, n, dl_b, n)?;
        let s = tril(&lt_dl, n);

        // Apply phi*: symmetrize S → (S + S^H - diag(S)) / 2
        let s_sym = phi_star(&s, n)?;

        // Solve L^H x = S_sym → x = L^{-H} S_sym
        let x = backend_solve_tri(ctx, &adjoint_transpose(l_b, n, n), &s_sym, n, n, true)?;

        // Solve result · L = x, i.e. result = x L^{-1}; equivalently
        // L^H result^H = x^H → result^H = L^{-H} x^H
        let xh = adjoint_transpose(&x, n, n);
        let result_h = backend_solve_tri(ctx, &adjoint_transpose(l_b, n, n), &xh, n, n, true)?;
        let da_b = adjoint_transpose(&result_h, n, n);

        grad_a[b * n * n..(b + 1) * n * n].copy_from_slice(&da_b);
    }

    let dims = output_dims(&[n, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_prims::CpuContext;
    use tenferro_tensor::MemoryOrder;

    fn tensor_data<T: tenferro_algebra::Scalar + Copy>(tensor: &Tensor<T>) -> Vec<T> {
        let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
        let offset = contiguous.offset() as usize;
        let len = contiguous.dims().iter().product::<usize>().max(1);
        contiguous.buffer().as_slice().unwrap()[offset..offset + len].to_vec()
    }

    #[test]
    fn lstsq_rrule_returns_zero_grads_when_both_cotangents_are_none() {
        let mut ctx = CpuContext::new(1);
        let a = Tensor::from_slice(&[1.0_f64, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
            .unwrap();
        let b = Tensor::from_slice(&[2.0_f64, 3.0], &[2], MemoryOrder::ColumnMajor).unwrap();

        let grad = lstsq_rrule(&mut ctx, &a, &b, None, None).unwrap();
        assert_eq!(tensor_data(&grad.a), vec![0.0, 0.0, 0.0, 0.0]);
        assert_eq!(tensor_data(&grad.b), vec![0.0, 0.0]);
    }

    #[test]
    fn lstsq_rrule_accepts_residual_summary_cotangent_for_multi_rhs() {
        let mut ctx = CpuContext::new(1);
        let a = Tensor::from_slice(
            &[2.0_f64, 0.0, 0.0, 1.0, 1.0, 1.0],
            &[3, 2],
            MemoryOrder::ColumnMajor,
        )
        .unwrap();
        let b = Tensor::from_slice(
            &[1.0_f64, 3.0, 2.0, 0.0, 1.0, 4.0],
            &[3, 2],
            MemoryOrder::ColumnMajor,
        )
        .unwrap();
        let dresiduals =
            Tensor::from_slice(&[0.5_f64, -0.25], &[2], MemoryOrder::ColumnMajor).unwrap();

        let grad = lstsq_rrule(&mut ctx, &a, &b, None, Some(&dresiduals)).unwrap();
        assert_eq!(grad.a.dims(), &[3, 2]);
        assert_eq!(grad.b.dims(), &[3, 2]);
        assert!(tensor_data(&grad.a).iter().any(|value| value.abs() > 0.0));
        assert!(tensor_data(&grad.b).iter().any(|value| value.abs() > 0.0));
    }
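
    // A minimal finite-difference sketch for the solution-cotangent path:
    // with loss L(A) = sum(x) (encoded by an all-ones solution cotangent),
    // grad.a should match central differences of L over the entries of A.
    // Epsilon and tolerance are illustrative, not tuned.
    #[test]
    fn lstsq_rrule_solution_grad_matches_finite_differences() {
        let mut ctx = CpuContext::new(1);
        let col = MemoryOrder::ColumnMajor;
        let a_vals = [1.0_f64, 0.0, 1.0, 0.0, 1.0, 1.0];
        let a = Tensor::from_slice(&a_vals, &[3, 2], col).unwrap();
        let b = Tensor::from_slice(&[1.0_f64, 2.0, 3.0], &[3], col).unwrap();
        let dx = Tensor::from_slice(&[1.0_f64, 1.0], &[2], col).unwrap();

        let grad = lstsq_rrule(&mut ctx, &a, &b, Some(&dx), None).unwrap();
        let grad_a = tensor_data(&grad.a);

        // L(A) = sum of the least-squares solution entries.
        let mut loss = |vals: &[f64]| -> f64 {
            let a = Tensor::from_slice(vals, &[3, 2], col).unwrap();
            let b = Tensor::from_slice(&[1.0_f64, 2.0, 3.0], &[3], col).unwrap();
            tensor_data(&lstsq(&mut ctx, &a, &b).unwrap().solution)
                .iter()
                .sum()
        };
        let eps = 1e-6;
        for idx in 0..a_vals.len() {
            let mut plus = a_vals;
            plus[idx] += eps;
            let mut minus = a_vals;
            minus[idx] -= eps;
            let fd = (loss(&plus) - loss(&minus)) / (2.0 * eps);
            assert!((grad_a[idx] - fd).abs() < 1e-4);
        }
    }

    // Closed-form check for the Cholesky pullback at A = I: there L = I, so
    // (per the phi* formula documented in cholesky_rrule) the gradient is
    // (S + S^H - diag(S)) / 2 with S = tril(L̄), and an all-ones cotangent
    // should give 0.5 in every entry. Assumes `phi_star` matches that formula.
    #[test]
    fn cholesky_rrule_identity_matches_closed_form() {
        let mut ctx = CpuContext::new(1);
        let col = MemoryOrder::ColumnMajor;
        let a = Tensor::from_slice(&[1.0_f64, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
        let cotangent = Tensor::from_slice(&[1.0_f64; 4], &[2, 2], col).unwrap();

        let grad = cholesky_rrule(&mut ctx, &a, &cotangent).unwrap();
        for value in tensor_data(&grad) {
            assert!((value - 0.5).abs() < 1e-12);
        }
    }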
}