Scalar And Tensor Wrapper AD Notes
Conventions
Unless noted otherwise, Linearization and Transpose are written for the raw-output-space operator before any DB observable projection. For complex tensors, Transpose means the adjoint under the real Frobenius inner product
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
Forward
This note groups raw scalar and tensor-wrapper operators of three common forms:
y = f(x), \qquad y = h(x_1, x_2), \qquad y = r(x).
Here f is a unary wrapper, h is a binary wrapper, and r is a reduction or small tensor composite.
Linearization
Representative raw-output-space linearizations are:
- unary analytic wrapper:
\dot{y} = f'(x)\dot{x}
- addition and subtraction:
\dot{y} = \dot{x}_1 \pm \dot{x}_2
- multiplication:
\dot{y} = \dot{x}_1 x_2 + x_1 \dot{x}_2
- quotient:
\dot{y} = \frac{\dot{x}_1 x_2 - x_1 \dot{x}_2}{x_2^2}
- reductions:
\dot{y}_{\mathrm{sum}} = \sum_i \dot{x}_i, \qquad \dot{y}_{\mathrm{mean}} = \frac{1}{n}\sum_i \dot{x}_i
The same pattern extends to var, std, and tensor composites through the scalar basis plus broadcast or reduction structure.
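A minimal forward-mode sketch of these rules, assuming real-valued NumPy arrays (so no conjugation appears) and illustrative function names:

```python
import numpy as np

def mul_jvp(x1, x2, dx1, dx2):
    # d(x1 * x2) = dx1 * x2 + x1 * dx2
    return x1 * x2, dx1 * x2 + x1 * dx2

def div_jvp(x1, x2, dx1, dx2):
    # d(x1 / x2) = (dx1 * x2 - x1 * dx2) / x2**2
    return x1 / x2, (dx1 * x2 - x1 * dx2) / x2**2

def mean_jvp(x, dx):
    # d(mean(x)) = mean(dx)
    return x.mean(), dx.mean()

# Finite-difference spot check of the quotient rule.
rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=4), rng.normal(size=4) + 3.0
dx1, dx2 = rng.normal(size=4), rng.normal(size=4)
eps = 1e-6
_, dy = div_jvp(x1, x2, dx1, dx2)
fd = ((x1 + eps * dx1) / (x2 + eps * dx2) - x1 / x2) / eps
assert np.allclose(dy, fd, atol=1e-4)
```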
JVP
The JVP is the linearization evaluated at the chosen tangent. Representative examples are
\operatorname{jvp}(\exp)(x;\dot{x}) = \exp(x)\dot{x}, \qquad \operatorname{jvp}(\mathrm{mul})(x_1, x_2;\dot{x}_1,\dot{x}_2) = \dot{x}_1 x_2 + x_1 \dot{x}_2.
For reductions, the JVP simply applies the corresponding sum, averaging, or centered-residual rule to the tangent tensor.
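The representative formulas can be spot-checked directly with jax.jvp; a small sketch on real scalars:

```python
import jax
import jax.numpy as jnp

# jvp(exp)(x; dx) = exp(x) * dx
x, dx = jnp.array(0.7), jnp.array(1.3)
y, dy = jax.jvp(jnp.exp, (x,), (dx,))
assert jnp.allclose(dy, jnp.exp(x) * dx)

# jvp(mul)(x1, x2; dx1, dx2) = dx1 * x2 + x1 * dx2
x1, x2 = jnp.array(2.0), jnp.array(-3.0)
dx1, dx2 = jnp.array(0.5), jnp.array(0.25)
y, dy = jax.jvp(jnp.multiply, (x1, x2), (dx1, dx2))
assert jnp.allclose(dy, dx1 * x2 + x1 * dx2)
```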
Transpose
For a raw output cotangent \bar{y}, representative transpose rules are:
- unary analytic wrapper:
\bar{x} = \overline{f'(x)} \odot \bar{y}
- addition and subtraction:
(\bar{x}_1, \bar{x}_2) = (\bar{y}, \pm \bar{y})
- multiplication:
(\bar{x}_1, \bar{x}_2) = (\bar{y}\,\overline{x_2}, \bar{y}\,\overline{x_1})
- reductions:
\bar{x}_{\mathrm{sum}} = \operatorname{broadcast}(\bar{y}), \qquad \bar{x}_{\mathrm{mean}} = \frac{1}{n}\operatorname{broadcast}(\bar{y})
var and std add the centered-residual correction recorded later in this note.
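To see where the conjugates come from, pair a cotangent \bar{y} against the mul linearization under the real inner product from the Conventions section; the scalar case is a one-line derivation, and the elementwise tensor case is identical:

\operatorname{Re}\big(\overline{\bar{y}}\,(\dot{x}_1 x_2)\big) = \operatorname{Re}\big(\overline{\bar{y}\,\overline{x_2}}\,\dot{x}_1\big) = \langle \bar{y}\,\overline{x_2},\, \dot{x}_1 \rangle_{\mathbb{R}},

so the x_1 cotangent is \bar{y}\,\overline{x_2}, matching the rule above.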
VJP (JAX convention)
JAX presents the same raw transpose map as the VJP or linear_transpose result, interpreted under the real Frobenius inner product. Complex families therefore follow the raw adjoint formulas directly.
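A sketch checking this on real inputs, where the conjugations are trivial and jax.vjp reproduces the raw transpose rules verbatim:

```python
import jax
import jax.numpy as jnp

x1, x2 = jnp.array([1.0, 2.0]), jnp.array([3.0, -4.0])
y, vjp_fn = jax.vjp(jnp.multiply, x1, x2)
bar_y = jnp.array([0.5, -1.5])
bar_x1, bar_x2 = vjp_fn(bar_y)

# Raw transpose rules for mul (conj is the identity on real inputs).
assert jnp.allclose(bar_x1, bar_y * x2)
assert jnp.allclose(bar_x2, bar_y * x1)
```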
VJP (PyTorch convention)
PyTorch uses the same raw formulas but packages them through its conjugate-Wirtinger convention. When a real input is embedded into complex intermediates, the final cotangent is projected back to the real domain via handle_r_to_c.
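The same real-input check in PyTorch, as a minimal sketch; complex leaves would add the conjugations from the Transpose section, and the handle_r_to_c projection for real leaves feeding complex intermediates is not exercised here:

```python
import torch

x1 = torch.tensor([1.0, 2.0], requires_grad=True)
x2 = torch.tensor([3.0, -4.0], requires_grad=True)
y = x1 * x2
bar_y = torch.tensor([0.5, -1.5])
bar_x1, bar_x2 = torch.autograd.grad(y, (x1, x2), grad_outputs=bar_y)

# Raw transpose rules for mul, real case.
assert torch.allclose(bar_x1, bar_y * x2)
assert torch.allclose(bar_x2, bar_y * x1)
```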
Scope
This note records shared scalar AD formulas together with the tensor-level wrappers built from them.
Complex Gradient Convention
For real-valued losses:
- gradients follow the conjugate-Wirtinger convention
- VJP formulas include complex conjugation where required
- real inputs project complex intermediates back to the real domain
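As a concrete instance of these three points, take the real-valued loss L(z) = |z|^2 = z\bar{z} on a complex scalar z, a standard Wirtinger-calculus example independent of any particular framework:

\frac{\partial L}{\partial z} = \bar{z}, \qquad \frac{\partial L}{\partial \bar{z}} = z,

so the conjugate-Wirtinger gradient is proportional to z itself, and for a real leaf z = x the projection back to the real domain recovers dL/dx = 2x.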
Current Complex Support Boundary
This note groups both complex-capable wrappers and wrappers that remain float-only in the pinned PyTorch upstream AD coverage.
For this repository phase, families that are still float-only upstream are tracked as explicitly unsupported for complex in docs/math/complex-support.json rather than being promoted to repo-specific complex extensions.
Scalar Basis Rules
Let g be the output cotangent, x the primal input, and y = f(x) the primal output.
Core arithmetic
- add: for x_1 + \alpha x_2, (dx_1, dx_2) = (g, \overline{\alpha}\, g)
- sub: for x_1 - \alpha x_2, (dx_1, dx_2) = (g, -\overline{\alpha}\, g)
- mul: (dx_1, dx_2) = (g \cdot \overline{x_2}, g \cdot \overline{x_1})
- div:
  - numerator path: dx_1 = g / \overline{x_2}
  - denominator path: dx_2 = -g \cdot \overline{x_1 / x_2^2}
  - integer-style rounding modes are treated as nondifferentiable branches
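A sketch of the div cotangent rule with the conjugation written out, checked by finite differences on real scalars (where conj is the identity); names are illustrative:

```python
import numpy as np

def div_vjp(g, x1, x2):
    # numerator path: g / conj(x2); denominator path: -g * conj(x1 / x2**2)
    return g / np.conj(x2), -g * np.conj(x1 / x2**2)

x1, x2, g = 0.9, -1.7, 1.0
dx1, dx2 = div_vjp(g, x1, x2)
eps = 1e-6
assert np.isclose(dx1, ((x1 + eps) / x2 - x1 / x2) / eps, atol=1e-4)
assert np.isclose(dx2, (x1 / (x2 + eps) - x1 / x2) / eps, atol=1e-4)
```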
Analytic unary wrappers
- conj: dx = \overline{g}
- sqrt: dx = g / (2 \overline{\sqrt{x}})
- exp: dx = g \cdot \overline{y}
- log: dx = g / \overline{x}
- expm1: dx = g \cdot \overline{\exp(x)}
- log1p: dx = g / \overline{(1 + x)}
- sin: dx = g \cdot \overline{\cos(x)}
- cos: dx = -g \cdot \overline{\sin(x)}
- tanh: dx = g \cdot \overline{(1 - y^2)}
Parameterized wrappers
- atan2: for real inputs (a, b), da = g \, b / (a^2 + b^2) and db = -g \, a / (a^2 + b^2), with the zero-denominator singularity masked by the implementation convention
- powf: for fixed exponent p, dx = g \cdot \overline{(p x^{p-1})}
- powi: integer-exponent specialization of powf
- pow:
  - base path: dx = g \cdot \overline{a x^{a-1}}
  - exponent path: da = g \cdot \overline{x^a \log(x)}
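Both pow paths can be checked against PyTorch autograd on real positive inputs, where the conjugations are trivial; a minimal sketch:

```python
import torch

x = torch.tensor(1.3, requires_grad=True)
a = torch.tensor(2.7, requires_grad=True)
y = torch.pow(x, a)
g = torch.tensor(1.0)
dx, da = torch.autograd.grad(y, (x, a), grad_outputs=g)

# base path: g * a * x**(a-1); exponent path: g * x**a * log(x)
assert torch.allclose(dx, g * a * x ** (a - 1))
assert torch.allclose(da, g * y * torch.log(x))
```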
Tensor-Composite Rules
Tensor-level wrappers built on top of the scalar basis include:
- pointwise unary analytic families
- broadcasted binary analytic families
- small tensor wrappers such as
cross, diagonal, matrix_power, multi_dot, vander, vecdot, and householder_product
Tensor Reduction Wrappers
sum_ad
For a reduction over index set \mathcal{I},
y = \sum_{i \in \mathcal{I}} x_i \quad \Longrightarrow \quad \bar{x}_i = \bar{y}
for every reduced element, with the cotangent broadcast back to the input shape.
mean_ad
If n entries are reduced,
y = \frac{1}{n} \sum_{i \in \mathcal{I}} x_i \quad \Longrightarrow \quad \bar{x}_i = \frac{\bar{y}}{n}.
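Both broadcast rules satisfy the adjoint identity \langle \bar{y}, J\dot{x} \rangle = \langle J^{\top}\bar{y}, \dot{x} \rangle; a NumPy sketch for mean over axis 0 (names illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 4))
dx = rng.normal(size=(3, 4))
bar_y = rng.normal(size=(4,))
n = x.shape[0]

dy = dx.mean(axis=0)                          # linearization of mean
bar_x = np.broadcast_to(bar_y, x.shape) / n   # transpose rule
assert np.isclose(np.sum(bar_y * dy), np.sum(bar_x * dx))
```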
var_ad
Let \mu = \operatorname{mean}(x) over the reduced axes and let correction denote the Bessel-style offset used by the variance operator. Then
\operatorname{var}(x) = \frac{1}{n - \mathrm{correction}} \sum_i |x_i - \mu|^2,
so away from the singular degrees-of-freedom boundary,
\bar{x} = \frac{2}{n - \mathrm{correction}} \, \bar{v} \, (x - \mu).
At n - \mathrm{correction} \le 0, the operator is singular and the derivative inherits the same NaN / infinity boundary behavior as the primal convention.
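A finite-difference sketch of this rule on a real input with correction = 1, away from the singular boundary:

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=5)
n, correction = x.size, 1
bar_v = 1.0

# var cotangent rule: 2 / (n - correction) * bar_v * (x - mean(x))
bar_x = 2.0 / (n - correction) * bar_v * (x - x.mean())

eps = 1e-6
for i in range(n):
    xp = x.copy()
    xp[i] += eps
    fd = (xp.var(ddof=correction) - x.var(ddof=correction)) / eps
    assert np.isclose(bar_x[i], fd, atol=1e-4)
```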
std_ad
For \sigma = \sqrt{v} with v = \operatorname{var}(x),
\bar{v} = \frac{\bar{\sigma}}{2 \sigma},
masked at \sigma = 0, and then the variance rule is applied to propagate back to x.
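Chaining the two rules gives the full std backward pass; a short sketch under the same assumptions as the var check above:

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=5)
n, correction = x.size, 1
sigma = x.std(ddof=correction)
bar_sigma = 1.0

bar_v = bar_sigma / (2.0 * sigma)                        # sqrt rule, masked at sigma == 0
bar_x = 2.0 / (n - correction) * bar_v * (x - x.mean())  # var rule

eps = 1e-6
xp = x.copy()
xp[0] += eps
fd = (xp.std(ddof=correction) - sigma) / eps
assert np.isclose(bar_x[0], fd, atol=1e-4)
```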
Published DB Families Using This Note
Reflected and arithmetic wrappers
Unary analytic, sign, rounding, and casts
abs, acos, acosh, angle, asin, asinh, atan, atan2, atanh, cdouble, ceil, clamp_max, clamp_min, complex, conj, conj_physical, copysign, cos, cosh, deg2rad, digamma, double, erf, erfc, erfinv, exp, exp2, expm1, fill, floor, fmax, fmin, frac, frexp, i0, imag, ldexp, lgamma, log, log10, log1p, log2, logaddexp, logit, nan_to_num, neg, positive, polar, rad2deg, real, reciprocal, round, round_decimals_0, round_decimals_3, round_decimals_neg_3, rsqrt, sgn, sigmoid, sign, sin, sinc, sinh, special_entr, special_erfcx, special_i0e, special_i1, special_i1e, special_log_ndtr, special_ndtr, special_ndtri, special_polygamma, special_polygamma_n_0, special_xlog1py, sqrt, square, tan, tanh, trunc
Reductions and statistics
Neural-network functional wrappers
nn_functional_celu, nn_functional_elu, nn_functional_hardshrink, nn_functional_hardsigmoid, nn_functional_hardtanh, nn_functional_logsigmoid, nn_functional_mish, nn_functional_prelu, nn_functional_relu, nn_functional_relu6, nn_functional_rrelu, nn_functional_selu, nn_functional_silu, nn_functional_softplus, nn_functional_softshrink, nn_functional_softsign, nn_functional_tanhshrink, nn_functional_threshold
Special-function parameter families
Small tensor wrappers currently grouped here
Notes On Future Splits
This shared note is intentionally broad in the first migration pass. Operations that later accumulate heavier derivation detail can be split into dedicated note files without changing the DB schema; only the central registry entry needs to move.