tenferro_linalg/rrules/spectral.rs

use super::*;

/// Reverse-mode AD rule for general eigendecomposition (VJP / pullback).
///
/// Given the eigendecomposition `A V = V diag(lambda)`, computes the gradient
/// of the input `A` from complex-valued cotangents for eigenvalues and
/// eigenvectors using the Mike Giles formulas.
///
/// The cotangent uses [`EigCotangent`] with complex-valued tensors
/// because `eig()` returns complex output even for real inputs.
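///
/// The formulas assume a non-degenerate spectrum: the pairwise factor
/// `1/conj(lambda_j - lambda_i)` below is undefined for repeated eigenvalues.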
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{eig_rrule, EigCotangent};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
/// // Both cotangent fields absent: the pullback is identically zero.
/// let cotangent = EigCotangent::<f64> {
///     values: None,
///     vectors: None,
/// };
/// let grad_a = eig_rrule(&mut ctx, &a, &cotangent).unwrap();
/// ```
pub fn eig_rrule<
    T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
    C,
>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &EigCotangent<T>,
) -> AdResult<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>
        + backend::TensorLinalgContextFor<num_complex::Complex<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    num_complex::Complex<T>: KernelLinalgScalar,
    <C as backend::TensorLinalgContextFor<T>>::Backend: 'static,
    <C as backend::TensorLinalgContextFor<num_complex::Complex<T>>>::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
    let bc = batch_count(batch_dims);

    // Compute eigendecomposition
    let eig_result = eig(ctx, tensor).map_err(to_ad_err)?;
    let val_data = extract_data_scalar(&eig_result.values)?;
    let vec_data = extract_data_scalar(&eig_result.vectors)?;

    let zero_c = Cx::new(T::zero(), T::zero());
    let one_c = Cx::new(T::one(), T::zero());

    let mut grad_data = vec![T::zero(); n * n * bc];

    for b in 0..bc {
        let lambda = &val_data[b * n..(b + 1) * n];
        let v = &vec_data[b * n * n..(b + 1) * n * n];
        // Compute F matrix: F[i,j] = 1/conj(lambda_j - lambda_i) for i != j,
        // 0 on the diagonal. Entries diverge for (near-)degenerate eigenvalues;
        // F is only consumed when a vectors cotangent is present.
        let mut f_mat = vec![zero_c; n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let diff = (lambda[j] - lambda[i]).conj();
                    f_mat[i + j * n] = one_c / diff;
                }
            }
        }

        // V^H (conjugate transpose of V)
        let vh = complex_conj_transpose(v, n);

        // Build M_bar = diag(d_bar_lambda) + F .* (V^H d_bar_V)
        let mut m_bar = vec![zero_c; n * n];

        if let Some(ref dv_bar) = cotangent.vectors {
            let dv_bar_data = extract_data_scalar(dv_bar)?;
            let dv_bar_b = &dv_bar_data[b * n * n..(b + 1) * n * n];
            let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, dv_bar_b, n)?;
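            // Gauge correction: eigenvectors are determined only up to scale,
            // so subtract the non-identifiable component of the cotangent along
            // each eigenvector (the real diagonal of V^H d_bar_V).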
            let mut dv_adj = dv_bar_b.to_vec();
            for j in 0..n {
                let correction = Cx::new(vh_dv[j + j * n].re, T::zero());
                for i in 0..n {
                    let index = i + j * n;
                    dv_adj[index] = dv_adj[index] - v[index] * correction;
                }
            }
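            // Recompute V^H d_bar_V with the corrected cotangent and apply F
            // elementwise to form the off-diagonal part of M_bar.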
            let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, &dv_adj, n)?;
            for k in 0..n * n {
                m_bar[k] = f_mat[k] * vh_dv[k];
            }
        }

        if let Some(ref dlam) = cotangent.values {
            let dlam_data = extract_data_scalar(dlam)?;
            for i in 0..n {
                m_bar[i + i * n] = m_bar[i + i * n] + dlam_data[b * n + i];
            }
        }

        // d_bar_A = V^{-H} M_bar V^H = solve(V^H, M_bar @ V^H)
        let m_vh = complex_mat_mul_nn_backend(ctx, &m_bar, &vh, n)?;
        let da_complex = complex_solve_nn(ctx, &vh, &m_vh, n)?;

        // Take real part (since input A was real)
        for k in 0..n * n {
            grad_data[b * n * n + k] = da_complex[k].re;
        }
    }

    let dims = output_dims(&[n, n], batch_dims);
    tensor_from_data(grad_data, &dims).map_err(to_ad_err)
}

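// A minimal, self-contained check of the F-matrix construction above. It
// assumes nothing beyond `num_complex`; the module and test names are
// illustrative sketches, not part of the crate's established test suite.
#[cfg(test)]
mod f_matrix_sketch {
    use num_complex::Complex64;

    #[test]
    fn f_matrix_2x2_distinct_eigenvalues() {
        // For lambda = [1, 3]: F[i,j] = 1/conj(lambda_j - lambda_i) off the
        // diagonal, stored column-major exactly as in the loop above.
        let lambda = [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)];
        let n = 2;
        let one = Complex64::new(1.0, 0.0);
        let mut f = vec![Complex64::new(0.0, 0.0); n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    f[i + j * n] = one / (lambda[j] - lambda[i]).conj();
                }
            }
        }
        assert_eq!(f[n], Complex64::new(0.5, 0.0)); // F[0,1] = 1/(3 - 1)
        assert_eq!(f[1], Complex64::new(-0.5, 0.0)); // F[1,0] = 1/(1 - 3)
    }
}
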
/// Reverse-mode AD rule for the Moore-Penrose pseudoinverse (VJP / pullback).
///
/// Given `A+ = pinv(A)` for an `m × n` input, computes the gradient of `A`
/// from an `n × m` cotangent of `A+`.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::pinv_rrule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let cotangent = Tensor::<f64>::ones(&[4, 3], mem, col).unwrap();
/// let grad_a = pinv_rrule(&mut ctx, &a, &cotangent, None).unwrap();
/// ```
pub fn pinv_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<T>,
    rcond: Option<f64>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv_rrule")
        .map_err(to_ad_err)?;

    // dA = -(A+)^T dA+ (A+)^T + (I - AA+)(dA+)^T A+ (A+)^T + (A+)^T A+ (dA+)^T (I - A+A)
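    // Shapes: A is m×n, A+ and the cotangent dA+ are n×m, so every term above
    // is m×n; the per-term comments below track the intermediate shapes.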
    let ap = pinv(ctx, tensor, rcond)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (a_data, _) = extract_data(tensor)?;
    let (ap_data, _) = extract_data(&ap)?;
    let (dap_data, _) = extract_data(cotangent)?;

    let mut grad_a = vec![T::zero(); m * n * bc];

    for batch in 0..bc {
        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
        let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
        let dap_b = &dap_data[batch * n * m..(batch + 1) * n * m];

        let apt = adjoint_transpose(ap_b, n, m); // (A+)^T, m×n
        let dapt = adjoint_transpose(dap_b, n, m); // (dA+)^T, m×n

        // Term 1: -(A+)^T dA+ (A+)^T = -apt * dap * apt
        // apt: m×n, dap: n×m, apt: m×n → (m×n * n×m) * m×n = m×n
        let t1 = backend_mat_mul(ctx, &apt, m, n, dap_b, m)?;
        let t1 = backend_mat_mul(ctx, &t1, m, m, &apt, n)?;
        let t1 = scale_vec(&t1, -T::one());

        // Term 2: (I - AA+)(dA+)^T A+ (A+)^T
        // AA+ (m×m)
        let aap = backend_mat_mul(ctx, a_b, m, n, ap_b, m)?;
        let i_m = eye::<T>(m);
        let i_aap = sub_vec(&i_m, &aap);
        // (dA+)^T A+ = dapt * ap (m×n * n×m = m×m)
        let dapt_ap = backend_mat_mul(ctx, &dapt, m, n, ap_b, m)?;
        // * (A+)^T = * apt (m×m * m×n = m×n)
        let dapt_ap_apt = backend_mat_mul(ctx, &dapt_ap, m, m, &apt, n)?;
        let t2 = backend_mat_mul(ctx, &i_aap, m, m, &dapt_ap_apt, n)?;

        // Term 3: (A+)^T A+ (dA+)^T (I - A+A)
        // A+A (n×n)
        let apa = backend_mat_mul(ctx, ap_b, n, m, a_b, n)?;
        let i_n = eye::<T>(n);
        let i_apa = sub_vec(&i_n, &apa);
        // (A+)^T A+ = apt * ap (m×n * n×m = m×m)
        let apt_ap = backend_mat_mul(ctx, &apt, m, n, ap_b, m)?;
        // * (dA+)^T = * dapt (m×m * m×n = m×n)
        let apt_ap_dapt = backend_mat_mul(ctx, &apt_ap, m, m, &dapt, n)?;
        let t3 = backend_mat_mul(ctx, &apt_ap_dapt, m, n, &i_apa, n)?;

        let da_b = add_vec(&t1, &add_vec(&t2, &t3));
        grad_a[batch * m * n..(batch + 1) * m * n].copy_from_slice(&da_b);
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}
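
// A minimal, dependency-free sanity check of the pullback formula above: for
// an invertible 1×1 input A = [a], A+ = 1/a, the projectors I - AA+ and
// I - A+A vanish, and the formula collapses to term 1, matching
// d(1/a)/da = -1/a^2. Module and test names are illustrative sketches only.
#[cfg(test)]
mod pinv_formula_sketch {
    #[test]
    fn scalar_case_matches_derivative_of_reciprocal() {
        let a = 2.0_f64;
        let ap = 1.0 / a; // A+ = 1/a
        let dap = 1.0; // cotangent of A+
        let grad = -ap * dap * ap; // term 1; terms 2 and 3 are identically zero
        assert!((grad - (-0.25)).abs() < 1e-12);
    }
}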