# Scalar and Reduction AD Rules
This note records the scalar AD formulas implemented in `chainrules` and the tensor-level wrappers exposed through `tenferro`.
## Scope
Implemented low-level scalar basis in `chainrules`:

- `add`, `sub`, `mul`, `div`
- `conj`
- `sqrt`, `exp`, `log`
- `atan2`
- `powf` (scalar exponent)
- `powi` (integer exponent; a restricted `pow` case)
Additional tensor-level scalar/analytic rules are implemented in `tenferro` by composing runtime-generic tensor primitives on top of that basis.

The tensor wrappers now share a centralized runtime-dispatch contract layer in `tenferro`, so adding a new scalar family does not require repeating backend-specific `Cpu`/`Cuda`/`Rocm` bounds at each eager or builder entrypoint.

Scalar reduction helpers also use transfer-aware rank-0 extraction and scalar broadcast utilities, which keeps reduction pullbacks on the same generic runtime path instead of introducing one-off CPU-only builder branches.
Tensor-level wrappers built on top of those formulas:

- pointwise binary: `add_ad`, `atan2_ad`, `pow_ad`, `hypot_ad`
- pointwise unary: `sqrt_ad`, `exp_ad`, `expm1_ad`, `log_ad`, `log1p_ad`, `sin_ad`, `cos_ad`, `tanh_ad`, `asin_ad`, `acos_ad`, `atan_ad`, `sinh_ad`, `cosh_ad`, `asinh_ad`, `acosh_ad`, `atanh_ad`
- reductions: `sum_ad`, `mean_ad`, `var_ad`, `std_ad`
Target scalar domains: `f32`, `f64`, `Complex32`, `Complex64`.
## PyTorch Baseline (Local Snapshot)
Reference repository: `../pytorch`

Commit used for this note: `8dd3b7637abd04433bafe77765de59df6388f9f9`
Primary source files:

- `tools/autograd/derivatives.yaml`
- `torch/csrc/autograd/FunctionsManual.cpp`
- `docs/source/notes/autograd.rst`
Key lines:

- `_conj` backward: `derivatives.yaml:477-479`
- `sqrt` backward: `derivatives.yaml:1622-1624`
- `exp` backward: `derivatives.yaml:652-654`
- `log` backward: `derivatives.yaml:966-968`
- `atan2` backward: `derivatives.yaml:260-263`
- `pow.Tensor_Scalar` / `pow.Tensor_Tensor`: `derivatives.yaml:1385-1392`
- `pow_backward*` implementation: `FunctionsManual.cpp:473-557`
- `handle_r_to_c`: `FunctionsManual.cpp:169-183`
- complex AD convention statement: `notes/autograd.rst:601-607`
## Complex Gradient Convention
We follow PyTorch's convention for real-valued losses:

- gradients are conjugate-Wirtinger (`dL/dz*`) style
- VJP formulas include complex conjugation where required
- if the input is real and an intermediate gradient is complex, it is projected back to real (`handle_r_to_c` behavior)
## Rule Summary
Let `g` be the output cotangent, `x` the input primal, and `y = f(x)` the output primal.
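For real inputs, each rrule below can be sanity-checked against a central finite difference of the primal. A minimal sketch of that check pattern (the helper `numeric_deriv` and the tolerance are illustrative, not part of `chainrules`):

```rust
/// Central finite difference: approximates f'(x) for a scalar function f at x.
pub fn numeric_deriv(f: impl Fn(f64) -> f64, x: f64) -> f64 {
    let h = 1e-6;
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    // Pattern: for real inputs, an rrule of the form `dx = g * f'(x)` with
    // seed g = 1 should match the finite difference of the primal.
    let analytic = 2.0_f64.exp(); // exp'(x) = exp(x), evaluated at x = 2
    let numeric = numeric_deriv(|x| x.exp(), 2.0);
    assert!((analytic - numeric).abs() < 1e-5);
}
```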
### conj

- Primal: `y = conj(x)`
- rrule: `dx = conj(g)`
- frule: `dy = conj(dx)`
### sqrt

- Primal: `y = sqrt(x)`
- rrule: `dx = g / (2 * conj(y))`
- frule: `dy = dx / (2 * conj(y))`
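Note that the rule reuses the primal output `y` rather than recomputing `sqrt(x)`. A hedged sketch for real `f64` inputs, where `conj` is the identity (the function name `sqrt_rrule` is illustrative, not the `chainrules` API):

```rust
/// sqrt rrule for real inputs: dx = g / (2 * y), reusing the primal output y.
pub fn sqrt_rrule(g: f64, y: f64) -> f64 {
    g / (2.0 * y)
}

fn main() {
    let x = 9.0_f64;
    let y = x.sqrt(); // primal pass
    let dx = sqrt_rrule(1.0, y);
    // d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 1/6 at x = 9
    assert!((dx - 1.0 / 6.0).abs() < 1e-12);
}
```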
### add

- Primal: `y = x1 + x2`
- rrule: `(dx1, dx2) = (g, g)`
- frule: `dy = dx1 + dx2`
### sub

- Primal: `y = x1 - x2`
- rrule: `(dx1, dx2) = (g, -g)`
- frule: `dy = dx1 - dx2`
### mul

- Primal: `y = x1 * x2`
- rrule: `(dx1, dx2) = (g * conj(x2), g * conj(x1))`
- frule: `dy = dx1 * conj(x2) + dx2 * conj(x1)`
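The conjugates in the mul rrule are what make the conjugate-Wirtinger convention work out for complex inputs. A self-contained sketch with a minimal complex type (the struct `C` and function `mul_rrule` are illustrative; it assumes the seed cotangent for a real loss `L = |y|^2` is `g = dL/dy* = y`):

```rust
#[derive(Clone, Copy, Debug)]
struct C { re: f64, im: f64 }

impl C {
    fn mul(self, o: C) -> C {
        C { re: self.re * o.re - self.im * o.im,
            im: self.re * o.im + self.im * o.re }
    }
    fn conj(self) -> C { C { re: self.re, im: -self.im } }
    fn abs2(self) -> f64 { self.re * self.re + self.im * self.im }
}

/// mul rrule under the conjugate-Wirtinger convention.
fn mul_rrule(g: C, x1: C, x2: C) -> (C, C) {
    (g.mul(x2.conj()), g.mul(x1.conj()))
}

fn main() {
    // For L = |x1 * x2|^2 the conjugate-Wirtinger gradient w.r.t. x1
    // is dL/dx1* = x1 * |x2|^2; the rrule reproduces it with seed g = y.
    let (x1, x2) = (C { re: 1.0, im: 2.0 }, C { re: 3.0, im: -1.0 });
    let y = x1.mul(x2);
    let (dx1, _dx2) = mul_rrule(y, x1, x2);
    let expect = C { re: x1.re * x2.abs2(), im: x1.im * x2.abs2() };
    assert!((dx1.re - expect.re).abs() < 1e-12);
    assert!((dx1.im - expect.im).abs() < 1e-12);
}
```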
### div

- Primal: `y = x1 / x2`
- rrule: `dx1 = g / conj(x2)`, `dx2 = g * conj(-x1 / x2^2)`
- frule: `dy = dx1 / conj(x2) + dx2 * conj(-x1 / x2^2)`
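A hedged sketch of the div rrule for real inputs, where the conjugates drop out (the function name `div_rrule` is illustrative):

```rust
/// div rrule for real inputs: dx1 = g / x2, dx2 = g * (-x1 / x2^2).
pub fn div_rrule(g: f64, x1: f64, x2: f64) -> (f64, f64) {
    (g / x2, g * (-x1 / (x2 * x2)))
}

fn main() {
    let (x1, x2, g) = (6.0_f64, 3.0_f64, 1.0_f64);
    let (dx1, dx2) = div_rrule(g, x1, x2);
    // d/dx1 (x1 / x2) = 1 / x2 = 1/3
    assert!((dx1 - 1.0 / 3.0).abs() < 1e-12);
    // d/dx2 (x1 / x2) = -x1 / x2^2 = -6/9
    assert!((dx2 + 6.0 / 9.0).abs() < 1e-12);
}
```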
### powf (fixed scalar exponent a)

- Primal: `y = x^a`
- rrule (self gradient): `dx = g * conj(a * x^(a - 1))`
- frule (self tangent): `dy = dx * conj(a * x^(a - 1))`
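A hedged sketch of the powf rrule for real inputs, checked against a central finite difference (the name `powf_rrule` is illustrative):

```rust
/// powf rrule for real x and fixed scalar exponent a: dx = g * a * x^(a - 1).
pub fn powf_rrule(g: f64, x: f64, a: f64) -> f64 {
    g * a * x.powf(a - 1.0)
}

fn main() {
    // Finite-difference check at x = 2, a = 2.5.
    let (x, a) = (2.0_f64, 2.5_f64);
    let h = 1e-6;
    let numeric = ((x + h).powf(a) - (x - h).powf(a)) / (2.0 * h);
    assert!((powf_rrule(1.0, x, a) - numeric).abs() < 1e-5);
}
```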
### powi (fixed integer exponent n)

This is `powf` with integer exponent semantics.

- Primal: `y = x^n`
- rrule: `dx = g * conj(n * x^(n - 1))`
- frule: `dy = dx * conj(n * x^(n - 1))`
### exp

- Primal: `y = exp(x)`
- rrule: `dx = g * conj(y)`
- frule: `dy = dx * conj(y)`
### expm1

- Primal: `y = exp(x) - 1`
- rrule: `dx = g * conj(exp(x))`
- frule: `dy = dx * conj(exp(x))`
### log

- Primal: `y = log(x)`
- rrule: `dx = g / conj(x)`
- frule: `dy = dx / conj(x)`
### log1p

- Primal: `y = log(1 + x)`
- rrule: `dx = g / conj(1 + x)`
- frule: `dy = dx / conj(1 + x)`
### sin

- Primal: `y = sin(x)`
- rrule: `dx = g * conj(cos(x))`
- frule: `dy = dx * conj(cos(x))`
### cos

- Primal: `y = cos(x)`
- rrule: `dx = g * conj(-sin(x))`
- frule: `dy = dx * conj(-sin(x))`
### tanh

- Primal: `y = tanh(x)`
- rrule: `dx = g * conj(1 - y^2)`
- frule: `dy = dx * conj(1 - y^2)`
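Like sqrt, the tanh rule reuses the primal output `y` instead of recomputing the forward pass. A hedged sketch for real inputs (the name `tanh_rrule` is illustrative):

```rust
/// tanh rrule for real inputs, reusing the primal output: dx = g * (1 - y^2).
pub fn tanh_rrule(g: f64, y: f64) -> f64 {
    g * (1.0 - y * y)
}

fn main() {
    let x = 0.5_f64;
    let y = x.tanh(); // primal pass
    // 1 - tanh(x)^2 == sech(x)^2 == 1 / cosh(x)^2
    let expect = 1.0 / (x.cosh() * x.cosh());
    assert!((tanh_rrule(1.0, y) - expect).abs() < 1e-12);
}
```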
## Tensor-Composite Rules in `tenferro`

The following rules are implemented one layer up by composing the scalar and tensor primitive families. They are not exported as standalone `chainrules::*_rrule` / `*_frule` functions.
### Unary analytic wrappers

- `expm1`: derivative factor `exp(x)`
- `log1p`: derivative factor `1 / (1 + x)`
- `sin`: derivative factor `cos(x)`
- `cos`: derivative factor `-sin(x)`
- `asin`: derivative factor `1 / sqrt(1 - x^2)`
- `acos`: derivative factor `-1 / sqrt(1 - x^2)`
- `atan`: derivative factor `1 / (1 + x^2)`
- `sinh`: derivative factor `cosh(x)`
- `cosh`: derivative factor `sinh(x)`
- `asinh`: derivative factor `1 / sqrt(1 + x^2)`
- `acosh`: derivative factor `1 / sqrt(x^2 - 1)`
- `atanh`: derivative factor `1 / (1 - x^2)`
### Binary analytic wrappers

- `pow(x, a)`: `dx = g * conj(a * x^(a - 1))`, `da = g * conj(pow(x, a) * log(x))`
- `hypot(x, y)` for real-valued inputs: `dx = g * x / hypot(x, y)`, `dy = g * y / hypot(x, y)`
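A hedged sketch of the hypot pullback for real inputs (the name `hypot_rrule` is illustrative):

```rust
/// hypot rrule for real inputs: dx = g * x / hypot(x, y), dy = g * y / hypot(x, y).
pub fn hypot_rrule(g: f64, x: f64, y: f64) -> (f64, f64) {
    let h = x.hypot(y);
    (g * x / h, g * y / h)
}

fn main() {
    // hypot(3, 4) = 5, so the gradients are (3/5, 4/5).
    let (dx, dy) = hypot_rrule(1.0, 3.0, 4.0);
    assert!((dx - 0.6).abs() < 1e-12);
    assert!((dy - 0.8).abs() < 1e-12);
}
```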
### atan2 (real-valued inputs)

Let `y = atan2(a, b)` with `a` as the numerator-like input and `b` as the denominator-like input.

- rrule: `da = g * b / (a^2 + b^2)`, `db = g * (-a) / (a^2 + b^2)`
- frule: `dy = da * b / (a^2 + b^2) + db * (-a) / (a^2 + b^2)`
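A hedged sketch of the atan2 rrule, checked against a finite difference in the numerator-like input (the name `atan2_rrule` is illustrative):

```rust
/// atan2 rrule for real inputs a (numerator-like) and b (denominator-like).
pub fn atan2_rrule(g: f64, a: f64, b: f64) -> (f64, f64) {
    let denom = a * a + b * b;
    (g * b / denom, g * (-a) / denom)
}

fn main() {
    // Finite-difference check in a at (a, b) = (1, 2); da = b / (a^2 + b^2) = 0.4.
    let (a, b) = (1.0_f64, 2.0_f64);
    let h = 1e-6;
    let numeric = ((a + h).atan2(b) - (a - h).atan2(b)) / (2.0 * h);
    let (da, _db) = atan2_rrule(1.0, a, b);
    assert!((da - numeric).abs() < 1e-6);
}
```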
## Edge Cases

Aligned with PyTorch:

- `pow` with exponent `0` gives a zero self-gradient.
- Real-input/complex-intermediate gradients are projected back to real (`handle_r_to_c` equivalent).
## Tensor Reduction Wrappers

The tensor-level reduction builders in `tenferro` reuse the scalar rules above plus runtime-generic tensor primitives.
### mean_ad

For `y = mean(x)` over all elements (`N = x.len()`):

- rrule: every element receives `g / N`
- frule: `dy = mean(dx)`
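A hedged sketch of the mean pullback over a flat slice (the name `mean_pullback` is illustrative, not the `tenferro` builder):

```rust
/// mean pullback over N elements: every element receives g / N.
pub fn mean_pullback(g: f64, n: usize) -> Vec<f64> {
    vec![g / n as f64; n]
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let dx = mean_pullback(1.0, x.len());
    // Each of the 4 elements receives 1/4.
    assert!(dx.iter().all(|&d| (d - 0.25).abs() < 1e-12));
}
```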
### var_ad

For `y = var(x)` with population normalization:

- center: `c = x - mean(x)`
- rrule: `dx = g * 2c / N`
- frule: `dy = mean(2c * dx)`
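A hedged sketch of the population-variance pullback (the name `var_pullback` is illustrative):

```rust
/// Population-variance pullback: dx_i = g * 2 * (x_i - mean(x)) / N.
pub fn var_pullback(g: f64, x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    x.iter().map(|&xi| g * 2.0 * (xi - mean) / n).collect()
}

fn main() {
    let x = [1.0, 3.0]; // mean = 2, population var = 1
    let dx = var_pullback(1.0, &x);
    // centered values c = (-1, 1); dx_i = 2 * c_i / 2 = c_i
    assert!((dx[0] + 1.0).abs() < 1e-12);
    assert!((dx[1] - 1.0).abs() < 1e-12);
}
```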
### std_ad

For `y = std(x)` with `y = sqrt(var(x))`:

- rrule: `dx = g * c / (N * y)`
- frule: `dy = dvar / (2y)`
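The std rule is the var rule chained through the sqrt rule: `g * 2c / N` scaled by `1 / (2y)`. A hedged sketch (the name `std_pullback` is illustrative):

```rust
/// std pullback via the chain rule through sqrt: dx_i = g * (x_i - mean) / (N * y).
pub fn std_pullback(g: f64, x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let y = (x.iter().map(|&xi| (xi - mean).powi(2)).sum::<f64>() / n).sqrt();
    x.iter().map(|&xi| g * (xi - mean) / (n * y)).collect()
}

fn main() {
    let x = [1.0, 3.0]; // mean = 2, population std = 1
    let dx = std_pullback(1.0, &x);
    // c = (-1, 1), N = 2, y = 1, so dx = (-0.5, 0.5)
    assert!((dx[0] + 0.5).abs() < 1e-12);
    assert!((dx[1] - 0.5).abs() < 1e-12);
}
```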
### sum_ad

For `y = sum(x)` over all elements:

- rrule: every element receives the same cotangent `g`
- frule: `dy = sum(dx)`
## Deferred Select-Style Ops

The following families are intentionally not exposed yet:

- `where`
- AD wrappers for `maximum`, `minimum`, and `clamp*`
- `xlogy` tensor-level AD wrappers

The blocker is not CPU math itself but the absence of a dedicated boolean/predicate tensor substrate. Without that layer, branch-select rules would have to smuggle predicate semantics into scalar-only contracts, which is not the design direction for #441.
## API Placement

Implementation placement:

- scalar formulas and helper projection (`handle_r_to_c` equivalent): `tensor4all/chainrules-rs/crates/chainrules`
- tensor-level generic unary/binary/reduction wrappers: `tenferro::ops::scalar::ad`
- eager AD entrypoints: `tenferro::ops::ad`