tenferro_linalg/
lib.rs

// NOTE(review): crate-wide lint allow — presumably some generic APIs declare
// bounds both in the type-parameter list and in a `where` clause; confirm
// this allow is still required and consider scoping it to the offending items.
#![allow(clippy::multiple_bound_locations)]
2
3//! Batched matrix linear algebra decompositions with AD rules.
4//!
5//! CPU decompositions and solvers are fully implemented via the
6//! [`faer`](https://crates.io/crates/faer) backend. CUDA/HIP linalg contracts
7//! are already part of the public surface, but backend coverage there remains
8//! partial and capability-gated.
9//!
10//! This crate provides SVD, QR, LU, eigendecomposition, Cholesky, least squares,
11//! linear solve, matrix inverse, determinant, pseudoinverse, matrix exponential,
12//! triangular solve, and norms for tensors
13//! with shape `(m, n, *)`, adapted from PyTorch's `torch.linalg` for
14//! column-major layout:
15//!
16//! - **First 2 dimensions** are the matrix (`m × n`).
17//! - **All following dimensions** (`*`) are independent batch dimensions.
18//! - Inputs are **internally normalized** to column-major contiguous layout.
19//!   If an input is not already contiguous, an internal copy is performed.
20//!   Calling `.contiguous(ColumnMajor)` explicitly is optional but useful
21//!   when you want to control exactly where copies happen.
22//!
23//! This convention mirrors PyTorch's `(*, m, n)` but is flipped for
24//! col-major: in col-major the first dimensions are contiguous, so
25//! placing the matrix there ensures LAPACK can operate directly without
26//! transposition.
27//!
28//! This module is **context-agnostic**: it does not know about tensor
29//! networks, MPS, or any specific application. If you need to decompose
30//! a tensor along arbitrary legs, `permute` + `reshape` before calling
31//! these functions.
32//!
33//! # AD rules
34//!
35//! Each decomposition has stateless `_rrule` (reverse-mode / VJP) and
36//! `_frule` (forward-mode / JVP) functions. These implement matrix-level
37//! AD formulas (Mathieu 2019 et al.) using batched operations that
38//! naturally broadcast over batch dimensions `*`.
39//!
40//! There are no `tracked_*` / `dual_*` functions — the chainrules tape
41//! engine composes `permute_backward` + `reshape_backward` + `svd_rrule`
42//! via the standard chain rule automatically.
43//!
44//! # Examples
45//!
46//! ## SVD of a matrix
47//!
48//! ```
49//! use tenferro_linalg::{svd, SvdOptions};
50//! use tenferro_prims::CpuContext;
51//! use tenferro_tensor::{Tensor, MemoryOrder};
52//! use tenferro_device::LogicalMemorySpace;
53//!
54//! let col = MemoryOrder::ColumnMajor;
55//! let mem = LogicalMemorySpace::MainMemory;
56//! let mut ctx = CpuContext::new(1);
57//!
58//! // 2D matrix: shape [3, 4]
59//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
60//! let result = svd(&mut ctx, &a, None).unwrap();
61//! // result.u:  shape [3, 3]  (m × k, k = min(m,n) = 3)
62//! // result.s:  shape [3]     (singular values)
63//! // result.vt: shape [3, 4]  (k × n)
64//! ```
65//!
66//! ## Batched SVD
67//!
68//! ```
69//! use tenferro_linalg::svd;
70//! use tenferro_prims::CpuContext;
71//! use tenferro_tensor::{Tensor, MemoryOrder};
72//! use tenferro_device::LogicalMemorySpace;
73//!
74//! let col = MemoryOrder::ColumnMajor;
75//! let mem = LogicalMemorySpace::MainMemory;
76//! let mut ctx = CpuContext::new(1);
77//!
78//! // Batched: shape [m, n, batch] = [3, 4, 10]
79//! let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col).unwrap();
80//! let result = svd(&mut ctx, &a, None).unwrap();
81//! // result.u:  shape [3, 3, 10]
82//! // result.s:  shape [3, 10]
83//! // result.vt: shape [3, 4, 10]
84//! ```
85//!
86//! ## Decomposing a 4D tensor along specific legs
87//!
88//! ```
89//! use tenferro_linalg::svd;
90//! use tenferro_prims::CpuContext;
91//! use tenferro_tensor::{Tensor, MemoryOrder};
92//! use tenferro_device::LogicalMemorySpace;
93//!
94//! let col = MemoryOrder::ColumnMajor;
95//! let mem = LogicalMemorySpace::MainMemory;
96//! let mut ctx = CpuContext::new(1);
97//!
98//! // 4D tensor [2, 3, 4, 5] — want SVD with left=[0,1], right=[2,3]
99//! let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col).unwrap();
100//!
101//! // permute + reshape (contiguous is handled internally, but can be called explicitly)
102//! let mat = t.permute(&[0, 1, 2, 3]).unwrap()  // already in order
103//!            .reshape(&[6, 20]).unwrap();        // m = 2*3 = 6, n = 4*5 = 20
104//! let result = svd(&mut ctx, &mat, None).unwrap();
105//! // Then reshape result.u, result.vt back to desired tensor shape
106//! ```
107//!
108//! ## Reverse-mode AD (stateless rrule)
109//!
110//! ```
111//! use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
112//! use tenferro_prims::CpuContext;
113//! use tenferro_tensor::{Tensor, MemoryOrder};
114//! use tenferro_device::LogicalMemorySpace;
115//!
116//! let col = MemoryOrder::ColumnMajor;
117//! let mem = LogicalMemorySpace::MainMemory;
118//! let mut ctx = CpuContext::new(1);
119//!
120//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
121//! let result = svd(&mut ctx, &a, None).unwrap();
122//!
123//! // Full cotangent: gradient through U, S, and Vt
124//! let cotangent = SvdCotangent {
125//!     u: Some(Tensor::ones(&[3, 3], mem, col).unwrap()),
126//!     s: Some(Tensor::ones(&[3], mem, col).unwrap()),
127//!     vt: Some(Tensor::ones(&[3, 4], mem, col).unwrap()),
128//! };
129//! let grad_a = svd_rrule(&mut ctx, &a, &cotangent, None).unwrap();
130//! // grad_a has same shape as a: [3, 4]
131//!
132//! // Partial cotangent: gradient only through singular values (always stable)
133//! let cotangent_s_only = SvdCotangent {
134//!     u: None,
135//!     s: Some(Tensor::ones(&[3], mem, col).unwrap()),
136//!     vt: None,
137//! };
138//! let grad_a2 = svd_rrule(&mut ctx, &a, &cotangent_s_only, None).unwrap();
139//! ```
140
// Compile-time guards: the provider-* and src-* Cargo features only make
// sense together with the `linalg-lapack` feature. Rejecting invalid
// combinations here produces a clear error message instead of confusing
// downstream link failures.
#[cfg(all(feature = "provider-src", not(feature = "linalg-lapack")))]
compile_error!("provider-src requires linalg-lapack");
#[cfg(all(feature = "provider-inject", not(feature = "linalg-lapack")))]
compile_error!("provider-inject requires linalg-lapack");
// Any src-* backend selection likewise requires linalg-lapack. (The
// "exactly one provider" / "exactly one src-*" rules are enforced by the
// const assertions that follow this guard block.)
#[cfg(all(
    any(
        feature = "src-openblas",
        feature = "src-netlib",
        feature = "src-accelerate",
        feature = "src-r",
        feature = "src-intel-mkl-dynamic-sequential",
        feature = "src-intel-mkl-dynamic-parallel",
        feature = "src-intel-mkl-static-sequential",
        feature = "src-intel-mkl-static-parallel"
    ),
    not(feature = "linalg-lapack")
))]
compile_error!("src-* features require linalg-lapack and provider-src");
159
// Const-evaluated feature-combination checks, active only when the LAPACK
// backend is enabled. Unlike the `compile_error!` guards above, these
// express "exactly one" / "at most zero" counting rules, which are easier
// to state as const assertions than as cfg predicates.
#[cfg(feature = "linalg-lapack")]
const _: () = {
    // Exactly one symbol provider must be chosen for the LAPACK backend.
    let has_src_provider = cfg!(feature = "provider-src");
    let has_inject_provider = cfg!(feature = "provider-inject");
    assert!(
        (has_src_provider as usize) + (has_inject_provider as usize) == 1,
        "linalg-lapack requires exactly one provider: provider-src or provider-inject"
    );

    // Count how many src-* backend crates were selected.
    let selected: [bool; 8] = [
        cfg!(feature = "src-openblas"),
        cfg!(feature = "src-netlib"),
        cfg!(feature = "src-accelerate"),
        cfg!(feature = "src-r"),
        cfg!(feature = "src-intel-mkl-dynamic-sequential"),
        cfg!(feature = "src-intel-mkl-dynamic-parallel"),
        cfg!(feature = "src-intel-mkl-static-sequential"),
        cfg!(feature = "src-intel-mkl-static-parallel"),
    ];
    let mut src_count = 0usize;
    let mut idx = 0usize;
    while idx < selected.len() {
        if selected[idx] {
            src_count += 1;
        }
        idx += 1;
    }

    // provider-src needs exactly one concrete source; provider-inject
    // supplies symbols externally and therefore forbids any source crate.
    if has_src_provider {
        assert!(
            src_count == 1,
            "provider-src requires exactly one src-* feature"
        );
    }
    if has_inject_provider {
        assert!(src_count == 0, "provider-inject forbids src-* features");
    }
};
188
// Link-only dependencies: `extern crate … as _` pulls these crates into the
// build purely so their native BLAS/CBLAS/LAPACK symbols are available at
// link time; no items are imported from them.
#[cfg(feature = "provider-src")]
extern crate blas_src as _;
#[cfg(feature = "provider-src")]
extern crate cblas_src as _;
#[cfg(feature = "provider-src")]
extern crate lapack_src as _;

