Batched matrix linear algebra decompositions with AD rules.
This crate provides SVD, QR, LU, and eigendecomposition for tensors
with shape (m, n, *), adapted from PyTorch’s torch.linalg for
column-major layout:
- The first 2 dimensions are the matrix (m × n).
- All following dimensions (*) are independent batch dimensions.
- Input must be column-major contiguous (LAPACK/cuSOLVER native).
This convention mirrors PyTorch’s (*, m, n) but is flipped for
column-major layout: there the leading dimensions vary fastest in memory,
so placing the matrix first makes each batch element a contiguous
m × n block that LAPACK can operate on directly, without transposition.
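Here is a minimal std-only sketch of what this convention means in memory. The helper names are hypothetical and not part of tenferro. With column-major strides, element (i, j, b) of an [m, n, batch] tensor sits at offset i + m·j + m·n·b. Batch element b is therefore a contiguous m × n column-major block starting at b·m·n, with leading dimension m. The shape helper reproduces the result shapes listed in the SVD examples, assuming a thin SVD with k = min(m, n).

```rust
/// Offset of a multi-index in a column-major contiguous buffer:
/// axis 0 has stride 1 and each later axis strides over all earlier ones.
fn col_major_offset(shape: &[usize], index: &[usize]) -> usize {
    assert_eq!(shape.len(), index.len(), "rank mismatch");
    let mut offset = 0;
    let mut stride = 1;
    for (&dim, &i) in shape.iter().zip(index) {
        assert!(i < dim, "index out of bounds");
        offset += i * stride;
        stride *= dim;
    }
    offset
}

/// Offset where batch element `b` of an [m, n, batch] tensor begins; the
/// m × n matrix stored there is column-major with leading dimension m.
fn batch_matrix_start(m: usize, n: usize, b: usize) -> usize {
    b * m * n
}

/// Result shapes of a thin SVD for an input of shape [m, n, *batch],
/// with k = min(m, n) (an assumption matching the example comments).
fn svd_output_shapes(shape: &[usize]) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    assert!(shape.len() >= 2, "need the two matrix dimensions");
    let (m, n, batch) = (shape[0], shape[1], &shape[2..]);
    let k = m.min(n);
    let mut u = vec![m, k];
    let mut s = vec![k];
    let mut vt = vec![k, n];
    // Batch dimensions are carried through unchanged.
    u.extend_from_slice(batch);
    s.extend_from_slice(batch);
    vt.extend_from_slice(batch);
    (u, s, vt)
}
```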
This crate is context-agnostic: it knows nothing about tensor
networks, MPS, or any other specific application. To decompose a
tensor along arbitrary legs, apply permute + reshape +
contiguous(ColumnMajor) to it before calling these functions.
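The permute + reshape + contiguous recipe can be sketched on a flat column-major buffer. This is a std-only illustration with hypothetical helper names, not the tenferro_tensor API. Take left legs [0, 2] and right legs [1, 3]. The permutation [0, 2, 1, 3] moves the row legs to the front, and the contiguous copy materializes that order. The reshape to an 8 × 15 matrix then needs no data movement, because in column-major order a leading group of axes fuses into one row index r = i0 + 2·i2.

```rust
/// Column-major strides for `shape` (axis 0 fastest).
fn col_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut s = 1;
    for &d in shape {
        strides.push(s);
        s *= d;
    }
    strides
}

/// Copies a column-major tensor into a new column-major buffer whose axis k is
/// the input's axis perm[k] (i.e. permute followed by a contiguous copy).
fn permute_contiguous(data: &[f64], shape: &[usize], perm: &[usize]) -> (Vec<f64>, Vec<usize>) {
    let new_shape: Vec<usize> = perm.iter().map(|&p| shape[p]).collect();
    let old_strides = col_major_strides(shape);
    let mut out = Vec::with_capacity(data.len());
    let mut idx = vec![0usize; new_shape.len()]; // multi-index into the output
    for _ in 0..data.len() {
        // Output axis k walks input axis perm[k].
        let src: usize = idx.iter().zip(perm).map(|(&i, &p)| i * old_strides[p]).sum();
        out.push(data[src]);
        // Advance the output multi-index in column-major order (axis 0 fastest).
        for (i, &d) in idx.iter_mut().zip(&new_shape) {
            *i += 1;
            if *i < d {
                break;
            }
            *i = 0;
        }
    }
    (out, new_shape)
}
```

After the copy, viewing the buffer as a matrix only changes the shape: m is the product of the leading (left) dimensions and n is the product of the rest.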
§AD rules
Each decomposition has stateless _rrule (reverse-mode / VJP) and
_frule (forward-mode / JVP) functions. These implement matrix-level
AD formulas (Mathieu 2019 et al.) using batched operations that
broadcast naturally over the batch dimensions (*).
There are no tracked_* / dual_* functions: the chainrules tape
engine composes permute_backward + reshape_backward + svd_rrule
automatically via the standard chain rule.
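The simplest of these matrix-level formulas is the pullback through the singular values alone. For a simple (non-repeated) singular value, ∂s_k/∂A = u_k v_kᵀ, so a cotangent s̄ pulls back to Ā = U diag(s̄) Vᵀ. This involves no division by singular-value gaps. The U and Vt terms of the standard rule, by contrast, contain 1/(s_i² − s_j²) factors.

The std-only sketch below illustrates this. It is restricted to 2×2 matrices, uses hypothetical helper names, and is not the crate's implementation. It builds A = U Σ Vᵀ from known rotations and checks the formula against central finite differences.

```rust
type M2 = [[f64; 2]; 2];

fn matmul(a: M2, b: M2) -> M2 {
    let mut c = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn transpose(a: M2) -> M2 {
    [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]
}

/// 2×2 rotation matrix (orthogonal), used to build U and V.
fn rot(t: f64) -> M2 {
    [[t.cos(), -t.sin()], [t.sin(), t.cos()]]
}

/// Singular values (descending) of a 2×2 matrix from the eigenvalues of AᵀA.
fn singular_values(a: M2) -> [f64; 2] {
    let ata = matmul(transpose(a), a);
    let (p, q, r) = (ata[0][0], ata[0][1], ata[1][1]);
    let mean = 0.5 * (p + r);
    let rad = (0.25 * (p - r) * (p - r) + q * q).sqrt();
    [(mean + rad).sqrt(), (mean - rad).max(0.0).sqrt()]
}

/// Singular-value-only pullback: Ā = U · diag(s̄) · Vᵀ.
fn svd_s_pullback(u: M2, v: M2, s_bar: [f64; 2]) -> M2 {
    let d = [[s_bar[0], 0.0], [0.0, s_bar[1]]];
    matmul(matmul(u, d), transpose(v))
}

/// Central finite-difference gradient of Σ_k s̄_k · s_k(A) with respect to A.
fn fd_gradient(a: M2, s_bar: [f64; 2], h: f64) -> M2 {
    let f = |m: M2| {
        let s = singular_values(m);
        s_bar[0] * s[0] + s_bar[1] * s[1]
    };
    let mut g = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            let (mut plus, mut minus) = (a, a);
            plus[i][j] += h;
            minus[i][j] -= h;
            g[i][j] = (f(plus) - f(minus)) / (2.0 * h);
        }
    }
    g
}
```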
§Examples
§SVD of a matrix
```rust
use tenferro_linalg::{svd, SvdOptions};
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
// 2D matrix: shape [3, 4]
let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
// `None`: no SvdOptions, so no truncation
let result = svd(&a, None).unwrap();
// result.u:  shape [3, 3] (m × k, k = min(m, n) = 3)
// result.s:  shape [3]    (singular values)
// result.vt: shape [3, 4] (k × n)
```
§Batched SVD
```rust
use tenferro_linalg::svd;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
// Batched: shape [m, n, batch] = [3, 4, 10]
let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col);
let result = svd(&a, None).unwrap();
// result.u:  shape [3, 3, 10]
// result.s:  shape [3, 10]
// result.vt: shape [3, 4, 10]
```
§Decomposing a 4D tensor along specific legs
```rust
use tenferro_linalg::svd;
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
// 4D tensor [2, 3, 4, 5]: want SVD with left legs [0, 1], right legs [2, 3]
let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col);
// User's responsibility: permute + reshape + contiguous
let mat = t.permute(&[0, 1, 2, 3]) // identity here: legs are already in order
    .reshape(&[6, 20]).unwrap()    // m = 2*3 = 6, n = 4*5 = 20
    .contiguous(col);
let result = svd(&mat, None).unwrap();
// Then reshape the factors back to tensor legs (k = min(6, 20) = 6):
// result.u [6, 6] -> [2, 3, 6], result.vt [6, 20] -> [6, 4, 5]
```
§Reverse-mode AD (stateless rrule)
```rust
use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
use tenferro_tensor::{Tensor, MemoryOrder};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
let result = svd(&a, None).unwrap();
// Full cotangent: gradient through U, S, and Vt.
// Caution: in the standard SVD rule the U/Vt terms divide by differences of squared
// singular values, so they are ill-conditioned for repeated singular values (as in
// this all-zero placeholder); use a matrix with distinct singular values in practice.
let cotangent = SvdCotangent {
    u: Some(Tensor::ones(&[3, 3], mem, col)),
    s: Some(Tensor::ones(&[3], mem, col)),
    vt: Some(Tensor::ones(&[3, 4], mem, col)),
};
let grad_a = svd_rrule(&a, &cotangent, None).unwrap();
// grad_a has the same shape as a: [3, 4]
// Partial cotangent: gradient only through the singular values (always stable)
let cotangent_s_only = SvdCotangent {
    u: None,
    s: Some(Tensor::ones(&[3], mem, col)),
    vt: None,
};
let grad_a2 = svd_rrule(&a, &cotangent_s_only, None).unwrap();
// grad_a2 also has shape [3, 4]
```
§Structs
- EigenCotangent - Cotangent (adjoint) for eigendecomposition outputs.
- EigenResult - Eigendecomposition result: A * V = V * diag(values).
- LuCotangent - Cotangent (adjoint) for LU outputs.
- LuResult - LU decomposition result: A = P * L * U (partial pivoting).
- QrCotangent - Cotangent (adjoint) for QR outputs.
- QrResult - QR decomposition result: A = Q * R.
- SvdCotangent - Cotangent (adjoint) for SVD outputs.
- SvdOptions - Options for truncated SVD.
- SvdResult - SVD result: A = U * diag(S) * Vt.
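To illustrate the LuResult convention A = P * L * U, here is a std-only sketch of Doolittle elimination with partial pivoting on a single square matrix. It is not the crate's implementation, which is batched and, per the layout notes above, targets LAPACK/cuSOLVER. It uses row-major Vec<Vec<f64>> purely for readability, and the helper names are hypothetical. The permutation is kept as a vector perm: row i of L·U equals row perm[i] of A.

```rust
/// LU with partial pivoting of a square matrix (row-major Vec<Vec<f64>>).
/// Returns (perm, l, u) such that row i of L·U equals row perm[i] of A,
/// i.e. A = P·L·U with P[perm[i]][i] = 1.
fn lu_partial_pivot(a: &[Vec<f64>]) -> (Vec<usize>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let n = a.len();
    let mut u = a.to_vec();
    let mut l = vec![vec![0.0; n]; n];
    let mut perm: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Partial pivoting: pick the row at or below k with the largest |u[i][k]|.
        let p = (k..n)
            .max_by(|&i, &j| u[i][k].abs().total_cmp(&u[j][k].abs()))
            .unwrap();
        u.swap(k, p);
        l.swap(k, p); // keep the multipliers computed so far attached to their rows
        perm.swap(k, p);
        for i in k + 1..n {
            let factor = if u[k][k] != 0.0 { u[i][k] / u[k][k] } else { 0.0 };
            l[i][k] = factor;
            u[i][k] = 0.0; // eliminated entry is exactly zero
            for j in k + 1..n {
                let ukj = u[k][j];
                u[i][j] -= factor * ukj;
            }
        }
    }
    for (i, row) in l.iter_mut().enumerate() {
        row[i] = 1.0; // unit lower-triangular L
    }
    (perm, l, u)
}

/// Rebuilds A = P·L·U from the factors.
fn reconstruct(perm: &[usize], l: &[Vec<f64>], u: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = perm.len();
    let mut a = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            // Row i of L·U lands in row perm[i] of A.
            a[perm[i]][j] = (0..n).map(|k| l[i][k] * u[k][j]).sum::<f64>();
        }
    }
    a
}
```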
§Functions
- eigen - Compute the eigendecomposition of a batched square matrix.
- eigen_frule - Forward-mode AD rule for eigendecomposition (JVP / pushforward).
- eigen_rrule - Reverse-mode AD rule for eigendecomposition (VJP / pullback).
- lu - Compute the LU decomposition of a batched matrix (partial pivoting).
- lu_frule - Forward-mode AD rule for LU (JVP / pushforward).
- lu_rrule - Reverse-mode AD rule for LU (VJP / pullback).
- qr - Compute the QR decomposition of a batched matrix.
- qr_frule - Forward-mode AD rule for QR (JVP / pushforward).
- qr_rrule - Reverse-mode AD rule for QR (VJP / pullback).
- svd - Compute the SVD of a batched matrix.
- svd_frule - Forward-mode AD rule for SVD (JVP / pushforward).
- svd_rrule - Reverse-mode AD rule for SVD (VJP / pullback).
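As an illustration of the kind of identity a forward-mode rule such as eigen_frule pushes tangents through, consider the symmetric case. A tangent dA perturbs a simple eigenvalue by dλ_k = v_kᵀ · dA · v_k, where v_k is a unit eigenvector. Restricting to symmetric matrices is an assumption of this sketch: the source only says eigen operates on batched square matrices, and the crate's exact formulas are not shown here. The std-only 2×2 sketch below uses hypothetical helper names and compares that JVP with central finite differences.

```rust
type M2 = [[f64; 2]; 2];

/// Eigenvalues (descending) of a symmetric 2×2 matrix [[a, b], [b, c]].
fn sym_eigenvalues(m: M2) -> [f64; 2] {
    let (a, b, c) = (m[0][0], m[0][1], m[1][1]);
    let mean = 0.5 * (a + c);
    let rad = (0.25 * (a - c) * (a - c) + b * b).sqrt();
    [mean + rad, mean - rad]
}

/// Builds A = V · diag(λ) · Vᵀ from orthonormal eigenvectors (columns of `v`).
fn from_eigen(v: M2, lam: [f64; 2]) -> M2 {
    let mut a = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                a[i][j] += v[i][k] * lam[k] * v[j][k];
            }
        }
    }
    a
}

/// Eigenvalue JVP: dλ_k = v_kᵀ · dA · v_k, with v_k stored as column k of `v`.
fn sym_eigenvalues_jvp(v: M2, da: M2) -> [f64; 2] {
    let mut dl = [0.0; 2];
    for k in 0..2 {
        for i in 0..2 {
            for j in 0..2 {
                dl[k] += v[i][k] * da[i][j] * v[j][k];
            }
        }
    }
    dl
}

/// Central finite-difference JVP for comparison: (λ(A + h·dA) − λ(A − h·dA)) / 2h.
fn fd_jvp(a: M2, da: M2, h: f64) -> [f64; 2] {
    let shifted = |t: f64| {
        let mut m = a;
        for i in 0..2 {
            for j in 0..2 {
                m[i][j] += t * da[i][j];
            }
        }
        sym_eigenvalues(m)
    };
    let (plus, minus) = (shifted(h), shifted(-h));
    [(plus[0] - minus[0]) / (2.0 * h), (plus[1] - minus[1]) / (2.0 * h)]
}
```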