// tenferro_linalg/rrules/svd_qr.rs

1use super::*;
2
/// Reverse-mode AD rule for SVD (VJP / pullback).
///
/// Computes the gradient of the input given cotangents for the SVD outputs.
/// Uses the F-matrix approach (Mathieu 2019).
///
/// Any of the three cotangents (`u`, `s`, `vt`) may be `None`; a missing
/// cotangent contributes nothing to the gradient.
///
/// # Errors
///
/// Returns `AutodiffError::InvalidArgument` if the forward SVD fails, if
/// `tensor` is not a (batched) matrix, or if a scalar conversion fails.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
///
/// let cotangent = SvdCotangent {
///     u: None,
///     s: Some(Tensor::ones(&[3], mem, col).unwrap()),
///     vt: None,
/// };
/// let grad_a = svd_rrule(&mut ctx, &a, &cotangent, None).unwrap();
/// ```
pub fn svd_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &SvdCotangent<T, T::Real>,
    options: Option<&SvdOptions>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    T::Real: num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    // Re-run the forward SVD: the pullback needs U, S, and Vᴴ of the primal.
    let result = svd(ctx, tensor, options)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    // Thin-SVD rank: U is m×k, S has k entries, Vᴴ is k×n.
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Regularization for the F-matrix: prevents division by zero when two
    // singular values are (nearly) equal.  We use max(1e-40, T::epsilon())
    // so that on f32 (where 1e-40 underflows to 0) we still get a safe floor.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::real_epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    // Flattened buffers for the forward results and the incoming cotangents.
    // Matrices are addressed as `i + j * rows` throughout, i.e. column-major
    // (consistent with the doc example's MemoryOrder::ColumnMajor).
    let (u_data, _) = extract_data(&result.u)?;
    let s_data = extract_data_scalar(&result.s)?;
    let (vt_data, _) = extract_data(&result.vt)?;
    // Each cotangent is optional; `None` is treated as an all-zero cotangent
    // by simply skipping its contribution below.
    let du_data = cotangent
        .u
        .as_ref()
        .map(extract_data)
        .transpose()?
        .map(|(data, _)| data);
    let ds_data = cotangent.s.as_ref().map(extract_data_scalar).transpose()?;
    let dvt_data = cotangent
        .vt
        .as_ref()
        .map(extract_data)
        .transpose()?
        .map(|(data, _)| data);

    let mut grad_a = vec![T::zero(); m * n * bc];

    // Process each batch matrix independently.
    for b in 0..bc {
        let u_b = &u_data[b * m * k..(b + 1) * m * k]; // U, m×k
        let s_b = &s_data[b * k..(b + 1) * k]; // singular values, length k
        let vt_b = &vt_data[b * k * n..(b + 1) * k * n]; // Vᴴ, k×n
        let v_b = adjoint_transpose(vt_b, k, n); // V, n×k

        // Γ is the k×k core of the gradient: the range-space part of the
        // pullback is dA = U·Γ·Vᴴ; corrections for m > k / n > k are added
        // separately below.
        let mut gamma = vec![T::zero(); k * k];

        // Contribution of dS: purely diagonal, Γ_ii += dS_i.
        if let Some(ds_data) = ds_data.as_ref() {
            let ds_b = &ds_data[b * k..(b + 1) * k];
            for i in 0..k {
                gamma[i + i * k] = gamma[i + i * k] + T::from_real(ds_b[i]);
            }
        }

        // Contribution of dU, expressed through Uᴴ·dU.
        if let Some(du_data) = du_data.as_ref() {
            let du_b = &du_data[b * m * k..(b + 1) * m * k];
            let uh_du = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, du_b, k)?;
            for i in 0..k {
                // Diagonal: gauge term for complex inputs — the imaginary
                // part of (Uᴴ dU)_ii scaled by 1/σ_i (regularized).
                let s_inv = stable_inverse_sigma(s_b[i], eta);
                gamma[i + i * k] =
                    gamma[i + i * k] + imag_axis_component(uh_du[i + i * k])? * T::from_real(s_inv);
                for j in 0..k {
                    if i == j {
                        continue;
                    }
                    // Off-diagonal F-matrix term.  NOTE(review):
                    // stable_inverse_gap presumably returns a regularized
                    // 1/(σ_i² − σ_j²) — confirm against its definition.
                    let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                    // Anti-Hermitian combination of Uᴴ dU.
                    let skew = uh_du[i + j * k] - uh_du[j + i * k].conj();
                    gamma[i + j * k] = gamma[i + j * k] + gap_inv * skew * T::from_real(s_b[j]);
                }
            }
        }

        // Contribution of dVᴴ, expressed through Vᴴ·dV (off-diagonal only;
        // note σ multiplies from the left here, vs. the right for dU).
        if let Some(dvt_data) = dvt_data.as_ref() {
            let dvt_b = &dvt_data[b * k * n..(b + 1) * k * n];
            let dv_b = adjoint_transpose(dvt_b, k, n);
            let vh_dv = backend_mat_mul(ctx, vt_b, k, n, &dv_b, k)?;
            for i in 0..k {
                for j in 0..k {
                    if i == j {
                        continue;
                    }
                    let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                    let skew = vh_dv[i + j * k] - vh_dv[j + i * k].conj();
                    gamma[i + j * k] = gamma[i + j * k] + T::from_real(s_b[i]) * gap_inv * skew;
                }
            }
        }

        // Range-space part: dA = U·Γ·Vᴴ.
        let u_gamma = backend_mat_mul(ctx, u_b, m, k, &gamma, k)?;
        let da_core = backend_mat_mul(ctx, &u_gamma, m, k, vt_b, n)?;

        for i in 0..m * n {
            grad_a[b * m * n + i] = da_core[i];
        }

        // Tall case (m > k): dU has a component outside the column span of U.
        // Add the projection term (I − U·Uᴴ)·dU·Σ⁻¹·Vᴴ.
        if m > k {
            if let Some(du_data) = du_data.as_ref() {
                let du_b = &du_data[b * m * k..(b + 1) * m * k];
                // dU·Σ⁻¹ (scale column j by the regularized 1/σ_j).
                let mut du_sinv = vec![T::zero(); m * k];
                for j in 0..k {
                    let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                    for i in 0..m {
                        du_sinv[i + j * m] = du_b[i + j * m] * sinv;
                    }
                }
                // U·Uᴴ·(dU·Σ⁻¹), then subtract to project onto the complement.
                let inner = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, &du_sinv, k)?;
                let uut_du = backend_mat_mul(ctx, u_b, m, k, &inner, k)?;
                let proj = sub_vec(&du_sinv, &uut_du);
                let correction = backend_mat_mul(ctx, &proj, m, k, vt_b, n)?;
                for i in 0..m * n {
                    grad_a[b * m * n + i] = grad_a[b * m * n + i] + correction[i];
                }
            }
        }

        // Wide case (n > k): dV has a component outside the column span of V.
        // Add U·Σ⁻¹·(dV·projected onto (I − V·Vᴴ))ᴴ.
        if n > k {
            if let Some(dvt_data) = dvt_data.as_ref() {
                let dvt_b = &dvt_data[b * k * n..(b + 1) * k * n];
                let dv_b = adjoint_transpose(dvt_b, k, n);
                // (I − V·Vᴴ)·dV via dV − V·(Vᴴ·dV).
                let inner = backend_mat_mul(ctx, vt_b, k, n, &dv_b, k)?;
                let vvt_dv = backend_mat_mul(ctx, &v_b, n, k, &inner, k)?;
                let proj_dv = sub_vec(&dv_b, &vvt_dv);
                // U·Σ⁻¹ (scale column j of U by the regularized 1/σ_j).
                let mut u_sinv = vec![T::zero(); m * k];
                for j in 0..k {
                    let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                    for i in 0..m {
                        u_sinv[i + j * m] = u_b[i + j * m] * sinv;
                    }
                }
                let correction =
                    backend_mat_mul(ctx, &u_sinv, m, k, &adjoint_transpose(&proj_dv, n, k), n)?;
                for i in 0..m * n {
                    grad_a[b * m * n + i] = grad_a[b * m * n + i] + correction[i];
                }
            }
        }
    }

    // Re-wrap the flat gradient buffer with the input's batched [m, n] shape.
    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}
