# PyTorch and JAX Mapping
This page is for readers who already know either torch or jax.numpy and want to find the tenferro equivalent quickly.
## Concept mapping
| Concept | PyTorch | JAX | tenferro |
|---|---|---|---|
| Eager tensor | `numpy.ndarray` | — | `Tensor` + a backend |
| Tensor handle | `torch.Tensor` | `jax.Array` / `jnp.ndarray` | `TracedTensor` |
| Concrete result | `torch.Tensor` | `jax.Array` | `Tensor` returned by `.eval(&mut engine)` |
| Execution | Eager by default | Eager arrays, often staged with `jit` | Eager (`Tensor` / `EagerTensor`) or lazy traced (`TracedTensor` + `.eval()`) |
| Eager gradients | `loss.backward()` | — | `EagerTensor::backward()` with accumulation |
| Transform AD | `torch.autograd.grad(...)` | `jax.grad`, `jax.vjp`, `jax.jvp`, HVP via composition | `loss.grad(&x)`, `.vjp()`, `.jvp()` |
| Device/runtime | Device is attached to tensors | Device is attached to arrays | Backend lives in direct tensor calls, `EagerContext`, or `Engine` |
| CUDA execution | `x.to("cuda")` | `jax.device_put(x)` | `tenferro::cuda::upload_tensor(...)` and `download_tensor(...)` |
| Matrix contraction | `torch.einsum` | `jnp.einsum` | `tenferro::traced_tensor::einsum` |
## Function mapping
| Task | PyTorch | JAX | tenferro (eager) | tenferro (lazy/AD) |
|---|---|---|---|---|
| Create tensor | `torch.tensor(data)` | `jnp.array(data)` | `Tensor::from_vec(shape, data)` | `TracedTensor::from_vec(shape, data)` |
| Matrix multiply | `torch.matmul(a, b)` | `jnp.matmul(a, b)` | `a.matmul(&b, &mut ctx)` | `tenferro::traced_tensor::matmul(&a, &b)` |
| Reshape | `x.reshape(shape)` | `jnp.reshape(x, shape)` | `x.reshape(&shape, &mut ctx)` | `x.reshape(&shape)` |
| Transpose | `x.transpose(0, 1)` | `jnp.transpose(x, axes)` | `x.transpose(&perm, &mut ctx)` | `x.transpose(&perm)` |
| Broadcast | `x.expand(...)` / implicit broadcast | implicit broadcast in many ops | backend-level op | `x.broadcast(&shape, &dims)` |
| Reduce sum | `x.sum(dim=...)` | `jnp.sum(x, axis=...)` | `x.reduce_sum(&axes, &mut ctx)` | `x.reduce_sum(&axes)` |
| Einsum | `torch.einsum(spec, ...)` | `jnp.einsum(spec, ...)` | `tenferro::tensor::einsum(&mut ctx, ...)` | `tenferro::traced_tensor::einsum(&mut engine, ...)` |
| SVD | `torch.linalg.svd(x)` | `jnp.linalg.svd(x)` | `x.svd(&mut ctx)` | `tenferro::traced_tensor::svd(&x)` |
| QR | `torch.linalg.qr(x)` | `jnp.linalg.qr(x)` | `x.qr(&mut ctx)` | `tenferro::traced_tensor::qr(&x)` |
| Cholesky | `torch.linalg.cholesky(x)` | `jnp.linalg.cholesky(x)` | `x.cholesky(&mut ctx)` | `tenferro::traced_tensor::cholesky(&x)` |
| Solve | `torch.linalg.solve(a, b)` | `jnp.linalg.solve(a, b)` | `a.solve(&b, &mut ctx)` | `tenferro::traced_tensor::solve(&a, &b)` |
| Scalar-loss backward | `loss.backward()` | — | `loss.backward()` on `EagerTensor` | — |
| Reverse-mode grad | `torch.autograd.grad(loss, x)` | `jax.grad(f)(x)` | — | `loss.grad(&x)` |
| VJP | `torch.autograd.grad(..., grad_outputs=...)` | `jax.vjp` | — | `y.vjp(&x, &cotangent)` |
| JVP | `torch.func.jvp` | `jax.jvp` | — | `y.jvp(&x, &tangent)` |
## Key differences

### Column-major storage
tenferro stores dense tensors in column-major order. If you write:
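```rust
use tenferro::TracedTensor;
// Shape [2, 3] (two rows, three columns); the flat data is read column by column.
let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
```

then the columns are `[1, 2]`, `[3, 4]`, and `[5, 6]`, so the rows are `[1, 3, 5]` and `[2, 4, 6]`.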
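Written as a formula, this is the standard definition of column-major layout (nothing here is specific to tenferro): for a tensor of shape $(n_0, \dots, n_{k-1})$ with zero-based indices, element $(i_0, \dots, i_{k-1})$ sits at the flat offset below, so the first index varies fastest.

$$
\operatorname{offset}(i_0, \dots, i_{k-1}) = \sum_{d=0}^{k-1} i_d \prod_{e=0}^{d-1} n_e,
\qquad
\text{shape } [2, 3]:\ \operatorname{offset}(i, j) = i + 2j
$$

For the example above, `a[1, 0]` is at offset 1 (value 2) and `a[0, 1]` is at offset 2 (value 3). Row-major order would instead use offset $3i + j$.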
If your flat input data is in the row-major order used by PyTorch, NumPy, and JAX, reorder it to column-major order first, then construct the tensor with `Tensor::from_vec`.
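A minimal sketch of that reordering for the 2-D case, in plain Rust with no tenferro dependency. `row_major_to_col_major` is a hypothetical helper, not a tenferro API, and it assumes the shape is given as `[rows, cols]`, the same order used for `from_vec` above.

```rust
/// Hypothetical helper (not part of tenferro): reorders a flat row-major
/// buffer for a 2-D shape [rows, cols] into column-major order.
fn row_major_to_col_major<T: Copy>(rows: usize, cols: usize, data: &[T]) -> Vec<T> {
    assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
    let mut out = Vec::with_capacity(data.len());
    // Walk columns in the outer loop so each column is written contiguously.
    for j in 0..cols {
        for i in 0..rows {
            // Row-major offset of element (i, j) is i * cols + j.
            out.push(data[i * cols + j]);
        }
    }
    out
}

fn main() {
    // Row-major [[1, 2, 3], [4, 5, 6]], as PyTorch, NumPy, or JAX would flatten it.
    let row_major = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let col_major = row_major_to_col_major(2, 3, &row_major);
    println!("{:?}", col_major); // prints "[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]"
}
```

The resulting `Vec` can then be passed as the data argument of `Tensor::from_vec(vec![2, 3], ...)`. For higher-rank tensors the same reordering applies, with the first index varying fastest in the output.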
### Explicit CUDA transfer
tenferro follows the PyTorch convention that CPU and CUDA tensors do not move between devices implicitly. Upload CPU tensors with `tenferro::cuda::upload_tensor` before running CUDA backend operations on them, and download results with `tenferro::cuda::download_tensor` before inspecting their values on the host.
CUDA support targets NVIDIA GPUs through the CubeCL backend. See the Devices and GPU page for current coverage and setup commands.
### Lazy evaluation
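PyTorch users usually expect every operation to execute immediately. JAX users often switch between eager execution and `jit`. In tenferro, eager `Tensor` and `EagerTensor` operations execute immediately, while traced `TracedTensor` expressions stay lazy until you call `.eval(&mut engine)`.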
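Deferred evaluation can be sketched with a toy expression tree in plain Rust. `Expr`, `ToyEngine`, `plus`, `scaled`, and the node counter are all hypothetical and only model the idea: building an expression records work, and arithmetic runs only inside `eval(&mut engine)`. tenferro's real `TracedTensor` and `Engine` are not modeled here.

```rust
// Toy deferred-expression tree (hypothetical, not tenferro's TracedTensor):
// constructing nodes only records the computation.
enum Expr {
    Leaf(Vec<f64>),
    Add(Box<Expr>, Box<Expr>),
    Scale(Box<Expr>, f64),
}

// Hypothetical stand-in for an engine; here it only counts evaluated nodes.
struct ToyEngine {
    evaluated_nodes: usize,
}

impl Expr {
    // Records an elementwise addition without computing it.
    fn plus(self, other: Expr) -> Expr {
        Expr::Add(Box::new(self), Box::new(other))
    }

    // Records a multiplication by a scalar without computing it.
    fn scaled(self, k: f64) -> Expr {
        Expr::Scale(Box::new(self), k)
    }

    // All arithmetic happens here, one node at a time.
    fn eval(&self, engine: &mut ToyEngine) -> Vec<f64> {
        engine.evaluated_nodes += 1;
        match self {
            Expr::Leaf(v) => v.clone(),
            Expr::Add(a, b) => {
                let x = a.eval(engine);
                let y = b.eval(engine);
                assert_eq!(x.len(), y.len(), "operands must have the same length");
                x.iter().zip(&y).map(|(p, q)| p + q).collect()
            }
            Expr::Scale(a, k) => a.eval(engine).iter().map(|p| p * k).collect(),
        }
    }
}

fn main() {
    let mut engine = ToyEngine { evaluated_nodes: 0 };
    // Build y = 2 * (a + b); no arithmetic runs yet.
    let y = Expr::Leaf(vec![1.0, 2.0])
        .plus(Expr::Leaf(vec![10.0, 20.0]))
        .scaled(2.0);
    assert_eq!(engine.evaluated_nodes, 0);
    // Evaluation happens only when the engine is asked for a result.
    let out = y.eval(&mut engine);
    println!("{:?} after {} node evaluations", out, engine.evaluated_nodes);
    // prints "[22.0, 44.0] after 4 node evaluations"
}
```

Passing the engine as `&mut` mirrors the `.eval(&mut engine)` call shape used on this page; whatever state tenferro's actual `Engine` keeps is outside this sketch.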
### Autodiff split
Eager tenferro (`EagerTensor`) matches PyTorch's scalar-loss `loss.backward()` workflow, including its gradient accumulation semantics. Traced tenferro (`TracedTensor`) is the transform-style surface that corresponds to `torch.autograd.grad`, `jax.grad`, `jax.vjp`, `jax.jvp`, and higher-order compositions such as Hessian-vector products (HVPs).
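To make the JVP terminology concrete, here is a minimal forward-mode sketch using dual numbers, a standard technique. It is not tenferro's implementation, and `Dual`, `f`, and `jvp` are hypothetical names. A JVP seeds each input with a tangent direction and returns the function value together with its directional derivative along that direction.

```rust
use std::ops::{Add, Mul};

// Hypothetical dual number: a primal value plus a tangent (directional derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: (u + w)' = u' + w'.
    fn add(self, o: Dual) -> Dual {
        Dual { primal: self.primal + o.primal, tangent: self.tangent + o.tangent }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (u * w)' = u * w' + u' * w.
    fn mul(self, o: Dual) -> Dual {
        Dual {
            primal: self.primal * o.primal,
            tangent: self.primal * o.tangent + self.tangent * o.primal,
        }
    }
}

// Example function f(x0, x1) = x0 * x1 + x0 * x0, written once for dual inputs.
fn f(x0: Dual, x1: Dual) -> Dual {
    x0 * x1 + x0 * x0
}

// JVP: seed each input's tangent with the matching entry of v, read the output tangent.
fn jvp(x: [f64; 2], v: [f64; 2]) -> (f64, f64) {
    let y = f(
        Dual { primal: x[0], tangent: v[0] },
        Dual { primal: x[1], tangent: v[1] },
    );
    (y.primal, y.tangent)
}

fn main() {
    // Directional derivative of f at (3, 2) along the first input axis.
    let (y, dy) = jvp([3.0, 2.0], [1.0, 0.0]);
    println!("f = {y}, jvp = {dy}"); // prints "f = 15, jvp = 8"
}
```

The gradient of this `f` at (3, 2) is (8, 3), so forward mode needs one JVP per input direction to recover it, whereas a single reverse-mode VJP with cotangent 1 (the operation `loss.grad(&x)` corresponds to) returns the whole gradient at once.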
### Engine ownership
For traced code, the backend and reusable execution state live in an `Engine`, not on each tensor. That means most traced user code follows this pattern:

- Create `TracedTensor` values.
- Build tensor expressions.
- Reuse one `Engine` to evaluate them.