PyTorch and JAX Mapping
This page is for readers who already know either torch or jax.numpy and want to find the tenferro equivalent quickly.
Concept mapping

| Concept | PyTorch | JAX | tenferro |
|---|---|---|---|
| Typed concrete tensor | `torch.Tensor` with fixed dtype | `jax.Array` with fixed dtype | `TypedTensor<T>` |
| Dynamic concrete tensor | `torch.Tensor` | `jax.Array` / `jnp.ndarray` | `Tensor` + a backend |
| Graph-building tensor handle | `torch.Tensor` under compiled/tracing tools | traced `jax.Array` values | `TracedTensor` |
| Concrete result | `torch.Tensor` | `jax.Array` | `Tensor` returned by `GraphExecutor::run` |
| Execution | Eager by default | Eager arrays, often staged with `jit` | Eager (`Tensor` / `EagerTensor`) or lazy traced (`TracedTensor` + `GraphCompiler` + `GraphExecutor`) |
| Eager forward and gradients | eager ops plus `loss.backward()` | — | `EagerTensor` forward ops, with `backward()` for tracked scalar losses |
| Transform AD | `torch.autograd.grad(...)` | `jax.grad`, `jax.vjp`, `jax.jvp`; HVP via composition | `loss.grad(&x)`, `.vjp()?`, `.jvp()?`; HVP via composition |
| Device/runtime | Device is attached to tensors | Device is attached to arrays | Backend lives in direct tensor calls, `EagerRuntime`, or `GraphExecutor` |
| CUDA execution | `x.to("cuda")` | `jax.device_put(x)` | `tenferro_gpu::upload_tensor(...)` and `download_tensor(...)` |
| Tensor contraction | `torch.einsum` | `jnp.einsum` | `compiler.einsum(...)` via `GraphCompilerEinsumExt` (traced); `[&a, &b].einsum(...)` via `EagerEinsumExt` (eager) |

Function mapping

| Task | PyTorch | JAX | tenferro (eager) | tenferro (lazy/AD) |
|---|---|---|---|---|
| Create typed tensor | `torch.tensor(data, dtype=...)` | `jnp.array(data, dtype=...)` | `TypedTensor::<T>::from_vec_col_major(shape, data)` | — |
| Create dynamic tensor | `torch.tensor(data)` | `jnp.array(data)` | `Tensor::from_vec_col_major(shape, data)` | `TracedTensor::from_vec_col_major(shape, data)` |
| Matrix multiply | `torch.matmul(a, b)` | `jnp.matmul(a, b)` | `a.matmul(&b, &mut ctx)` via `TensorOpsExt` | `a.matmul(&b)` |
| Reshape | `x.reshape(shape)` | `jnp.reshape(x, shape)` | `x.reshape(&shape, &mut ctx)` via `TensorOpsExt` | `x.reshape(&shape)` |
| Transpose | `x.transpose(0, 1)` (two axes) or `x.permute(dims)` | `jnp.transpose(x, axes)` | `x.transpose(&perm, &mut ctx)` via `TensorOpsExt` | `x.transpose(&perm)` |
| Broadcast | `x.expand(...)` / implicit broadcast | implicit broadcast in many ops; `jax.lax.broadcast_in_dim` | backend-level op | `x.broadcast_in_dim(&shape, &dims)` |
| Reduce sum | `x.sum(dim=...)` | `jnp.sum(x, axis=...)` | `x.reduce_sum(&axes, &mut ctx)` via `TensorOpsExt` | `x.reduce_sum(&axes)` |
| Einsum | `torch.einsum(spec, ...)` | `jnp.einsum(spec, ...)` | `[&a, &b].einsum(...)` via `EagerEinsumExt` | `compiler.einsum(...)` via `GraphCompilerEinsumExt` plus `register_runtime` |
| SVD | `torch.linalg.svd(x)` | `jnp.linalg.svd(x)` | `tenferro_linalg::LinalgBackend::svd(&mut ctx, &x)?` | `x.svd()?` via `TracedTensorLinalgExt` |
| QR | `torch.linalg.qr(x)` | `jnp.linalg.qr(x)` | `tenferro_linalg::LinalgBackend::qr(&mut ctx, &x)?` | `x.qr()?` via `TracedTensorLinalgExt` |
| Cholesky | `torch.linalg.cholesky(x)` | `jnp.linalg.cholesky(x)` | `tenferro_linalg::LinalgBackend::cholesky(&mut ctx, &x)?` | `x.cholesky()?` via `TracedTensorLinalgExt` |
| Solve | `torch.linalg.solve(a, b)` | `jnp.linalg.solve(a, b)` | `tenferro_linalg::LinalgBackend::solve(&mut ctx, &a, &b)?` | `a.solve(&b)?` via `TracedTensorLinalgExt` |
| Scalar-loss backward | `loss.backward()` | — | `loss.backward()` on `EagerTensor` | — |
| Reverse-mode grad | `torch.autograd.grad(loss, x)` | `jax.grad(f)(x)` | — | `loss.grad(&x)` |
| VJP | `torch.autograd.grad(..., grad_outputs=...)` | `jax.vjp` | — | `y.vjp(&x, &cotangent)?` |
| JVP | `torch.func.jvp` | `jax.jvp` | — | `y.jvp(&x, &tangent)?` |

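
The Broadcast row is where the libraries differ most in spirit: PyTorch and `jnp` ops align trailing axes implicitly, while `broadcast_in_dim` names the output axis that each operand axis lands on. The std-only function below is a reference sketch of that explicit rule on column-major buffers. It assumes tenferro's `broadcast_in_dim` follows the same convention as `jax.lax.broadcast_in_dim` (operand axis `i` becomes output axis `dims[i]`, and a size-1 operand axis is repeated); the function name and signature are hypothetical and not tenferro's API.

```rust
/// Hypothetical reference for XLA-style broadcast_in_dim on column-major data:
/// operand axis `i` becomes output axis `dims[i]`; an operand extent of 1 is
/// repeated along that output axis (stride 0).
fn broadcast_in_dim(x: &[f64], x_shape: &[usize], out_shape: &[usize], dims: &[usize]) -> Vec<f64> {
    assert_eq!(x_shape.len(), dims.len(), "one target axis per operand axis");
    for (i, &d) in dims.iter().enumerate() {
        assert!(x_shape[i] == 1 || x_shape[i] == out_shape[d], "incompatible extent on axis {i}");
    }
    let len: usize = out_shape.iter().product();
    let mut out = Vec::with_capacity(len);
    let mut idx = vec![0usize; out_shape.len()];
    for linear in 0..len {
        // Decode the column-major multi-index (first axis varies fastest).
        let mut rem = linear;
        for (axis, &extent) in out_shape.iter().enumerate() {
            idx[axis] = rem % extent;
            rem /= extent;
        }
        // Map it to the operand's column-major offset, skipping size-1 axes.
        let mut src = 0;
        let mut stride = 1;
        for (i, &d) in dims.iter().enumerate() {
            if x_shape[i] != 1 {
                src += idx[d] * stride;
            }
            stride *= x_shape[i];
        }
        out.push(x[src]);
    }
    out
}
```

For example, broadcasting the length-2 vector `[1, 2]` into a 2 x 3 output with `dims = [0]` repeats it as every column, giving the column-major buffer `[1, 2, 1, 2, 1, 2]`; a length-3 vector with `dims = [1]` becomes every row instead.
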
Key differences
Column-major storage
tenferro stores dense tensors in column-major order. If you write:
```rust
use tenferro_runtime::TracedTensor;

// A 2 x 3 tensor; the flat buffer is read column by column.
let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
```

then the columns are [1, 2], [3, 4], and [5, 6].
Use from_vec_col_major only with buffers already in tenferro’s physical column-major order. Flat buffers copied from PyTorch, NumPy, or JAX row-major arrays must be reordered explicitly before tensor construction.
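
Because the logical shape stays the same and only the element order changes, the reordering is a pure gather. The sketch below is a standalone helper (the name `row_major_to_col_major` is hypothetical, not part of tenferro) that converts a row-major buffer, as exported by NumPy, PyTorch, or JAX, into the order `from_vec_col_major` expects. On the Python side, NumPy's `x.ravel(order="F")` produces the same column-major buffer.

```rust
/// Hypothetical helper: reorder a row-major (C-order) buffer into
/// column-major (Fortran-order) layout for the same logical shape.
fn row_major_to_col_major<T: Copy>(shape: &[usize], row_major: &[T]) -> Vec<T> {
    let len: usize = shape.iter().product();
    assert_eq!(len, row_major.len(), "buffer length must match shape");
    let rank = shape.len();
    // Row-major strides: the last axis varies fastest.
    let mut rm_strides = vec![1usize; rank];
    for i in (0..rank.saturating_sub(1)).rev() {
        rm_strides[i] = rm_strides[i + 1] * shape[i + 1];
    }
    let mut out = Vec::with_capacity(len);
    // Walk destination positions in column-major order (first axis fastest)
    // and gather each element from its row-major source offset.
    for linear in 0..len {
        let mut rem = linear;
        let mut src = 0;
        for axis in 0..rank {
            let idx = rem % shape[axis];
            rem /= shape[axis];
            src += idx * rm_strides[axis];
        }
        out.push(row_major[src]);
    }
    out
}
```

For the 2 x 3 matrix with rows [1, 2, 3] and [4, 5, 6], the row-major buffer [1, 2, 3, 4, 5, 6] becomes [1, 4, 2, 5, 3, 6], which can then be passed to from_vec_col_major with the unchanged shape vec![2, 3].
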
Explicit CUDA transfer
tenferro follows the PyTorch convention that CPU and CUDA tensors do not move between devices implicitly. Upload CPU tensors with tenferro_gpu::upload_tensor before CUDA backend operations, and download with tenferro_gpu::download_tensor before inspecting values on the host.
For eager CUDA execution, operation calls submit work and return CUDA tensor handles. The host synchronizes at download/read boundaries or at operations that must inspect device-side status; there is no user-visible ready flag.
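
The submit-now, synchronize-at-download model can be illustrated without a GPU. The toy below uses one worker thread and `std::sync::mpsc` channels as a stand-in for an in-order CUDA stream. `ToyStream`, `DeviceHandle`, `upload`, `scale`, and `download` are hypothetical names, not `tenferro_gpu` functions, and a real CUDA backend differs in every detail. What carries over is the shape of the API: operations return handles immediately, and only the explicit download blocks the host.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::{self, JoinHandle};

/// Toy stand-in for a device tensor handle: its data may not exist yet.
struct DeviceHandle {
    result: Receiver<Vec<f64>>,
}

type Job = Box<dyn FnOnce() + Send>;

/// Toy in-order queue standing in for a CUDA stream; one worker thread
/// plays the role of the GPU.
struct ToyStream {
    jobs: Option<Sender<Job>>,
    worker: Option<JoinHandle<()>>,
}

impl ToyStream {
    fn new() -> Self {
        let (tx, rx) = channel::<Job>();
        // Jobs run strictly in submission order, like kernels on one stream.
        let worker = thread::spawn(move || {
            for job in rx {
                job();
            }
        });
        ToyStream { jobs: Some(tx), worker: Some(worker) }
    }

    /// Queues `f` and returns immediately with a handle to its future output.
    fn submit(&self, f: impl FnOnce() -> Vec<f64> + Send + 'static) -> DeviceHandle {
        let (tx, rx) = channel();
        let job: Job = Box::new(move || {
            let _ = tx.send(f());
        });
        self.jobs.as_ref().expect("stream closed").send(job).expect("worker stopped");
        DeviceHandle { result: rx }
    }

    /// Explicit host-to-device copy: the host buffer is moved into the queue.
    fn upload(&self, host: Vec<f64>) -> DeviceHandle {
        self.submit(move || host)
    }

    /// Elementwise scale on a device handle; the host does not wait here.
    fn scale(&self, x: DeviceHandle, factor: f64) -> DeviceHandle {
        // In-order execution guarantees `x` was produced before this job runs.
        self.submit(move || {
            let data = x.result.recv().expect("producer job failed");
            data.iter().map(|v| v * factor).collect()
        })
    }

    /// Explicit device-to-host copy: the only point where the host blocks.
    fn download(&self, x: DeviceHandle) -> Vec<f64> {
        x.result.recv().expect("producer job failed")
    }
}

impl Drop for ToyStream {
    fn drop(&mut self) {
        // Dropping the sender ends the worker loop; then join the worker.
        drop(self.jobs.take());
        if let Some(worker) = self.worker.take() {
            let _ = worker.join();
        }
    }
}
```

In `download(scale(scale(upload(v), 2.0), 10.0))`, both `scale` calls return before any arithmetic has necessarily run; the host first waits inside `download`.
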
CUDA support targets NVIDIA CUDA. See Devices and GPU for the current coverage and setup commands.
Dynamic shapes and static-shape specialization
JAX/XLA specializes compiled programs to known shapes. That is a strong fit for large static-shape workloads where compiler optimization dominates. tenferro’s native traced runtime is built for the complementary case: programs whose output sizes may depend on runtime values, such as thresholded or truncated linear algebra.
When shapes are static, the proposed optional XLA backend (#984) may eventually route traced programs to XLA. When ranks or extents are value-dependent, tenferro keeps dynamic shape metadata in the traced program and resolves the concrete sizes during execution. See Dynamic and Symbolic Shape Metadata.
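
A concrete instance of a value-dependent extent is truncating an SVD by a relative singular-value threshold: the number of kept columns is known only after the singular values have been computed. The std-only functions below (hypothetical names, not tenferro's API) compute that rank and slice the leading columns of a column-major factor, which form one contiguous prefix of the buffer.

```rust
/// Number of singular values kept by a relative-threshold truncation.
/// `s` is assumed sorted in descending order, as SVD routines return it.
fn truncation_rank(s: &[f64], rel_tol: f64) -> usize {
    let largest = match s.first() {
        Some(&v) => v,
        None => return 0,
    };
    // Keep values strictly above rel_tol * largest; stop at the first miss.
    s.iter().take_while(|&&v| v > rel_tol * largest).count()
}

/// Leading `k` columns of a column-major `rows x cols` matrix.
/// In column-major order these are a contiguous prefix of the buffer.
fn leading_columns(u: &[f64], rows: usize, k: usize) -> Vec<f64> {
    u[..rows * k].to_vec()
}
```

With `rel_tol = 1e-6`, the spectrum [10, 3, 1e-3, 1e-9] keeps 3 columns while [10, 3, 2, 1] keeps all 4, so the output extent cannot be fixed when the program is built; a statically specialized compiler would have to recompile or pad for each possible rank.
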
Lazy traced execution
PyTorch users usually expect every operation to execute immediately. JAX users often switch between eager execution and jit. tenferro’s traced API stays lazy until you lower a TracedTensor graph with GraphCompiler and run the resulting GraphProgram with GraphExecutor.
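
The difference is when work happens, not what is computed. The toy below uses a hypothetical `Graph` type rather than TracedTensor, GraphCompiler, or GraphExecutor: it records scalar operations as nodes and evaluates them only when `run` is called, and it omits the compilation step entirely.

```rust
/// Toy expression graph: building nodes records work; nothing is computed.
#[derive(Clone, Debug)]
enum Node {
    Input(usize),      // index into the inputs supplied at run time
    Add(usize, usize), // node ids of the operands
    Mul(usize, usize),
}

#[derive(Default)]
struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    fn push(&mut self, node: Node) -> usize {
        self.nodes.push(node);
        self.nodes.len() - 1
    }
    fn input(&mut self, slot: usize) -> usize { self.push(Node::Input(slot)) }
    fn add(&mut self, a: usize, b: usize) -> usize { self.push(Node::Add(a, b)) }
    fn mul(&mut self, a: usize, b: usize) -> usize { self.push(Node::Mul(a, b)) }

    /// Execution happens only here. Nodes are stored in creation order,
    /// which is already a valid topological order.
    fn run(&self, output: usize, inputs: &[f64]) -> f64 {
        let mut values: Vec<f64> = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let v = match *node {
                Node::Input(slot) => inputs[slot],
                Node::Add(a, b) => values[a] + values[b],
                Node::Mul(a, b) => values[a] * values[b],
            };
            values.push(v);
        }
        values[output]
    }
}
```

Building `(x + y) * x` only appends four nodes; `run` evaluates the recorded graph, and the same graph can be run again with different inputs.
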
Autodiff split
Eager tenferro matches PyTorch-style eager forward execution. When tensors are tracked, it also supports PyTorch's scalar-loss workflow, loss.backward(), with gradient-accumulation semantics. Traced tenferro provides the equivalents of torch.autograd.grad, jax.grad, jax.vjp, and jax.jvp, and supports higher-order compositions such as HVPs.
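
An HVP is obtained by composition: take the gradient, then push a tangent through it with forward mode (forward-over-reverse). The std-only sketch below shows that composition for f(x, y) = x²y, whose Hessian is [[2y, 2x], [2x, 0]]. Dual numbers supply the JVP; the gradient is written by hand as a stand-in for a reverse-mode transform such as `loss.grad(&x)`, so none of this is tenferro's API.

```rust
use std::ops::{Add, Mul};

/// Dual number re + eps·ε with ε² = 0; carrying one tangent gives a JVP.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule on the tangent part.
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.re * o.eps + self.eps * o.re }
    }
}

fn constant(c: f64) -> Dual {
    Dual { re: c, eps: 0.0 }
}

/// Hand-derived gradient of f(x, y) = x²y, standing in for reverse mode:
/// ∇f = (2xy, x²).
fn grad_f(x: Dual, y: Dual) -> [Dual; 2] {
    [constant(2.0) * x * y, x * x]
}

/// HVP as "JVP of the gradient": seed the inputs with tangent `v` and
/// read the tangent part of each gradient component.
fn hvp(x: f64, y: f64, v: [f64; 2]) -> [f64; 2] {
    let g = grad_f(Dual { re: x, eps: v[0] }, Dual { re: y, eps: v[1] });
    [g[0].eps, g[1].eps]
}
```

At (x, y) = (1, 2), `hvp(1.0, 2.0, [1.0, 0.0])` returns `[4.0, 2.0]`, the first column of the Hessian [[4, 2], [2, 0]].
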
For complex reverse-mode AD, tenferro uses a Hermitian real-inner-product cotangent representation. When comparing scalar grad values to JAX, the JAX-like value is conj(tenferro_grad) for the same scalar seed-1 calculation. See Complex Autodiff for the full convention and VJP seed comparison.
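
A small worked case makes the convention concrete. For the real loss L(z) = |z|² = x² + y² at z = 3 + 4i, the gradient with respect to the real inner product Re(conj(a)·b) is ∂L/∂x + i ∂L/∂y = 6 + 8i, and its conjugate, 6 - 8i, is what jax.grad returns for this loss. The sketch below checks this with central finite differences on a hypothetical minimal complex type. It assumes that tenferro's Hermitian real-inner-product representation is exactly this representer, consistent with the conj relation stated above; it is not tenferro code.

```rust
/// Minimal complex number to keep the sketch dependency-free.
#[derive(Clone, Copy, Debug)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    fn conj(self) -> C64 {
        C64 { re: self.re, im: -self.im }
    }
}

/// Real-valued loss L(z) = |z|² = x² + y².
fn loss(z: C64) -> f64 {
    z.re * z.re + z.im * z.im
}

/// Gradient under <a, b> = Re(conj(a)·b): the g with dL = Re(conj(g)·dz),
/// i.e. g = ∂L/∂x + i ∂L/∂y, estimated here by central differences.
fn hermitian_grad(f: impl Fn(C64) -> f64, z: C64, h: f64) -> C64 {
    let dx = (f(C64 { re: z.re + h, ..z }) - f(C64 { re: z.re - h, ..z })) / (2.0 * h);
    let dy = (f(C64 { im: z.im + h, ..z }) - f(C64 { im: z.im - h, ..z })) / (2.0 * h);
    C64 { re: dx, im: dy }
}
```

`hermitian_grad(loss, C64 { re: 3.0, im: 4.0 }, 1e-4)` is approximately 6 + 8i, and its `conj()` is approximately 6 - 8i, the JAX-like value.
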
Compiler and executor ownership
In tenferro, graph lowering state and backend runtime state are separate. That means most traced user code follows this pattern:
- Create `TracedTensor` values.
- Build tensor expressions.
- Reuse one `GraphCompiler` for graph lowering and static planning caches.
- Reuse one `GraphExecutor<B>` for backend execution and runtime caches.
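
Reuse matters because both objects hold caches. The toy below uses hypothetical `ToyCompiler` and `ToyExecutor` types keyed by strings, not tenferro's data structures, to separate lowering-time state from run-time state: a second compile of the same graph is a cache hit, and repeated runs reuse a scratch buffer instead of reallocating.

```rust
use std::collections::HashMap;

/// Toy "compiler": lowering is expensive, so plans are cached by graph key.
#[derive(Default)]
struct ToyCompiler {
    cache: HashMap<String, usize>, // graph key -> plan id
    lowerings: usize,              // how many times lowering actually ran
}

impl ToyCompiler {
    fn compile(&mut self, graph_key: &str) -> usize {
        if let Some(&plan) = self.cache.get(graph_key) {
            return plan; // cache hit: no lowering work
        }
        self.lowerings += 1; // cache miss: lower the graph once
        let plan = self.cache.len();
        self.cache.insert(graph_key.to_string(), plan);
        plan
    }
}

/// Toy "executor": keeps runtime state (a scratch buffer) across runs.
#[derive(Default)]
struct ToyExecutor {
    scratch: Vec<f64>,
    allocations: usize,
}

impl ToyExecutor {
    /// Pretend to run `plan` on `n` elements; returns the sum 0 + 1 + ... + (n - 1).
    fn run(&mut self, _plan: usize, n: usize) -> f64 {
        if self.scratch.capacity() < n {
            self.scratch = Vec::with_capacity(n);
            self.allocations += 1;
        }
        self.scratch.clear(); // keeps the capacity for the next run
        self.scratch.extend((0..n).map(|i| i as f64));
        self.scratch.iter().sum()
    }
}
```

Creating a fresh compiler or executor per call would discard both caches and pay the lowering and allocation cost every time.
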