use super::*;

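/// Forward-mode (JVP) rule for the eigendecomposition of a (batched) real square matrix.
///
/// Returns the primal `EigResult` together with an `EigResult` holding the tangents of
/// the eigenvalues and eigenvectors. With `W = V^-1 dA V`, the eigenvalue tangents are
/// `diag(W)` and the eigenvector tangents are `V (F .* W)`, where
/// `F[i, j] = 1 / (lambda[j] - lambda[i])` off the diagonal and zero on it, followed by a
/// gauge correction that keeps `Re(v_j^H dv_j) = 0` for each column. Because the rule
/// divides by eigenvalue gaps, it assumes pairwise distinct eigenvalues.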
pub fn eig_frule<
    T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
    C,
>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(EigResult<T>, EigResult<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + backend::TensorLinalgContextFor<num_complex::Complex<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    num_complex::Complex<T>: KernelLinalgScalar,
    <C as backend::TensorLinalgContextFor<T>>::Backend: 'static,
    <C as backend::TensorLinalgContextFor<num_complex::Complex<T>>>::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    let eig_result = eig(ctx, tensor).map_err(to_ad_err)?;

    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
    let bc = batch_count(batch_dims);

    let val_data = extract_data_scalar(&eig_result.values)?;
    let vec_data = extract_data_scalar(&eig_result.vectors)?;
    let (tang_data, _) = extract_data(tangent)?;

    let zero_c = Cx::new(T::zero(), T::zero());
    let one_c = Cx::new(T::one(), T::zero());

    let mut dval_data = vec![zero_c; n * bc];
    let mut dvec_data = vec![zero_c; n * n * bc];

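    // Compute the tangents one batched matrix at a time.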
    for b in 0..bc {
        let lambda = &val_data[b * n..(b + 1) * n];
        let v = &vec_data[b * n * n..(b + 1) * n * n];
        let da = &tang_data[b * n * n..(b + 1) * n * n];

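        // Promote the real tangent slice to complex so it can be combined with the
        // complex eigenvector matrix.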
        let da_complex: Vec<Cx<T>> = da.iter().map(|&x| Cx::new(x, T::zero())).collect();

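        // W = V^-1 dA V: first dA * V, then a linear solve against V.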
        let da_v = complex_mat_mul_nn_backend(ctx, &da_complex, v, n)?;
        let w = complex_solve_nn(ctx, v, &da_v, n)?;

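        // Eigenvalue tangents are the diagonal of W.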
        for i in 0..n {
            dval_data[b * n + i] = w[i + i * n];
        }

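        // F[i, j] = 1 / (lambda[j] - lambda[i]) for i != j, zero on the diagonal
        // (column-major storage, matching the eigenvector layout).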
        let mut f_mat = vec![zero_c; n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let diff = lambda[j] - lambda[i];
                    f_mat[i + j * n] = one_c / diff;
                }
            }
        }

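        // Eigenvector tangents before gauge fixing: dV = V * (F .* W), with .* the
        // elementwise product.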
        let mut fw = vec![zero_c; n * n];
        for k in 0..n * n {
            fw[k] = f_mat[k] * w[k];
        }
        let dv_raw = complex_mat_mul_nn_backend(ctx, v, &fw, n)?;

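        // Gauge correction: subtract v_j * Re(v_j^H dv_j) from each column dv_j so that
        // Re(v_j^H dv_j) = 0, fixing the scale ambiguity of unit-norm eigenvectors.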
        let vh = complex_conj_transpose(v, n);
        let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, &dv_raw, n)?;
        let mut dv = dv_raw;
        for j in 0..n {
            let correction = Cx::new(vh_dv[j + j * n].re, T::zero());
            for i in 0..n {
                let index = i + j * n;
                dv[index] = dv[index] - v[index] * correction;
            }
        }
        dvec_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dv);
    }

    let val_dims = output_dims(&[n], batch_dims);
    let vec_dims = output_dims(&[n, n], batch_dims);

    let d_result = EigResult {
        values: tensor_from_data_scalar(dval_data, &val_dims).map_err(to_ad_err)?,
        vectors: tensor_from_data_scalar(dvec_data, &vec_dims).map_err(to_ad_err)?,
    };

    Ok((eig_result, d_result))
}

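/// Forward-mode (JVP) rule for the Moore-Penrose pseudoinverse of a (batched) `m x n` matrix.
///
/// Returns the primal pseudoinverse `A+` together with its tangent
/// `dA+ = -A+ dA A+ + A+ A+^H dA^H (I - A A+) + (I - A+ A) dA^H A+^H A+`.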
pub fn pinv_frule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    rcond: Option<f64>,
) -> AdResult<(Tensor<T>, Tensor<T>)>
where
    T: KernelLinalgScalar
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv_frule")
        .map_err(to_ad_err)?;

    let ap = pinv(ctx, tensor, rcond)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (a_data, _) = extract_data(tensor)?;
    let (ap_data, _) = extract_data(&ap)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dap_data = vec![T::zero(); n * m * bc];

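    // Assemble the tangent of the pseudoinverse batch by batch as the sum of three terms.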
    for batch in 0..bc {
        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
        let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
        let da_b = &da_data[batch * m * n..(batch + 1) * m * n];

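        // dat = dA^H and apt = (A+)^H are reused by the two correction terms below.
        // Term 1: -(A+ dA A+).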
        let dat = adjoint_transpose(da_b, m, n);
        let apt = adjoint_transpose(ap_b, n, m);
        let ap_da = backend_mat_mul(ctx, ap_b, n, m, da_b, n)?;
        let ap_da_ap = backend_mat_mul(ctx, &ap_da, n, n, ap_b, m)?;
        let t1 = scale_vec(&ap_da_ap, -T::one());

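        // Term 2: (I - A+ A) dA^H (A+)^H A+.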
        let apa = backend_mat_mul(ctx, ap_b, n, m, a_b, n)?;
        let i_n = eye::<T>(n);
        let i_apa = sub_vec(&i_n, &apa);
        let dat_apt = backend_mat_mul(ctx, &dat, n, m, &apt, n)?;
        let dat_apt_ap = backend_mat_mul(ctx, &dat_apt, n, n, ap_b, m)?;
        let t2 = backend_mat_mul(ctx, &i_apa, n, n, &dat_apt_ap, m)?;

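        // Term 3: A+ (A+)^H dA^H (I - A A+).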
        let aap = backend_mat_mul(ctx, a_b, m, n, ap_b, m)?;
        let i_m = eye::<T>(m);
        let i_aap = sub_vec(&i_m, &aap);
        let ap_apt = backend_mat_mul(ctx, ap_b, n, m, &apt, n)?;
        let ap_apt_dat = backend_mat_mul(ctx, &ap_apt, n, n, &dat, m)?;
        let t3 = backend_mat_mul(ctx, &ap_apt_dat, n, m, &i_aap, m)?;

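        // dA+ is the sum of the three terms.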
        let dap_b_vec = add_vec(&t1, &add_vec(&t2, &t3));
        dap_data[batch * n * m..(batch + 1) * n * m].copy_from_slice(&dap_b_vec);
    }

    let dims = output_dims(&[n, m], batch_dims);
    let dap = tensor_from_data(dap_data, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    Ok((ap, dap))
}