Complex Autodiff
tenferro supports complex JVP, VJP, and scalar-loss gradient workflows, but complex reverse-mode values need one convention choice that is easy to miss when comparing against JAX.
tenferro reverse mode represents cotangents under the Hermitian real inner product:
<a, b> = Re(conj(a) * b)
For a holomorphic scalar map y = f(z), this means:
JVP: dy = f'(z) * dz
VJP: dz_bar = y_bar * conj(f'(z))
So grad with the default scalar seed 1 returns the Hermitian-adjoint cotangent. For scalar seed-1 cases, JAX’s public grad presentation is the complex conjugate of tenferro’s value:
jax_like_grad = conj(tenferro_grad)
This is a representation difference for reverse-mode cotangents, not a local AD-rule failure. JVPs are directional derivatives and do not have this public cotangent-encoding ambiguity.
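The VJP formula above follows directly from the inner product. As a short derivation sketch, write the output cotangent as \bar{y}, the input cotangent as \bar{z}, and complex conjugation as a superscript star (this notation is introduced here for illustration only). Requiring the pairing on the output side to equal the pairing on the input side gives:

```latex
\begin{aligned}
\langle \bar{y},\, dy \rangle
  &= \operatorname{Re}\!\bigl(\bar{y}^{*}\, f'(z)\, dz\bigr) \\
  &= \operatorname{Re}\!\bigl(\bigl(\bar{y}\, f'(z)^{*}\bigr)^{*}\, dz\bigr) \\
  &= \langle \bar{y}\, f'(z)^{*},\, dz \rangle
  \quad\Longrightarrow\quad
  \bar{z} = \bar{y}\, f'(z)^{*} .
\end{aligned}
```

Setting \bar{y} = 1 recovers the scalar-seed case: the returned value is conj(f'(z)), whose conjugate is f'(z).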
Examples
The example below demonstrates:
- grad(sum(exp(z))), where tenferro returns conj(exp(z)),
- grad(sum(abs(z)^2)), where tenferro returns 2z,
- a VJP with a non-real complex cotangent seed.
use num_complex::{Complex64, ComplexFloat};
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, Tensor, TracedTensor};

const TOL: f64 = 1.0e-12;

fn assert_complex_close(actual: Complex64, expected: Complex64) {
    let error = (actual - expected).norm();
    assert!(
        error < TOL,
        "actual={actual}, expected={expected}, error={error}"
    );
}

fn assert_complex_slice_close(actual: &[Complex64], expected: &[Complex64]) {
    assert_eq!(actual.len(), expected.len());
    for (&actual, &expected) in actual.iter().zip(expected) {
        assert_complex_close(actual, expected);
    }
}

fn run(tensor: &TracedTensor) -> Result<Tensor, tenferro_runtime::Error> {
    let mut compiler = GraphCompiler::new();
    let program = compiler.compile(tensor)?;
    let mut executor = GraphExecutor::new(CpuBackend::new());
    executor.run(&program)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let z_data = vec![
        Complex64::new(0.0, std::f64::consts::FRAC_PI_2),
        Complex64::new(std::f64::consts::LN_2, 0.0),
    ];
    let z = TracedTensor::from_vec_col_major(vec![2], z_data.clone())?;

    let holomorphic_loss = z.exp().reduce_sum(&[0]);
    let tenferro_grad = holomorphic_loss?.grad(&z)?;
    let tenferro_grad_value = run(&tenferro_grad)?;
    let expected_grad: Vec<_> = z_data.iter().map(|z| z.exp().conj()).collect();
    assert_complex_slice_close(
        tenferro_grad_value.as_slice::<Complex64>().unwrap(),
        &expected_grad,
    );

    let magnitude = z.abs();
    let real_loss = (&magnitude * &magnitude)?.reduce_sum(&[0]);
    let real_loss_grad = real_loss?.grad(&z)?;
    let real_loss_grad_value = run(&real_loss_grad)?;
    let expected_real_loss_grad: Vec<_> = z_data.iter().map(|z| *z * 2.0).collect();
    assert_complex_slice_close(
        real_loss_grad_value.as_slice::<Complex64>().unwrap(),
        &expected_real_loss_grad,
    );

    let scalar_z = TracedTensor::from_vec_col_major(
        vec![1],
        vec![Complex64::new(0.0, std::f64::consts::FRAC_PI_2)],
    )?;
    let y = scalar_z.exp();
    let seed = Complex64::new(2.0, -3.0);
    let cotangent = TracedTensor::from_vec_col_major(vec![1], vec![seed])?;
    let vjp = y.vjp(&scalar_z, &cotangent)?;
    let vjp_value = run(&vjp)?;
    let expected_vjp = seed
        * Complex64::new(0.0, std::f64::consts::FRAC_PI_2)
            .exp()
            .conj();
    assert_complex_close(vjp_value.as_slice::<Complex64>().unwrap()[0], expected_vjp);

    Ok(())
}
}

For non-real complex VJP cotangent seeds, the seed is part of the convention comparison. Calling tenferro with seed y_bar computes y_bar * conj(f'(z)). If you want to compare to a JAX-style public seed c for the same holomorphic scalar map, compare against conj(tenferro_vjp(conj(c))).
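The seed conversion can be checked with plain arithmetic, without running tenferro at all. The sketch below is illustrative only: the `C` type and the functions `hermitian_vjp_exp` and `jax_style_vjp_exp` are hypothetical stand-ins written for this note, not tenferro or JAX APIs. It hard-codes f(z) = exp(z), applies the formula y_bar * conj(f'(z)) stated above, and confirms that conj(vjp(conj(c))) reduces algebraically to c * f'(z).

```rust
// Minimal hand-rolled complex number so the sketch needs no external crates.
// (Hypothetical helper type; real code would use num_complex::Complex64.)
#[derive(Clone, Copy, Debug, PartialEq)]
struct C {
    re: f64,
    im: f64,
}

impl C {
    fn conj(self) -> C {
        C { re: self.re, im: -self.im }
    }
    fn mul(self, o: C) -> C {
        C {
            re: self.re * o.re - self.im * o.im,
            im: self.re * o.im + self.im * o.re,
        }
    }
    // exp(a + bi) = e^a * (cos b + i sin b)
    fn exp(self) -> C {
        let r = self.re.exp();
        C { re: r * self.im.cos(), im: r * self.im.sin() }
    }
    fn dist(self, o: C) -> f64 {
        (self.re - o.re).hypot(self.im - o.im)
    }
}

// Hermitian-adjoint VJP for f(z) = exp(z): y_bar * conj(f'(z)), with f'(z) = exp(z).
fn hermitian_vjp_exp(z: C, y_bar: C) -> C {
    y_bar.mul(z.exp().conj())
}

// Value to compare against a JAX-style public seed c: conj(vjp(conj(c))).
fn jax_style_vjp_exp(z: C, c: C) -> C {
    hermitian_vjp_exp(z, c.conj()).conj()
}

fn main() {
    let z = C { re: 0.0, im: std::f64::consts::FRAC_PI_2 };
    let c = C { re: 2.0, im: -3.0 };

    let converted = jax_style_vjp_exp(z, c);
    // Expanding the two conjugations leaves c * f'(z).
    let direct = c.mul(z.exp());
    assert!(converted.dist(direct) < 1e-12);
    println!("converted = {:?}", converted);
}
```

Note that conjugating only the output is not enough when the seed is non-real: the seed itself has to be conjugated on the way in, which is why both conjugations appear in the comparison formula.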
Where This Matters
Use this convention note when comparing:
- TracedTensorAdExt::grad and AdContext::grad on complex scalar outputs,
- vjp with complex cotangent seeds,
- eager EagerTensor::backward() and stored grad() values for complex scalar losses.
For real-valued code, or for JVPs viewed as directional derivatives, there is no extra conjugation step to apply when comparing values.
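To see why JVPs carry no such ambiguity, the rule dy = f'(z) * dz can be compared directly against a finite difference, with no conjugation on either side. This is a minimal std-only sketch for f(z) = exp(z); the `C` type and the functions `jvp_exp` and `finite_difference_exp` are hypothetical names introduced here, not part of tenferro.

```rust
// Minimal hand-rolled complex number (hypothetical helper, std only).
#[derive(Clone, Copy, Debug)]
struct C {
    re: f64,
    im: f64,
}

impl C {
    fn add(self, o: C) -> C {
        C { re: self.re + o.re, im: self.im + o.im }
    }
    fn sub(self, o: C) -> C {
        C { re: self.re - o.re, im: self.im - o.im }
    }
    fn scale(self, s: f64) -> C {
        C { re: self.re * s, im: self.im * s }
    }
    fn mul(self, o: C) -> C {
        C {
            re: self.re * o.re - self.im * o.im,
            im: self.re * o.im + self.im * o.re,
        }
    }
    // exp(a + bi) = e^a * (cos b + i sin b)
    fn exp(self) -> C {
        let r = self.re.exp();
        C { re: r * self.im.cos(), im: r * self.im.sin() }
    }
    fn abs(self) -> f64 {
        self.re.hypot(self.im)
    }
}

// JVP of f(z) = exp(z): dy = f'(z) * dz, with f'(z) = exp(z).
fn jvp_exp(z: C, dz: C) -> C {
    z.exp().mul(dz)
}

// Forward finite difference along the direction dz with real step h.
fn finite_difference_exp(z: C, dz: C, h: f64) -> C {
    z.add(dz.scale(h)).exp().sub(z.exp()).scale(1.0 / h)
}

fn main() {
    let z = C { re: 0.3, im: -0.7 };
    let dz = C { re: 1.0, im: 2.0 };

    let analytic = jvp_exp(z, dz);
    let numeric = finite_difference_exp(z, dz, 1e-6);
    // The two agree up to the truncation error of the forward difference.
    assert!(analytic.sub(numeric).abs() < 1e-4);
}
```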