tenferro_linalg/rrules/linear_systems/solve.rs

use super::*;

/// Reverse-mode AD rule for linear solve (VJP / pullback).
///
/// Given `Ax = b` and cotangent `x̄`, computes `(Ā, b̄)`.
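///
/// - `G = A^{-H} x̄`
/// - `b̄ = G`
/// - `Ā = -G x^H`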
///
/// # Examples
///
/// ```
/// use tenferro_linalg::solve_rrule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
/// let b = Tensor::<f64>::ones(&[3], mem, col).unwrap();
/// let cotangent = Tensor::<f64>::ones(&[3], mem, col).unwrap();
/// let grad = solve_rrule(&mut ctx, &a, &b, &cotangent).unwrap();
/// ```
pub fn solve_rrule<T, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
) -> AdResult<SolveGrad<T>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T> + tenferro_prims::TensorResolveConjContextFor<T>,
    C::Backend: 'static,
{
    // Ax = b → G = A^{-H} x̄, dB = G, dA = -G x^H
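    // Derivation: differentiating Ax = b gives dA·x + A·dx = db, so
    // dx = A^{-1}(db - dA·x); pairing with the cotangent x̄ yields
    // dB = A^{-H} x̄ = G and dA = -G x^H.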
    let x = solve(ctx, a, b).map_err(to_ad_err)?;
    let (n, batch_dims) = validate_square(a).map_err(to_ad_err)?;
    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(b, n, batch_dims, "solve_rrule")
        .map_err(to_ad_err)?;
    let nrhs = rhs.nrhs;
    let sr = rhs.structural_rank;

    // Promote cotangent and x to matrix form [n, nrhs, batch...]
    let dx_mat = rhs_to_matrix(cotangent, sr).map_err(to_ad_err)?;
    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;

    // G = solve(A^H, cotangent)
    let a_h = matrix_adjoint_eager(ctx, a).map_err(to_ad_err)?;
    let g = solve(ctx, &a_h, &dx_mat).map_err(to_ad_err)?;

    // dB = G (convert back to original RHS shape)
    let grad_b = matrix_to_rhs(g.clone(), sr).map_err(to_ad_err)?;

    // dA = -G x^H  (use alpha=-1 in GEMM)
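    // G is [n, nrhs, batch...] and x^H is [nrhs, n, batch...], so the
    // product matches A's shape [n, n, batch...].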
    let x_h = matrix_adjoint_eager(ctx, &x_mat).map_err(to_ad_err)?;
    let grad_a = prims_bridge::batched_gemm_alpha_tensors(ctx, &g, &x_h, n, nrhs, n, -T::one())
        .map_err(to_ad_err)?;

    Ok(SolveGrad {
        a: grad_a,
        b: grad_b,
    })
}

/// Reverse-mode AD rule for triangular solve (VJP / pullback).
///
/// Given `Ax = b` with triangular `A` and cotangent `x̄`, computes `(Ā, b̄)`.
///
/// - `G = A^{-H} x̄`, solved with conjugate-transposed triangular structure
/// - `b̄ = G`
/// - `Ā = proj(-G x^H)` where `proj = triu` for upper, `tril` for lower
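///
/// # Examples
///
/// A minimal sketch mirroring the [`solve_rrule`] example (assuming
/// `solve_triangular_rrule` gets the same crate-root re-export):
///
/// ```
/// use tenferro_linalg::solve_triangular_rrule;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// // The identity matrix is both upper and lower triangular.
/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
/// let b = Tensor::<f64>::ones(&[3], mem, col).unwrap();
/// let cotangent = Tensor::<f64>::ones(&[3], mem, col).unwrap();
/// let grad = solve_triangular_rrule(&mut ctx, &a, &b, &cotangent, true).unwrap();
/// ```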
pub fn solve_triangular_rrule<T, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
    upper: bool,
) -> AdResult<SolveGrad<T>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorMetadataContextFor,
    C::Backend: 'static,
{
    let x = solve_triangular(ctx, a, b, upper).map_err(to_ad_err)?;
    let (n, _) = validate_square(a).map_err(to_ad_err)?;
    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(
        b,
        n,
        &a.dims()[2..],
        "solve_triangular_rrule",
    )
    .map_err(to_ad_err)?;
    let nrhs = rhs.nrhs;
    let sr = rhs.structural_rank;

    // Promote cotangent and x to matrix form
    let dx_mat = rhs_to_matrix(cotangent, sr).map_err(to_ad_err)?;
    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;

    // G = solve_triangular(A^H, dx, !upper)  (A^H flips upper/lower)
    let a_h = matrix_adjoint_eager(ctx, a).map_err(to_ad_err)?;
    let g = solve_triangular(ctx, &a_h, &dx_mat, !upper).map_err(to_ad_err)?;

    // dB = G
    let grad_b = matrix_to_rhs(g.clone(), sr).map_err(to_ad_err)?;

    // dA = proj(-G x^H)
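    // Off-triangle entries of A are structurally zero rather than free
    // parameters, so the cotangent is projected back onto the triangular
    // support.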
    let x_h = matrix_adjoint_eager(ctx, &x_mat).map_err(to_ad_err)?;
    let neg_g_xh = prims_bridge::batched_gemm_alpha_tensors(ctx, &g, &x_h, n, nrhs, n, -T::one())
        .map_err(to_ad_err)?;
    let grad_a = if upper {
        tenferro_prims::tensor_ops::triu(ctx, &neg_g_xh, 0).map_err(to_ad_err)?
    } else {
        tenferro_prims::tensor_ops::tril(ctx, &neg_g_xh, 0).map_err(to_ad_err)?
    };

    Ok(SolveGrad {
        a: grad_a,
        b: grad_b,
    })
}