Tensor Operations

This guide covers the everyday tensor operations you would reach for in PyTorch or JAX: creating tensors, applying elementwise math, changing shapes, and reducing over axes.

Create tensors from shape and data

This is the tenferro equivalent of torch.tensor(...) or jnp.array(...).

use tenferro::{CpuBackend, Engine, TracedTensor};

let mut a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);

let mut engine = Engine::new(CpuBackend::new());
let result = a.eval(&mut engine).unwrap();

assert_eq!(result.shape(), &[2, 3]);
assert_eq!(result.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
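from_vec pairs a shape with a flat buffer, so the buffer must hold exactly the product of the shape's dimensions. A minimal plain-Rust sketch of that contract (no tenferro involved; elements_for is just an illustrative helper, not part of the library):

```rust
// Number of elements a tensor of the given shape must hold.
fn elements_for(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    // A [2, 3] tensor needs exactly 2 * 3 = 6 values in its flat buffer.
    assert_eq!(elements_for(&[2, 3]), 6);
    let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(elements_for(&[2, 3]), data.len());
}
```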

Elementwise arithmetic

Like PyTorch and JAX, tenferro supports familiar elementwise operators on tensors with compatible shapes.

use tenferro::{CpuBackend, Engine, TracedTensor};

let a = TracedTensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
let b = TracedTensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
let mut sum = &a + &b;
let mut product = &a * &b;

let mut engine = Engine::new(CpuBackend::new());
let sum_result = sum.eval(&mut engine).unwrap();
let product_result = product.eval(&mut engine).unwrap();

assert_eq!(sum_result.as_slice::<f64>().unwrap(), &[5.0, 7.0, 9.0]);
assert_eq!(product_result.as_slice::<f64>().unwrap(), &[4.0, 10.0, 18.0]);
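Independent of any tensor library, elementwise arithmetic pairs up entries at the same position, so the expected values above can be checked with plain Rust (add and mul here are illustrative helpers, not tenferro API):

```rust
// Position-by-position addition of two equal-length slices.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

// Position-by-position multiplication of two equal-length slices.
fn mul(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}

fn main() {
    let a = [1.0_f64, 2.0, 3.0];
    let b = [4.0_f64, 5.0, 6.0];
    assert_eq!(add(&a, &b), vec![5.0, 7.0, 9.0]);
    assert_eq!(mul(&a, &b), vec![4.0, 10.0, 18.0]);
}
```

Because the operation never mixes entries from different positions, the result is the same regardless of how the tensor's data is laid out in memory.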

Elementwise math functions

The common unary math ops are methods on TracedTensor, similar to torch.exp(x) or jnp.exp(x).

use tenferro::{CpuBackend, Engine, TracedTensor};

let x = TracedTensor::from_vec(vec![3], vec![0.0_f64, 1.0, 2.0]);
let mut y = x.exp();

let mut engine = Engine::new(CpuBackend::new());
let result = y.eval(&mut engine).unwrap();
let data = result.as_slice::<f64>().unwrap();

assert!((data[0] - 1.0).abs() < 1e-12);
assert!((data[1] - std::f64::consts::E).abs() < 1e-12);
assert!((data[2] - 7.38905609893065).abs() < 1e-12);
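As a reference point for the tolerances above, the same three values computed directly with the standard library's f64::exp:

```rust
fn main() {
    // e^0 = 1, e^1 = e, e^2 ≈ 7.389056...
    let data: Vec<f64> = [0.0_f64, 1.0, 2.0].iter().map(|x| x.exp()).collect();
    assert!((data[0] - 1.0).abs() < 1e-12);
    assert!((data[1] - std::f64::consts::E).abs() < 1e-12);
    assert!((data[2] - 7.38905609893065).abs() < 1e-12);
}
```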

Reshape and transpose

These match the ideas behind reshape and transpose in PyTorch and JAX. One difference to keep in mind: the assertions below indicate that tenferro flattens data in column-major order, so the transposed [3, 2] buffer reads [1.0, 3.0, 5.0, ...] rather than the [1.0, 4.0, 2.0, ...] a row-major library like PyTorch would give you.

use tenferro::{CpuBackend, Engine, TracedTensor};

let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut reshaped = a.reshape(&[6]);
let mut transposed = a.transpose(&[1, 0]);

let mut engine = Engine::new(CpuBackend::new());
let reshaped_result = reshaped.eval(&mut engine).unwrap();
let transposed_result = transposed.eval(&mut engine).unwrap();

assert_eq!(reshaped_result.shape(), &[6]);
assert_eq!(reshaped_result.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert_eq!(transposed_result.shape(), &[3, 2]);
assert_eq!(
    transposed_result.as_slice::<f64>().unwrap(),
    &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
);
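The transposed buffer [1.0, 3.0, 5.0, 2.0, 4.0, 6.0] follows from column-major storage, which is what the assertions in this guide are consistent with. A plain-Rust sketch of the index arithmetic (no tenferro; transpose_col_major is an illustrative helper, and the column-major layout is an inference from the expected values above):

```rust
// Transpose a rows x cols matrix stored in column-major order.
// In column-major order, element (i, j) lives at flat index i + rows * j.
fn transpose_col_major(data: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut out = vec![0.0; data.len()];
    for j in 0..cols {
        for i in 0..rows {
            // Element (j, i) of the cols x rows result sits at j + cols * i.
            out[j + cols * i] = data[i + rows * j];
        }
    }
    out
}

fn main() {
    // The 2x3 matrix from the example: columns are [1,2], [3,4], [5,6].
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let t = transpose_col_major(&data, 2, 3);
    assert_eq!(t, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
}
```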

Explicit broadcast

PyTorch and JAX usually hide broadcasting behind arithmetic. tenferro supports broadcasted arithmetic as well, but it also exposes broadcasting as a standalone operation for when you want the expansion explicit.

use tenferro::{CpuBackend, Engine, TracedTensor};

let v = TracedTensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
let mut repeated = v.broadcast(&[3, 2], &[0]);

let mut engine = Engine::new(CpuBackend::new());
let result = repeated.eval(&mut engine).unwrap();

assert_eq!(result.shape(), &[3, 2]);
assert_eq!(result.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
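The expected buffer [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] is what you get when the length-3 input maps to axis 0 of a [3, 2] result stored in column-major order (the layout inferred from this guide's assertions). A plain-Rust sketch of that expansion, with broadcast_axis0_col_major as an illustrative helper rather than tenferro API:

```rust
// Broadcast a vector along axis 0 of a column-major [len, extra] result:
// result (i, j) = v[i] for every j.
fn broadcast_axis0_col_major(v: &[f64], extra: usize) -> Vec<f64> {
    let rows = v.len();
    let mut out = vec![0.0; rows * extra];
    for j in 0..extra {
        for i in 0..rows {
            // Column-major flat index of element (i, j).
            out[i + rows * j] = v[i];
        }
    }
    out
}

fn main() {
    let repeated = broadcast_axis0_col_major(&[1.0, 2.0, 3.0], 2);
    assert_eq!(repeated, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
```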

Reduce over axes

Reductions are the closest match to torch.sum(x, dim=...) or jnp.sum(x, axis=...): you pass the axes to collapse, and reducing over every axis yields a scalar (rank-0) tensor.

use tenferro::{CpuBackend, Engine, TracedTensor};

let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut row_sums = a.reduce_sum(&[1]);
let mut total = a.reduce_sum(&[0, 1]);

let mut engine = Engine::new(CpuBackend::new());
let rows = row_sums.eval(&mut engine).unwrap();
let total_value = total.eval(&mut engine).unwrap();

assert_eq!(rows.shape(), &[2]);
assert_eq!(rows.as_slice::<f64>().unwrap(), &[9.0, 12.0]);
assert_eq!(total_value.shape(), &[] as &[usize]);
assert_eq!(total_value.as_slice::<f64>().unwrap(), &[21.0]);
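The values [9.0, 12.0] are the sums of the two rows of the 2x3 matrix [[1, 3, 5], [2, 4, 6]], which is how the flat buffer reads under the column-major layout this guide's assertions imply. A plain-Rust sketch of the axis-1 reduction (reduce_sum_axis1_col_major is an illustrative helper, not tenferro API):

```rust
// Sum over axis 1 (the columns) of a rows x cols matrix stored in
// column-major order, producing one sum per row.
fn reduce_sum_axis1_col_major(data: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut out = vec![0.0; rows];
    for j in 0..cols {
        for i in 0..rows {
            // Element (i, j) lives at flat index i + rows * j.
            out[i] += data[i + rows * j];
        }
    }
    out
}

fn main() {
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Row sums: 1 + 3 + 5 = 9 and 2 + 4 + 6 = 12.
    assert_eq!(reduce_sum_axis1_col_major(&data, 2, 3), vec![9.0, 12.0]);
    // Reducing over both axes gives the grand total, 21.
    assert_eq!(data.iter().sum::<f64>(), 21.0);
}
```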