Cholesky Decomposition AD Rules
Forward
A = L L^{\mathsf{H}}, \quad A \in \mathbb{C}^{N \times N} \text{ (Hermitian positive-definite)},\; L \text{ lower triangular}
Auxiliary operator
Define \varphi(X) = \mathrm{tril}(X) - \tfrac{1}{2}\mathrm{diag}(X) (extract lower triangle, halve the diagonal).
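Restricted to Hermitian inputs, its adjoint with respect to \langle X, Y\rangle = \mathrm{Re}\,\mathrm{tr}(X^{\mathsf{H}}Y), applied to a lower-triangular X, is \varphi^*(X) = \tfrac{1}{2}(X + X^{\mathsf{H}} - \mathrm{diag}(X)).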
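The adjoint relation can be checked numerically. The following is a minimal sketch for the real symmetric case using NumPy; the helper names phi, phi_star, and inner are illustrative choices, not taken from any of the referenced libraries.

```python
import numpy as np

def phi(X):
    """Lower triangle of X with the diagonal halved."""
    return np.tril(X) - 0.5 * np.diag(np.diag(X))

def phi_star(S):
    """Adjoint form stated above: (S + S^H - diag(S)) / 2."""
    return 0.5 * (S + S.conj().T - np.diag(np.diag(S)))

def inner(X, Y):
    """Real inner product <X, Y> = Re tr(X^H Y)."""
    return np.real(np.vdot(X, Y))

rng = np.random.default_rng(0)
n = 4
M = rng.standard_normal((n, n))
M = M + M.T                                 # symmetric argument of phi
S = np.tril(rng.standard_normal((n, n)))    # lower-triangular cotangent

lhs = inner(S, phi(M))         # <S, phi(M)>
rhs = inner(phi_star(S), M)    # <phi^*(S), M>
print(abs(lhs - rhs))          # expected to be at rounding level
```

The check relies on M being symmetric and S being lower triangular; for unrestricted matrices the two sides generally differ.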
Forward rule (JVP)
Given tangent \dot{A} (Hermitian):
\dot{L} = L \,\varphi\!\bigl(L^{-1}\,\dot{A}\,L^{-\mathsf{H}}\bigr)
Derivation. Differentiating A = L L^{\mathsf{H}}:
\dot{A} = \dot{L}\,L^{\mathsf{H}} + L\,\dot{L}^{\mathsf{H}}
Left-multiply by L^{-1}, right-multiply by L^{-\mathsf{H}}:
L^{-1}\dot{A}\,L^{-\mathsf{H}} = L^{-1}\dot{L} + (L^{-1}\dot{L})^{\mathsf{H}}
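Since L^{-1}\dot{L} is lower triangular with a real diagonal (L and \dot{L} are both lower triangular, and the diagonal of L is real and positive), the symmetrization X \mapsto X + X^{\mathsf{H}} is inverted by \varphi. This gives L^{-1}\dot{L} = \varphi(L^{-1}\dot{A}\,L^{-\mathsf{H}}), hence \dot{L} = L\,\varphi(L^{-1}\dot{A}\,L^{-\mathsf{H}}).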
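As a quick sanity check of the rule (a worked example added for illustration), take N = 1 with a real scalar A = a > 0, so that L = \sqrt{a}:

```latex
L^{-1}\dot{A}\,L^{-\mathsf{H}} = \frac{\dot{a}}{a}, \qquad
\varphi\!\left(\frac{\dot{a}}{a}\right) = \frac{\dot{a}}{2a}, \qquad
\dot{L} = \sqrt{a}\cdot\frac{\dot{a}}{2a} = \frac{\dot{a}}{2\sqrt{a}}
```

This agrees with differentiating \sqrt{a} directly.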
Algorithm:
- Solve T \leftarrow L^{-1}\,\dot{A} (triangular solve, left)
- Solve T \leftarrow T\,L^{-\mathsf{H}} (triangular solve, right)
- T \leftarrow \mathrm{tril}(T) - \tfrac{1}{2}\mathrm{diag}(T)
- \dot{L} \leftarrow L\,T
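The four steps above can be sketched in NumPy as follows. This is an illustrative sketch, not the implementation used by any of the referenced libraries: the names phi and cholesky_jvp are chosen here, and np.linalg.solve (a general solver) stands in for a dedicated triangular solve to keep the example dependent on NumPy only. The result is compared with a central finite difference on a real symmetric test matrix.

```python
import numpy as np

def phi(X):
    """Lower triangle of X with the diagonal halved."""
    return np.tril(X) - 0.5 * np.diag(np.diag(X))

def cholesky_jvp(L, A_dot):
    """Tangent L_dot of the Cholesky factor for a Hermitian tangent A_dot."""
    # Step 1: T = L^{-1} A_dot. A triangular solver would exploit the
    # structure of L; np.linalg.solve is only a stand-in here.
    T = np.linalg.solve(L, A_dot)
    # Step 2: T = T L^{-H}, computed as (L^{-1} T^H)^H.
    T = np.linalg.solve(L, T.conj().T).conj().T
    # Steps 3 and 4: apply phi, then left-multiply by L.
    return L @ phi(T)

rng = np.random.default_rng(0)
n = 5
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)          # symmetric positive-definite test matrix
A_dot = rng.standard_normal((n, n))
A_dot = A_dot + A_dot.T              # symmetric tangent

L = np.linalg.cholesky(A)
L_dot = cholesky_jvp(L, A_dot)

# Central finite difference of the factorization along A_dot.
eps = 1e-6
L_fd = (np.linalg.cholesky(A + eps * A_dot)
        - np.linalg.cholesky(A - eps * A_dot)) / (2 * eps)
print(np.max(np.abs(L_dot - L_fd)))  # expected to be small
```

A second check that avoids finite differences is the defining identity: \dot{L}\,L^{\mathsf{H}} + L\,\dot{L}^{\mathsf{H}} should reproduce \dot{A} up to rounding.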
Reverse rule (VJP)
Given cotangent \bar{L}:
\bar{A} = L^{-\mathsf{H}}\,\varphi^*\!\bigl(\mathrm{tril}(L^{\mathsf{H}}\bar{L})\bigr)\,L^{-1}
Derivation. Taking the adjoint of the JVP linear map \dot{A} \mapsto \dot{L}:
\delta\ell = \langle \bar{L},\,\dot{L}\rangle = \mathrm{Re}\,\mathrm{tr}\!\bigl(\bar{L}^{\mathsf{H}}\,L\,\varphi(L^{-1}\dot{A}\,L^{-\mathsf{H}})\bigr)
Working through the adjoint chain, in the order the adjoints are applied to \bar{L}:
- Adjoint of left-multiply by L: left-multiply by L^{\mathsf{H}}, then project to the lower triangle (the output of \varphi is lower triangular)
- Adjoint of \varphi: \varphi^*
- Adjoint of L^{-1}(\cdot)L^{-\mathsf{H}}: L^{-\mathsf{H}}(\cdot)L^{-1}
This yields \bar{A} = L^{-\mathsf{H}}\,\varphi^*(\mathrm{tril}(L^{\mathsf{H}}\bar{L}))\,L^{-1}.
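The output \bar{A} is Hermitian (symmetric in the real case), consistent with the constraint on A, provided \mathrm{diag}(L^{\mathsf{H}}\bar{L}) is real; for complex types see the implementation notes.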
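The same scalar sanity check applies to the reverse rule (a worked example added for illustration). Take N = 1 with a real scalar A = a > 0, L = \sqrt{a}, and cotangent \bar{L} = \bar{\ell}:

```latex
S = \mathrm{tril}(L^{\mathsf{H}}\bar{L}) = \sqrt{a}\,\bar{\ell}, \qquad
\varphi^*(S) = \tfrac{1}{2}(S + S - S) = \tfrac{1}{2}\sqrt{a}\,\bar{\ell}, \qquad
\bar{A} = \frac{1}{\sqrt{a}}\cdot\tfrac{1}{2}\sqrt{a}\,\bar{\ell}\cdot\frac{1}{\sqrt{a}} = \frac{\bar{\ell}}{2\sqrt{a}}
```

This is \bar{\ell} times the derivative of \sqrt{a}, as expected.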
Algorithm:
- S \leftarrow \mathrm{tril}(L^{\mathsf{H}}\bar{L})
- S \leftarrow \tfrac{1}{2}(S + S^{\mathsf{H}} - \mathrm{diag}(S)) (equivalently, when \mathrm{diag}(S) is real: S \leftarrow \tfrac{1}{2}(S + \mathrm{tril}(S,-1)^{\mathsf{H}}))
- Solve \bar{A} \leftarrow L^{-\mathsf{H}}\,S (triangular solve, left)
- Solve \bar{A} \leftarrow \bar{A}\,L^{-1} (triangular solve, right)
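The reverse-mode steps above can be sketched the same way. Again this is an illustrative NumPy sketch under stated assumptions: the function names are chosen here, np.linalg.solve stands in for a triangular solve, and the test data is real. The JVP is included so that the defining adjoint identity \langle \bar{L}, \dot{L}\rangle = \langle \bar{A}, \dot{A}\rangle can be checked.

```python
import numpy as np

def phi(X):
    """Lower triangle of X with the diagonal halved."""
    return np.tril(X) - 0.5 * np.diag(np.diag(X))

def cholesky_jvp(L, A_dot):
    """Forward rule: L_dot = L phi(L^{-1} A_dot L^{-H})."""
    T = np.linalg.solve(L, A_dot)
    T = np.linalg.solve(L, T.conj().T).conj().T
    return L @ phi(T)

def cholesky_vjp(L, L_bar):
    """Reverse rule: A_bar = L^{-H} phi^*(tril(L^H L_bar)) L^{-1}."""
    LH = L.conj().T
    # Step 1: S = tril(L^H L_bar).
    S = np.tril(LH @ L_bar)
    # Step 2: S = (S + S^H - diag(S)) / 2.
    S = 0.5 * (S + S.conj().T - np.diag(np.diag(S)))
    # Step 3: A_bar = L^{-H} S (general solve as a stand-in).
    A_bar = np.linalg.solve(LH, S)
    # Step 4: A_bar = A_bar L^{-1}, computed as (L^{-H} A_bar^H)^H.
    A_bar = np.linalg.solve(LH, A_bar.conj().T).conj().T
    return A_bar

rng = np.random.default_rng(1)
n = 5
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)          # symmetric positive-definite test matrix
A_dot = rng.standard_normal((n, n))
A_dot = A_dot + A_dot.T              # symmetric tangent
L_bar = rng.standard_normal((n, n))  # arbitrary cotangent

L = np.linalg.cholesky(A)
L_dot = cholesky_jvp(L, A_dot)
A_bar = cholesky_vjp(L, L_bar)

lhs = np.real(np.vdot(L_bar, L_dot))   # <L_bar, L_dot>
rhs = np.real(np.vdot(A_bar, A_dot))   # <A_bar, A_dot>
print(abs(lhs - rhs))                  # expected to be at rounding level
```

Agreement of the two inner products for random tangents and cotangents is a common way to test a VJP against its JVP without finite differences.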
Implementation notes
- All operations are O(N^3), same as the forward Cholesky.
- Use blocked algorithms (Murray 2016) for large matrices to exploit Level-3 BLAS.
- Never form L^{-1} explicitly; always use triangular solves.
- For complex types, ensure the diagonal of \bar{A} is real, since A is Hermitian.
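One way to satisfy the last note is to project the computed cotangent onto the Hermitian matrices. The sketch below is an assumption about how this could be done, not a description of what the referenced libraries do; hermitize is a hypothetical helper name. The projection leaves the pairing with any Hermitian tangent unchanged, because the anti-Hermitian part is orthogonal to Hermitian matrices under \mathrm{Re}\,\mathrm{tr}(X^{\mathsf{H}}Y).

```python
import numpy as np

def hermitize(A_bar):
    """Project onto Hermitian matrices; the result has a real diagonal."""
    return 0.5 * (A_bar + A_bar.conj().T)

rng = np.random.default_rng(0)
n = 4
# A generic complex matrix standing in for an unprojected cotangent.
A_bar_raw = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
A_bar = hermitize(A_bar_raw)

# A Hermitian tangent to pair against.
H = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
H = H + H.conj().T

pair_raw = np.real(np.vdot(A_bar_raw, H))   # Re tr(A_bar_raw^H H)
pair_proj = np.real(np.vdot(A_bar, H))      # Re tr(A_bar^H H)
print(abs(pair_raw - pair_proj))            # expected to be at rounding level
```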
References
- Smith, S. P. (1995). “Differentiation of the Cholesky Algorithm.” J. Comput. Graph. Stat., 4(2), 134-147.
- Giles, M. B. (2008). “An extended collection of matrix derivative results for forward and reverse mode AD.”
- Murray, I. (2016). “Differentiation of the Cholesky decomposition.” arXiv:1602.07527.
- Seeger, M. et al. (2017). “Auto-Differentiating Linear Algebra.” arXiv:1710.08717.
- PyTorch: FunctionsManual.cpp: cholesky_jvp (L1962), cholesky_backward (L1983).
- JAX: jax/_src/lax/linalg.py: _cholesky_jvp_rule.
- ChainRules.jl: src/rulesets/LinearAlgebra/factorization.jl: _cholesky_pullback_shared_code.