// tenferro_linalg/frules/spectral.rs

1use super::*;
2
/// Forward-mode AD rule for general eigendecomposition (JVP / pushforward).
///
/// Given eigendecomposition `A V = V diag(lambda)`, computes the tangents
/// of eigenvalues and eigenvectors from a real tangent `dA` using the
/// Mike Giles formulas:
///
/// * `W = V^{-1} dA V`
/// * `d_lambda = diag(W)`
/// * `dV_raw = V (F .* W)` with `F[i,j] = 1/(lambda_j - lambda_i)` for
///   `i != j` and zero on the diagonal, then re-gauged so that the
///   unit-norm eigenvector convention is preserved.
///
/// Returns `(primal, tangent)` where both are [`EigResult`] with complex
/// eigenvalues and eigenvectors.
///
/// NOTE(review): repeated (degenerate) eigenvalues make `lambda_j - lambda_i`
/// zero for some off-diagonal pair, so the eigenvector tangent then contains
/// infinities/NaNs — no guard is performed here.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::eig_frule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
/// let (result, dresult) = eig_frule(&mut ctx, &a, &da).unwrap();
/// ```
pub fn eig_frule<
    // `Real = T` restricts this rule to real input scalars whose complex
    // counterpart is `num_complex::Complex<T>`.
    T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
    C,
