Complex AD
tidu follows the JAX convention for complex automatic differentiation.
Forward mode computes the full real-linear derivative:
df = (df/dz) * dz + (df/dconj(z)) * conj(dz)
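As a rough illustration of the forward-mode rule, the sketch below checks it numerically in plain Python. The function f(z) = z^2 * conj(z) and the helpers `wirtinger_derivatives` and `jvp` are assumptions made for this example, with derivatives worked out by hand. They are not part of tidu's API.

```python
# Illustrative only: these helpers are assumptions for this sketch, not tidu API.

def f(z: complex) -> complex:
    """Non-holomorphic test function f(z) = z^2 * conj(z)."""
    return z * z * z.conjugate()


def wirtinger_derivatives(z: complex):
    """Hand-derived pair (df/dz, df/dconj(z)) for the f defined above."""
    df_dz = 2 * z * z.conjugate()
    df_dzbar = z * z
    return df_dz, df_dzbar


def jvp(z: complex, dz: complex) -> complex:
    """Forward-mode tangent: df = (df/dz) * dz + (df/dconj(z)) * conj(dz)."""
    df_dz, df_dzbar = wirtinger_derivatives(z)
    return df_dz * dz + df_dzbar * dz.conjugate()


z = 1.0 + 2.0j
dz = 0.3 - 0.5j

# Central finite difference along the direction dz, using a real step size.
eps = 1e-6
finite_diff = (f(z + eps * dz) - f(z - eps * dz)) / (2 * eps)
tangent = jvp(z, dz)

# The map dz -> df is real-linear but not complex-linear: scaling dz by 1j
# does not scale the tangent by 1j, because of the conj(dz) term.
complex_linearity_gap = abs(jvp(z, 1j * dz) - 1j * jvp(z, dz))
```

Here `tangent` should agree with `finite_diff` up to discretization error, while `complex_linearity_gap` is nonzero because df/dconj(z) does not vanish for this f.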
Reverse mode transposes linear maps with respect to the real inner product:
<a, b> = Re(conj(a) * b)
For a general function f: C -> C with y = f(z), the output cotangent ct_y is pulled back to the input cotangent:
ct_z = ct_y * conj(df/dz) + conj(ct_y) * (df/dconj(z))
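The cotangent formula can be sanity-checked against the transpose identity <ct_y, df> = <ct_z, dz> under the real inner product. The following is a minimal sketch in plain Python, assuming the hand-differentiated example f(z) = z^2 * conj(z). The names `inner`, `jvp`, and `vjp` are hypothetical helpers for this check, not tidu functions.

```python
# Illustrative only: these helpers are assumptions for this sketch, not tidu API.

def inner(a: complex, b: complex) -> float:
    """Real inner product <a, b> = Re(conj(a) * b)."""
    return (a.conjugate() * b).real


def wirtinger_derivatives(z: complex):
    """Hand-derived (df/dz, df/dconj(z)) for f(z) = z^2 * conj(z)."""
    return 2 * z * z.conjugate(), z * z


def jvp(z: complex, dz: complex) -> complex:
    """Forward mode: df = (df/dz) * dz + (df/dconj(z)) * conj(dz)."""
    df_dz, df_dzbar = wirtinger_derivatives(z)
    return df_dz * dz + df_dzbar * dz.conjugate()


def vjp(z: complex, ct_y: complex) -> complex:
    """Reverse mode: ct_z = ct_y * conj(df/dz) + conj(ct_y) * (df/dconj(z))."""
    df_dz, df_dzbar = wirtinger_derivatives(z)
    return ct_y * df_dz.conjugate() + ct_y.conjugate() * df_dzbar


z = 1.0 + 2.0j
dz = 0.3 - 0.5j
ct_y = -0.7 + 0.2j

# Transposition means both pairings give the same real number.
lhs = inner(ct_y, jvp(z, dz))
rhs = inner(vjp(z, ct_y), dz)
```

For any choice of `z`, `dz`, and `ct_y`, `lhs` and `rhs` should match up to floating-point rounding, which is what it means for the reverse-mode map to be the transpose of the forward-mode map.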
For a real-valued loss L, conj(dL/dz) = dL/dconj(z), so seeding the output cotangent with ct_y = 1 gives ct_z = 2 * dL/dconj(z). Frameworks that return dL/dconj(z) directly differ from this by a factor of 2, but the steepest-descent direction is the same.
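To make the real-loss case concrete, here is a small sketch using the assumed loss L(z) = |z - target|^2, for which dL/dz = conj(z - target) and dL/dconj(z) = z - target by hand. The function names, learning rate, and starting point are illustrative choices, not anything prescribed by tidu.

```python
# Illustrative only: these helpers are assumptions for this sketch, not tidu API.

def loss(z: complex, target: complex) -> float:
    """Real-valued loss L(z) = |z - target|^2."""
    return abs(z - target) ** 2


def dL_dzbar(z: complex, target: complex) -> complex:
    """Hand-derived Wirtinger derivative dL/dconj(z) = z - target."""
    return z - target


def cotangent(z: complex, target: complex, ct_y: complex = 1.0 + 0.0j) -> complex:
    """ct_z = ct_y * conj(dL/dz) + conj(ct_y) * (dL/dconj(z))."""
    dL_dz = (z - target).conjugate()
    return ct_y * dL_dz.conjugate() + ct_y.conjugate() * dL_dzbar(z, target)


target = 1.0 - 1.0j
z0 = 0.0 + 0.0j

# With seed ct_y = 1, the cotangent is twice dL/dconj(z).
ct_z0 = cotangent(z0, target)
half_convention = dL_dzbar(z0, target)

# Plain gradient descent along -ct_z. A framework returning dL/dconj(z)
# would follow the same direction with half the effective step.
z = z0
learning_rate = 0.25
for _ in range(20):
    z = z - learning_rate * cotangent(z, target)
final_loss = loss(z, target)
```

With these settings each step halves the distance to `target`, so `final_loss` ends up far below the starting loss. Switching to the dL/dconj(z) convention would only require doubling `learning_rate` to reproduce the same iterates.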