1use super::*;
2use num_traits::{Float, One};
3use tenferro_algebra::Conjugate;
4
/// Forward-mode (JVP) rule for the pivoted LU factorisation `P·A = L·U`.
///
/// Runs the primal factorisation of `tensor`, then propagates the input
/// tangent `dA` (`tangent`) through it to obtain tangents of the factors.
///
/// Square case: with `F = L⁻¹·(P·dA)·U⁻¹`,
/// `dL = L·tril_strict(F)` and `dU = triu(F)·U`.
/// Wide (`m < n`) and tall (`m > n`) inputs block-split `U = [U1 | U2]`
/// resp. `L = [L1; L2]` around the leading `k×k` corner, `k = min(m, n)`.
/// The pivot tangent is identically zero (a permutation is discrete), so
/// the returned tangent's `p` is a zero tensor with the primal `p`'s dims.
///
/// All per-batch matrices are addressed in column-major order
/// (element `(i, j)` of an `r×c` matrix lives at `i + j * r`).
///
/// # Errors
/// Returns `AutodiffError::InvalidArgument` if the factorisation, shape
/// validation, permutation recovery, or tensor construction fails.
pub fn lu_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    pivot: LuPivot,
) -> AdResult<(LuResult<T>, LuResult<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorMetadataContextFor
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<
        tenferro_algebra::Standard<T::Real>,
    >>::ScalarBackend: tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
    T: crate::primal::LiftPermutationMatrixTensor<C>,
    C::Backend: 'static,
{
    // Primal factorisation; its L, U and P feed the tangent computation.
    let result = lu(ctx, tensor, pivot)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);

    let (l_data, _) = extract_data(&result.l)?;
    let (u_data, _) = extract_data(&result.u)?;
    // Recover the row permutation as per-batch index vectors from the
    // materialised permutation-matrix tensor.
    let p_vec = crate::forward_perm_from_permutation_matrix(&result.p, m, bc)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dl_data = vec![T::zero(); m * k * bc];
    let mut du_data = vec![T::zero(); k * n * bc];

    for b in 0..bc {
        let l_b = &l_data[b * m * k..(b + 1) * m * k];
        let u_b = &u_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];

        // pda = P·dA: permute the tangent's rows to match L·U's row order.
        let mut pda = vec![T::zero(); m * n];
        let p_b = &p_vec[b * m..(b + 1) * m];
        for i in 0..m {
            for j in 0..n {
                pda[i + j * m] = da_b[p_b[i] + j * m];
            }
        }

        if m == n {
            // Square: F = L⁻¹·(P·dA)·U⁻¹ via two triangular solves; the
            // right-division by U is a solve against Uᴴ on adjoint-transposed
            // operands.
            let l_sq = l_b.to_vec();
            let u_sq = u_b.to_vec();
            let linv_pda = backend_solve_tri(ctx, &l_sq, &pda, k, k, false)?;
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u_sq, k, k),
                &adjoint_transpose(&linv_pda, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            // Split F so that dL stays (unit-)lower and dU stays upper
            // triangular, matching the structure of L and U.
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl_b_vec = backend_mat_mul(ctx, &l_sq, k, k, &lower_f, k)?;
            let du_b_vec = backend_mat_mul(ctx, &upper_f, k, k, &u_sq, k)?;
            dl_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dl_b_vec);
            du_data[b * k * n..(b + 1) * k * n].copy_from_slice(&du_b_vec);
        } else if m < n {
            // Wide: U = [U1 | U2] with square k×k U1. Here m == k, so the
            // first k columns of u_b / pda are contiguous in column-major.
            let l_sq = l_b.to_vec();
            let u1 = u_b[..k * k].to_vec();
            let u2 = u_b[k * k..].to_vec();
            let pda1 = pda[..k * k].to_vec();
            let pda2 = pda[k * k..].to_vec();

            // Square-case recipe on the leading block.
            let linv_pda1 = backend_solve_tri(ctx, &l_sq, &pda1, k, k, false)?;
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u1, k, k),
                &adjoint_transpose(&linv_pda1, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl_b_vec = backend_mat_mul(ctx, &l_sq, k, k, &lower_f, k)?;
            let du1 = backend_mat_mul(ctx, &upper_f, k, k, &u1, k)?;
            // dU2 = L⁻¹·(P·dA)₂ − tril_strict(F)·U2.
            // (`n > k` always holds in this branch; the guard is defensive.)
            let du2 = if n > k {
                let linv_pda2 = backend_solve_tri(ctx, &l_sq, &pda2, k, n - k, false)?;
                let correction = backend_mat_mul(ctx, &lower_f, k, k, &u2, n - k)?;
                sub_vec(&linv_pda2, &correction)
            } else {
                Vec::new()
            };

            dl_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dl_b_vec);
            du_data[b * k * n..b * k * n + k * k].copy_from_slice(&du1);
            if n > k {
                du_data[b * k * n + k * k..(b + 1) * k * n].copy_from_slice(&du2);
            }
        } else {
            // Tall: L = [L1; L2] with square k×k L1. Rows of the two blocks
            // are interleaved in column-major order, so gather element-wise.
            let mut l1 = vec![T::zero(); k * k];
            let mut l2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    l1[i + j * k] = l_b[i + j * m];
                }
                for i in k..m {
                    l2[(i - k) + j * (m - k)] = l_b[i + j * m];
                }
            }
            let u_sq = u_b.to_vec();

            // Split P·dA into the top k×k and bottom (m-k)×k blocks.
            let mut pda1 = vec![T::zero(); k * k];
            let mut pda2 = vec![T::zero(); (m - k) * k];
            for j in 0..k {
                for i in 0..k {
                    pda1[i + j * k] = pda[i + j * m];
                }
                for i in k..m {
                    pda2[(i - k) + j * (m - k)] = pda[i + j * m];
                }
            }

            // Square-case recipe on the leading block.
            let linv_pda1 = backend_solve_tri(ctx, &l1, &pda1, k, k, false)?;
            let f_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&u_sq, k, k),
                &adjoint_transpose(&linv_pda1, k, k),
                k,
                k,
                false,
            )?;
            let f = adjoint_transpose(&f_h, k, k);
            let lower_f = tril_strict(&f, k);
            let upper_f = triu(&f, k);

            let dl1 = backend_mat_mul(ctx, &l1, k, k, &lower_f, k)?;
            let du_b_vec = backend_mat_mul(ctx, &upper_f, k, k, &u_sq, k)?;
            // dL2 = (P·dA)₂·U⁻¹ − L2·triu(F); the right-division by U is a
            // triangular solve against Uᴴ on adjoint-transposed operands.
            // (`m > k` always holds in this branch; the guard is defensive.)
            let dl2 = if m > k {
                let pda2_uinv_h = backend_solve_tri(
                    ctx,
                    &adjoint_transpose(&u_sq, k, k),
                    &adjoint_transpose(&pda2, m - k, k),
                    k,
                    m - k,
                    false,
                )?;
                let pda2_uinv = adjoint_transpose(&pda2_uinv_h, k, m - k);
                let correction = backend_mat_mul(ctx, &l2, m - k, k, &upper_f, k)?;
                sub_vec(&pda2_uinv, &correction)
            } else {
                Vec::new()
            };

            // Scatter [dL1; dL2] back into the interleaved m×k batch slot.
            for j in 0..k {
                for i in 0..k {
                    dl_data[b * m * k + i + j * m] = dl1[i + j * k];
                }
                for i in k..m {
                    dl_data[b * m * k + i + j * m] = dl2[(i - k) + j * (m - k)];
                }
            }
            du_data[b * k * n..(b + 1) * k * n].copy_from_slice(&du_b_vec);
        }
    }

    let l_dims = output_dims(&[m, k], batch_dims);
    let u_dims = output_dims(&[k, n], batch_dims);
    let dresult = LuResult {
        // The permutation does not vary with A, so its tangent is all-zero.
        p: Tensor::zeros(
            result.p.dims(),
            result.p.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        )
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        l: tensor_from_data(dl_data, &l_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        u: tensor_from_data(du_data, &u_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}
211
/// Forward-mode (JVP) rule for the eigendecomposition of a square tensor.
///
/// Runs the primal decomposition of `tensor`, then propagates the tangent
/// `dA` (`tangent`): with `C = Vᴴ·dA·V`,
/// `dλᵢ = Re(Cᵢᵢ)` and `dV = V·(F ∘ C)`,
/// where `Fᵢⱼ = 1 / (λⱼ − λᵢ)` for `i ≠ j` and `Fᵢᵢ = 0`.
///
/// NOTE(review): using `Vᴴ` in place of `V⁻¹`, real-typed eigenvalues, and
/// taking only the real part of `diag(C)` assume `eigen` returns a unitary
/// `V` (Hermitian/symmetric input) — confirm against the `eigen` primal.
///
/// Near-degenerate eigenvalue pairs are regularised by shifting each
/// denominator away from zero by `eta` (in the denominator's sign), so the
/// rule never divides by zero but loses accuracy at exact degeneracy.
///
/// # Errors
/// Returns `AutodiffError::InvalidArgument` if the decomposition, shape
/// validation, or tensor construction fails.
pub fn eigen_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(EigenResult<T, T::Real>, EigenResult<T, T::Real>)>
where
    T: KernelLinalgScalar + Conjugate,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let result = eigen(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (n, batch_dims) = validate_square(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    // Degeneracy guard: the 1e-40 constant clamped up to machine epsilon.
    // In practice eta == epsilon for standard floats (1e-40 < eps there);
    // the clamp also covers types where 1e-40 underflows.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::Real::epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    let (v_data, _) = extract_data(&result.vectors)?;
    let (e_data, _) = extract_data(&result.values)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut de_data = vec![T::Real::zero(); n * bc];
    let mut dv_data = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let v_b = &v_data[b * n * n..(b + 1) * n * n];
        let e_b = &e_data[b * n..(b + 1) * n];
        let da_b = &da_data[b * n * n..(b + 1) * n * n];

        // C = Vᴴ·dA·V (n×n, column-major: element (i, j) at i + j*n).
        let vh_da = backend_mat_mul(ctx, &adjoint_transpose(v_b, n, n), n, n, da_b, n)?;
        let c = backend_mat_mul(ctx, &vh_da, n, n, v_b, n)?;

        // Eigenvalue tangents: dλᵢ = Re(Cᵢᵢ).
        for i in 0..n {
            de_data[b * n + i] = c[i + i * n].real_part();
        }

        // fc = F ∘ C with Fᵢⱼ = 1/(λⱼ − λᵢ + sign(λⱼ − λᵢ)·eta) off the
        // diagonal and zero on it (sign(0) taken as +1).
        let mut fc = vec![T::zero(); n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let denom = e_b[j] - e_b[i];
                    let f_ij = T::Real::one()
                        / (denom
                            + eta
                                * if denom >= T::Real::zero() {
                                    T::Real::one()
                                } else {
                                    -T::Real::one()
                                });
                    fc[i + j * n] = T::from_real(f_ij) * c[i + j * n];
                }
            }
        }
        // Eigenvector tangents: dV = V·(F ∘ C).
        let dv_b_vec = backend_mat_mul(ctx, v_b, n, n, &fc, n)?;
        dv_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dv_b_vec);
    }

    let val_dims = output_dims(&[n], batch_dims);
    let vec_dims = output_dims(&[n, n], batch_dims);
    let dresult = EigenResult {
        values: tensor_from_data(de_data, &val_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        vectors: tensor_from_data(dv_data, &vec_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}