>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
) -> AdResult<(EigResult<T>, EigResult<T>)>
where
    T: KernelLinalgScalar,
    // The context must provide linalg kernels for both the real input scalar
    // and its complex counterpart (the eigen output type).
    C: backend::TensorLinalgContextFor<T>
        + backend::TensorLinalgContextFor<num_complex::Complex<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    num_complex::Complex<T>: KernelLinalgScalar,
    <C as backend::TensorLinalgContextFor<T>>::Backend: 'static,
    <C as backend::TensorLinalgContextFor<num_complex::Complex<T>>>::Backend: 'static,
    T::Real: tenferro_tensor::KeepCountScalar,
{
    // Forward pass: primal eigendecomposition (complex values/vectors).
    let eig_result = eig(ctx, tensor).map_err(to_ad_err)?;

    // n = matrix dimension, bc = number of batched matrices.
    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
    let bc = batch_count(batch_dims);

    // Flat host buffers; each batch occupies a contiguous slice, and matrix
    // entries are addressed as `i + j * n` (row i, column j — column-major).
    let val_data = extract_data_scalar(&eig_result.values)?;
    let vec_data = extract_data_scalar(&eig_result.vectors)?;
    let (tang_data, _) = extract_data(tangent)?;

    let zero_c = Cx::new(T::zero(), T::zero());
    let one_c = Cx::new(T::one(), T::zero());

    let mut dval_data = vec![zero_c; n * bc];
    let mut dvec_data = vec![zero_c; n * n * bc];

    for b in 0..bc {
        // Per-batch views: eigenvalues, eigenvectors, and the input tangent.
        let lambda = &val_data[b * n..(b + 1) * n];
        let v = &vec_data[b * n * n..(b + 1) * n * n];
        let da = &tang_data[b * n * n..(b + 1) * n * n];

        // Convert real dA to complex
        let da_complex: Vec<Cx<T>> = da.iter().map(|&x| Cx::new(x, T::zero())).collect();

        // W = V^{-1} dA V = solve(V, dA_c @ V)
        let da_v = complex_mat_mul_nn_backend(ctx, &da_complex, v, n)?;
        let w = complex_solve_nn(ctx, v, &da_v, n)?;

        // d_lambda = diag(W)
        for i in 0..n {
            dval_data[b * n + i] = w[i + i * n];
        }

        // F matrix: F[i,j] = 1/(lambda_j - lambda_i) for i != j, 0 on diagonal
        // (diagonal entries stay `zero_c`, which drops diag(W)'s contribution
        // to the eigenvector tangent below).
        let mut f_mat = vec![zero_c; n * n];
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    // NOTE(review): `diff` is zero for degenerate eigenvalue
                    // pairs; the division then yields inf/NaN.
                    let diff = lambda[j] - lambda[i];
                    f_mat[i + j * n] = one_c / diff;
                }
            }
        }

        // Raw eigenvector tangent dV_raw = V * (F .* W)
        let mut fw = vec![zero_c; n * n];
        for k in 0..n * n {
            fw[k] = f_mat[k] * w[k];
        }
        let dv_raw = complex_mat_mul_nn_backend(ctx, v, &fw, n)?;

        // PyTorch and tenferro normalize eigenvectors to unit norm, so the raw
        // tangent must be projected back to that gauge:
        // dV = dV_raw - V * diag(Re(V^H dV_raw)).
        let vh = complex_conj_transpose(v, n);
        let vh_dv = complex_mat_mul_nn_backend(ctx, &vh, &dv_raw, n)?;
        let mut dv = dv_raw;
        for j in 0..n {
            // Only the real part of the diagonal is removed: that is the
            // component of the tangent along the eigenvector itself, which
            // would perturb the vector's norm.
            let correction = Cx::new(vh_dv[j + j * n].re, T::zero());
            for i in 0..n {
                let index = i + j * n;
                dv[index] = dv[index] - v[index] * correction;
            }
        }
        dvec_data[b * n * n..(b + 1) * n * n].copy_from_slice(&dv);
    }

    // Build tangent EigResult with the same (batched) dims as the primal.
    let val_dims = output_dims(&[n], batch_dims);
    let vec_dims = output_dims(&[n, n], batch_dims);

    let d_result = EigResult {
        values: tensor_from_data_scalar(dval_data, &val_dims).map_err(to_ad_err)?,
        vectors: tensor_from_data_scalar(dvec_data, &vec_dims).map_err(to_ad_err)?,
    };

    Ok((eig_result, d_result))
}
123
124/// Forward-mode AD rule for pseudoinverse (JVP / pushforward).
125///
126/// # Examples
127///
128/// ```
129/// use tenferro_linalg::pinv_frule;
130/// use tenferro_prims::CpuContext;
131/// use tenferro_tensor::{Tensor, MemoryOrder};
132/// use tenferro_device::LogicalMemorySpace;
133///
134/// let col = MemoryOrder::ColumnMajor;
135/// let mem = LogicalMemorySpace::MainMemory;
136/// let mut ctx = CpuContext::new(1);
137/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
138/// let da = Tensor::<f64>::ones(&[3, 4], mem, col).unwrap();
139/// let (pinv_a, dpinv_a) = pinv_frule(&mut ctx, &a, &da, None).unwrap();
140/// ```
141pub fn pinv_frule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
142    ctx: &mut C,
143    tensor: &Tensor<T>,
144    tangent: &Tensor<T>,
145    rcond: Option<f64>,
146) -> AdResult<(Tensor<T>, Tensor<T>)>
147where
148    T: KernelLinalgScalar
149        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
150        + tenferro_algebra::Conjugate,
151    C: backend::TensorLinalgContextFor<T>
152        + tenferro_prims::TensorResolveConjContextFor<T>
153        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
154    C::Backend: 'static,
155    T::Real: tenferro_tensor::KeepCountScalar,
156{
157    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv_frule")
158        .map_err(to_ad_err)?;
159
160    // dA+ = -A+ dA A+ + (I - A+A) dA^T (A+)^T A+ + A+ (A+)^T dA^T (I - AA+)
161    let ap = pinv(ctx, tensor, rcond)
162        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
163    let (m, n, batch_dims) = validate_2d(tensor)
164        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
165    let bc = batch_count(batch_dims);
166
167    let (a_data, _) = extract_data(tensor)?;
168    let (ap_data, _) = extract_data(&ap)?;
169    let (da_data, _) = extract_data(tangent)?;
170
171    let mut dap_data = vec![T::zero(); n * m * bc];
172
173    for batch in 0..bc {
174        let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
175        let ap_b = &ap_data[batch * n * m..(batch + 1) * n * m];
176        let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
177
178        let dat = adjoint_transpose(da_b, m, n); // n×m
179        let apt = adjoint_transpose(ap_b, n, m); // m×n
180
181        // Term 1: -A+ dA A+ (n×m × m×n × n×m = n×m)
182        let ap_da = backend_mat_mul(ctx, ap_b, n, m, da_b, n)?;
183        let ap_da_ap = backend_mat_mul(ctx, &ap_da, n, n, ap_b, m)?;
184        let t1 = scale_vec(&ap_da_ap, -T::one());
185
186        // Term 2: (I - A+A) dA^T (A+)^T A+
187        let apa = backend_mat_mul(ctx, ap_b, n, m, a_b, n)?; // n×n
188        let i_n = eye::<T>(n);
189        let i_apa = sub_vec(&i_n, &apa);
190        let dat_apt = backend_mat_mul(ctx, &dat, n, m, &apt, n)?; // n×n
191        let dat_apt_ap = backend_mat_mul(ctx, &dat_apt, n, n, ap_b, m)?; // n×m
192        let t2 = backend_mat_mul(ctx, &i_apa, n, n, &dat_apt_ap, m)?;
193
194        // Term 3: A+ (A+)^T dA^T (I - AA+)
195        let aap = backend_mat_mul(ctx, a_b, m, n, ap_b, m)?; // m×m
196        let i_m = eye::<T>(m);
197        let i_aap = sub_vec(&i_m, &aap);
198        let ap_apt = backend_mat_mul(ctx, ap_b, n, m, &apt, n)?; // n×n
199        let ap_apt_dat = backend_mat_mul(ctx, &ap_apt, n, n, &dat, m)?; // n×m
200        let t3 = backend_mat_mul(ctx, &ap_apt_dat, n, m, &i_aap, m)?;
201
202        let dap_b_vec = add_vec(&t1, &add_vec(&t2, &t3));
203        dap_data[batch * n * m..(batch + 1) * n * m].copy_from_slice(&dap_b_vec);
204    }
205
206    let dims = output_dims(&[n, m], batch_dims);
207    let dap = tensor_from_data(dap_data, &dims)
208        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
209    Ok((ap, dap))
210}