Batched matrix linear algebra decompositions with AD rules.
CPU decompositions and solvers are fully implemented via the
faer backend. CUDA/HIP linalg contracts
are already part of the public surface, but backend coverage there remains
partial and capability-gated.
This crate provides SVD, QR, LU, eigendecomposition, Cholesky, least squares,
linear solve, matrix inverse, determinant, pseudoinverse, matrix exponential,
triangular solve, and norms for tensors
with shape (m, n, *), adapted from PyTorch’s torch.linalg for
column-major layout:
- First 2 dimensions are the matrix (m × n).
- All following dimensions (*) are independent batch dimensions.
- Inputs are internally normalized to column-major contiguous layout.
If an input is not already contiguous, an internal copy is performed.
Calling .contiguous(ColumnMajor) explicitly is optional but useful when you want to control exactly where copies happen.
This convention mirrors PyTorch’s (*, m, n) but is flipped for
col-major: in col-major the first dimensions are contiguous, so
placing the matrix there ensures LAPACK can operate directly without
transposition.
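To make the layout concrete, here is a standalone sketch of standard column-major stride arithmetic (plain Rust, not a tenferro API): the first axis gets stride 1, and each later stride is the previous stride times the previous extent, so the (m, n) matrix occupies one contiguous block per batch entry.

```rust
// Column-major strides: the first axis is contiguous (stride 1),
// and each subsequent stride multiplies in the previous extent.
fn col_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut acc = 1;
    for &dim in shape {
        strides.push(acc);
        acc *= dim;
    }
    strides
}

fn main() {
    // Batched [3, 4, 10]: element (i, j) of batch b sits at
    // offset i + j*3 + b*12, so each 3×4 matrix is a contiguous
    // 12-element block that LAPACK-style kernels can use directly.
    let strides = col_major_strides(&[3, 4, 10]);
    assert_eq!(strides, vec![1, 3, 12]);
    let (i, j, b) = (2, 1, 4);
    println!("{}", i * strides[0] + j * strides[1] + b * strides[2]);
}
```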
This module is context-agnostic: it does not know about tensor
networks, MPS, or any specific application. If you need to decompose
a tensor along arbitrary legs, permute + reshape before calling
these functions.
§AD rules
Each decomposition has stateless _rrule (reverse-mode / VJP) and
_frule (forward-mode / JVP) functions. These implement matrix-level
AD formulas (Mathieu et al., 2019) using batched operations that
naturally broadcast over batch dimensions *.
There are no tracked_* / dual_* functions — the chainrules tape
engine composes permute_backward + reshape_backward + svd_rrule
via the standard chain rule automatically.
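To see why no tracked_* wrappers are needed, note that the backward pass of a reshape simply reshapes the cotangent back to the input shape, and the backward pass of a permute applies the inverse permutation. A standalone sketch of that inverse-permutation step (plain Rust, not a tenferro API):

```rust
// permute_backward needs the inverse of the forward permutation:
// if the forward pass moved input axis perm[k] to output axis k,
// the backward pass must move cotangent axis k back to axis perm[k].
fn inverse_permutation(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; perm.len()];
    for (k, &p) in perm.iter().enumerate() {
        inv[p] = k;
    }
    inv
}

fn main() {
    // Forward permutes axes as [2, 0, 1]; backward must apply [1, 2, 0].
    let perm = [2, 0, 1];
    let inv = inverse_permutation(&perm);
    assert_eq!(inv, vec![1, 2, 0]);
    // Composing the two permutations restores the identity ordering,
    // which is exactly what the tape relies on when chaining rules.
    let composed: Vec<usize> = perm.iter().map(|&p| inv[p]).collect();
    assert_eq!(composed, vec![0, 1, 2]);
}
```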
§Examples
§SVD of a matrix
use tenferro_linalg::{svd, SvdOptions};
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;
let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);
// 2D matrix: shape [3, 4]
let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();
// result.u: shape [3, 3] (m × k, k = min(m,n) = 3)
// result.s: shape [3] (singular values)
// result.vt: shape [3, 4] (k × n)
§Batched SVD
use tenferro_linalg::svd;
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;
let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);
// Batched: shape [m, n, batch] = [3, 4, 10]
let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();
// result.u: shape [3, 3, 10]
// result.s: shape [3, 10]
// result.vt: shape [3, 4, 10]
§Decomposing a 4D tensor along specific legs
use tenferro_linalg::svd;
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;
let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);
// 4D tensor [2, 3, 4, 5] — want SVD with left=[0,1], right=[2,3]
let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col).unwrap();
// permute + reshape (contiguous is handled internally, but can be called explicitly)
let mat = t.permute(&[0, 1, 2, 3]).unwrap() // already in order
.reshape(&[6, 20]).unwrap(); // m = 2*3 = 6, n = 4*5 = 20
let result = svd(&mut ctx, &mat, None).unwrap();
// Then reshape result.u, result.vt back to desired tensor shape
§Reverse-mode AD (stateless rrule)
use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
use tenferro_prims::CpuContext;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;
let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let mut ctx = CpuContext::new(1);
let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
let result = svd(&mut ctx, &a, None).unwrap();
// Full cotangent: gradient through U, S, and Vt
let cotangent = SvdCotangent {
u: Some(Tensor::ones(&[3, 3], mem, col).unwrap()),
s: Some(Tensor::ones(&[3], mem, col).unwrap()),
vt: Some(Tensor::ones(&[3, 4], mem, col).unwrap()),
};
let grad_a = svd_rrule(&mut ctx, &a, &cotangent, None).unwrap();
// grad_a has same shape as a: [3, 4]
// Partial cotangent: gradient only through singular values (always stable)
let cotangent_s_only = SvdCotangent {
u: None,
s: Some(Tensor::ones(&[3], mem, col).unwrap()),
vt: None,
};
let grad_a2 = svd_rrule(&mut ctx, &a, &cotangent_s_only, None).unwrap();
Modules§
- backend
- Backend abstraction for linear algebra operations.
Structs§
- CholeskyExResult - Structured Cholesky result with numerical status information.
- EigCotangent - Cotangent (adjoint) for general eigendecomposition outputs.
- EigResult - Result of general eigendecomposition (always complex-valued).
- EigenCotangent - Cotangent (adjoint) for eigendecomposition outputs.
- EigenResult - Eigendecomposition result: A * V = V * diag(values).
- InvExResult - Structured inverse result with numerical status information.
- LstsqAuxResult - Auxiliary metadata for least-squares solves.
- LstsqGrad - Gradient result for lstsq_rrule: cotangents for both A and b.
- LstsqResult - Least-squares result.
- LuCotangent - Cotangent (adjoint) for LU outputs.
- LuFactorExResult - Packed LU factorization result with numerical status information.
- LuFactorResult - Packed LU factorization result.
- LuResult - LU decomposition result: A = P * L * U.
- QrCotangent - Cotangent (adjoint) for QR outputs.
- QrResult - QR decomposition result: A = Q * R.
- SlogdetCotangent - Cotangent (adjoint) for slogdet outputs.
- SlogdetResult - Sign-and-log-determinant result: det(A) = sign * exp(logabsdet).
- SolveExResult - Structured solve result with numerical status information.
- SolveGrad - Gradient result for solve_rrule: cotangents for both A and b.
- SvdCotangent - Cotangent (adjoint) for SVD outputs.
- SvdOptions - Options for truncated SVD.
- SvdResult - SVD result: A = U * diag(S) * Vt.
Enums§
- LinalgCapabilityOp - Backend-facing tensor linalg protocol.
- LuPivot - Pivoting strategy for LU decomposition.
- MatrixNormOrd - Matrix norm order for [crate::matrix_norm]-style public surfaces.
- NormKind - Norm kind for crate::norm.
- VectorNormOrd - Vector norm order for [crate::vector_norm]-style public surfaces.
Traits§
- KernelLinalgScalar - Scalar types with concrete backend kernel support in the current workspace.
- LinalgScalar - Scalar types supported by linalg kernel contracts.
Functions§
- cholesky - Compute the Cholesky decomposition of a Hermitian positive-definite matrix.
- cholesky_ex - Compute the Cholesky decomposition with numerical status information.
- cholesky_frule - Forward-mode AD rule for Cholesky (JVP / pushforward).
- cholesky_rrule - Reverse-mode AD rule for Cholesky (VJP / pullback).
- cond - Compute the matrix condition number with a selected norm convention.
- cross - Compute the cross product along the leading vector axis.
- det - Compute the determinant of a square matrix.
- det_frule - Forward-mode AD rule for determinant (JVP / pushforward).
- det_rrule - Reverse-mode AD rule for determinant (VJP / pullback).
- eig - Compute the eigendecomposition of a general (non-symmetric) square matrix.
- eig_frule - Forward-mode AD rule for general eigendecomposition (JVP / pushforward).
- eig_rrule - Reverse-mode AD rule for general eigendecomposition (VJP / pullback).
- eigen - Compute the eigendecomposition of a batched square matrix.
- eigen_frule - Forward-mode AD rule for eigendecomposition (JVP / pushforward).
- eigen_rrule - Reverse-mode AD rule for eigendecomposition (VJP / pullback).
- householder_product - Form the explicit product of Householder reflectors.
- inv - Compute the inverse of a square matrix.
- inv_ex - Compute the inverse with numerical status information.
- inv_frule - Forward-mode AD rule for matrix inverse (JVP / pushforward).
- inv_rrule - Reverse-mode AD rule for matrix inverse (VJP / pullback).
- lstsq - Solve the least squares problem: x = argmin ||Ax - b||².
- lstsq_aux - Compute least-squares auxiliary metadata.
- lstsq_frule - Forward-mode AD rule for least squares (JVP / pushforward).
- lstsq_rrule - Reverse-mode AD rule for least squares (VJP / pullback).
- lu - Compute the LU decomposition of a batched matrix.
- lu_factor - Compute the packed LU factorization of a batched matrix.
- lu_factor_ex - Compute the packed LU factorization with numerical status information.
- lu_frule - Forward-mode AD rule for LU (JVP / pushforward).
- lu_rrule - Reverse-mode AD rule for LU (VJP / pullback).
- lu_solve - Solve A x = b from a packed LU factorization.
- matrix_exp - Compute the matrix exponential exp(A) of a square matrix.
- matrix_exp_frule - Forward-mode AD rule for matrix exponential (JVP / pushforward).
- matrix_exp_rrule - Reverse-mode AD rule for matrix exponential (VJP / pullback).
- matrix_power - Raise a square matrix to an integer power.
- norm - Compute a norm.
- norm_frule - Forward-mode AD rule for norm (JVP / pushforward).
- norm_rrule - Reverse-mode AD rule for norm (VJP / pullback).
- pinv - Compute the Moore-Penrose pseudoinverse of a matrix.
- pinv_frule - Forward-mode AD rule for pseudoinverse (JVP / pushforward).
- pinv_rrule - Reverse-mode AD rule for pseudoinverse (VJP / pullback).
- qr - Compute the QR decomposition of a batched matrix.
- qr_frule - Forward-mode AD rule for QR (JVP / pushforward).
- qr_rrule - Reverse-mode AD rule for QR (VJP / pullback).
- slogdet - Compute sign and log-absolute-determinant of a square matrix.
- slogdet_frule - Forward-mode AD rule for slogdet (JVP / pushforward).
- slogdet_rrule - Reverse-mode AD rule for slogdet (VJP / pullback).
- solve - Solve a square linear system A x = b.
- solve_ex - Solve a square linear system with numerical status information.
- solve_frule - Forward-mode AD rule for linear solve (JVP / pushforward).
- solve_rrule - Reverse-mode AD rule for linear solve (VJP / pullback).
- solve_triangular - Solve a triangular linear system A x = b.
- solve_triangular_frule - Forward-mode AD rule for triangular solve (JVP / pushforward).
- solve_triangular_rrule - Reverse-mode AD rule for triangular solve (VJP / pullback).
- svd - Compute the SVD of a batched matrix.
- svd_frule - Forward-mode AD rule for SVD (JVP / pushforward).
- svd_rrule - Reverse-mode AD rule for SVD (VJP / pullback).
- svdvals - Compute singular values only for a batched matrix.
- tensorinv - Invert a tensorized square operator.
- tensorsolve - Solve a tensorized linear system.
- vander - Build a Vandermonde matrix from leading-dimension vectors.