solve_triangular_rrule

Function solve_triangular_rrule 

pub fn solve_triangular_rrule<T, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
    upper: bool,
) -> AdResult<SolveGrad<T>>

Reverse-mode AD rule for triangular solve (VJP / pullback).

Given A x = b with triangular A and the cotangent x̄ of the solution x, computes the cotangents (Ā, b̄).
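The rule follows from differentiating the forward equation A x = b. Below is a sketch of the derivation, assuming the common complex cotangent convention dL = Re(x̄^H dx) (for real data, ^H is plain transposition and Re can be dropped):

```latex
\begin{aligned}
x &= A^{-1} b, \qquad dx = A^{-1}\,(db - dA\,x),\\
dL &= \operatorname{Re}\,\bar{x}^{H} dx
    = \operatorname{Re}\,(A^{-H}\bar{x})^{H} db
    - \operatorname{Re}\,(A^{-H}\bar{x})^{H} dA\,x,\\
G &= A^{-H}\bar{x} \;\Rightarrow\;
  dL = \operatorname{Re}\,G^{H} db
     - \operatorname{Re}\operatorname{tr}\!\big((G x^{H})^{H} dA\big),\\
\bar{b} &= G, \qquad \bar{A} = \operatorname{proj}(-G x^{H}).
\end{aligned}
```

Only the entries of dA inside A's triangle are free to vary, so the full-matrix gradient -G x^H is restricted to that triangle by proj.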

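  • G = A^{-H} x̄, computed by a triangular solve with A^H (lower triangular when A is upper, upper triangular when A is lower)
  • b̄ = G
  • Ā = proj(-G x^H), where proj = triu if upper is true and tril otherwise, so entries outside A's triangle receive zero gradient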
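The three steps above can be sketched for real-valued data, where A^H reduces to A^T. This is a minimal illustration assuming dense row-major `Vec<Vec<f64>>` matrices and a single right-hand-side vector; the helpers `tri_solve` and `solve_triangular_pullback` are hypothetical and do not use this crate's `Tensor`, context, or `AdResult` types.

```rust
/// Hypothetical helper: solve M x = b where M = A (transpose = false)
/// or M = A^T (transpose = true), and A is an n x n triangular matrix
/// stored row-major. `upper` describes the triangle of A itself.
fn tri_solve(a: &[Vec<f64>], b: &[f64], upper: bool, transpose: bool) -> Vec<f64> {
    let n = b.len();
    // Entry (i, j) of the effective matrix M.
    let m = |i: usize, j: usize| if transpose { a[j][i] } else { a[i][j] };
    // M is upper triangular iff exactly one of (upper, transpose) holds.
    let m_upper = upper != transpose;
    let mut x = vec![0.0; n];
    if m_upper {
        // Back substitution: last row first.
        for i in (0..n).rev() {
            let mut s = b[i];
            for j in (i + 1)..n {
                s -= m(i, j) * x[j];
            }
            x[i] = s / m(i, i);
        }
    } else {
        // Forward substitution: first row first.
        for i in 0..n {
            let mut s = b[i];
            for j in 0..i {
                s -= m(i, j) * x[j];
            }
            x[i] = s / m(i, i);
        }
    }
    x
}

/// Hypothetical real-valued pullback of x = A^{-1} b.
/// Returns (a_bar, b_bar) for the cotangent x_bar.
fn solve_triangular_pullback(
    a: &[Vec<f64>],
    b: &[f64],
    x_bar: &[f64],
    upper: bool,
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let n = b.len();
    // Forward solution x = A^{-1} b (recomputed in this sketch).
    let x = tri_solve(a, b, upper, false);
    // G = A^{-T} x_bar: a solve with the transposed triangle.
    let g = tri_solve(a, x_bar, upper, true);
    // a_bar = proj(-G x^T): keep only the triangle that A occupies.
    let mut a_bar = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            let in_triangle = if upper { j >= i } else { j <= i };
            if in_triangle {
                a_bar[i][j] = -g[i] * x[j];
            }
        }
    }
    // b_bar = G.
    (a_bar, g)
}
```

The sketch never forms A^{-1}: both the forward solve and the transposed solve reuse the triangular structure, costing O(n^2) per right-hand side. A central finite-difference check on the loss L = x̄ · x is a convenient way to confirm the signs and the projection.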