1use super::*;
2
/// Reverse-mode (VJP) rule for the thin SVD `A = U * diag(S) * Vᴴ`.
///
/// Recomputes the forward SVD of `tensor` (batched `m x n` matrices) and
/// pulls the optional cotangents of `U`, `S`, and `Vᴴ` in `cotangent` back
/// to a cotangent of `tensor`. Any cotangent that is `None` contributes
/// nothing.
///
/// # Errors
/// Failures from the forward SVD, 2-D shape validation, scalar conversion,
/// and data extraction are surfaced as
/// `chainrules_core::AutodiffError::InvalidArgument`.
pub fn svd_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &SvdCotangent<T, T::Real>,
    options: Option<&SvdOptions>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    T::Real: num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    // Recompute the primal factorization; the rule needs U, S, and Vᴴ.
    let result = svd(ctx, tensor, options)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Stabilization floor for inverted singular values and singular-value
    // gaps: 1e-40, clamped up to machine epsilon for T::Real when that
    // constant underflows the real type (e.g. f32).
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::real_epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    // Factor data is packed batch-by-batch; within a batch, matrices are
    // addressed as `data[i + j * rows]` (column-major), as the index loops
    // below show.
    let (u_data, _) = extract_data(&result.u)?;
    let s_data = extract_data_scalar(&result.s)?;
    let (vt_data, _) = extract_data(&result.vt)?;
    // Optional cotangent buffers; `transpose()` lifts the extraction error
    // out of the `Option`.
    let du_data = cotangent
        .u
        .as_ref()
        .map(extract_data)
        .transpose()?
        .map(|(data, _)| data);
    let ds_data = cotangent.s.as_ref().map(extract_data_scalar).transpose()?;
    let dvt_data = cotangent
        .vt
        .as_ref()
        .map(extract_data)
        .transpose()?
        .map(|(data, _)| data);

    let mut grad_a = vec![T::zero(); m * n * bc];

    for b in 0..bc {
        // Per-batch views: U is m x k, S has k entries, Vᴴ is k x n.
        let u_b = &u_data[b * m * k..(b + 1) * m * k];
        let s_b = &s_data[b * k..(b + 1) * k];
        let vt_b = &vt_data[b * k * n..(b + 1) * k * n];
        let v_b = adjoint_transpose(vt_b, k, n);

        // Core k x k matrix Γ; the main gradient term is U Γ Vᴴ.
        let mut gamma = vec![T::zero(); k * k];

        // Singular-value cotangent: Γ += diag(dS).
        if let Some(ds_data) = ds_data.as_ref() {
            let ds_b = &ds_data[b * k..(b + 1) * k];
            for i in 0..k {
                gamma[i + i * k] = gamma[i + i * k] + T::from_real(ds_b[i]);
            }
        }

        // U cotangent contribution.
        if let Some(du_data) = du_data.as_ref() {
            let du_b = &du_data[b * m * k..(b + 1) * m * k];
            let uh_du = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, du_b, k)?;
            for i in 0..k {
                // Diagonal: gauge/phase term built from the imaginary axis of
                // (Uᴴ dU)_ii, scaled by a stabilized 1/σ_i. Relevant for
                // complex T; `imag_axis_component` may reject invalid values.
                let s_inv = stable_inverse_sigma(s_b[i], eta);
                gamma[i + i * k] =
                    gamma[i + i * k] + imag_axis_component(uh_du[i + i * k])? * T::from_real(s_inv);
                for j in 0..k {
                    if i == j {
                        continue;
                    }
                    // Off-diagonal: antisymmetrized Uᴴ dU weighted by a
                    // stabilized inverse singular-value gap and by σ_j.
                    // NOTE(review): the exact gap convention (σ_i²−σ_j² vs
                    // σ_i−σ_j) lives inside `stable_inverse_gap` — confirm
                    // there.
                    let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                    let skew = uh_du[i + j * k] - uh_du[j + i * k].conj();
                    gamma[i + j * k] = gamma[i + j * k] + gap_inv * skew * T::from_real(s_b[j]);
                }
            }
        }

        // Vᴴ cotangent contribution: mirrors the U term with σ_i on the left.
        if let Some(dvt_data) = dvt_data.as_ref() {
            let dvt_b = &dvt_data[b * k * n..(b + 1) * k * n];
            let dv_b = adjoint_transpose(dvt_b, k, n);
            let vh_dv = backend_mat_mul(ctx, vt_b, k, n, &dv_b, k)?;
            for i in 0..k {
                for j in 0..k {
                    if i == j {
                        continue;
                    }
                    let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                    let skew = vh_dv[i + j * k] - vh_dv[j + i * k].conj();
                    gamma[i + j * k] = gamma[i + j * k] + T::from_real(s_b[i]) * gap_inv * skew;
                }
            }
        }

        // Main term: dA = U Γ Vᴴ, written into this batch's slot.
        let u_gamma = backend_mat_mul(ctx, u_b, m, k, &gamma, k)?;
        let da_core = backend_mat_mul(ctx, &u_gamma, m, k, vt_b, n)?;

        for i in 0..m * n {
            grad_a[b * m * n + i] = da_core[i];
        }

        // Tall case (m > k, i.e. m > n): add the part of dU outside the
        // column space of U, (I − U Uᴴ) dU Σ⁻¹ Vᴴ.
        if m > k {
            if let Some(du_data) = du_data.as_ref() {
                let du_b = &du_data[b * m * k..(b + 1) * m * k];
                // du_sinv = dU Σ⁻¹ (scale column j by stabilized 1/σ_j).
                let mut du_sinv = vec![T::zero(); m * k];
                for j in 0..k {
                    let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                    for i in 0..m {
                        du_sinv[i + j * m] = du_b[i + j * m] * sinv;
                    }
                }
                let inner = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, &du_sinv, k)?;
                let uut_du = backend_mat_mul(ctx, u_b, m, k, &inner, k)?;
                // proj = (I − U Uᴴ) dU Σ⁻¹.
                let proj = sub_vec(&du_sinv, &uut_du);
                let correction = backend_mat_mul(ctx, &proj, m, k, vt_b, n)?;
                for i in 0..m * n {
                    grad_a[b * m * n + i] = grad_a[b * m * n + i] + correction[i];
                }
            }
        }

        // Wide case (n > k, i.e. n > m): add the part of dV outside the
        // column space of V, U Σ⁻¹ [(I − V Vᴴ) dV]ᴴ.
        if n > k {
            if let Some(dvt_data) = dvt_data.as_ref() {
                let dvt_b = &dvt_data[b * k * n..(b + 1) * k * n];
                let dv_b = adjoint_transpose(dvt_b, k, n);
                let inner = backend_mat_mul(ctx, vt_b, k, n, &dv_b, k)?;
                let vvt_dv = backend_mat_mul(ctx, &v_b, n, k, &inner, k)?;
                // proj_dv = (I − V Vᴴ) dV.
                let proj_dv = sub_vec(&dv_b, &vvt_dv);
                // u_sinv = U Σ⁻¹.
                let mut u_sinv = vec![T::zero(); m * k];
                for j in 0..k {
                    let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                    for i in 0..m {
                        u_sinv[i + j * m] = u_b[i + j * m] * sinv;
                    }
                }
                let correction =
                    backend_mat_mul(ctx, &u_sinv, m, k, &adjoint_transpose(&proj_dv, n, k), n)?;
                for i in 0..m * n {
                    grad_a[b * m * n + i] = grad_a[b * m * n + i] + correction[i];
                }
            }
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}
183
/// Reverse-mode (VJP) rule for the thin QR decomposition `A = Q R`.
///
/// Recomputes the forward QR of `tensor` (batched `m x n` matrices) and
/// pulls the optional cotangents of `Q` and `R` back to a cotangent of
/// `tensor`. Missing cotangents are treated as zero. Uses the standard
/// square/tall formula for `m >= n` and the wide-matrix variant otherwise.
///
/// # Errors
/// Failures from the forward QR, shape validation, and data extraction are
/// surfaced as `chainrules_core::AutodiffError::InvalidArgument`.
pub fn qr_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &QrCotangent<T>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Recompute the primal factorization; the rule needs Q and R.
    let result = qr(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);

    let (q_data, _) = extract_data(&result.q)?;
    let (r_data, _) = extract_data(&result.r)?;

    let mut grad_a = vec![T::zero(); m * n * bc];

    for b in 0..bc {
        // Per-batch views: Q is m x k, R is k x n, stored column-major
        // (`data[i + j * rows]`, see the index loops below).
        let q_b = &q_data[b * m * k..(b + 1) * m * k];
        let r_b = &r_data[b * k * n..(b + 1) * k * n];

        // Absent cotangents contribute as zero matrices.
        let dq_b: Vec<T> = if let Some(ref dq) = cotangent.q {
            let (dq_data, _) = extract_data(dq)?;
            dq_data[b * m * k..(b + 1) * m * k].to_vec()
        } else {
            vec![T::zero(); m * k]
        };
        let dr_b: Vec<T> = if let Some(ref dr) = cotangent.r {
            let (dr_data, _) = extract_data(dr)?;
            dr_data[b * k * n..(b + 1) * k * n].to_vec()
        } else {
            vec![T::zero(); k * n]
        };

        if m >= n {
            // Square/tall case (k == n). Builds the standard
            //   Ā = (dQ + Q copyltu(R dRᴴ − dQᴴ Q)) · R⁻ᴴ
            // shape, applying the inverse of R via a triangular solve on the
            // adjoint-transposed right-hand side.
            let r_drh = backend_mat_mul(ctx, r_b, k, n, &adjoint_transpose(&dr_b, k, n), k)?;
            let dqh_q = backend_mat_mul(ctx, &adjoint_transpose(&dq_b, m, k), k, m, q_b, k)?;
            let w = sub_vec(&r_drh, &dqh_q);

            let h = copyltu(&w, k);
            let qh = backend_mat_mul(ctx, q_b, m, k, &h, k)?;
            let rhs = add_vec(&dq_b, &qh);

            // k == n here, so this is the full (square) R.
            let r_square = r_b[..k * n].to_vec();
            let rhs_h = adjoint_transpose(&rhs, m, k);
            // NOTE(review): the trailing `true` presumably selects the
            // triangle/transpose mode of the solve — confirm against the
            // `backend_solve_tri` contract.
            let da_h = backend_solve_tri(ctx, &r_square, &rhs_h, k, m, true)?;
            let da_first_k = adjoint_transpose(&da_h, k, m);

            // k.min(n) == k == n here; copy all n gradient columns.
            for j in 0..k.min(n) {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = da_first_k[i + j * m];
                }
            }
        } else {
            // Wide case (k == m): split R = [R1 | R2] with R1 the leading
            // square block. The first k gradient columns get a
            // skew-restricted term solved against R1; all n columns then
            // additionally receive Q dR.
            let qhgq = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, k), k, m, &dq_b, k)?;
            let gr_rh = backend_mat_mul(ctx, &dr_b, k, n, &adjoint_transpose(r_b, k, n), k)?;
            // wide_inner = Qᴴ dQ − dR Rᴴ, reduced to its lower/skew part by
            // the helper below.
            let wide_inner = sub_vec(&qhgq, &gr_rh);
            let lower_skew = tril_im_inv_adj_skew(&wide_inner, k)?;

            let q_lower = backend_mat_mul(ctx, q_b, m, k, &lower_skew, k)?;
            // r1 = leading k x k block of R (its first k columns).
            let mut r1 = vec![T::zero(); k * k];
            for j in 0..k {
                for i in 0..k {
                    r1[i + j * k] = r_b[i + j * k];
                }
            }
            // Apply R1's inverse via the same adjoint-transpose + triangular
            // solve trick as the tall branch.
            let q_lower_h = adjoint_transpose(&q_lower, m, k);
            let leading_h = backend_solve_tri(ctx, &r1, &q_lower_h, k, m, true)?;
            let leading = adjoint_transpose(&leading_h, k, m);

            // Write the leading k columns of this batch's gradient.
            for j in 0..k {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = leading[i + j * m];
                }
            }

            // All n columns additionally accumulate Q dR.
            let qgr = backend_mat_mul(ctx, q_b, m, k, &dr_b, n)?;
            for j in 0..n {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = grad_a[b * m * n + i + j * m] + qgr[i + j * m];
                }
            }
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}