use super::*;
use num_traits::Float;
use tenferro_algebra::Conjugate;

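/// Reverse-mode rule (pullback) for the pivoted LU factorization.
///
/// For square `A` this accumulates the textbook LU pullback
///
/// ```text
/// Ā = L⁻ᴴ · (tril_strict(Lᴴ·L̄) + triu(Ū·Uᴴ)) · U⁻ᴴ
/// ```
///
/// and then undoes the row pivoting by scattering row `i` of the unpivoted
/// gradient into row `p[i]` of `Ā`. Wide (`m < n`) and tall (`m > n`) inputs
/// partition `U = [U₁ | U₂]` and `L = [L₁; L₂]` and handle the rectangular
/// blocks separately.
///
/// Call-shape sketch only: `make_ctx` and `LuPivot::Partial` below are
/// hypothetical stand-ins for whatever context constructor and pivot variant
/// the crate actually exposes.
///
/// ```ignore
/// let mut ctx = make_ctx(); // hypothetical context constructor
/// let cot = LuCotangent { l: Some(l_bar), u: Some(u_bar) };
/// let a_bar = lu_rrule(&mut ctx, &a, &cot, LuPivot::Partial)?;
/// ```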
pub fn lu_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &LuCotangent<T>,
    pivot: LuPivot,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorMetadataContextFor
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<
        tenferro_algebra::Standard<T::Real>,
    >>::ScalarBackend: tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
    T: crate::primal::LiftPermutationMatrixTensor<C>,
    C::Backend: 'static,
{
    let result = lu(ctx, tensor, pivot)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);

    if let Some(ref dl) = cotangent.l {
        if dl.dims() != result.l.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lu_rrule L cotangent shape mismatch: expected {:?}, got {:?}",
                result.l.dims(),
                dl.dims()
            ))));
        }
    }
    if let Some(ref du) = cotangent.u {
        if du.dims() != result.u.dims() {
            return Err(to_ad_err(Error::InvalidArgument(format!(
                "lu_rrule U cotangent shape mismatch: expected {:?}, got {:?}",
                result.u.dims(),
                du.dims()
            ))));
        }
    }

    let (l_data, _) = extract_data(&result.l)?;
    let (u_data, _) = extract_data(&result.u)?;
    let dl_data = if let Some(ref dl) = cotangent.l {
        Some(extract_data(dl)?.0)
    } else {
        None
    };
    let du_data = if let Some(ref du) = cotangent.u {
        Some(extract_data(du)?.0)
    } else {
        None
    };
    let p_vec = crate::forward_perm_from_permutation_matrix(&result.p, m, bc)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    let mut grad_a = vec![T::zero(); m * n * bc];

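    // All buffers are column-major; each batch entry is an independent m×n matrix.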
    for b in 0..bc {
        let l_b = &l_data[b * m * k..(b + 1) * m * k];
        let u_b = &u_data[b * k * n..(b + 1) * k * n];
        let dl_b = dl_data
            .as_ref()
            .map(|data| &data[b * m * k..(b + 1) * m * k]);
        let du_b = du_data
            .as_ref()
            .map(|data| &data[b * k * n..(b + 1) * k * n]);

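        // Three shape regimes below: square (m == n), wide (m < n), and tall
        // (m > n). The square case is the textbook pullback
        //   Ā = L⁻ᴴ · (tril_strict(Lᴴ·L̄) + triu(Ū·Uᴴ)) · U⁻ᴴ.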
        let batch_grad = if m == n {
            let l_h = adjoint_transpose(l_b, k, k);
            let mut inner = vec![T::zero(); k * k];

            if let Some(dl_b) = dl_b {
                let lt_dl = backend_mat_mul(ctx, &l_h, k, k, dl_b, k)?;
                inner = add_vec(&inner, &tril_strict(&lt_dl, k));
            }
            if let Some(du_b) = du_b {
                let du_ut = backend_mat_mul(ctx, du_b, k, k, &adjoint_transpose(u_b, k, k), k)?;
                inner = add_vec(&inner, &triu(&du_ut, k));
            }

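            // Two triangular solves evaluate L⁻ᴴ · inner · U⁻ᴴ without forming inverses.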
            let left = backend_solve_tri(ctx, &l_h, &inner, k, k, true)?;
            let grad_h = backend_solve_tri(ctx, u_b, &adjoint_transpose(&left, k, k), k, k, true)?;
            adjoint_transpose(&grad_h, k, k)
        } else if m < n {
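            // Wide case (m < n, k = m): partition U = [U₁ | U₂] with U₁ square.
            // Ā = L⁻ᴴ · [inner · U₁⁻ᴴ | Ū₂], where inner folds in −Ū₂·U₂ᴴ below
            // the diagonal and Ū₁·U₁ᴴ on and above it.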
            let l_h = adjoint_transpose(l_b, k, k);
            let u1 = u_b[..k * k].to_vec();
            let u2 = u_b[k * k..].to_vec();
            let mut lower_source = vec![T::zero(); k * k];
            if let Some(dl_b) = dl_b {
                let lt_dl = backend_mat_mul(ctx, &l_h, k, k, dl_b, k)?;
                lower_source = add_vec(&lower_source, &lt_dl);
            }
            if let Some(du_b) = du_b.filter(|_| n > k) {
                let du2 = &du_b[k * k..];
                let du2_u2h =
                    backend_mat_mul(ctx, du2, k, n - k, &adjoint_transpose(&u2, k, n - k), k)?;
                lower_source = sub_vec(&lower_source, &du2_u2h);
            }

            let mut inner = tril_strict(&lower_source, k);
            if let Some(du_b) = du_b {
                let du1 = &du_b[..k * k];
                let du1_u1h = backend_mat_mul(ctx, du1, k, k, &adjoint_transpose(&u1, k, k), k)?;
                inner = add_vec(&inner, &triu(&du1_u1h, k));
            }

            let leading_h = backend_solve_tri(
                ctx,
                u1.as_slice(),
                &adjoint_transpose(&inner, k, k),
                k,
                k,
                true,
            )?;
            let leading = adjoint_transpose(&leading_h, k, k);

            let mut pre_left = vec![T::zero(); k * n];
            pre_left[..k * k].copy_from_slice(&leading);
            if let Some(du_b) = du_b.filter(|_| n > k) {
                pre_left[k * k..].copy_from_slice(&du_b[k * k..]);
            }

            backend_solve_tri(ctx, &l_h, &pre_left, k, n, true)?
        } else {
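            // Tall case (m > n, k = n): partition L = [L₁; L₂] with L₁ square.
            // Mirrors the wide case: Ā = [L₁⁻ᴴ · inner; L̄₂] · U⁻ᴴ, with the L₂
            // cotangent entering inner as −triu(L₂ᴴ·L̄₂).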
            let mut l1 = vec![T::zero(); k * k];
            let mut l2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    l1[i + j * k] = l_b[i + j * m];
                }
                for i in k..m {
                    l2[(i - k) + j * (m - k)] = l_b[i + j * m];
                }
            }
            let l1_h = adjoint_transpose(&l1, k, k);

            let mut inner = vec![T::zero(); k * k];
            if let Some(dl_b) = dl_b {
                let mut dl1 = vec![T::zero(); k * k];
                let mut dl2 = vec![T::zero(); (m - k) * k];
                for j in 0..k {
                    for i in 0..k {
                        dl1[i + j * k] = dl_b[i + j * m];
                    }
                    for i in k..m {
                        dl2[(i - k) + j * (m - k)] = dl_b[i + j * m];
                    }
                }
                let l1h_dl1 = backend_mat_mul(ctx, &l1_h, k, k, &dl1, k)?;
                inner = add_vec(&inner, &tril_strict(&l1h_dl1, k));
                if m > k {
                    let l2h_dl2 =
                        backend_mat_mul(ctx, &adjoint_transpose(&l2, m - k, k), k, m - k, &dl2, k)?;
                    inner = sub_vec(&inner, &triu(&l2h_dl2, k));
                }
            }
            if let Some(du_b) = du_b {
                let du_term = backend_mat_mul(ctx, du_b, k, k, &adjoint_transpose(u_b, k, k), k)?;
                inner = add_vec(&inner, &triu(&du_term, k));
            }

            let leading = backend_solve_tri(ctx, &l1_h, &inner, k, k, true)?;

            let mut pre_right = vec![T::zero(); m * k];
            for j in 0..k {
                for i in 0..k {
                    pre_right[i + j * m] = leading[i + j * k];
                }
            }
            if let Some(dl_b) = dl_b {
                for j in 0..k {
                    for i in k..m {
                        pre_right[i + j * m] = dl_b[i + j * m];
                    }
                }
            }

            let batch_grad_h =
                backend_solve_tri(ctx, u_b, &adjoint_transpose(&pre_right, m, k), k, m, true)?;
            adjoint_transpose(&batch_grad_h, k, m)
        };

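        // Undo the row pivoting: row i of the unpivoted gradient lands in row p[i].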
        let out = &mut grad_a[b * m * n..(b + 1) * m * n];
        let p_b = &p_vec[b * m..(b + 1) * m];
        for j in 0..n {
            for i in 0..m {
                out[p_b[i] + j * m] = batch_grad[i + j * m];
            }
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}

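/// Reverse-mode rule (pullback) for the Hermitian eigendecomposition
/// `A = V·Λ·Vᴴ`, accumulating
///
/// ```text
/// Ā = V · (diag(λ̄) + ½(H + Hᴴ)) · Vᴴ,   H = F ∘ (V̄ᴴ·V),
/// F[i, j] = (λᵢ − λⱼ) / ((λᵢ − λⱼ)² + η),   F[i, i] = 0
/// ```
///
/// where the `η` regularizer keeps `F` finite for (near-)degenerate
/// eigenvalue pairs.
///
/// Call-shape sketch only: `make_ctx` below is a hypothetical stand-in for
/// the crate's actual context constructor.
///
/// ```ignore
/// let mut ctx = make_ctx(); // hypothetical context constructor
/// let cot = EigenCotangent { values: Some(lam_bar), vectors: Some(v_bar) };
/// let a_bar = eigen_rrule(&mut ctx, &a, &cot)?;
/// ```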
pub fn eigen_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &EigenCotangent<T, T::Real>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar + Conjugate,
    T::Real: KernelLinalgScalar<Real = T::Real> + Float,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let result = eigen(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    let eta: T::Real = {
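        // Gap regularizer: a tiny constant floored at machine epsilon so it
        // never vanishes in lower-precision scalar types.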
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::Real::epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    let (v_data, _) = extract_data(&result.vectors)?;
    let (e_data, _) = extract_data(&result.values)?;

    let mut grad_a = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let v_b = &v_data[b * n * n..(b + 1) * n * n];
        let e_b = &e_data[b * n..(b + 1) * n];

        let mut f_mat = vec![T::Real::zero(); n * n];
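        // F[i, j] = (λᵢ − λⱼ) / ((λᵢ − λⱼ)² + η), zero on the diagonal.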
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let gap = e_b[i] - e_b[j];
                    f_mat[i + j * n] = gap / (gap * gap + eta);
                }
            }
        }

        let mut d_mat = vec![T::zero(); n * n];

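        // Diagonal of D: the (real) eigenvalue cotangent, lifted into T.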
        if let Some(ref de) = cotangent.values {
            let (de_data, _) = extract_data(de)?;
            let de_b = &de_data[b * n..(b + 1) * n];
            for i in 0..n {
                d_mat[i + i * n] = T::from_real(de_b[i]);
            }
        }

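        // Off-diagonal of D: the Hermitian part ½(H + Hᴴ) of H = F ∘ (V̄ᴴ·V),
        // which projects the eigenvector cotangent onto admissible directions.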
        if let Some(ref dv) = cotangent.vectors {
            let (dv_data, _) = extract_data(dv)?;
            let dv_b = &dv_data[b * n * n..(b + 1) * n * n];
            let dv_h_v = backend_mat_mul(ctx, &adjoint_transpose(dv_b, n, n), n, n, v_b, n)?;
            let half: T::Real = scalar_from(0.5).map_err(to_ad_err)?;
            for i in 0..n {
                for j in 0..n {
                    let h_ij = T::from_real(f_mat[i + j * n]) * dv_h_v[i + j * n];
                    let h_h_ij = (T::from_real(f_mat[j + i * n]) * dv_h_v[j + i * n]).conj();
                    d_mat[i + j * n] = d_mat[i + j * n] + (h_ij + h_h_ij) * T::from_real(half);
                }
            }
        }

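        // Ā_b = V · D · Vᴴ; V is unitary for Hermitian input, so V⁻ᴴ = V.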
        let vd = backend_mat_mul(ctx, v_b, n, n, &d_mat, n)?;
        let da_b = backend_mat_mul(ctx, &vd, n, n, &adjoint_transpose(v_b, n, n), n)?;

        grad_a[b * n * n..(b + 1) * n * n].copy_from_slice(&da_b);
    }

    let dims = output_dims(&[n, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}