// tenferro_linalg/frules/svd_qr.rs

1use super::*;
2
/// Return type of [`svd_frule`]: the primal SVD result and its tangent,
/// in that order.
type SvdFruleOutput<T> = (
    SvdResult<T, <T as LinalgScalar>::Real>,
    SvdResult<T, <T as LinalgScalar>::Real>,
);
7
8// ============================================================================
9// AD functions: frule (forward-mode, stateless)
10// ============================================================================
11
/// Forward-mode AD rule for SVD (JVP / pushforward).
///
/// Computes the JVP of all SVD outputs given a tangent for the input.
/// Uses batched matrix operations that broadcast over `*`.
///
/// Per batch entry this forms `C = U^H dA V`, takes the real diagonal of
/// `C` as the tangent of `s`, and assembles the tangents of `U` and `V^H`
/// from the remainder.  When the factorization is thin (`m > k` or
/// `n > k`) an extra projection term accounts for the component of `dA`
/// outside the span of `U` / `V`.
///
/// Returns `(result, dresult)` where `result` is the primal SVD of
/// `tensor` and `dresult` holds the tangents of `u`, `s`, and `vt`.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::svd_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 4], mem, col).unwrap();
/// let (result, dresult) = svd_frule(&mut ctx, &a, &da, None).unwrap();
/// ```
pub fn svd_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    options: Option<&SvdOptions>,
) -> AdResult<SvdFruleOutput<T>>
where
    T: KernelLinalgScalar,
    T::Real: num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    // Primal pass: compute the SVD itself; its factors are reused below.
    let result = svd(ctx, tensor, options)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Regularization for the F-matrix: prevents division by zero when two
    // singular values are (nearly) equal.  We use max(1e-40, T::epsilon())
    // so that on f32 (where 1e-40 underflows to 0) we still get a safe floor.
    let eta: T::Real = {
        let raw: T::Real = scalar_from(1e-40).map_err(to_ad_err)?;
        let eps = T::real_epsilon();
        if raw < eps {
            eps
        } else {
            raw
        }
    };

    // Flat factor buffers; batches are laid out contiguously, one
    // m*k / k / k*n / m*n panel per batch entry.
    let (u_data, _) = extract_data(&result.u)?;
    let s_data = extract_data_scalar(&result.s)?;
    let (vt_data, _) = extract_data(&result.vt)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut du_data = vec![T::zero(); m * k * bc];
    let mut ds_data = vec![T::Real::zero(); k * bc];
    let mut dvt_data = vec![T::zero(); k * n * bc];
    let half = scalar_from::<T::Real>(0.5).map_err(to_ad_err)?;

    for b in 0..bc {
        // Per-batch panels (column-major: element (i, j) lives at i + j*rows).
        let u_b = &u_data[b * m * k..(b + 1) * m * k];
        let s_b = &s_data[b * k..(b + 1) * k];
        let vt_b = &vt_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];
        let v_b = adjoint_transpose(vt_b, k, n);

        // C = U^H dA V  (k x k).
        let uh_da = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, da_b, n)?;
        let c = backend_mat_mul(ctx, &uh_da, k, n, &v_b, k)?;

        // ds = Re(diag(C)); X = C with that real diagonal subtracted off
        // (only the imaginary diagonal and the off-diagonals remain).
        let mut x = c.clone();
        for i in 0..k {
            ds_data[b * k + i] = real_diagonal_from_scalar(c[i + i * k]);
            x[i + i * k] = x[i + i * k] - T::from_real(ds_data[b * k + i]);
        }

        // Build the k x k inner tangents from X.  The diagonal splits the
        // imaginary part of diag(C) evenly between U and V (+/- diag_term);
        // the off-diagonals are scaled by the regularized inverse spectral
        // gap supplied by `stable_inverse_gap` -- note the conj/index
        // asymmetry between the dU and dV entries.
        let mut du_inner = vec![T::zero(); k * k];
        let mut dv_inner = vec![T::zero(); k * k];
        for i in 0..k {
            let diag_term = imag_axis_component(x[i + i * k])?
                * T::from_real(half * stable_inverse_sigma(s_b[i], eta));
            du_inner[i + i * k] = du_inner[i + i * k] + diag_term;
            dv_inner[i + i * k] = dv_inner[i + i * k] - diag_term;
            for j in 0..k {
                if i == j {
                    continue;
                }
                let gap_inv = T::from_real(stable_inverse_gap(s_b[i], s_b[j], eta));
                du_inner[i + j * k] = gap_inv
                    * (T::from_real(s_b[i]) * x[j + i * k].conj()
                        + x[i + j * k] * T::from_real(s_b[j]));
                dv_inner[i + j * k] = gap_inv
                    * (T::from_real(s_b[i]) * x[i + j * k]
                        + x[j + i * k].conj() * T::from_real(s_b[j]));
            }
        }
        // Range-space part of dU: U * du_inner.
        let du_core = backend_mat_mul(ctx, u_b, m, k, &du_inner, k)?;
        let mut du_b_vec = du_core;

        // Thin case (m > k): add the null-space correction
        // (I - U U^H) dA V Sigma^{-1}, with Sigma^{-1} regularized.
        if m > k {
            let da_v = backend_mat_mul(ctx, da_b, m, n, &v_b, k)?;
            let inner = backend_mat_mul(ctx, &adjoint_transpose(u_b, m, k), k, m, &da_v, k)?;
            let uu_h_da_v = backend_mat_mul(ctx, u_b, m, k, &inner, k)?;
            let proj_da_v = sub_vec(&da_v, &uu_h_da_v);
            for j in 0..k {
                // Column j of the projected term is scaled by 1/sigma_j.
                let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                for i in 0..m {
                    du_b_vec[i + j * m] = du_b_vec[i + j * m] + proj_da_v[i + j * m] * sinv;
                }
            }
        }
        du_data[b * m * k..(b + 1) * m * k].copy_from_slice(&du_b_vec);

        // Range-space part of dV: V * dv_inner.
        let dv_core = backend_mat_mul(ctx, &v_b, n, k, &dv_inner, k)?;
        let mut dv_b_vec = dv_core;

        // Thin case (n > k): add the mirrored correction
        // (I - V V^H) dA^H U Sigma^{-1}.
        if n > k {
            let da_h = adjoint_transpose(da_b, m, n);
            let da_h_u = backend_mat_mul(ctx, &da_h, n, m, u_b, k)?;
            let inner = backend_mat_mul(ctx, vt_b, k, n, &da_h_u, k)?;
            let vv_h_da_h_u = backend_mat_mul(ctx, &v_b, n, k, &inner, k)?;
            let proj = sub_vec(&da_h_u, &vv_h_da_h_u);
            for j in 0..k {
                let sinv = T::from_real(stable_inverse_sigma(s_b[j], eta));
                for i in 0..n {
                    dv_b_vec[i + j * n] = dv_b_vec[i + j * n] + proj[i + j * n] * sinv;
                }
            }
        }
        // The output stores V^H, so transpose the accumulated dV back.
        let dvt_b_vec = adjoint_transpose(&dv_b_vec, n, k);
        dvt_data[b * k * n..(b + 1) * k * n].copy_from_slice(&dvt_b_vec);
    }

    // Reassemble the flat buffers into tensors with the original batch dims.
    let u_dims = output_dims(&[m, k], batch_dims);
    let s_dims = output_dims(&[k], batch_dims);
    let vt_dims = output_dims(&[k, n], batch_dims);

    let dresult = SvdResult {
        u: tensor_from_data(du_data, &u_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        s: tensor_from_data_scalar(ds_data, &s_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        vt: tensor_from_data(dvt_data, &vt_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };

    Ok((result, dresult))
}
163
/// Forward-mode AD rule for QR (JVP / pushforward).
///
/// Given the primal factorization `A = Q R` and an input tangent `dA`,
/// computes the tangents `dQ` and `dR` per batch entry.  The tall /
/// square case (`m >= n`) and the wide case (`m < n`) are handled
/// separately; both reduce to a triangular solve against `R` (or its
/// leading `k x k` block) instead of forming `R^{-1}` explicitly.
///
/// Returns `(result, dresult)` where `result` is the primal QR and
/// `dresult` holds the tangents of `q` and `r`.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::qr_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::from_slice(
///     &[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
///     &[4, 3],
///     col,
/// ).unwrap();
/// let da = Tensor::<f64>::ones(&[4, 3], mem, col).unwrap();
/// let (result, dresult) = qr_frule(&mut ctx, &a, &da).unwrap();
/// ```
pub fn qr_frule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(QrResult<T>, QrResult<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Primal pass; Q and R are reused when assembling the tangents.
    let result = qr(ctx, tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    let half: T::Real = scalar_from(0.5).map_err(to_ad_err)?;

    // Flat column-major buffers; one contiguous panel per batch entry.
    let (q_data, _) = extract_data(&result.q)?;
    let (r_data, _) = extract_data(&result.r)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dq_data = vec![T::zero(); m * k * bc];
    let mut dr_data = vec![T::zero(); k * n * bc];

    for b in 0..bc {
        let q_b = &q_data[b * m * k..(b + 1) * m * k];
        let r_b = &r_data[b * k * n..(b + 1) * k * n];
        let da_b = &da_data[b * m * n..(b + 1) * m * n];

        let (dq_b_vec, dr_b_vec) = if m >= n {
            // Tall / square case (k == n).  R is n x n upper triangular.
            let r_sq = r_b[..n * n].to_vec();
            // darinv = dA R^{-1}: solve R^H X = dA^H, then adjoint back.
            let darinv_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&r_sq, n, n),
                &adjoint_transpose(da_b, m, n),
                n,
                m,
                false,
            )?;
            let darinv = adjoint_transpose(&darinv_h, n, m);
            // sym = M + M^H with M = Q^H dA R^{-1}; differentiating
            // A = Q R gives M = Q^H dQ + dR R^{-1}, and the anti-Hermitian
            // Q^H dQ term cancels in this sum.
            let qhdarinv = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, n), n, m, &darinv, n)?;
            let sym = add_vec(&qhdarinv, &adjoint_transpose(&qhdarinv, n, n));

            // dr_hat = dR R^{-1}: the upper triangle of `sym` with the
            // diagonal halved and forced real.
            let mut dr_hat = vec![T::zero(); n * n];
            for j in 0..n {
                for i in 0..=j {
                    let mut val = sym[i + j * n];
                    if i == j {
                        val = T::from_real(val.real_part() * half);
                    }
                    dr_hat[i + j * n] = val;
                }
            }

            // dQ = dA R^{-1} - Q dr_hat;  dR = dr_hat R.
            let dq = sub_vec(&darinv, &backend_mat_mul(ctx, q_b, m, n, &dr_hat, n)?);
            let dr = backend_mat_mul(ctx, &dr_hat, n, n, &r_sq, n)?;
            (dq, dr)
        } else {
            // Wide case (k == m): work with the leading k x k block of R.
            let qhda = backend_mat_mul(ctx, &adjoint_transpose(q_b, m, k), k, m, da_b, n)?;
            let r1 = r_b[..k * k].to_vec();
            let qhda1 = qhda[..k * k].to_vec();
            // qhda1_rinv = (Q^H dA)_1 R_1^{-1}, again via a transposed
            // triangular solve.
            let qhda1_rinv_h = backend_solve_tri(
                ctx,
                &adjoint_transpose(&r1, k, k),
                &adjoint_transpose(&qhda1, k, k),
                k,
                k,
                false,
            )?;
            let qhda1_rinv = adjoint_transpose(&qhda1_rinv_h, k, k);
            // Recover dq_hat = Q^H dQ from the lower-triangular part.
            // NOTE(review): the exact convention lives in `tril_im` /
            // `tril_im_inv` (presumably an anti-Hermitian completion of
            // the strict lower triangle) -- confirm against their docs.
            let lower = tril_im(&qhda1_rinv, k)?;
            let dq_hat = tril_im_inv(&lower, k)?;

            // dR = Q^H dA - dq_hat R;  dQ = Q dq_hat.
            let dr = sub_vec(&qhda, &backend_mat_mul(ctx, &dq_hat, k, k, r_b, n)?);
            let dq = backend_mat_mul(ctx, q_b, m, k, &dq_hat, k)?;
            (dq, dr)
        };

        dq_data[b * m * k..(b + 1) * m * k].copy_from_slice(&dq_b_vec);
        dr_data[b * k * n..(b + 1) * k * n].copy_from_slice(&dr_b_vec);
    }

    // Reassemble the flat buffers into tensors with the original batch dims.
    let q_dims = output_dims(&[m, k], batch_dims);
    let r_dims = output_dims(&[k, n], batch_dims);
    let dresult = QrResult {
        q: tensor_from_data(dq_data, &q_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
        r: tensor_from_data(dr_data, &r_dims)
            .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?,
    };
    Ok((result, dresult))
}
277}