# PyTorch and JAX Mapping

This page is for readers who already know either `torch` or `jax.numpy` and want to find the tenferro equivalent quickly.
## Concept mapping

| Concept | PyTorch | JAX | tenferro |
|---|---|---|---|
| Eager tensor | `numpy.ndarray` | — | `Tensor` + `CpuBackend` |
| Tensor handle | `torch.Tensor` | `jax.Array` / `jnp.ndarray` | `TracedTensor` |
| Concrete result | `torch.Tensor` | `jax.Array` | `Tensor` returned by `.eval(&mut engine)` |
| Execution | Eager by default | Eager arrays, often staged with `jit` | Eager (`Tensor` / `EagerTensor`) or lazy traced (`TracedTensor` + `.eval()`) |
| Eager gradients | `loss.backward()` | — | `EagerTensor::backward()` with accumulation |
| Transform AD | `torch.autograd.grad(...)` | `jax.grad`, `jax.vjp`, `jax.jvp`, HVP via composition | `loss.grad(&x)`, `.vjp()`, `.jvp()` |
| Device/runtime | Device is attached to tensors | Device is attached to arrays | Backend lives inside `Engine` |
| Matrix contraction | `torch.einsum` | `jnp.einsum` | `tenferro::einsum::einsum` |
## Function mapping

| Task | PyTorch | JAX | tenferro (eager) | tenferro (lazy/AD) |
|---|---|---|---|---|
| Create tensor | `torch.tensor(data)` | `jnp.array(data)` | `Tensor::from_vec(shape, data)` | `TracedTensor::from_vec(shape, data)` |
| Matrix multiply | `torch.matmul(a, b)` | `jnp.matmul(a, b)` | `a.matmul(&b, &mut ctx)` | `tenferro::matmul(&a, &b)` |
| Reshape | `x.reshape(shape)` | `jnp.reshape(x, shape)` | `x.reshape(&shape, &mut ctx)` | `x.reshape(&shape)` |
| Transpose | `x.transpose(0, 1)` | `jnp.transpose(x, axes)` | `x.transpose(&perm, &mut ctx)` | `x.transpose(&perm)` |
| Broadcast | `x.expand(...)` / implicit broadcast | implicit broadcast in many ops | backend-level op | `x.broadcast(&shape, &dims)` |
| Reduce sum | `x.sum(dim=...)` | `jnp.sum(x, axis=...)` | `x.reduce_sum(&axes, &mut ctx)` | `x.reduce_sum(&axes)` |
| Einsum | `torch.einsum(spec, ...)` | `jnp.einsum(spec, ...)` | `eager_einsum(&mut ctx, ...)` | `einsum(&mut engine, ...)` |
| SVD | `torch.linalg.svd(x)` | `jnp.linalg.svd(x)` | `x.svd(&mut ctx)` | `tenferro::svd(&x)` |
| QR | `torch.linalg.qr(x)` | `jnp.linalg.qr(x)` | `x.qr(&mut ctx)` | `tenferro::qr(&x)` |
| Cholesky | `torch.linalg.cholesky(x)` | `jnp.linalg.cholesky(x)` | `x.cholesky(&mut ctx)` | `tenferro::cholesky(&x)` |
| Solve | `torch.linalg.solve(a, b)` | `jnp.linalg.solve(a, b)` | `a.solve(&b, &mut ctx)` | `tenferro::solve(&a, &b)` |
| Scalar-loss backward | `loss.backward()` | — | `loss.backward()` on `EagerTensor` | — |
| Reverse-mode grad | `torch.autograd.grad(loss, x)` | `jax.grad(f)(x)` | — | `loss.grad(&x)` |
| VJP | `torch.autograd.grad(..., grad_outputs=...)` | `jax.vjp` | — | `y.vjp(&x, &cotangent)` |
| JVP | `torch.func.jvp` | `jax.jvp` | — | `y.jvp(&x, &tangent)` |
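All three libraries share the same einsum spec language. As a reference for what a spec such as `ij,jk->ik` actually computes, here is a standalone sketch in plain Rust (no tenferro types; the function name and row-major layout are illustrative assumptions):

```rust
/// Naive `ij,jk->ik` contraction over row-major slices:
/// out[i][kk] = sum over j of a[i][j] * b[j][kk].
fn einsum_ij_jk_ik(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut out = vec![0.0; m * k];
    for i in 0..m {
        for kk in 0..k {
            for j in 0..n {
                // The repeated index `j` is summed; `i` and `kk` survive.
                out[i * k + kk] += a[i * n + j] * b[j * k + kk];
            }
        }
    }
    out
}

fn main() {
    // (1x2) @ (2x2 identity) leaves the row unchanged.
    let out = einsum_ij_jk_ik(&[1.0, 2.0], &[1.0, 0.0, 0.0, 1.0], 1, 2, 2);
    assert_eq!(out, vec![1.0, 2.0]);
}
```

The same spec string drives `torch.einsum`, `jnp.einsum`, and tenferro's einsum entry points; only the calling convention differs.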
## Key differences

### Column-major storage

tenferro stores dense tensors in column-major order. If you write:

```rust
use tenferro::TracedTensor;

let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
```

then the columns are `[1, 2]`, `[3, 4]`, and `[5, 6]`.
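To see why those are the columns: in a column-major `m x n` buffer, element `(i, j)` lives at flat offset `i + j * m`. A standalone sketch in plain Rust (independent of tenferro's internals; the helper name is illustrative):

```rust
/// Column-major offset: element (row i, col j) of an m-row matrix
/// lives at index i + j * m in the flat buffer.
fn col_major_get(data: &[f64], m: usize, i: usize, j: usize) -> f64 {
    data[i + j * m]
}

fn main() {
    let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // shape [2, 3]
    // Column 1 is [3, 4]: consecutive buffer entries, not a stride-n walk.
    assert_eq!(col_major_get(&data, 2, 0, 1), 3.0);
    assert_eq!(col_major_get(&data, 2, 1, 1), 4.0);
}
```

NumPy and PyTorch default to row-major, so the same flat data would instead read as rows `[1, 2, 3]` and `[4, 5, 6]`.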
### Lazy evaluation

PyTorch users usually expect every operation to execute immediately. JAX users often switch between eager execution and `jit`. tenferro stays lazy until you call `.eval(&mut engine)`.
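The distinction can be sketched with a toy expression type in plain Rust (illustrative only; tenferro's real `TracedTensor` graph is far richer): constructing nodes records the computation, and nothing runs until evaluation is requested.

```rust
/// Toy lazy scalar expression: building a node records work, `eval` runs it.
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    /// Walk the recorded graph and compute the value on demand.
    fn eval(&self) -> f64 {
        match self {
            Expr::Lit(v) => *v,
            Expr::Add(a, b) => a.eval() + b.eval(),
            Expr::Mul(a, b) => a.eval() * b.eval(),
        }
    }
}

fn main() {
    // Building (1 + 2) * 4 performs no arithmetic yet...
    let e = Expr::Mul(
        Box::new(Expr::Add(Box::new(Expr::Lit(1.0)), Box::new(Expr::Lit(2.0)))),
        Box::new(Expr::Lit(4.0)),
    );
    // ...until eval, mirroring TracedTensor + .eval(&mut engine).
    assert_eq!(e.eval(), 12.0);
}
```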
### Autodiff split

Eager tenferro matches PyTorch's scalar-loss `loss.backward()` workflow with accumulation semantics. Traced tenferro is the transform surface for `torch.autograd.grad`, `jax.grad`, `jax.vjp`, `jax.jvp`, and higher-order compositions such as HVPs.
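For readers new to the VJP/JVP vocabulary, here is what the two transforms compute for the scalar map f(x) = x², sketched in plain Rust (illustrative, not tenferro's API): a JVP pushes a tangent forward through f'(x) = 2x, and a VJP pulls a cotangent back through the same derivative.

```rust
/// f(x) = x^2, with derivative f'(x) = 2x.
fn f(x: f64) -> f64 { x * x }

/// Forward mode: given (x, tangent), return (f(x), f'(x) * tangent).
fn jvp(x: f64, tangent: f64) -> (f64, f64) {
    (f(x), 2.0 * x * tangent)
}

/// Reverse mode: given (x, cotangent), return cotangent * f'(x).
fn vjp(x: f64, cotangent: f64) -> f64 {
    cotangent * 2.0 * x
}

fn main() {
    let (y, dy) = jvp(3.0, 1.0);
    assert_eq!(y, 9.0);
    assert_eq!(dy, 6.0);            // d(x^2)/dx at 3 is 6
    assert_eq!(vjp(3.0, 1.0), 6.0); // same number from reverse mode
}
```

For scalar f both directions reduce to multiplying by f'(x); the two modes only diverge in cost and shape once inputs and outputs are tensors.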
### Engine ownership

In tenferro, the backend and reusable execution state live in `Engine`, not on each tensor. That means most user code follows this pattern:

- Create `TracedTensor` values.
- Build tensor expressions.
- Reuse one `Engine` to evaluate them.
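The ownership split behind that pattern can be sketched in plain Rust (a toy model with made-up fields, not tenferro's actual types): mutable scratch state lives in one engine value that is borrowed at evaluation time, while tensor handles stay cheap and immutable.

```rust
/// Toy engine owning reusable execution state (a call counter stands in
/// for backend buffers and kernels here).
struct Engine {
    evals_run: usize,
}

/// Cheap handle describing a deferred computation; it carries no backend.
struct Traced {
    value: f64,
}

impl Traced {
    /// Evaluation borrows the engine mutably, mirroring `.eval(&mut engine)`.
    fn eval(&self, engine: &mut Engine) -> f64 {
        engine.evals_run += 1;
        self.value
    }
}

fn main() {
    let mut engine = Engine { evals_run: 0 };
    let a = Traced { value: 1.0 };
    let b = Traced { value: 2.0 };
    // One engine is reused across many evaluations.
    let sum = a.eval(&mut engine) + b.eval(&mut engine);
    assert_eq!(sum, 3.0);
    assert_eq!(engine.evals_run, 2);
}
```

Keeping the backend in one place means device or allocator changes touch the `Engine`, not every tensor handle in the program.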