Crate tenferro_linalg


Batched matrix linear algebra decompositions with AD rules.

This crate provides SVD, QR, LU, and eigendecomposition for tensors with shape (m, n, *), adapted from PyTorch’s torch.linalg for column-major layout:

  • First 2 dimensions are the matrix (m × n).
  • All following dimensions (*) are independent batch dimensions.
  • Input must be column-major contiguous (LAPACK/cuSOLVER native).

This convention mirrors PyTorch’s (*, m, n), flipped for column-major layout: in column-major order the leading dimensions are the contiguous ones, so placing the matrix there lets LAPACK operate on each batch entry directly, without transposition.

This module is context-agnostic: it does not know about tensor networks, MPS, or any specific application. If you need to decompose a tensor along arbitrary legs, permute + reshape + contiguous(ColumnMajor) before calling these functions.

§AD rules

Each decomposition has stateless _rrule (reverse-mode / VJP) and _frule (forward-mode / JVP) functions. These implement matrix-level AD formulas (Mathieu et al., 2019, and related work) using batched operations that broadcast naturally over the batch dimensions *.

There are no tracked_* / dual_* functions — the chainrules tape engine composes permute_backward + reshape_backward + svd_rrule via the standard chain rule automatically.
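As a sketch of how that composition looks when gradients flow back through a permute + reshape + SVD pipeline (the `reshape_backward` and `permute_backward` signatures below are illustrative assumptions, not the actual tape-engine API):

```rust
// Hypothetical sketch of what the tape engine composes automatically.
// Forward: t --permute--> --reshape--> mat --svd--> result
let mat = t.permute(&perm).reshape(&[m, n]).unwrap().contiguous(col);
let result = svd(&mat, None).unwrap();

// Pullback: cotangents flow through the stateless rules in reverse order.
// `reshape_backward` / `permute_backward` signatures are assumptions here.
let grad_mat = svd_rrule(&mat, &cotangent, None).unwrap(); // pullback of svd
let grad_permuted = reshape_backward(&grad_mat, &permuted_shape); // pullback of reshape
let grad_t = permute_backward(&grad_permuted, &perm); // pullback of permute
```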

§Examples

§SVD of a matrix

```rust
use tenferro_linalg::svd;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

// 2D matrix: shape [3, 4]
let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
let result = svd(&a, None).unwrap();
// result.u:  shape [3, 3]  (m × k, k = min(m, n) = 3)
// result.s:  shape [3]     (singular values)
// result.vt: shape [3, 4]  (k × n)
```

§Batched SVD

```rust
use tenferro_linalg::svd;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

// Batched: shape [m, n, batch] = [3, 4, 10]
let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col);
let result = svd(&a, None).unwrap();
// result.u:  shape [3, 3, 10]
// result.s:  shape [3, 10]
// result.vt: shape [3, 4, 10]
```

§Decomposing a 4D tensor along specific legs

```rust
use tenferro_linalg::svd;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

// 4D tensor [2, 3, 4, 5] — want SVD with left=[0,1], right=[2,3]
let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col);

// User's responsibility: permute + reshape + contiguous
let mat = t.permute(&[0, 1, 2, 3])   // already in order
           .reshape(&[6, 20]).unwrap() // m = 2*3 = 6, n = 4*5 = 20
           .contiguous(col);
let result = svd(&mat, None).unwrap();
// Then reshape result.u, result.vt back to desired tensor shape
```
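Continuing this example, restoring the tensor legs is one more pair of reshapes (a sketch; here k = min(6, 20) = 6):

```rust
// result.u has shape [6, 6] (m × k); split m back into the left legs [2, 3].
let u_tensor = result.u.reshape(&[2, 3, 6]).unwrap();
// result.vt has shape [6, 20] (k × n); split n back into the right legs [4, 5].
let vt_tensor = result.vt.reshape(&[6, 4, 5]).unwrap();
```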

§Reverse-mode AD (stateless rrule)

```rust
use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
let result = svd(&a, None).unwrap();

// Full cotangent: gradient through U, S, and Vt
let cotangent = SvdCotangent {
    u: Some(Tensor::ones(&[3, 3], mem, col)),
    s: Some(Tensor::ones(&[3], mem, col)),
    vt: Some(Tensor::ones(&[3, 4], mem, col)),
};
let grad_a = svd_rrule(&a, &cotangent, None).unwrap();
// grad_a has same shape as a: [3, 4]

// Partial cotangent: gradient only through singular values (always stable)
let cotangent_s_only = SvdCotangent {
    u: None,
    s: Some(Tensor::ones(&[3], mem, col)),
    vt: None,
};
let grad_a2 = svd_rrule(&a, &cotangent_s_only, None).unwrap();
```
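§Forward-mode AD (stateless frule)

For symmetry with the rrule example above, a forward-mode sketch; the exact svd_frule signature and the shape of its return value are assumptions mirroring svd_rrule, not confirmed API.

```rust
use tenferro_linalg::{svd, svd_frule};
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

let a = Tensor::<f64>::zeros(&[3, 4], mem, col);

// Tangent of the input: same shape as `a`.
let da = Tensor::<f64>::ones(&[3, 4], mem, col);

// Assumed signature: primal input, input tangent, options.
let dresult = svd_frule(&a, &da, None).unwrap();
// Output tangents mirror the primal shapes: du [3, 3], ds [3], dvt [3, 4].
```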

Structs§

  • EigenCotangent – Cotangent (adjoint) for eigendecomposition outputs.
  • EigenResult – Eigendecomposition result: A * V = V * diag(values).
  • LuCotangent – Cotangent (adjoint) for LU outputs.
  • LuResult – LU decomposition result: A = P * L * U (partial pivoting).
  • QrCotangent – Cotangent (adjoint) for QR outputs.
  • QrResult – QR decomposition result: A = Q * R.
  • SvdCotangent – Cotangent (adjoint) for SVD outputs.
  • SvdOptions – Options for truncated SVD.
  • SvdResult – SVD result: A = U * diag(S) * Vt.

Functions§

  • eigen – Compute the eigendecomposition of a batched square matrix.
  • eigen_frule – Forward-mode AD rule for eigendecomposition (JVP / pushforward).
  • eigen_rrule – Reverse-mode AD rule for eigendecomposition (VJP / pullback).
  • lu – Compute the LU decomposition of a batched matrix (partial pivoting).
  • lu_frule – Forward-mode AD rule for LU (JVP / pushforward).
  • lu_rrule – Reverse-mode AD rule for LU (VJP / pullback).
  • qr – Compute the QR decomposition of a batched matrix.
  • qr_frule – Forward-mode AD rule for QR (JVP / pushforward).
  • qr_rrule – Reverse-mode AD rule for QR (VJP / pullback).
  • svd – Compute the SVD of a batched matrix.
  • svd_frule – Forward-mode AD rule for SVD (JVP / pushforward).
  • svd_rrule – Reverse-mode AD rule for SVD (VJP / pullback).