183
/// Reverse-mode AD rule for QR (VJP / pullback).
///
/// Either cotangent (`q`, `r`) may be `None`; a missing cotangent is treated
/// as zero.  Handles both tall/square (`m >= n`) and wide (`m < n`) inputs.
///
/// # Errors
///
/// Returns `AutodiffError::InvalidArgument` if the forward QR fails or if
/// `tensor` is not a (batched) matrix.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{qr_rrule, QrCotangent};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(
///     &[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
///     &[4, 3],
///     col,
/// ).unwrap();
/// let cotangent = QrCotangent {
///     q: Some(Tensor::ones(&[4, 3], mem, col).unwrap()),
///     r: None,
/// };
/// let grad_a = qr_rrule(&mut ctx, &a, &cotangent).unwrap();
/// ```
pub fn qr_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &QrCotangent<T>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Re-run the forward QR: the pullback needs Q (m×k) and R (k×n).
    let result = qr(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    // Thin-QR rank; matrices are addressed column-major (`i + j * rows`).
    let k = m.min(n);
    let bc = batch_count(batch_dims);

    let (q_data, _) = extract_data(&result.q)?;
    let (r_data, _) = extract_data(&result.r)?;

    let mut grad_a = vec![T::zero(); m * n * bc];

    // Process each batch matrix independently.
    for b in 0..bc {
        let q_b = &q_data[b * m * k..(b + 1) * m * k];
        let r_b = &r_data[b * k * n..(b + 1) * k * n];

        // Initialize dQ and dR from cotangents (zero if not provided)
        let dq_b: Vec<T> = if let Some(ref dq) = cotangent.q {
            let (dq_data, _) = extract_data(dq)?;
            dq_data[b * m * k..(b + 1) * m * k].to_vec()
        } else {
            vec![T::zero(); m * k]
        };
        let dr_b: Vec<T> = if let Some(ref dr) = cotangent.r {
            let (dr_data, _) = extract_data(dr)?;
            dr_data[b * k * n..(b + 1) * k * n].to_vec()
        } else {
            vec![T::zero(); k * n]
        };

        if m >= n {
            // Tall/square case (k == n).  Standard formula:
            //   dA = (dQ + Q·copyltu(R·dRᴴ − dQᴴ·Q)) · R⁻ᴴ
            // NOTE(review): copyltu presumably builds a Hermitian-ish matrix
            // from the lower triangle of W — confirm against its definition.
            let r_drh = backend_mat_mul(ctx, r_b, k, n, &adjoint_transpose(&dr_b, k, n), k)?;
            let dqh_q = backend_mat_mul(ctx, &adjoint_transpose(&dq_b, m, k), k, m, q_b, k)?;
            let w = sub_vec(&r_drh, &dqh_q);

            let h = copyltu(&w, k);
            let qh = backend_mat_mul(ctx, q_b, m, k, &h, k)?;
            let rhs = add_vec(&dq_b, &qh);

            // Apply R⁻ᴴ from the right by solving the adjoint system:
            // solve against rhsᴴ, then transpose the solution back.
            // NOTE(review): the `true` flag's meaning (upper/lower or
            // transpose) is not visible here — confirm against
            // backend_solve_tri's signature.
            let r_square = r_b[..k * n].to_vec();
            let rhs_h = adjoint_transpose(&rhs, m, k);
            let da_h = backend_solve_tri(ctx, &r_square, &rhs_h, k, m, true)?;
            let da_first_k = adjoint_transpose(&da_h, k, m);

            // k == n in this branch, so k.min(n) covers all n columns.
            for j in 0..k.min(n) {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = da_first_k[i + j * m];
                }
            }
        } else {
            // Wide case (k == m): split A = [A₁ A₂] with A₁ the leading k
            // columns.  The A₁ gradient goes through a triangular solve on
            // the leading k×k block of R; the dR contribution Q·dR is added
            // over all n columns afterwards.
            let qhgq = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, k), k, m, &dq_b, k)?;
            let gr_rh = backend_mat_mul(ctx, &dr_b, k, n, &adjoint_transpose(r_b, k, n), k)?;
            let wide_inner = sub_vec(&qhgq, &gr_rh);
            // NOTE(review): tril_im_inv_adj_skew presumably extracts the
            // lower-triangular/skew part needed by the wide-QR pullback —
            // confirm against its definition.
            let lower_skew = tril_im_inv_adj_skew(&wide_inner, k)?;

            let q_lower = backend_mat_mul(ctx, q_b, m, k, &lower_skew, k)?;
            // R₁: leading k×k block of R (columns 0..k).
            let mut r1 = vec![T::zero(); k * k];
            for j in 0..k {
                for i in 0..k {
                    r1[i + j * k] = r_b[i + j * k];
                }
            }
            // Apply R₁⁻ᴴ from the right via the adjoint solve (as above).
            let q_lower_h = adjoint_transpose(&q_lower, m, k);
            let leading_h = backend_solve_tri(ctx, &r1, &q_lower_h, k, m, true)?;
            let leading = adjoint_transpose(&leading_h, k, m);

            // Write the A₁ part (leading k columns).
            for j in 0..k {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = leading[i + j * m];
                }
            }

            // Add Q·dR across all n columns (covers A₂ and completes A₁).
            let qgr = backend_mat_mul(ctx, q_b, m, k, &dr_b, n)?;
            for j in 0..n {
                for i in 0..m {
                    grad_a[b * m * n + i + j * m] = grad_a[b * m * n + i + j * m] + qgr[i + j * m];
                }
            }
        }
    }

    // Re-wrap the flat gradient buffer with the input's batched [m, n] shape.
    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}