# PyTorch and JAX Mapping

This page is for readers who already know either torch or jax.numpy and want to find the tenferro equivalent quickly.

## Concept mapping

| Concept | PyTorch | JAX | tenferro |
| --- | --- | --- | --- |
| Eager tensor | `torch.Tensor` | `numpy.ndarray` | `Tensor` + `CpuBackend` |
| Tensor handle | `torch.Tensor` | `jax.Array` / `jnp.ndarray` | `TracedTensor` |
| Concrete result | `torch.Tensor` | `jax.Array` | `Tensor` returned by `.eval(&mut engine)` |
| Execution | Eager by default | Eager arrays, often staged with `jit` | Eager (`Tensor` / `EagerTensor`) or lazy traced (`TracedTensor` + `.eval()`) |
| Eager gradients | `loss.backward()` | | `EagerTensor::backward()` with accumulation |
| Transform AD | `torch.autograd.grad(...)` | `jax.grad`, `jax.vjp`, `jax.jvp`, HVP via composition | `loss.grad(&x)`, `.vjp()`, `.jvp()` |
| Device/runtime | Device is attached to tensors | Device is attached to arrays | Backend lives inside `Engine` |
| Matrix contraction | `torch.einsum` | `jnp.einsum` | `tenferro::einsum::einsum` |

## Function mapping

| Task | PyTorch | JAX | tenferro (eager) | tenferro (lazy/AD) |
| --- | --- | --- | --- | --- |
| Create tensor | `torch.tensor(data)` | `jnp.array(data)` | `Tensor::from_vec(shape, data)` | `TracedTensor::from_vec(shape, data)` |
| Matrix multiply | `torch.matmul(a, b)` | `jnp.matmul(a, b)` | `a.matmul(&b, &mut ctx)` | `tenferro::matmul(&a, &b)` |
| Reshape | `x.reshape(shape)` | `jnp.reshape(x, shape)` | `x.reshape(&shape, &mut ctx)` | `x.reshape(&shape)` |
| Transpose | `x.transpose(0, 1)` | `jnp.transpose(x, axes)` | `x.transpose(&perm, &mut ctx)` | `x.transpose(&perm)` |
| Broadcast | `x.expand(...)` / implicit broadcast | implicit broadcast in many ops | backend-level op | `x.broadcast(&shape, &dims)` |
| Reduce sum | `x.sum(dim=...)` | `jnp.sum(x, axis=...)` | `x.reduce_sum(&axes, &mut ctx)` | `x.reduce_sum(&axes)` |
| Einsum | `torch.einsum(spec, ...)` | `jnp.einsum(spec, ...)` | `eager_einsum(&mut ctx, ...)` | `einsum(&mut engine, ...)` |
| SVD | `torch.linalg.svd(x)` | `jnp.linalg.svd(x)` | `x.svd(&mut ctx)` | `tenferro::svd(&x)` |
| QR | `torch.linalg.qr(x)` | `jnp.linalg.qr(x)` | `x.qr(&mut ctx)` | `tenferro::qr(&x)` |
| Cholesky | `torch.linalg.cholesky(x)` | `jnp.linalg.cholesky(x)` | `x.cholesky(&mut ctx)` | `tenferro::cholesky(&x)` |
| Solve | `torch.linalg.solve(a, b)` | `jnp.linalg.solve(a, b)` | `a.solve(&b, &mut ctx)` | `tenferro::solve(&a, &b)` |
| Scalar-loss backward | `loss.backward()` | | `loss.backward()` on `EagerTensor` | |
| Reverse-mode grad | `torch.autograd.grad(loss, x)` | `jax.grad(f)(x)` | | `loss.grad(&x)` |
| VJP | `torch.autograd.grad(..., grad_outputs=...)` | `jax.vjp` | | `y.vjp(&x, &cotangent)` |
| JVP | `torch.func.jvp` | `jax.jvp` | | `y.jvp(&x, &tangent)` |

## Key differences

### Column-major storage

tenferro stores dense tensors in column-major order. If you write:

```rust
use tenferro::TracedTensor;

let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
```

then the columns are [1, 2], [3, 4], and [5, 6].
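Concretely, element (i, j) of an m-by-n column-major tensor sits at linear index i + j * m. The snippet below is ordinary Rust index arithmetic, not tenferro API:

```rust
// Plain Rust, not tenferro API: element (i, j) of an nrows x ncols
// column-major matrix lives at linear index i + j * nrows.
fn col_major_index(i: usize, j: usize, nrows: usize) -> usize {
    i + j * nrows
}

fn main() {
    let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // shape [2, 3]
    // Column 1 is [3, 4]; column 2 is [5, 6].
    assert_eq!(data[col_major_index(0, 1, 2)], 3.0);
    assert_eq!(data[col_major_index(1, 2, 2)], 6.0);
    println!("ok");
}
```

The same buffer interpreted row-major would instead put element (0, 1) at index 1, so code ported from PyTorch or NumPy defaults needs this mental flip.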

### Lazy evaluation

PyTorch users usually expect every operation to execute immediately. JAX users often switch between eager execution and jit-staged computation. tenferro has both modes: the eager Tensor / EagerTensor path executes immediately, while TracedTensor expressions stay lazy until you call .eval(&mut engine).
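A sketch of the contrast, using the calls from the function table above (the `ctx` and `engine` values are assumed to exist already; this page does not show how they are constructed):

```rust
use tenferro::{Tensor, TracedTensor};

// Eager path: each op runs immediately, as in PyTorch.
let a = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);
let b = a.matmul(&a, &mut ctx);       // executes now

// Traced path: ops only record an expression, like jit-staged JAX.
let x = TracedTensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);
let y = tenferro::matmul(&x, &x);     // nothing executes yet
let out = y.eval(&mut engine);        // executes here, on the engine's backend
```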

### Autodiff split

Eager tenferro matches PyTorch’s scalar-loss loss.backward() workflow with accumulation semantics. Traced tenferro is the transform surface for torch.autograd.grad, jax.grad, jax.vjp, jax.jvp, and higher-order compositions such as HVPs.
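For instance, an HVP can be sketched by composing the traced calls from the tables above; that `grad` returns another composable `TracedTensor` is inferred from the higher-order-composition claim, so treat this as a sketch rather than confirmed API:

```rust
use tenferro::TracedTensor;

let x = TracedTensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
let v = TracedTensor::from_vec(vec![2], vec![1.0_f64, 0.0]);

let loss = x.reduce_sum(&[0]);  // scalar loss expression
let g = loss.grad(&x);          // reverse mode, like jax.grad
let hvp = g.jvp(&x, &v);        // forward-over-reverse, like jax.jvp over jax.grad
```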

### Engine ownership

In tenferro, the backend and reusable execution state live in Engine, not on each tensor. That means most user code follows this pattern:

  1. Create TracedTensor values.
  2. Build tensor expressions.
  3. Reuse one Engine to evaluate them.
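The three steps above can be sketched as follows; `Engine::new()` is a hypothetical constructor, since this page does not specify how an `Engine` is created:

```rust
use tenferro::{Engine, TracedTensor};

// 1. Create TracedTensor values.
let a = TracedTensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);

// 2. Build tensor expressions; nothing runs yet.
let prod = tenferro::matmul(&a, &a);
let total = prod.reduce_sum(&[0, 1]);

// 3. Reuse one Engine to evaluate them.
let mut engine = Engine::new();          // hypothetical constructor name
let prod_val = prod.eval(&mut engine);
let total_val = total.eval(&mut engine); // same engine, reused
```

Keeping the backend on the `Engine` rather than on each tensor is what lets the same expression graph be evaluated against different backends without rebuilding it.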