tenferro_linalg/frules/least_squares.rs

use super::*;

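/// Output dims for an RHS-shaped result: a length-`core_rows` vector when
/// `nrhs == 1`, otherwise a `core_rows x nrhs` matrix, combined with the batch
/// dims via `output_dims`.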
fn rhs_output_dims(core_rows: usize, nrhs: usize, batch_dims: &[usize]) -> Vec<usize> {
    let core_dims = if nrhs == 1 {
        vec![core_rows]
    } else {
        vec![core_rows, nrhs]
    };
    output_dims(&core_dims, batch_dims)
}

/// Forward-mode AD rule for least squares (JVP / pushforward).
///
/// # Examples
///
/// ```
/// use tenferro_linalg::lstsq_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(&[1.0, 0.0, 1.0, 0.0, 1.0, 1.0], &[3, 2], col).unwrap();
/// let b = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3], col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 2], mem, col).unwrap();
/// let db = Tensor::<f64>::ones(&[3], mem, col).unwrap();
/// let (result, dresult) = lstsq_frule(&mut ctx, &a, &b, &da, &db).unwrap();
/// ```
pub fn lstsq_frule<
    T: KernelLinalgScalar<Real = T>
        + num_traits::Float
        + tenferro_algebra::Conjugate
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
    C,
>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    tangent_a: &Tensor<T>,
    tangent_b: &Tensor<T>,
) -> AdResult<(LstsqResult<T, T::Real>, LstsqResult<T, T::Real>)>
where
    T::Real: LinalgScalar<Real = T::Real> + num_traits::Float + tenferro_tensor::KeepCountScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq_frule")
        .map_err(to_ad_err)?;

    // dx = dA^+ * b + A^+ * db
    // d residual_summaries = 2 * sum(real((A x - b) * conj(dA x - db))), per RHS
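    //
    // x = A^+ b, so the product rule gives dx = dA^+ b + A^+ db. For the residual
    // summary, r = A x - b satisfies the normal equations A^H r = 0 at the full-rank
    // least-squares solution, so the A dx contribution to 2 * sum(real(conj(r) * dr))
    // vanishes and only dA x - db remains inside the sum.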
    let result = lstsq(ctx, a, b)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (pinv_a, dpinv_a) = pinv_frule(ctx, a, tangent_a, None)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(a)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (x_data, _) = extract_data(&result.solution)?;
    let (da_data, _) = extract_data(tangent_a)?;
    let (ap_data, _) = extract_data(&pinv_a)?;
    let (dap_data, _) = extract_data(&dpinv_a)?;
    let (b_data, _) = extract_data(b)?;
    let (db_data, _) = extract_data(tangent_b)?;
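    // b is a bare vector (single RHS) when it has exactly one more dim than the
    // batch dims; otherwise its second dim is the number of right-hand sides.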
    let nrhs = if b.ndim() == 1 + batch_dims.len() {
        1
    } else {
        b.dims()[1]
    };
    let rhs_is_vector = nrhs == 1 && b.ndim() == 1 + batch_dims.len();
    let aux = lstsq_aux(ctx, a)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
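    // The tangent of the residual summaries is only computed for overdetermined,
    // full-column-rank systems; otherwise an empty summary is returned below.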
    let summarize_residuals = m > n
        && crate::primal::lstsq_has_full_rank(&aux.rank, n)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let two = scalar_from::<T::Real>(2.0)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    let mut dx_data = vec![T::zero(); n * nrhs * bc];
    let mut dresidual_data = vec![T::Real::zero(); bc * nrhs];
    let (a_data, _) = extract_data(a)?;

    for batch in 0..bc {
        let x_b = &x_data[batch * n * nrhs..(batch + 1) * n * nrhs];
        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
        let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
        let dap_b = &dap_data[batch * n * m..(batch + 1) * n * m];
        let b_b = &b_data[batch * m * nrhs..(batch + 1) * m * nrhs];
        let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
        let db_b = &db_data[batch * m * nrhs..(batch + 1) * m * nrhs];

        let dpinv_b = backend_mat_mul(ctx, dap_b, n, m, b_b, nrhs)?;
        let pinv_db = backend_mat_mul(ctx, ap_b, n, m, db_b, nrhs)?;
        let dx_b_vec = add_vec(&dpinv_b, &pinv_db);
        dx_data[batch * n * nrhs..(batch + 1) * n * nrhs].copy_from_slice(&dx_b_vec);

        if summarize_residuals {
            let ax = backend_mat_mul(ctx, a_b, m, n, x_b, nrhs)?;
            let da_x = backend_mat_mul(ctx, da_b, m, n, x_b, nrhs)?;
            for col in 0..nrhs {
                let mut acc = T::Real::zero();
                for row in 0..m {
                    let idx = row + col * m;
                    let residual = ax[idx] - b_b[idx];
                    let dresidual = da_x[idx] - db_b[idx];
                    acc = acc + residual * dresidual;
                }
                dresidual_data[batch * nrhs + col] = two * acc;
            }
        }
    }

    let x_dims = rhs_output_dims(n, nrhs, batch_dims);
    let dx = tensor_from_data(dx_data, &x_dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let dresiduals = if summarize_residuals {
        let dims = crate::primal::residual_summary_output_dims(batch_dims, nrhs, rhs_is_vector);
        tensor_from_data(dresidual_data, &dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?
    } else {
        crate::primal::empty_residual_summary::<T::Real>(a.logical_memory_space())
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?
    };
    let dresult = LstsqResult {
        solution: dx,
        residuals: dresiduals,
    };
    Ok((result, dresult))
}

/// Forward-mode AD rule for Cholesky (JVP / pushforward).
///
/// # Examples
///
/// ```
/// use tenferro_linalg::cholesky_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// // A symmetric positive definite matrix, so the Cholesky factorization exists.
/// let a = Tensor::from_slice(&[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0], &[3, 3], col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
/// let (l, dl) = cholesky_frule(&mut ctx, &a, &da).unwrap();
/// ```
pub fn cholesky_frule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(Tensor<T>, Tensor<T>)>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // dL = L phi(L^{-1} dA L^{-H})
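    //
    // From A = L L^H, dA = dL L^H + L dL^H, so L^{-1} dA L^{-H} = M + M^H with
    // M = L^{-1} dL lower triangular. Since diag(L) is real, diag(M) is real too,
    // and phi (tril with halved diagonal) recovers M from M + M^H, giving
    // dL = L phi(L^{-1} dA L^{-H}).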
    let l = cholesky(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (l_data, _) = extract_data(&l)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dl_data = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let l_b = &l_data[b * n * n..(b + 1) * n * n];
        let da_b = &da_data[b * n * n..(b + 1) * n * n];

        // L^{-1} dA: solve L x = dA
        let linv_da = backend_solve_tri(ctx, l_b, da_b, n, n, false)?;
        // (L^{-1} dA) L^{-H}: solve L x = (L^{-1} dA)^H, then adjoint back
        let linv_da_linvh_h =
            backend_solve_tri(ctx, l_b, &adjoint_transpose(&linv_da, n, n), n, n, false)?;
        let inner = adjoint_transpose(&linv_da_linvh_h, n, n);

        // phi(inner) = tril with diagonal halved
        let phi_inner = phi(&inner, n)?;

        // dL = L phi(inner)
        let dl_b_vec = backend_mat_mul(ctx, l_b, n, n, &phi_inner, n)?;
        dl_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dl_b_vec);
    }

    let dims = output_dims(&[n, n], batch_dims);
    let dl = tensor_from_data(dl_data, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    Ok((l, dl))
}