Scalar And Tensor Wrapper AD Notes

Conventions

Unless noted otherwise, Linearization and Transpose are written for the raw-output-space operator before any DB observable projection. For complex tensors, Transpose means the adjoint under the real Frobenius inner product

\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).

Forward

This note groups raw scalar and tensor-wrapper operators of three common forms:

y = f(x), \qquad y = h(x_1, x_2), \qquad y = r(x).

Here f is a unary wrapper, h is a binary wrapper, and r is a reduction or small tensor composite.

Linearization

Representative raw-output-space linearizations are:

  • unary analytic wrapper:

\dot{y} = f'(x)\dot{x}

  • addition and subtraction:

\dot{y} = \dot{x}_1 \pm \dot{x}_2

  • multiplication:

\dot{y} = \dot{x}_1 x_2 + x_1 \dot{x}_2

  • quotient:

\dot{y} = \frac{\dot{x}_1 x_2 - x_1 \dot{x}_2}{x_2^2}

  • reductions:

\dot{y}_{\mathrm{sum}} = \sum_i \dot{x}_i, \qquad \dot{y}_{\mathrm{mean}} = \frac{1}{n}\sum_i \dot{x}_i

The same pattern extends to var, std, and the tensor composites: each is built from the scalar basis together with broadcast or reduction structure.
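The rules above can be sketched with a tiny dual-number forward pass (pure Python; the Dual class and helper names are illustrative, not part of any library):

```python
import math

class Dual:
    """Primal value x paired with a tangent dx (illustrative sketch)."""
    def __init__(self, x, dx=0.0):
        self.x, self.dx = x, dx
    def __add__(self, o):        # dy = dx1 + dx2
        return Dual(self.x + o.x, self.dx + o.dx)
    def __sub__(self, o):        # dy = dx1 - dx2
        return Dual(self.x - o.x, self.dx - o.dx)
    def __mul__(self, o):        # product rule: dy = dx1*x2 + x1*dx2
        return Dual(self.x * o.x, self.dx * o.x + self.x * o.dx)
    def __truediv__(self, o):    # quotient rule: dy = (dx1*x2 - x1*dx2) / x2**2
        return Dual(self.x / o.x, (self.dx * o.x - self.x * o.dx) / o.x ** 2)

def dexp(d):                     # unary analytic wrapper: dy = f'(x) * dx
    return Dual(math.exp(d.x), math.exp(d.x) * d.dx)

def dmean(ds):                   # reduction: dy = (1/n) * sum_i dx_i
    n = len(ds)
    return Dual(sum(d.x for d in ds) / n, sum(d.dx for d in ds) / n)

# Check the product rule against a central finite difference.
x1, x2, t1, t2 = 1.3, -0.7, 0.5, 2.0
y = Dual(x1, t1) * Dual(x2, t2)
eps = 1e-6
fd = ((x1 + eps * t1) * (x2 + eps * t2)
      - (x1 - eps * t1) * (x2 - eps * t2)) / (2 * eps)
assert abs(y.dx - fd) < 1e-6
```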

JVP

The JVP is the linearization evaluated at the chosen tangent. Representative examples are

\operatorname{jvp}(\exp)(x;\dot{x}) = \exp(x)\dot{x}, \qquad \operatorname{jvp}(\mathrm{mul})(x_1, x_2;\dot{x}_1,\dot{x}_2) = \dot{x}_1 x_2 + x_1 \dot{x}_2.

For reductions, the JVP applies the corresponding sum, averaging, or centered-residual rule to the tangent tensor.
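As a sketch of the reduction case, the mean and (uncorrected) var JVPs can be written out with the centered-residual rule and checked by finite differences (pure Python; function names are illustrative):

```python
# JVP of mean and of uncorrected variance at a tangent (illustrative sketch).
def jvp_mean(xs, ts):
    return sum(ts) / len(xs)

def jvp_var(xs, ts):
    # var(x) = (1/n) * sum_i (x_i - mu)^2, so the JVP is
    # (2/n) * sum_i (x_i - mu) * (t_i - t_mu); the t_mu term vanishes
    # because sum_i (x_i - mu) = 0.
    n = len(xs)
    mu = sum(xs) / n
    return (2.0 / n) * sum((x - mu) * t for x, t in zip(xs, ts))

def var(xs):
    n = len(xs)
    mu = sum(xs) / n
    return sum((x - mu) ** 2 for x in xs) / n

xs = [1.0, 2.0, 4.0]
ts = [0.5, -1.0, 2.0]
eps = 1e-6
fd = (var([x + eps * t for x, t in zip(xs, ts)])
      - var([x - eps * t for x, t in zip(xs, ts)])) / (2 * eps)
assert abs(jvp_var(xs, ts) - fd) < 1e-6
```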

Transpose

For a raw output cotangent \bar{y}, representative transpose rules are:

  • unary analytic wrapper:

\bar{x} = \overline{f'(x)} \odot \bar{y}

  • addition and subtraction:

(\bar{x}_1, \bar{x}_2) = (\bar{y}, \pm \bar{y})

  • multiplication:

(\bar{x}_1, \bar{x}_2) = (\bar{y}\,\overline{x_2}, \bar{y}\,\overline{x_1})

  • reductions:

\bar{x}_{\mathrm{sum}} = \operatorname{broadcast}(\bar{y}), \qquad \bar{x}_{\mathrm{mean}} = \frac{1}{n}\operatorname{broadcast}(\bar{y})

var and std add the centered-residual correction recorded later in this note.
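The sum and mean transposes reduce to broadcasting the output cotangent, which can be verified directly against the adjoint identity Re⟨x, Aᵀȳ⟩ = Re⟨Ax, ȳ⟩ (pure-Python sketch; function names are illustrative):

```python
# Transpose of sum/mean: broadcast the scalar cotangent back to the input shape.
def vjp_sum(ybar, n):
    return [ybar] * n          # xbar_i = ybar

def vjp_mean(ybar, n):
    return [ybar / n] * n      # xbar_i = ybar / n

# Adjoint identity <x, A^T ybar> = <A x, ybar> for A = mean over n entries.
xs = [3.0, -1.0, 2.0, 6.0]
ybar = 0.25
lhs = sum(x * c for x, c in zip(xs, vjp_mean(ybar, len(xs))))
rhs = (sum(xs) / len(xs)) * ybar
assert abs(lhs - rhs) < 1e-12
```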

VJP (JAX convention)

JAX exposes this same raw transpose map as the result of vjp or linear_transpose, interpreted under the real Frobenius inner product. Complex families therefore follow the raw adjoint formulas directly.

VJP (PyTorch convention)

PyTorch uses the same raw formulas but packages them through its conjugate-Wirtinger convention. When a real input is embedded into complex intermediates, the final cotangent is projected back to the real domain via handle_r_to_c.

Scope

This note records shared scalar AD formulas together with the tensor-level wrappers built from them.

Complex Gradient Convention

For real-valued losses:

  • gradients follow the conjugate-Wirtinger convention
  • VJP formulas include complex conjugation where required
  • real inputs project complex intermediates back to the real domain
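A minimal pure-Python check of the conjugate-Wirtinger convention: for a real-valued loss L(z), the gradient is dL/da + i·dL/db for z = a + ib, which equals 2·∂L/∂z̄ (stdlib-only sketch; no framework involved):

```python
import cmath

def loss(z):
    # Real-valued loss of a complex input: |exp(z)|^2 = exp(2a) for z = a + ib.
    return abs(cmath.exp(z)) ** 2

def wirtinger_grad(f, z, eps=1e-6):
    # Conjugate-Wirtinger gradient via central differences on real/imag parts.
    da = (f(z + eps) - f(z - eps)) / (2 * eps)
    db = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)
    return da + 1j * db

z = 0.3 + 0.4j
g = wirtinger_grad(loss, z)
# Analytically |exp(z)|^2 = exp(2a), so the gradient is 2*exp(2a) + 0j.
expected = 2 * cmath.exp(2 * z.real)
assert abs(g - expected) < 1e-4
```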

Current Complex Support Boundary

This note groups both complex-capable wrappers and wrappers that remain float-only in the pinned PyTorch upstream AD coverage.

For this repository phase, families that remain float-only upstream are tracked as explicitly unsupported for complex in docs/math/complex-support.json rather than being promoted to repo-specific complex extensions.

Scalar Basis Rules

Let g be the output cotangent, x the primal input, and y = f(x) the primal output.

Core arithmetic

  • add: for x_1 + \alpha x_2, (dx_1, dx_2) = (g, \overline{\alpha}\, g)
  • sub: for x_1 - \alpha x_2, (dx_1, dx_2) = (g, -\overline{\alpha}\, g)
  • mul: (dx_1, dx_2) = (g \cdot \overline{x_2}, g \cdot \overline{x_1})
  • div:
    • numerator path: dx_1 = g / \overline{x_2}
    • denominator path: dx_2 = -g \cdot \overline{x_1 / x_2^2}
    • integer-style rounding modes are treated as nondifferentiable branches
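The conjugations in the mul and div rules can be spot-checked against Wirtinger finite differences of the real pairing s(x) = Re(conj(g) · f(x)) (illustrative pure-Python sketch; function names are not from any library):

```python
import cmath

# VJP rules for mul and the div denominator path, with the conjugations above.
def vjp_mul(g, x1, x2):
    return g * x2.conjugate(), g * x1.conjugate()

def vjp_div_denominator(g, x1, x2):
    return -g * (x1 / x2 ** 2).conjugate()

def wgrad(f, z, eps=1e-6):
    # Conjugate-Wirtinger gradient of a real-valued f via real/imag partials.
    da = (f(z + eps) - f(z - eps)) / (2 * eps)
    db = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)
    return da + 1j * db

g = 0.8 - 0.3j
x1 = 1.2 + 0.5j
x2 = -0.4 + 0.9j

d1, d2 = vjp_mul(g, x1, x2)
assert abs(d1 - wgrad(lambda z: (g.conjugate() * (z * x2)).real, x1)) < 1e-5
assert abs(d2 - wgrad(lambda z: (g.conjugate() * (x1 * z)).real, x2)) < 1e-5

dden = vjp_div_denominator(g, x1, x2)
assert abs(dden - wgrad(lambda z: (g.conjugate() * (x1 / z)).real, x2)) < 1e-5
```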

Analytic unary wrappers

  • conj: dx = \overline{g}
  • sqrt: dx = g / (2 \overline{\sqrt{x}})
  • exp: dx = g \cdot \overline{y}
  • log: dx = g / \overline{x}
  • expm1: dx = g \cdot \overline{\exp(x)}
  • log1p: dx = g / \overline{(1 + x)}
  • sin: dx = g \cdot \overline{\cos(x)}
  • cos: dx = -g \cdot \overline{\sin(x)}
  • tanh: dx = g \cdot \overline{(1 - y^2)}
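A spot check of one entry in the table above: the tanh rule dx = g · conj(1 − y²), verified with the same Wirtinger finite-difference pairing (pure-Python sketch):

```python
import cmath

def wgrad(f, z, eps=1e-6):
    # Conjugate-Wirtinger gradient of a real-valued f via real/imag partials.
    da = (f(z + eps) - f(z - eps)) / (2 * eps)
    db = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)
    return da + 1j * db

g = 0.7 + 0.2j
x = 0.5 - 0.3j
y = cmath.tanh(x)
dx = g * (1 - y * y).conjugate()   # tanh rule: dx = g * conj(1 - y^2)
assert abs(dx - wgrad(lambda z: (g.conjugate() * cmath.tanh(z)).real, x)) < 1e-5
```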

Parameterized wrappers

  • atan2: for real inputs (a, b), da = g \, b / (a^2 + b^2) and db = -g \, a / (a^2 + b^2), with the zero-denominator singularity masked by the implementation convention
  • powf: for fixed exponent p, dx = g \cdot \overline{(p x^{p-1})}
  • powi: integer-exponent specialization of powf
  • pow:
    • base path: dx = g \cdot \overline{a x^{a-1}}
    • exponent path: da = g \cdot \overline{x^a \log(x)}
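Both pow cotangent paths can be checked the same way against Wirtinger finite differences of s = Re(conj(g) · x^a), using the principal branch for the complex power and logarithm (illustrative pure-Python sketch):

```python
import cmath

def vjp_pow(g, x, a):
    dx = g * (a * x ** (a - 1)).conjugate()       # base path
    da = g * (x ** a * cmath.log(x)).conjugate()  # exponent path
    return dx, da

def wgrad(f, z, eps=1e-6):
    # Conjugate-Wirtinger gradient of a real-valued f via real/imag partials.
    dre = (f(z + eps) - f(z - eps)) / (2 * eps)
    dim = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)
    return dre + 1j * dim

g = 0.4 - 0.6j
x = 1.1 + 0.2j
a = 0.7 + 0.1j
dx, da = vjp_pow(g, x, a)
assert abs(dx - wgrad(lambda z: (g.conjugate() * z ** a).real, x)) < 1e-5
assert abs(da - wgrad(lambda z: (g.conjugate() * x ** z).real, a)) < 1e-5
```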

Tensor-Composite Rules

Tensor-level wrappers built on top of the scalar basis include:

  • pointwise unary analytic families
  • broadcasted binary analytic families
  • small tensor wrappers such as cross, diagonal, matrix_power, multi_dot, vander, vecdot, and householder_product

Tensor Reduction Wrappers

sum_ad

For a reduction over index set \mathcal{I},

y = \sum_{i \in \mathcal{I}} x_i \quad \Longrightarrow \quad \bar{x}_i = \bar{y}

for every reduced element, with the cotangent broadcast back to the input shape.

mean_ad

If n entries are reduced,

y = \frac{1}{n} \sum_{i \in \mathcal{I}} x_i \quad \Longrightarrow \quad \bar{x}_i = \frac{\bar{y}}{n}.

var_ad

Let \mu = \operatorname{mean}(x) over the reduced axes and let correction denote the Bessel-style offset used by the variance operator. Then

\operatorname{var}(x) = \frac{1}{n - \mathrm{correction}} \sum_i |x_i - \mu|^2,

so away from the singular degrees-of-freedom boundary,

\bar{x} = \frac{2}{n - \mathrm{correction}} \, \bar{v} \, (x - \mu).

At n - \mathrm{correction} \le 0, the operator is singular and the derivative inherits the same NaN / infinity boundary behavior as the primal convention.

std_ad

For \sigma = \sqrt{v} with v = \operatorname{var}(x),

\bar{v} = \frac{\bar{\sigma}}{2 \sigma},

masked at \sigma = 0, and then the variance rule is applied to propagate back to x.
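The var and std backward passes above can be sketched together for real inputs, with the Bessel-style correction threaded through and checked elementwise by finite differences (pure Python; function names are illustrative):

```python
# var/std backward with the centered-residual rule (illustrative sketch).
def var_backward(xs, vbar, correction=0):
    n = len(xs)
    mu = sum(xs) / n
    scale = 2.0 * vbar / (n - correction)
    return [scale * (x - mu) for x in xs]

def std_backward(xs, sbar, correction=0):
    n = len(xs)
    mu = sum(xs) / n
    v = sum((x - mu) ** 2 for x in xs) / (n - correction)
    sigma = v ** 0.5
    vbar = sbar / (2.0 * sigma)   # masked at sigma == 0 in a real implementation
    return var_backward(xs, vbar, correction)

def std(xs, correction=1):
    n = len(xs)
    mu = sum(xs) / n
    return (sum((x - mu) ** 2 for x in xs) / (n - correction)) ** 0.5

# Elementwise finite-difference check with Bessel correction = 1.
xs, sbar, eps = [1.0, 2.0, 4.0, 7.0], 1.0, 1e-6
grads = std_backward(xs, sbar, correction=1)
for i in range(len(xs)):
    xp = list(xs); xp[i] += eps
    xm = list(xs); xm[i] -= eps
    fd = (std(xp) - std(xm)) / (2 * eps)
    assert abs(grads[i] - fd) < 1e-5
```

Because std is invariant to a constant shift of the inputs, the cotangents returned by std_backward sum to zero, which is a cheap sanity check on the centered residuals.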

Published DB Families Using This Note

Reflected and arithmetic wrappers

  • __radd__
  • __rdiv__
  • __rmod__
  • __rmul__
  • __rpow__
  • __rsub__
  • add
  • div_no_rounding_mode
  • float_power
  • hypot
  • max_binary
  • maximum
  • min_binary
  • minimum
  • mul
  • pow
  • rsub
  • sub
  • true_divide
  • xlogy

Unary analytic, sign, rounding, and casts

  • abs
  • acos
  • acosh
  • angle
  • asin
  • asinh
  • atan
  • atan2
  • atanh
  • cdouble
  • ceil
  • clamp_max
  • clamp_min
  • complex
  • conj
  • conj_physical
  • copysign
  • cos
  • cosh
  • deg2rad
  • digamma
  • double
  • erf
  • erfc
  • erfinv
  • exp
  • exp2
  • expm1
  • fill
  • floor
  • fmax
  • fmin
  • frac
  • frexp
  • i0
  • imag
  • ldexp
  • lgamma
  • log
  • log10
  • log1p
  • log2
  • logaddexp
  • logit
  • nan_to_num
  • neg
  • positive
  • polar
  • rad2deg
  • real
  • reciprocal
  • round
  • round_decimals_0
  • round_decimals_3
  • round_decimals_neg_3
  • rsqrt
  • sgn
  • sigmoid
  • sign
  • sin
  • sinc
  • sinh
  • special_entr
  • special_erfcx
  • special_i0e
  • special_i1
  • special_i1e
  • special_log_ndtr
  • special_ndtr
  • special_ndtri
  • special_polygamma_special_polygamma_n_0
  • special_xlog1py
  • sqrt
  • square
  • tan
  • tanh
  • trunc

Reductions and statistics

  • amax
  • amin
  • mean
  • nanmean
  • nansum
  • prod
  • std
  • std_unbiased
  • sum
  • var
  • var_unbiased

Neural-network functional wrappers

  • nn_functional_celu
  • nn_functional_elu
  • nn_functional_hardshrink
  • nn_functional_hardsigmoid
  • nn_functional_hardtanh
  • nn_functional_logsigmoid
  • nn_functional_mish
  • nn_functional_prelu
  • nn_functional_relu
  • nn_functional_relu6
  • nn_functional_rrelu
  • nn_functional_selu
  • nn_functional_silu
  • nn_functional_softplus
  • nn_functional_softshrink
  • nn_functional_softsign
  • nn_functional_tanhshrink
  • nn_functional_threshold

Special-function parameter families

  • mvlgamma_mvlgamma_p_1
  • mvlgamma_mvlgamma_p_3
  • mvlgamma_mvlgamma_p_5
  • polygamma_polygamma_n_0
  • polygamma_polygamma_n_1
  • polygamma_polygamma_n_2
  • polygamma_polygamma_n_3
  • polygamma_polygamma_n_4

Small tensor wrappers currently grouped here

  • cross
  • diagonal
  • householder_product
  • matrix_power
  • multi_dot
  • vander
  • vecdot

Notes On Future Splits

This shared note is intentionally broad in the first migration pass. Operations that later grow heavier derivation detail can be split into dedicated note files without changing the DB schema; only the central registry needs to move.