1use super::*;
2
/// Output of [`svd_frule`]: the primal SVD result paired with its
/// forward-mode tangent (both carry the same `(u, s, vt)` structure).
type SvdFruleOutput<T> = (
    SvdResult<T, <T as LinalgScalar>::Real>,
    SvdResult<T, <T as LinalgScalar>::Real>,
);
7
/// Forward-mode rule (frule / JVP) for the batched SVD.
///
/// Computes the primal decomposition `A = U·S·Vᴴ` of `tensor` and, from the
/// input tangent `dA` in `tangent`, the corresponding tangents
/// `(dU, dS, dVᴴ)`. Returns the pair `(primal, tangent)` as two
/// [`SvdResult`]s.
///
/// Per-batch matrices are addressed as column-major slices: element `(i, j)`
/// of a `rows × cols` matrix lives at index `i + j * rows` (see the inner
/// loops below). `U` is m×k, `S` has length k, `Vᴴ` is k×n, with
/// `k = min(m, n)`.
///
/// # Errors
/// Returns `AutodiffError::InvalidArgument` when the primal SVD, the 2-D
/// shape validation, or tangent-tensor construction fails; data-extraction
/// and scalar-conversion errors are propagated via `?` / `to_ad_err`.
pub fn svd_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    options: Option<&SvdOptions>,
) -> AdResult<SvdFruleOutput<T>>
where
    T: KernelLinalgScalar,
    T::Real: num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    // Primal decomposition; its factors are reused below to assemble the tangents.
    let result = svd(ctx, tensor, options)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Regularizer for the stabilized reciprocals below, clamped up to machine
    // epsilon so it does not underflow to zero for low-precision real types.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::real_epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    let (u_data, _) = extract_data(&result.u)?;
    let s_data = extract_data_scalar(&result.s)?;
    let (vt_data, _) = extract_data(&result.vt)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut du_data = vec![T::zero(); m * k * bc];
    let mut ds_data = vec![T::Real::zero(); k * bc];
    let mut dvt_data = vec![T::zero(); k * n * bc];
    let half = scalar_from::<T::Real>(0.5).map_err(to_ad_err)?;

    for b in 0..bc {
        // Per-batch views: U (m×k), S (k), Vᴴ (k×n), dA (m×n).
        let u_b = &u_data[b * m * k..(b + 1) * m * k];
        let s_b = &s_data[b * k..(b + 1) * k];
        let vt_b = &vt_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];
        let v_b = adjoint_transpose(vt_b, k, n);

        // C = Uᴴ·dA·V (k×k): the tangent expressed in the factor bases.
        let uh_da = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, da_b, n)?;
        let c = backend_mat_mul(ctx, &uh_da, k, n, &v_b, k)?;

        // dS is the real diagonal of C; X = C with the real part of its
        // diagonal removed (any imaginary diagonal part stays in X).
        let mut x = c.clone();
        for i in 0..k {
            ds_data[b * k + i] = real_diagonal_from_scalar(c[i + i * k]);
            x[i + i * k] = x[i + i * k] - T::from_real(ds_data[b * k + i]);
        }

        // Rotation-tangent coefficients. Off-diagonal entries use the
        // regularized inverse singular-value gap so (near-)degenerate
        // singular values do not blow up. The diagonal term distributes the
        // imaginary-axis component of diag(X) evenly between U and V — a
        // gauge choice for complex scalars; presumably zero for real T
        // (NOTE(review): confirm against `imag_axis_component`).
        let mut du_inner = vec![T::zero(); k * k];
        let mut dv_inner = vec![T::zero(); k * k];
        for i in 0..k {
            let diag_term = imag_axis_component(x[i + i * k])?
                * T::from_real(half * stable_inverse_sigma(s_b[i], eta));
            du_inner[i + i * k] = du_inner[i + i * k] + diag_term;
            dv_inner[i + i * k] = dv_inner[i + i * k] - diag_term;
            for j in 0..k {
                if i == j {
                    continue;
                }
                let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                du_inner[i + j * k] = gap_inv
                    * (T::from_real(s_b[i]) * x[j + i * k].conj()
                        + x[i + j * k] * T::from_real(s_b[j]));
                dv_inner[i + j * k] = gap_inv
                    * (T::from_real(s_b[i]) * x[i + j * k]
                        + x[j + i * k].conj() * T::from_real(s_b[j]));
            }
        }
        // In-subspace part of dU: U · du_inner.
        let du_core = backend_mat_mul(ctx, u_b, m, k, &du_inner, k)?;
        let mut du_b_vec = du_core;

        // Tall case (m > k): add the orthogonal-complement contribution
        // (I − U·Uᴴ)·dA·V·S⁻¹, column-scaled by the regularized 1/σ_j.
        if m > k {
            let da_v = backend_mat_mul(ctx, da_b, m, n, &v_b, k)?;
            let inner = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, &da_v, k)?;
            let uu_h_da_v = backend_mat_mul(ctx, u_b, m, k, &inner, k)?;
            let proj_da_v = sub_vec(&da_v, &uu_h_da_v);
            for j in 0..k {
                let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                for i in 0..m {
                    du_b_vec[i + j * m] = du_b_vec[i + j * m] + proj_da_v[i + j * m] * sinv;
                }
            }
        }
        du_data[b * m * k..(b + 1) * m * k].copy_from_slice(&du_b_vec);

        // In-subspace part of dV: V · dv_inner.
        let dv_core = backend_mat_mul(ctx, &v_b, n, k, &dv_inner, k)?;
        let mut dv_b_vec = dv_core;

        // Wide case (n > k): add the complement contribution for V,
        // (I − V·Vᴴ)·dAᴴ·U·S⁻¹.
        if n > k {
            let da_h = adjoint_transpose(da_b, m, n);
            let da_h_u = backend_mat_mul(ctx, &da_h, n, m, u_b, k)?;
            let inner = backend_mat_mul(ctx, vt_b, k, n, &da_h_u, k)?;
            let vv_h_da_h_u = backend_mat_mul(ctx, &v_b, n, k, &inner, k)?;
            let proj = sub_vec(&da_h_u, &vv_h_da_h_u);
            for j in 0..k {
                let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                for i in 0..n {
                    dv_b_vec[i + j * n] = dv_b_vec[i + j * n] + proj[i + j * n] * sinv;
                }
            }
        }
        // Store dVᴴ (the adjoint of dV) to match the primal's Vᴴ layout.
        let dvt_b_vec = adjoint_transpose(&dv_b_vec, n, k);
        dvt_data[b * k * n..(b + 1) * k * n].copy_from_slice(&dvt_b_vec);
    }

    // Tangent tensors carry the same batch/core dims as the primal outputs.
    let u_dims = output_dims(&[m, k], batch_dims);
    let s_dims = output_dims(&[k], batch_dims);
    let vt_dims = output_dims(&[k, n], batch_dims);

    let dresult = SvdResult {
        u: tensor_from_data(du_data, &u_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        s: tensor_from_data_scalar(ds_data, &s_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        vt: tensor_from_data(dvt_data, &vt_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };

    Ok((result, dresult))
}
163
/// Forward-mode rule (frule / JVP) for the batched thin QR decomposition.
///
/// Computes the primal `A = Q·R` of `tensor` and, from the input tangent `dA`
/// in `tangent`, the corresponding tangents `(dQ, dR)`. Returns the pair
/// `(primal, tangent)` as two [`QrResult`]s.
///
/// Per-batch matrices are addressed as column-major slices (element `(i, j)`
/// of a `rows × cols` matrix at index `i + j * rows`); `Q` is m×k and `R` is
/// k×n with `k = min(m, n)`.
///
/// # Errors
/// Returns `AutodiffError::InvalidArgument` when the primal QR, the 2-D shape
/// validation, or tangent-tensor construction fails; data-extraction and
/// scalar-conversion errors are propagated via `?` / `to_ad_err`.
pub fn qr_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(QrResult<T>, QrResult<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Primal decomposition; Q and R are reused below to assemble the tangents.
    let result = qr(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    let half: T::Real = scalar_from(0.5).map_err(to_ad_err)?;

    let (q_data, _) = extract_data(&result.q)?;
    let (r_data, _) = extract_data(&result.r)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dq_data = vec![T::zero(); m * k * bc];
    let mut dr_data = vec![T::zero(); k * n * bc];

    for b in 0..bc {
        // Per-batch views: Q (m×k), R (k×n), dA (m×n).
        let q_b = &q_data[b * m * k..(b + 1) * m * k];
        let r_b = &r_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];

        let (dq_b_vec, dr_b_vec) = if m >= n {
            // Tall/square case (k == n). Standard rule built from
            // M = Qᴴ·dA·R⁻¹. R is n×n here, so the slice copies all of it.
            let r_sq = r_b[..n * n].to_vec();
            // Solve Rᴴ·X = dAᴴ, i.e. X = (dA·R⁻¹)ᴴ — a triangular solve in
            // adjoint form rather than an explicit inverse of R.
            let darinv_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&r_sq, n, n),
                &adjoint_transpose(da_b, m, n),
                n,
                m,
                false,
            )?;
            let darinv = adjoint_transpose(&darinv_h, n, m);
            let qhdarinv = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, n), n, m, &darinv, n)?;
            // sym = M + Mᴴ (Hermitian, n×n).
            let sym = add_vec(&qhdarinv, &adjoint_transpose(&qhdarinv, n, n));

            // dR̂: upper triangle of sym with the diagonal replaced by half
            // its real part — the "copyltu"-style projection that keeps dR
            // upper triangular.
            let mut dr_hat = vec![T::zero(); n * n];
            for j in 0..n {
                for i in 0..=j {
                    let mut val = sym[i + j * n];
                    if i == j {
                        val = T::from_real(val.real_part() * half);
                    }
                    dr_hat[i + j * n] = val;
                }
            }

            // dQ = dA·R⁻¹ − Q·dR̂,  dR = dR̂·R.
            let dq = sub_vec(&darinv, &backend_mat_mul(ctx, q_b, m, n, &dr_hat, n)?);
            let dr = backend_mat_mul(ctx, &dr_hat, n, n, &r_sq, n)?;
            (dq, dr)
        } else {
            // Wide case (k == m). Split R = [R₁ | R₂] with R₁ the leading
            // k×k block; only the first k columns of Qᴴ·dA determine dQ.
            let qhda = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, k), k, m, da_b, n)?;
            let r1 = r_b[..k * k].to_vec();
            let qhda1 = qhda[..k * k].to_vec();
            // Solve R₁ᴴ·X = (Qᴴ·dA₁)ᴴ, i.e. X = ((Qᴴ·dA₁)·R₁⁻¹)ᴴ.
            let qhda1_rinv_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&r1, k, k),
                &adjoint_transpose(&qhda1, k, k),
                k,
                k,
                false,
            )?;
            let qhda1_rinv = adjoint_transpose(&qhda1_rinv_h, k, k);
            // Recover dQ̂ = Qᴴ·dQ from (Qᴴ·dA₁)·R₁⁻¹ via the tril_im helpers.
            // NOTE(review): exact semantics of tril_im / tril_im_inv live in
            // those helpers — presumably they extract the strictly-lower /
            // imaginary part and rebuild the anti-Hermitian generator; confirm.
            let lower = tril_im(&qhda1_rinv, k)?;
            let dq_hat = tril_im_inv(&lower, k)?;

            // dR = Qᴴ·dA − dQ̂·R,  dQ = Q·dQ̂.
            let dr = sub_vec(&qhda, &backend_mat_mul(ctx, &dq_hat, k, k, r_b, n)?);
            let dq = backend_mat_mul(ctx, q_b, m, k, &dq_hat, k)?;
            (dq, dr)
        };

        dq_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dq_b_vec);
        dr_data[b * k * n..(b + 1) * k * n].copy_from_slice(&dr_b_vec);
    }

    // Tangent tensors carry the same batch/core dims as the primal outputs.
    let q_dims = output_dims(&[m, k], batch_dims);
    let r_dims = output_dims(&[k, n], batch_dims);
    let dresult = QrResult {
        q: tensor_from_data(dq_data, &q_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        r: tensor_from_data(dr_data, &r_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}