Structural AD Notes
Conventions
These notes cover selection and shape-rearrangement primitives whose integer, boolean, and axis arguments are metadata. Tensor inputs are differentiated under the real Frobenius inner product used by the rest of the corpus. Unless noted otherwise, Linearization and Transpose are written in raw-output-space before any DB observable projection.
Forward
The covered raw operators are:
- where(condition, x, y), where condition is nondifferentiable metadata
- cat(inputs, dim), where the axis is metadata
- narrow(a, dim, start, length), where the start index and length are nondifferentiable metadata
- clamp(x, lo, hi), where x, lo, and hi are differentiable tensors or scalar tensors
Linearization
For where, with a fixed boolean mask m,
\dot{z} = m \odot \dot{x} + (1-m) \odot \dot{y},
using the same broadcasting as the primal operation. The condition is nondifferentiable.
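As a concrete check, the fixed-mask linearization can be sketched in NumPy (the helper name `where_jvp` is illustrative, not a library function):

```python
import numpy as np

def where_jvp(m, x_dot, y_dot):
    # Tangent of z = where(m, x, y): the boolean mask m is fixed metadata,
    # so the tangent map just selects between the input tangents.
    m = np.asarray(m, dtype=bool)
    return np.where(m, x_dot, y_dot)

m = np.array([True, False, True])
x_dot = np.array([1.0, 2.0, 3.0])
y_dot = np.array([10.0, 20.0, 30.0])
z_dot = where_jvp(m, x_dot, y_dot)  # -> [1., 20., 3.]
```

NumPy's broadcasting matches the primal operation, so broadcasted branches need no extra handling on the forward pass.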
For cat, if z=\operatorname{cat}(x_1,\ldots,x_n;\operatorname{dim}=d), then
\dot{z} = \operatorname{cat}(\dot{x}_1,\ldots,\dot{x}_n; d).
The axis is metadata.
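Because cat is linear in its tensor inputs, its tangent map is cat itself applied to the tangents; a minimal NumPy sketch (the name `cat_jvp` is illustrative):

```python
import numpy as np

def cat_jvp(tangents, dim):
    # Tangent of z = cat(x_1, ..., x_n; dim): concatenate the input
    # tangents along the same fixed axis.
    return np.concatenate(tangents, axis=dim)

x1_dot = np.ones((2, 3))
x2_dot = 2.0 * np.ones((2, 2))
z_dot = cat_jvp([x1_dot, x2_dot], dim=1)  # shape (2, 5)
```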
For narrow, the tangent is the same slice of the source tangent:
\dot{z} = \operatorname{narrow}(\dot{a}, d, s, \ell).
The dimension, start index, and slice length are nondifferentiable metadata.
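A NumPy sketch of the narrow tangent (the helper `narrow` mimics Tensor.narrow's slicing; the name is illustrative):

```python
import numpy as np

def narrow(a, dim, start, length):
    # Slice `length` elements starting at `start` along axis `dim`,
    # mirroring Tensor.narrow. Applied to a tangent, this IS the JVP,
    # since the slice metadata is held fixed.
    idx = [slice(None)] * a.ndim
    idx[dim] = slice(start, start + length)
    return a[tuple(idx)]

a_dot = np.arange(12.0).reshape(3, 4)
z_dot = narrow(a_dot, dim=1, start=1, length=2)  # columns 1..2
```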
For clamp, away from the switching surfaces x=lo and x=hi,
\dot{z}_i = \begin{cases} \dot{lo}_i, & x_i < lo_i, \\ \dot{x}_i, & lo_i < x_i < hi_i, \\ \dot{hi}_i, & x_i > hi_i. \end{cases}
When lo or hi is broadcast against x, reverse mode sums the corresponding cotangents over the broadcasted dimensions via the framework's broadcast reduction.
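The piecewise tangent can be sketched in NumPy for same-shape inputs (the name `clamp_jvp` is illustrative; boundary points x == lo and x == hi are excluded, matching the case split above):

```python
import numpy as np

def clamp_jvp(x, lo, hi, x_dot, lo_dot, hi_dot):
    # Piecewise tangent of clamp away from the switching surfaces:
    # below the lower bound the output tracks lo, above the upper
    # bound it tracks hi, and in the interior it tracks x.
    below = x < lo
    above = x > hi
    return np.where(below, lo_dot, np.where(above, hi_dot, x_dot))

x = np.array([-2.0, 0.5, 3.0])
z_dot = clamp_jvp(x, lo=-1.0, hi=1.0,
                  x_dot=np.array([1.0, 1.0, 1.0]),
                  lo_dot=np.array([5.0, 5.0, 5.0]),
                  hi_dot=np.array([7.0, 7.0, 7.0]))
# one element per region: below -> lo_dot, interior -> x_dot, above -> hi_dot
```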
JVP
The JVP is the linearization above evaluated at the chosen tangent inputs while holding masks, axes, starts, and lengths fixed.
Transpose
The transpose of where routes the output cotangent to x where the mask is true and to y where the mask is false, summing over broadcasted dimensions.
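For same-shape branches this routing is two masked selections; a NumPy sketch (the name `where_transpose` is illustrative, and the broadcast-reduction sum is omitted because all shapes match):

```python
import numpy as np

def where_transpose(m, z_bar):
    # Adjoint of the fixed-mask where: the output cotangent goes to x
    # where the mask is True and to y where it is False; each input
    # receives zero elsewhere.
    m = np.asarray(m, dtype=bool)
    x_bar = np.where(m, z_bar, 0.0)
    y_bar = np.where(m, 0.0, z_bar)
    return x_bar, y_bar

m = np.array([True, False, True])
z_bar = np.array([1.0, 2.0, 3.0])
x_bar, y_bar = where_transpose(m, z_bar)  # -> [1., 0., 3.], [0., 2., 0.]
```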
The transpose of cat splits the cotangent along the concatenation axis and returns one slice per differentiable input.
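Since cat's tangent map stacks slices, its transpose unstacks them; a NumPy sketch (the name `cat_transpose` is illustrative):

```python
import numpy as np

def cat_transpose(z_bar, sizes, dim):
    # Adjoint of cat: split the output cotangent along the concatenation
    # axis into one slice per input, using the inputs' recorded sizes.
    splits = np.cumsum(sizes)[:-1]
    return np.split(z_bar, splits, axis=dim)

z_bar = np.arange(10.0).reshape(2, 5)
x1_bar, x2_bar = cat_transpose(z_bar, sizes=[3, 2], dim=1)
# x1_bar has shape (2, 3), x2_bar has shape (2, 2)
```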
The transpose of narrow scatters the cotangent back into the selected source region and writes zero outside that region.
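A NumPy sketch of this scatter (the name `narrow_transpose` is illustrative; it needs the source shape, which a framework records during the forward pass):

```python
import numpy as np

def narrow_transpose(z_bar, src_shape, dim, start, length):
    # Adjoint of narrow: embed the cotangent of the slice into a zero
    # tensor of the full source shape; entries outside the selected
    # region stay zero.
    a_bar = np.zeros(src_shape)
    idx = [slice(None)] * len(src_shape)
    idx[dim] = slice(start, start + length)
    a_bar[tuple(idx)] = z_bar
    return a_bar

z_bar = np.ones((3, 2))
a_bar = narrow_transpose(z_bar, src_shape=(3, 4), dim=1, start=1, length=2)
```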
The transpose of clamp routes cotangents to lo, x, or hi according to the active region of each element, again summing over broadcasted dimensions.
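For same-shape bounds this region routing is elementwise; a NumPy sketch (the name `clamp_transpose` is illustrative; with broadcasted bounds one would additionally sum lo_bar and hi_bar over the broadcast axes, which is omitted here):

```python
import numpy as np

def clamp_transpose(x, lo, hi, z_bar):
    # Adjoint of the clamp tangent: each output cotangent element goes
    # to exactly one of lo, x, or hi, matching the active region.
    below = x < lo
    above = x > hi
    x_bar = np.where(~below & ~above, z_bar, 0.0)
    lo_bar = np.where(below, z_bar, 0.0)
    hi_bar = np.where(above, z_bar, 0.0)
    return x_bar, lo_bar, hi_bar

x = np.array([-2.0, 0.5, 3.0])
z_bar = np.array([1.0, 2.0, 3.0])
x_bar, lo_bar, hi_bar = clamp_transpose(x, lo=-1.0, hi=1.0, z_bar=z_bar)
```

Note that x_bar, lo_bar, and hi_bar partition z_bar: each cotangent element is claimed by exactly one input, consistent with the disjoint case split in the linearization.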
VJP (JAX convention)
JAX-style VJPs are the same transposes of the fixed-metadata linear maps. The metadata arguments are not part of the differentiable input tuple.
VJP (PyTorch convention)
PyTorch autograd exposes the same tensor-input VJPs. Boolean masks, axes, start indices, and slice sizes are treated as nondifferentiable arguments.
DB Families
The DB publishes direct torch.where(condition, x, y) outputs. Cases include mixed masks, tensor/tensor branches, scalar x broadcast against tensor y, and tensor x broadcast against scalar y.
The DB publishes direct torch.cat(inputs, dim) outputs. Cases cover concatenation on leading, interior, and negative axes; dim is metadata.
The DB publishes direct dynamic-slice outputs using Tensor.narrow. Cases cover positive and negative dimensions. The dimension, start index, and slice length are metadata.
The DB publishes direct ternary clamp outputs for real tensors. Cases cover scalar bounds, tensor bounds, and broadcasted tensor bounds. Boundary points where x=lo or x=hi are nonsmooth and are not used as success probes.