use super::*;

3pub fn eig_rrule<
32 T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
33 C,
34>(
35 ctx: &mut C,
36 tensor: &Tensor<T>,
37 cotangent: &EigCotangent<T>,
38) -> AdResult<Tensor<T>>
39where
40 T: KernelLinalgScalar,
41 C: backend::TensorLinalgContextFor<T>
42 + backend::TensorLinalgContextFor<num_complex::Complex<T>>
43 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
44 num_complex::Complex<T>: KernelLinalgScalar,
45 <C as backend::TensorLinalgContextFor<T>>::Backend: 'static,
46 <C as backend::TensorLinalgContextFor<num_complex::Complex<T>>>::Backend: 'static,
47 T::Real: tenferro_tensor::KeepCountScalar,
48{
49 let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
50 let bc = batch_count(batch_dims);
51
52 let eig_result = eig(ctx, tensor).map_err(to_ad_err)?;
54 let val_data = extract_data_scalar(&eig_result.values)?;
55 let vec_data = extract_data_scalar(&eig_result.vectors)?;
56
57 let zero_c = Cx::new(T::zero(), T::zero());
58 let one_c = Cx::new(T::one(), T::zero());
59
60 let mut grad_data = vec![T::zero(); n * n * bc];
61
62 for b in 0..bc {
63 let lambda = &val_data[b * n..(b + 1) * n];
64 let v = &vec_data[b * n * n..(b + 1) * n * n];
65
66 let mut f_mat = vec![zero_c; n * n];
68 for i in 0..n {
69 for j in 0..n {
70 if i != j {
71 let diff = (lambda[j] - lambda[i]).conj();
72 f_mat[i + j * n] = one_c / diff;
73 }
74 }
75 }
76
77 let vh = complex_conj_transpose(v, n);
79
80 let mut m_bar = vec![zero_c; n * n];
82
83 if let Some(ref dv_bar) = cotangent.vectors {
84 let dv_bar_data = extract_data_scalar(dv_bar)?;
85 let dv_bar_b = &dv_bar_data[b * n * n..(b + 1) * n * n];
86 let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, dv_bar_b, n)?;
87 let mut dv_adj = dv_bar_b.to_vec();
88 for j in 0..n {
89 let correction = Cx::new(vh_dv[j + j * n].re, T::zero());
90 for i in 0..n {
91 let index = i + j * n;
92 dv_adj[index] = dv_adj[index] - v[index] * correction;
93 }
94 }
95 let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, &dv_adj, n)?;
96 for k in 0..n * n {
97 m_bar[k] = f_mat[k] * vh_dv[k];
98 }
99 }
100
101 if let Some(ref dlam) = cotangent.values {
102 let dlam_data = extract_data_scalar(dlam)?;
103 for i in 0..n {
104 m_bar[i + i * n] = m_bar[i + i * n] + dlam_data[b * n + i];
105 }
106 }
107
108 let m_vh = complex_mat_mul_nn_backend(ctx, &m_bar, &vh, n)?;
110 let da_complex = complex_solve_nn(ctx, &vh, &m_vh, n)?;
111
112 for k in 0..n * n {
114 grad_data[b * n * n + k] = da_complex[k].re;
115 }
116 }
117
118 let dims = output_dims(&[n, n], batch_dims);
119 tensor_from_data(grad_data, &dims).map_err(to_ad_err)
120}
121
122pub fn pinv_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
140 ctx: &mut C,
141 tensor: &Tensor<T>,
142 cotangent: &Tensor<T>,
143 rcond: Option<f64>,
144) -> AdResult<Tensor<T>>
145where
146 T: KernelLinalgScalar
147 + crate::prims_bridge::ScaleTensorByRealSameShape<C>
148 + tenferro_algebra::Conjugate,
149 C: backend::TensorLinalgContextFor<T>
150 + tenferro_prims::TensorResolveConjContextFor<T>
151 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
152 C::Backend: 'static,
153 T::Real: tenferro_tensor::KeepCountScalar,
154{
155 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv_rrule")
156 .map_err(to_ad_err)?;
157
158 let ap = pinv(ctx, tensor, rcond)
160 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
161 let (m, n, batch_dims) = validate_2d(tensor)
162 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
163 let bc = batch_count(batch_dims);
164
165 let (a_data, _) = extract_data(tensor)?;
166 let (ap_data, _) = extract_data(&ap)?;
167 let (dap_data, _) = extract_data(cotangent)?;
168
169 let mut grad_a = vec![T::zero(); m * n * bc];
170
171 for batch in 0..bc {
172 let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
173 let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
174 let dap_b = &dap_data[batch * n * m..(batch + 1) * n * m];
175
176 let apt = adjoint_transpose(ap_b, n, m); let dapt = adjoint_transpose(dap_b, n, m); let t1 = backend_mat_mul(ctx, &apt, m, n, dap_b, m)?;
182 let t1 = backend_mat_mul(ctx, &t1, m, m, &apt, n)?;
183 let t1 = scale_vec(&t1, -T::one());
184
185 let aap = backend_mat_mul(ctx, a_b, m, n, ap_b, m)?;
188 let i_m = eye::<T>(m);
189 let i_aap = sub_vec(&i_m, &aap);
190 let dapt_ap = backend_mat_mul(ctx, &dapt, m, n, ap_b, m)?;
192 let dapt_ap_apt = backend_mat_mul(ctx, &dapt_ap, m, m, &apt, n)?;
194 let t2 = backend_mat_mul(ctx, &i_aap, m, m, &dapt_ap_apt, n)?;
195
196 let apa = backend_mat_mul(ctx, ap_b, n, m, a_b, n)?;
199 let i_n = eye::<T>(n);
200 let i_apa = sub_vec(&i_n, &apa);
201 let apt_ap = backend_mat_mul(ctx, &apt, m, n, ap_b, m)?;
203 let apt_ap_dapt = backend_mat_mul(ctx, &apt_ap, m, m, &dapt, n)?;
205 let t3 = backend_mat_mul(ctx, &apt_ap_dapt, m, n, &i_apa, n)?;
206
207 let da_b = add_vec(&t1, &add_vec(&t2, &t3));
208 grad_a[batch * m * n..(batch + 1) * m * n].copy_from_slice(&da_b);
209 }
210
211 let dims = output_dims(&[m, n], batch_dims);
212 tensor_from_data(grad_a, &dims)
213 .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
214}