// With provider-inject, the symbols are presumably supplied at runtime via
// the *_inject crates rather than a compiled-in source backend.
#[cfg(feature = "provider-inject")]
extern crate cblas_inject as _;
#[cfg(feature = "provider-inject")]
extern crate lapack_inject as _;
200
// Public backend/capability surface.
pub mod backend;
// Injection surface for a caller-supplied CBLAS/LAPACK implementation;
// only compiled with the provider-inject provider.
#[cfg(all(feature = "linalg-lapack", feature = "provider-inject"))]
pub mod inject;
// Glue between this crate and tenferro-prims execution contexts.
mod prims_bridge;

// NOTE(review): these crate-root imports are not visibly referenced in this
// file — confirm they are used (e.g. by intra-doc links or macro expansions)
// or remove them.
use chainrules_core::AdResult;
use num_traits::Zero;
use tenferro_device::{Error, Result};
use tenferro_tensor::{MemoryOrder, Tensor};

// Implementation modules: primal decompositions (`primal`), forward-mode AD
// rules (`frules`), reverse-mode AD rules (`rrules`), shared AD helpers,
// and the result structs the decompositions return.
mod ad_helpers;
mod frules;
mod primal;
mod result_types;
mod rrules;

// AD helper utilities shared across the rule implementations.
pub(crate) use ad_helpers::*;

// `#[doc(hidden)]` re-exports are public for use by sibling crates/bridges,
// presumably not intended as end-user API — confirm before relying on them.
#[doc(hidden)]
pub use ad_helpers::MatrixExpAbsTensor;
pub use frules::*;
pub(crate) use primal::require_linalg_support;
pub use primal::*;
#[doc(hidden)]
pub use prims_bridge::ScaleTensorByRealSameShape;
pub use result_types::*;
pub use rrules::*;
// Re-export the scalar/capability contracts so downstream users need not
// depend on tenferro-linalg-prims directly.
#[doc(inline)]
pub use tenferro_linalg_prims::{KernelLinalgScalar, LinalgCapabilityOp, LinalgScalar};

#[cfg(test)]
mod tests;