// tenferro_linalg/frules/lu_eigen.rs

use super::*;
use num_traits::{Float, One};
use tenferro_algebra::Conjugate;

/// Forward-mode AD rule for LU (JVP / pushforward).
///
/// The `pivot` argument must match the pivoting strategy used by the primal
/// `lu` call being differentiated.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{lu_frule, LuPivot};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0], &[3, 3], col)
///     .unwrap();
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
/// let (result, dresult) = lu_frule(&mut ctx, &a, &da, LuPivot::Partial).unwrap();
/// ```
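///
/// The tangent returned for the permutation factor is identically zero: the
/// pivot choice is piecewise constant in the input and carries no derivative.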
pub fn lu_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    pivot: LuPivot,
) -> AdResult<(LuResult<T>, LuResult<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorMetadataContextFor
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<
        tenferro_algebra::Standard<T::Real>,
    >>::ScalarBackend: tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
    T: crate::primal::LiftPermutationMatrixTensor<C>,
    C::Backend: 'static,
{
    let result = lu(ctx, tensor, pivot)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);

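    // Writing the primal factorization as P A = L U (P the row permutation
    // applied below) and differentiating with P held fixed gives
    //   P dA = dL U + L dU.
    // With F = L^{-1} (P dA) U^{-1} and L unit lower triangular (diag(dL) = 0),
    // this separates into dL = L * tril_strict(F) and dU = triu(F) * U. The
    // rectangular branches apply this to the leading k×k block and handle the
    // remaining rows/columns of L and U separately.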
    let (l_data, _) = extract_data(&result.l)?;
    let (u_data, _) = extract_data(&result.u)?;
    let p_vec = crate::forward_perm_from_permutation_matrix(&result.p, m, bc)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dl_data = vec![T::zero(); m * k * bc];
    let mut du_data = vec![T::zero(); k * n * bc];

    for b in 0..bc {
        let l_b = &l_data[b * m * k..(b + 1) * m * k];
        let u_b = &u_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];

        // Apply permutation: P dA (m×n)
        let mut pda = vec![T::zero(); m * n];
        let p_b = &p_vec[b * m..(b + 1) * m];
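        // p_b[i] is the source row: row i of P dA is row p_b[i] of dA.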
        for i in 0..m {
            for j in 0..n {
                pda[i + j * m] = da_b[p_b[i] + j * m];
            }
        }

        if m == n {
            let l_sq = l_b.to_vec();
            let u_sq = u_b.to_vec();
            let linv_pda = backend_solve_tri(ctx, &l_sq, &pda, k, k, false)?;
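            // Compute F = L^{-1} (P dA) U^{-1}; the right-division by U is
            // done via adjoints: F^H = U^{-H} (L^{-1} P dA)^H.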
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u_sq, k, k),
                &adjoint_transpose(&linv_pda, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl_b_vec = backend_mat_mul(ctx, &l_sq, k, k, &lower_f, k)?;
            let du_b_vec = backend_mat_mul(ctx, &upper_f, k, k, &u_sq, k)?;
            dl_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dl_b_vec);
            du_data[b * k * n..(b + 1) * k * n].copy_from_slice(&du_b_vec);
        } else if m < n {
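            // Wide case (m < n): split U = [U1 | U2] (U1 k×k) and P dA into
            // leading/trailing columns. From P dA2 = dL U2 + L dU2 with
            // dL = L tril_strict(F):  dU2 = L^{-1} (P dA2) - tril_strict(F) U2.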
            let l_sq = l_b.to_vec();
            let u1 = u_b[..k * k].to_vec();
            let u2 = u_b[k * k..].to_vec();
            let pda1 = pda[..k * k].to_vec();
            let pda2 = pda[k * k..].to_vec();

            let linv_pda1 = backend_solve_tri(ctx, &l_sq, &pda1, k, k, false)?;
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u1, k, k),
                &adjoint_transpose(&linv_pda1, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl_b_vec = backend_mat_mul(ctx, &l_sq, k, k, &lower_f, k)?;
            let du1 = backend_mat_mul(ctx, &upper_f, k, k, &u1, k)?;
            let du2 = if n > k {
                let linv_pda2 = backend_solve_tri(ctx, &l_sq, &pda2, k, n - k, false)?;
                let correction = backend_mat_mul(ctx, &lower_f, k, k, &u2, n - k)?;
                sub_vec(&linv_pda2, &correction)
            } else {
                Vec::new()
            };

            dl_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dl_b_vec);
            du_data[b * k * n..b * k * n + k * k].copy_from_slice(&du1);
            if n > k {
                du_data[b * k * n + k * k..(b + 1) * k * n].copy_from_slice(&du2);
            }
        } else {
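            // Tall case (m > n): split L = [L1; L2] (L1 k×k) and P dA into
            // top/bottom rows. From P dA2 = dL2 U + L2 dU with
            // dU = triu(F) U:  dL2 = (P dA2) U^{-1} - L2 triu(F).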
            let mut l1 = vec![T::zero(); k * k];
            let mut l2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    l1[i + j * k] = l_b[i + j * m];
                }
                for i in k..m {
                    l2[(i - k) + j * (m - k)] = l_b[i + j * m];
                }
            }
            let u_sq = u_b.to_vec();

            let mut pda1 = vec![T::zero(); k * k];
            let mut pda2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    pda1[i + j * k] = pda[i + j * m];
                }
                for i in k..m {
                    pda2[(i - k) + j * (m - k)] = pda[i + j * m];
                }
            }

            let linv_pda1 = backend_solve_tri(ctx, &l1, &pda1, k, k, false)?;
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u_sq, k, k),
                &adjoint_transpose(&linv_pda1, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl1 = backend_mat_mul(ctx, &l1, k, k, &lower_f, k)?;
            let du_b_vec = backend_mat_mul(ctx, &upper_f, k, k, &u_sq, k)?;
            let dl2 = if m > k {
                let pda2_uinv_h = backend_solve_tri(
                    ctx,
                    &adjoint_transpose(&u_sq, k, k),
                    &adjoint_transpose(&pda2, m - k, k),
                    k,
                    m - k,
                    false,
                )?;
                let pda2_uinv = adjoint_transpose(&pda2_uinv_h, k, m - k);
                let correction = backend_mat_mul(ctx, &l2, m - k, k, &upper_f, k)?;
                sub_vec(&pda2_uinv, &correction)
            } else {
                Vec::new()
            };

            for j in 0..k {
                for i in 0..k {
                    dl_data[b * m * k + i + j * m] = dl1[i + j * k];
                }
                for i in k..m {
                    dl_data[b * m * k + i + j * m] = dl2[(i - k) + j * (m - k)];
                }
            }
            du_data[b * k * n..(b + 1) * k * n].copy_from_slice(&du_b_vec);
        }
    }

    let l_dims = output_dims(&[m, k], batch_dims);
    let u_dims = output_dims(&[k, n], batch_dims);
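    // The permutation tangent is zero (pivoting is locally constant), so dP
    // is an all-zeros tensor with the same shape as P.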
    let dresult = LuResult {
        p: Tensor::zeros(
            result.p.dims(),
            result.p.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        )
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        l: tensor_from_data(dl_data, &l_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        u: tensor_from_data(du_data, &u_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}

/// Forward-mode AD rule for eigendecomposition (JVP / pushforward).
///
/// # Examples
///
/// ```
/// use tenferro_linalg::eigen_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
/// let (result, dresult) = eigen_frule(&mut ctx, &a, &da).unwrap();
/// ```
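///
/// The rule treats `V^H` as the inverse of the eigenvector matrix, so it is
/// valid for inputs with a unitary eigenbasis (Hermitian/symmetric matrices),
/// consistent with the real-valued eigenvalues in the result.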
pub fn eigen_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(EigenResult<T, T::Real>, EigenResult<T, T::Real>)>
where
    T: KernelLinalgScalar + Conjugate,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let result = eigen(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
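    // JVP of A = V diag(e) V^H with unitary V. With C = V^H dA V:
    //   de_i = C_ii,    dV = V (F ⊙ C),    F_ij = 1/(e_j - e_i) for i ≠ j.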
    // Regularization for the F-matrix: prevents division by zero when two
    // eigenvalues are (nearly) equal. We take max(1e-40, T::Real::epsilon())
    // so the shift never falls below machine precision.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::Real::epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };
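    // For the standard float types this evaluates to machine epsilon itself,
    // since 1e-40 < f32::EPSILON (~1.2e-7) and 1e-40 < f64::EPSILON (~2.2e-16).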

    let (v_data, _) = extract_data(&result.vectors)?;
    let (e_data, _) = extract_data(&result.values)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut de_data = vec![T::Real::zero(); n * bc];
    let mut dv_data = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let v_b = &v_data[b * n * n..(b + 1) * n * n];
        let e_b = &e_data[b * n..(b + 1) * n];
        let da_b = &da_data[b * n * n..(b + 1) * n * n];

        // C = V^H dA V (n×n)
        let vh_da = backend_mat_mul(ctx, &adjoint_transpose(v_b, n, n), n, n, da_b, n)?;
        let c = backend_mat_mul(ctx, &vh_da, n, n, v_b, n)?;

        // dE = diag(C)
        for i in 0..n {
            de_data[b * n + i] = c[i + i * n].real_part();
        }

        // dV = V * (F ⊙ (C - diag(dE))) where F_ij = 1/(e_j - e_i) for i≠j.
        let mut fc = vec![T::zero(); n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let denom = e_b[j] - e_b[i];
                    let f_ij = T::Real::one()
                        / (denom
                            + eta
                                * if denom >= T::Real::zero() {
                                    T::Real::one()
                                } else {
                                    -T::Real::one()
                                });
                    fc[i + j * n] = T::from_real(f_ij) * c[i + j * n];
                }
            }
        }
        let dv_b_vec = backend_mat_mul(ctx, v_b, n, n, &fc, n)?;
        dv_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dv_b_vec);
    }

    let val_dims = output_dims(&[n], batch_dims);
    let vec_dims = output_dims(&[n, n], batch_dims);
    let dresult = EigenResult {
        values: tensor_from_data(de_data, &val_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        vectors: tensor_from_data(dv_data, &vec_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}