Crate tenferro_linalg


Batched matrix linear algebra decompositions with AD rules.

CPU decompositions and solvers are fully implemented via the faer backend. CUDA/HIP linalg contracts are already part of the public surface, but backend coverage there remains partial and capability-gated.

This crate provides SVD, QR, LU, eigendecomposition, Cholesky, least squares, linear solve, matrix inverse, determinant, pseudoinverse, matrix exponential, triangular solve, and norms for tensors with shape (m, n, *), adapted from PyTorch’s torch.linalg for column-major layout:

  • First 2 dimensions are the matrix (m × n).
  • All following dimensions (*) are independent batch dimensions.
  • Inputs are internally normalized to column-major contiguous layout. If an input is not already contiguous, an internal copy is performed. Calling .contiguous(ColumnMajor) explicitly is optional but useful when you want to control exactly where copies happen.

This convention mirrors PyTorch’s (*, m, n) but is flipped for col-major: in col-major the first dimensions are contiguous, so placing the matrix there ensures LAPACK can operate directly without transposition.
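The layout argument above can be seen concretely with NumPy (used here purely as an illustration; nothing below depends on the tenferro crates):

```python
import numpy as np

# Column-major (Fortran-order) batch of matrices: shape (m, n, batch) = (3, 4, 10).
a = np.zeros((3, 4, 10), order="F")

# In column-major order the FIRST axes are the fastest-varying: element
# strides (in units of f64) are (1, m, m*n) = (1, 3, 12).
print(tuple(s // a.itemsize for s in a.strides))  # (1, 3, 12)

# Each batch slice a[:, :, k] is therefore one contiguous m*n block that a
# LAPACK routine can consume directly, with no transposition or copy.
assert a[:, :, 0].flags["F_CONTIGUOUS"]
```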

This module is context-agnostic: it does not know about tensor networks, MPS, or any specific application. If you need to decompose a tensor along arbitrary legs, permute + reshape before calling these functions.

§AD rules

Each decomposition has stateless _rrule (reverse-mode / VJP) and _frule (forward-mode / JVP) functions. These implement matrix-level AD formulas (e.g., Mathieu et al., 2019) using batched operations that naturally broadcast over the batch dimensions *.

There are no tracked_* / dual_* functions; the chainrules tape engine automatically composes permute_backward, reshape_backward, and svd_rrule via the standard chain rule.
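This VJP composition can be sketched in NumPy (an illustration only: the names svd_rrule, reshape_backward, and permute_backward belong to the crate, while the steps below implement the same chain-rule composition by hand for the loss sum(S)):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5))
perm = (0, 1, 2, 3)  # identity here; any permutation composes the same way

# Forward: permute + reshape + SVD, with loss = sum of singular values.
B = A.transpose(perm).reshape(6, 20)
U, s, Vt = np.linalg.svd(B, full_matrices=False)

# Reverse pass, composed step by step exactly as a tape engine would:
s_bar = np.ones_like(s)                 # cotangent of the loss w.r.t. s
B_bar = (U * s_bar) @ Vt                # SVD pullback, singular values only
A_perm_bar = B_bar.reshape(2, 3, 4, 5)  # reshape backward: reshape the cotangent
inv_perm = np.argsort(perm)
A_bar = A_perm_bar.transpose(inv_perm)  # permute backward: inverse permutation

# Spot-check one entry against a central finite difference.
eps = 1e-6
Ap, Am = A.copy(), A.copy()
Ap[0, 1, 2, 3] += eps
Am[0, 1, 2, 3] -= eps
loss = lambda X: np.linalg.svd(X.reshape(6, 20), compute_uv=False).sum()
fd = (loss(Ap) - loss(Am)) / (2 * eps)
assert abs(A_bar[0, 1, 2, 3] - fd) < 1e-5
```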

§Examples

§SVD of a matrix

use tenferro_linalg::{svd, SvdOptions};
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);

// 2D matrix: shape [3, 4]
let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();
// result.u:  shape [3, 3]  (m × k, k = min(m,n) = 3)
// result.s:  shape [3]     (singular values)
// result.vt: shape [3, 4]  (k × n)

§Batched SVD

use tenferro_linalg::svd;
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);

// Batched: shape [m, n, batch] = [3, 4, 10]
let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();
// result.u:  shape [3, 3, 10]
// result.s:  shape [3, 10]
// result.vt: shape [3, 4, 10]

§Decomposing a 4D tensor along specific legs

use tenferro_linalg::svd;
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);

// 4D tensor [2, 3, 4, 5] — want SVD with left=[0,1], right=[2,3]
let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col).unwrap();

// permute + reshape (contiguous is handled internally, but can be called explicitly)
let mat = t.permute(&[0, 1, 2, 3]).unwrap()  // already in order
           .reshape(&[6, 20]).unwrap();        // m = 2*3 = 6, n = 4*5 = 20
let result = svd(&mut ctx, &mat, None).unwrap();
// Then reshape result.u, result.vt back to desired tensor shape
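The full matricize / decompose / fold-back roundtrip can be checked numerically in NumPy (illustration only; NumPy's default reshape is row-major rather than this crate's column-major, but the permute + reshape + SVD + reshape-back pattern is identical):

```python
import numpy as np

rng = np.random.default_rng(1)
t = rng.standard_normal((2, 3, 4, 5))

# Split legs [0,1] | [2,3]: matricize to (2*3, 4*5) and decompose.
mat = t.reshape(6, 20)
U, s, Vt = np.linalg.svd(mat, full_matrices=False)  # U: (6,6), s: (6,), Vt: (6,20)

# Fold the factors back into tensor legs: U -> (2,3,k), Vt -> (k,4,5).
k = s.size
U_t = U.reshape(2, 3, k)
Vt_t = Vt.reshape(k, 4, 5)

# Contracting back over k reproduces t (up to floating-point error).
recon = np.einsum("abk,k,kcd->abcd", U_t, s, Vt_t)
assert np.allclose(recon, t)
```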

§Reverse-mode AD (stateless rrule)

use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);

let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();

// Full cotangent: gradient through U, S, and Vt
let cotangent = SvdCotangent {
    u: Some(Tensor::<f64>::ones(&[3, 3], mem, col).unwrap()),
    s: Some(Tensor::<f64>::ones(&[3], mem, col).unwrap()),
    vt: Some(Tensor::<f64>::ones(&[3, 4], mem, col).unwrap()),
};
let grad_a = svd_rrule(&mut ctx, &a, &cotangent, None).unwrap();
// grad_a has same shape as a: [3, 4]

// Partial cotangent: gradient only through singular values (always stable)
let cotangent_s_only = SvdCotangent {
    u: None,
    s: Some(Tensor::<f64>::ones(&[3], mem, col).unwrap()),
    vt: None,
};
let grad_a2 = svd_rrule(&mut ctx, &a, &cotangent_s_only, None).unwrap();
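Why the s-only cotangent is "always stable" can be seen from the standard SVD pullback formulas: the terms flowing through U and Vt involve gap factors of the form 1 / (s[j]² - s[i]²), which diverge when two singular values coincide, whereas the s-only contribution U @ diag(s̄) @ Vt contains no such factors. A NumPy sketch of the gap factor (illustration only):

```python
import numpy as np

# Two nearly equal singular values: the off-diagonal gap factors that enter
# the U/Vt pullback terms blow up, while U @ diag(s_bar) @ Vt needs none.
s = np.array([2.0, 1.0 + 1e-9, 1.0])
diff = s[None, :] ** 2 - s[:, None] ** 2           # diff[i, j] = s[j]^2 - s[i]^2
F = np.where(np.eye(3, dtype=bool), 0.0,
             1.0 / np.where(diff == 0, np.inf, diff))

print(np.abs(F).max())  # ~5e8: huge amplification near the degeneracy
```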

Modules§

backend
Backend abstraction for linear algebra operations.

Structs§

CholeskyExResult
Structured Cholesky result with numerical status information.
EigCotangent
Cotangent (adjoint) for general eigendecomposition outputs.
EigResult
Result of general eigendecomposition (always complex-valued).
EigenCotangent
Cotangent (adjoint) for eigendecomposition outputs.
EigenResult
Eigendecomposition result: A * V = V * diag(values).
InvExResult
Structured inverse result with numerical status information.
LstsqAuxResult
Auxiliary metadata for least-squares solves.
LstsqGrad
Gradient result for lstsq_rrule: cotangents for both A and b.
LstsqResult
Least-squares result.
LuCotangent
Cotangent (adjoint) for LU outputs.
LuFactorExResult
Packed LU factorization result with numerical status information.
LuFactorResult
Packed LU factorization result.
LuResult
LU decomposition result: A = P * L * U.
QrCotangent
Cotangent (adjoint) for QR outputs.
QrResult
QR decomposition result: A = Q * R.
SlogdetCotangent
Cotangent (adjoint) for slogdet outputs.
SlogdetResult
Sign-and-log-determinant result: det(A) = sign * exp(logabsdet).
SolveExResult
Structured solve result with numerical status information.
SolveGrad
Gradient result for solve_rrule: cotangents for both A and b.
SvdCotangent
Cotangent (adjoint) for SVD outputs.
SvdOptions
Options for truncated SVD.
SvdResult
SVD result: A = U * diag(S) * Vt.

Enums§

LinalgCapabilityOp
Backend-facing tensor linalg protocol.
LuPivot
Pivoting strategy for LU decomposition.
MatrixNormOrd
Matrix norm order for crate::matrix_norm-style public surfaces.
NormKind
Norm kind for crate::norm.
VectorNormOrd
Vector norm order for crate::vector_norm-style public surfaces.

Traits§

KernelLinalgScalar
Scalar types with concrete backend kernel support in the current workspace.
LinalgScalar
Scalar types supported by linalg kernel contracts.

Functions§

cholesky
Compute the Cholesky decomposition of a Hermitian positive-definite matrix.
cholesky_ex
Compute the Cholesky decomposition with numerical status information.
cholesky_frule
Forward-mode AD rule for Cholesky (JVP / pushforward).
cholesky_rrule
Reverse-mode AD rule for Cholesky (VJP / pullback).
cond
Compute the matrix condition number with a selected norm convention.
cross
Compute the cross product along the leading vector axis.
det
Compute the determinant of a square matrix.
det_frule
Forward-mode AD rule for determinant (JVP / pushforward).
det_rrule
Reverse-mode AD rule for determinant (VJP / pullback).
eig
Compute the eigendecomposition of a general (non-symmetric) square matrix.
eig_frule
Forward-mode AD rule for general eigendecomposition (JVP / pushforward).
eig_rrule
Reverse-mode AD rule for general eigendecomposition (VJP / pullback).
eigen
Compute the eigendecomposition of a batched square matrix.
eigen_frule
Forward-mode AD rule for eigendecomposition (JVP / pushforward).
eigen_rrule
Reverse-mode AD rule for eigendecomposition (VJP / pullback).
householder_product
Form the explicit product of Householder reflectors.
inv
Compute the inverse of a square matrix.
inv_ex
Compute the inverse with numerical status information.
inv_frule
Forward-mode AD rule for matrix inverse (JVP / pushforward).
inv_rrule
Reverse-mode AD rule for matrix inverse (VJP / pullback).
lstsq
Solve the least squares problem: x = argmin ||Ax - b||².
lstsq_aux
Compute least-squares auxiliary metadata.
lstsq_frule
Forward-mode AD rule for least squares (JVP / pushforward).
lstsq_rrule
Reverse-mode AD rule for least squares (VJP / pullback).
lu
Compute the LU decomposition of a batched matrix.
lu_factor
Compute the packed LU factorization of a batched matrix.
lu_factor_ex
Compute the packed LU factorization with numerical status information.
lu_frule
Forward-mode AD rule for LU (JVP / pushforward).
lu_rrule
Reverse-mode AD rule for LU (VJP / pullback).
lu_solve
Solve A x = b from a packed LU factorization.
matrix_exp
Compute the matrix exponential exp(A) of a square matrix.
matrix_exp_frule
Forward-mode AD rule for matrix exponential (JVP / pushforward).
matrix_exp_rrule
Reverse-mode AD rule for matrix exponential (VJP / pullback).
matrix_power
Raise a square matrix to an integer power.
norm
Compute a norm.
norm_frule
Forward-mode AD rule for norm (JVP / pushforward).
norm_rrule
Reverse-mode AD rule for norm (VJP / pullback).
pinv
Compute the Moore-Penrose pseudoinverse of a matrix.
pinv_frule
Forward-mode AD rule for pseudoinverse (JVP / pushforward).
pinv_rrule
Reverse-mode AD rule for pseudoinverse (VJP / pullback).
qr
Compute the QR decomposition of a batched matrix.
qr_frule
Forward-mode AD rule for QR (JVP / pushforward).
qr_rrule
Reverse-mode AD rule for QR (VJP / pullback).
slogdet
Compute sign and log-absolute-determinant of a square matrix.
slogdet_frule
Forward-mode AD rule for slogdet (JVP / pushforward).
slogdet_rrule
Reverse-mode AD rule for slogdet (VJP / pullback).
solve
Solve a square linear system A x = b.
solve_ex
Solve a square linear system with numerical status information.
solve_frule
Forward-mode AD rule for linear solve (JVP / pushforward).
solve_rrule
Reverse-mode AD rule for linear solve (VJP / pullback).
solve_triangular
Solve a triangular linear system A x = b.
solve_triangular_frule
Forward-mode AD rule for triangular solve (JVP / pushforward).
solve_triangular_rrule
Reverse-mode AD rule for triangular solve (VJP / pullback).
svd
Compute the SVD of a batched matrix.
svd_frule
Forward-mode AD rule for SVD (JVP / pushforward).
svd_rrule
Reverse-mode AD rule for SVD (VJP / pullback).
svdvals
Compute singular values only for a batched matrix.
tensorinv
Invert a tensorized square operator.
tensorsolve
Solve a tensorized linear system.
vander
Build a Vandermonde matrix from leading-dimension vectors.