tenferro_linalg/rrules/lu_eigen.rs

use super::*;
use num_traits::Float;
use tenferro_algebra::Conjugate;

/// Reverse-mode AD rule for LU (VJP / pullback).
///
/// The `pivot` argument must match the pivoting strategy used in the forward pass.
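///
/// For square `A` with `P A = L U`, the returned gradient is
/// `grad A = P^T L^{-H} (tril_strict(L^H dL) + triu(dU U^H)) U^{-H}`, where
/// `dL`/`dU` are the cotangents, `tril_strict` keeps the strictly lower
/// triangle, and `triu` keeps the upper triangle including the diagonal.
/// Wide (`m < n`) and tall (`m > n`) inputs partition `U = [U1 U2]` and
/// `L = [L1; L2]` and apply the same two triangular solves to the leading
/// `k × k` block, `k = min(m, n)`.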
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{lu_rrule, LuCotangent, LuPivot};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0], &[3, 3], col)
///     .unwrap();
/// let cotangent = LuCotangent {
///     l: Some(Tensor::ones(&[3, 3], mem, col).unwrap()),
///     u: None,
/// };
/// let grad_a = lu_rrule(&mut ctx, &a, &cotangent, LuPivot::Partial).unwrap();
/// ```
pub fn lu_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &LuCotangent<T>,
    pivot: LuPivot,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorMetadataContextFor
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<
        tenferro_algebra::Standard<T::Real>,
    >>::ScalarBackend: tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
    T: crate::primal::LiftPermutationMatrixTensor<C>,
    C::Backend: 'static,
{
    let result = lu(ctx, tensor, pivot)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);

    if let Some(ref dl) = cotangent.l {
        if dl.dims() != result.l.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lu_rrule L cotangent shape mismatch: expected {:?}, got {:?}",
                result.l.dims(),
                dl.dims()
            ))));
        }
    }
    if let Some(ref du) = cotangent.u {
        if du.dims() != result.u.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lu_rrule U cotangent shape mismatch: expected {:?}, got {:?}",
                result.u.dims(),
                du.dims()
            ))));
        }
    }

    let (l_data, _) = extract_data(&result.l)?;
    let (u_data, _) = extract_data(&result.u)?;
    let dl_data = if let Some(ref dl) = cotangent.l {
        Some(extract_data(dl)?.0)
    } else {
        None
    };
    let du_data = if let Some(ref du) = cotangent.u {
        Some(extract_data(du)?.0)
    } else {
        None
    };
    let p_vec = crate::forward_perm_from_permutation_matrix(&result.p, m, bc)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    let mut grad_a = vec![T::zero(); m * n * bc];

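    // Factor data is column-major with the batch index as the outermost
    // (slowest-varying) stride, so each batch occupies a contiguous slice.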
    for b in 0..bc {
        let l_b = &l_data[b * m * k..(b + 1) * m * k];
        let u_b = &u_data[b * k * n..(b + 1) * k * n];
        let dl_b = dl_data
            .as_ref()
            .map(|data| &data[b * m * k..(b + 1) * m * k]);
        let du_b = du_data
            .as_ref()
            .map(|data| &data[b * k * n..(b + 1) * k * n]);

        let batch_grad = if m == n {
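            // Square case: compute
            //   grad = L^{-H} (tril_strict(L^H dL) + triu(dU U^H)) U^{-H}
            // via two triangular solves; rows are un-pivoted after the branch.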
            let l_h = adjoint_transpose(l_b, k, k);
            let mut inner = vec![T::zero(); k * k];

            if let Some(dl_b) = dl_b {
                let lt_dl = backend_mat_mul(ctx, &l_h, k, k, dl_b, k)?;
                inner = add_vec(&inner, &tril_strict(&lt_dl, k));
            }
            if let Some(du_b) = du_b {
                let du_ut = backend_mat_mul(ctx, du_b, k, k, &adjoint_transpose(u_b, k, k), k)?;
                inner = add_vec(&inner, &triu(&du_ut, k));
            }

            let left = backend_solve_tri(ctx, &l_h, &inner, k, k, true)?;
            let grad_h = backend_solve_tri(ctx, u_b, &adjoint_transpose(&left, k, k), k, k, true)?;
            adjoint_transpose(&grad_h, k, k)
        } else if m < n {
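            // Wide case (m < n, k = m): partition U = [U1 U2] with U1 square.
            // Only the leading k×k block enters the triangular solves; the U2
            // cotangent contributes through -dU2 U2^H and passes straight
            // through L^{-H}.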
            let l_h = adjoint_transpose(l_b, k, k);
            let u1 = u_b[..k * k].to_vec();
            let u2 = u_b[k * k..].to_vec();
            let mut lower_source = vec![T::zero(); k * k];
            if let Some(dl_b) = dl_b {
                let lt_dl = backend_mat_mul(ctx, &l_h, k, k, dl_b, k)?;
                lower_source = add_vec(&lower_source, &lt_dl);
            }
            if let Some(du_b) = du_b.filter(|_| n > k) {
                let du2 = &du_b[k * k..];
                let du2_u2h =
                    backend_mat_mul(ctx, du2, k, n - k, &adjoint_transpose(&u2, k, n - k), k)?;
                lower_source = sub_vec(&lower_source, &du2_u2h);
            }

            let mut inner = tril_strict(&lower_source, k);
            if let Some(du_b) = du_b {
                let du1 = &du_b[..k * k];
                let du1_u1h = backend_mat_mul(ctx, du1, k, k, &adjoint_transpose(&u1, k, k), k)?;
                inner = add_vec(&inner, &triu(&du1_u1h, k));
            }

            let leading_h = backend_solve_tri(
                ctx,
                u1.as_slice(),
                &adjoint_transpose(&inner, k, k),
                k,
                k,
                true,
            )?;
            let leading = adjoint_transpose(&leading_h, k, k);

            let mut pre_left = vec![T::zero(); k * n];
            pre_left[..k * k].copy_from_slice(&leading);
            if let Some(du_b) = du_b.filter(|_| n > k) {
                pre_left[k * k..].copy_from_slice(&du_b[k * k..]);
            }

            backend_solve_tri(ctx, &l_h, &pre_left, k, n, true)?
        } else {
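            // Tall case (m > n, k = n): partition L = [L1; L2] with L1 square.
            // The L2 cotangent enters through -triu(L2^H dL2) and passes
            // straight through U^{-H}.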
            let mut l1 = vec![T::zero(); k * k];
            let mut l2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    l1[i + j * k] = l_b[i + j * m];
                }
                for i in k..m {
                    l2[(i - k) + j * (m - k)] = l_b[i + j * m];
                }
            }
            let l1_h = adjoint_transpose(&l1, k, k);

            let mut inner = vec![T::zero(); k * k];
            if let Some(dl_b) = dl_b {
                let mut dl1 = vec![T::zero(); k * k];
                let mut dl2 = vec![T::zero(); (m - k) * k];
                for j in 0..k {
                    for i in 0..k {
                        dl1[i + j * k] = dl_b[i + j * m];
                    }
                    for i in k..m {
                        dl2[(i - k) + j * (m - k)] = dl_b[i + j * m];
                    }
                }
                let l1h_dl1 = backend_mat_mul(ctx, &l1_h, k, k, &dl1, k)?;
                inner = add_vec(&inner, &tril_strict(&l1h_dl1, k));
                if m > k {
                    let l2h_dl2 =
                        backend_mat_mul(ctx, &adjoint_transpose(&l2, m - k, k), k, m - k, &dl2, k)?;
                    inner = sub_vec(&inner, &triu(&l2h_dl2, k));
                }
            }
            if let Some(du_b) = du_b {
                let du_term = backend_mat_mul(ctx, du_b, k, k, &adjoint_transpose(u_b, k, k), k)?;
                inner = add_vec(&inner, &triu(&du_term, k));
            }

            let leading = backend_solve_tri(ctx, &l1_h, &inner, k, k, true)?;

            let mut pre_right = vec![T::zero(); m * k];
            for j in 0..k {
                for i in 0..k {
                    pre_right[i + j * m] = leading[i + j * k];
                }
            }
            if let Some(dl_b) = dl_b {
                for j in 0..k {
                    for i in k..m {
                        pre_right[i + j * m] = dl_b[i + j * m];
                    }
                }
            }

            let batch_grad_h =
                backend_solve_tri(ctx, u_b, &adjoint_transpose(&pre_right, m, k), k, m, true)?;
            adjoint_transpose(&batch_grad_h, k, m)
        };

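        // Undo the row pivoting: row i of the pivoted gradient scatters back
        // to row p_b[i] of the output (grad A = P^T grad(P A)).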
        let out = &mut grad_a[b * m * n..(b + 1) * m * n];
        let p_b = &p_vec[b * m..(b + 1) * m];
        for j in 0..n {
            for i in 0..m {
                out[p_b[i] + j * m] = batch_grad[i + j * m];
            }
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}

/// Reverse-mode AD rule for Hermitian eigendecomposition (VJP / pullback).
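///
/// For Hermitian `A = V diag(E) V^H`, the returned gradient is
/// `grad A = V (diag(dE) + 1/2 (H + H^H)) V^H` with `H = F ⊙ (dV^H V)`,
/// where the gap matrix `F_ij = (e_i - e_j) / ((e_i - e_j)^2 + eta)` is
/// regularized so that (near-)degenerate eigenvalues do not divide by zero.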
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{eigen_rrule, EigenCotangent};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
/// let cotangent = EigenCotangent {
///     values: Some(Tensor::ones(&[3], mem, col).unwrap()),
///     vectors: None,
/// };
/// let grad_a = eigen_rrule(&mut ctx, &a, &cotangent).unwrap();
/// ```
pub fn eigen_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &EigenCotangent<T, T::Real>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar + Conjugate,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Hermitian eigendecomposition: A = V diag(E) V^H
    let result = eigen(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    // Regularization for the F-matrix: prevents division by zero when two
    // eigenvalues are (nearly) equal.  We use max(1e-40, T::Real::epsilon())
    // so that on f32 (where 1e-40 is subnormal, far below machine epsilon)
    // we still get a safe floor.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::Real::epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    let (v_data, _) = extract_data(&result.vectors)?;
    let (e_data, _) = extract_data(&result.values)?;

    let mut grad_a = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let v_b = &v_data[b * n * n..(b + 1) * n * n];
        let e_b = &e_data[b * n..(b + 1) * n];

        // Build F-matrix (n×n): F_ij = (e_i - e_j)/((e_i - e_j)^2 + eta), 0 diagonal.
        let mut f_mat = vec![T::Real::zero(); n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let gap = e_b[i] - e_b[j];
                    f_mat[i + j * n] = gap / (gap * gap + eta);
                }
            }
        }

        // Inner matrix D = diag(dE) + 1/2 * (H + H^H),
        // where H = F ⊙ (dV^H V), matching the matmul below.
        let mut d_mat = vec![T::zero(); n * n];

        if let Some(ref de) = cotangent.values {
            let (de_data, _) = extract_data(de)?;
            let de_b = &de_data[b * n..(b + 1) * n];
            for i in 0..n {
                d_mat[i + i * n] = T::from_real(de_b[i]);
            }
        }

        if let Some(ref dv) = cotangent.vectors {
            let (dv_data, _) = extract_data(dv)?;
            let dv_b = &dv_data[b * n * n..(b + 1) * n * n];
            let dv_h_v = backend_mat_mul(ctx, &adjoint_transpose(dv_b, n, n), n, n, v_b, n)?;
            let half: T::Real = scalar_from(0.5).map_err(to_ad_err)?;
            for i in 0..n {
                for j in 0..n {
                    let h_ij = T::from_real(f_mat[i + j * n]) * dv_h_v[i + j * n];
                    let h_h_ij = (T::from_real(f_mat[j + i * n]) * dv_h_v[j + i * n]).conj();
                    d_mat[i + j * n] = d_mat[i + j * n] + (h_ij + h_h_ij) * T::from_real(half);
                }
            }
        }

        // dA = V D V^H
        let vd = backend_mat_mul(ctx, v_b, n, n, &d_mat, n)?;
        let da_b = backend_mat_mul(ctx, &vd, n, n, &adjoint_transpose(v_b, n, n), n)?;

        grad_a[b * n * n..(b + 1) * n * n].copy_from_slice(&da_b);
    }

    let dims = output_dims(&[n, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}
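
// Sanity-check sketch for the eigendecomposition rule: for Hermitian `A`,
// the sum of eigenvalues equals `tr(A)`, so pulling back a ones cotangent
// on `values` alone must reproduce the identity matrix (up to roundoff).
// Tensor and context construction are assumed to behave as in the doc
// examples above.
#[cfg(test)]
mod rrule_sanity {
    use super::*;
    use tenferro_device::LogicalMemorySpace;
    use tenferro_prims::CpuContext;
    use tenferro_tensor::MemoryOrder;

    #[test]
    fn eigen_values_pullback_of_trace_is_identity() {
        let col = MemoryOrder::ColumnMajor;
        let mem = LogicalMemorySpace::MainMemory;
        let mut ctx = CpuContext::new(1);
        // Symmetric 3x3 matrix with well-separated eigenvalues (column-major).
        let a = Tensor::from_slice(
            &[2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 5.0],
            &[3, 3],
            col,
        )
        .unwrap();
        let cotangent = EigenCotangent {
            values: Some(Tensor::ones(&[3], mem, col).unwrap()),
            vectors: None,
        };
        let grad = eigen_rrule(&mut ctx, &a, &cotangent).unwrap();
        let (g, _) = extract_data(&grad).unwrap();
        for j in 0..3 {
            for i in 0..3 {
                let expect = if i == j { 1.0 } else { 0.0 };
                assert!((g[i + j * 3] - expect).abs() < 1e-8);
            }
        }
    }
}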