tenferro_linalg/result_types/decomposition.rs
1use super::*;
2
3/// SVD result: `A = U * diag(S) * Vt`.
4///
5/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
6///
7/// - `u`: shape `(m, k, *)`
8/// - `s`: shape `(k, *)` in descending order
9/// - `vt`: shape `(k, n, *)`
10///
11/// # Examples
12///
13/// ```
14/// use tenferro_device::LogicalMemorySpace;
15/// use tenferro_linalg::svd;
16/// use tenferro_prims::CpuContext;
17/// use tenferro_tensor::{MemoryOrder, Tensor};
18///
19/// let mut ctx = CpuContext::new(1);
20/// let a = Tensor::<f64>::zeros(
21/// &[3, 4],
22/// LogicalMemorySpace::MainMemory,
23/// MemoryOrder::ColumnMajor,
24/// ).unwrap();
25/// let result = svd(&mut ctx, &a, None).unwrap();
26/// assert_eq!(result.s.ndim(), 1);
27/// ```
28#[derive(Debug)]
29pub struct SvdResult<T: Scalar, R: Scalar = T> {
30 /// Left singular vectors. Shape: `(m, k, *)`.
31 pub u: Tensor<T>,
32 /// Singular values. Shape: `(k, *)`.
33 pub s: Tensor<R>,
34 /// Right singular vectors (conjugate-transposed). Shape: `(k, n, *)`.
35 pub vt: Tensor<T>,
36}
37
/// Truncation options for SVD.
///
/// Both limits are optional. When both fields are specified, the more
/// restrictive condition determines how many singular values are kept.
/// `SvdOptions::default()` leaves both limits unset (`None`).
///
/// # Examples
///
/// ```
/// use tenferro_linalg::SvdOptions;
///
/// let opts = SvdOptions {
///     max_rank: Some(8),
///     cutoff: Some(1e-12),
/// };
/// assert_eq!(opts.max_rank, Some(8));
///
/// // Set only one limit and leave the other at its default.
/// let rank_only = SvdOptions {
///     max_rank: Some(4),
///     ..Default::default()
/// };
/// assert_eq!(rank_only.cutoff, None);
///
/// // The default sets no limits at all.
/// assert_eq!(
///     SvdOptions::default(),
///     SvdOptions { max_rank: None, cutoff: None }
/// );
/// ```
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SvdOptions {
    /// Maximum number of singular values to keep. `None` places no limit on the rank.
    pub max_rank: Option<usize>,
    // NOTE(review): whether `cutoff` is absolute or relative to the largest
    // singular value is decided by the SVD routine, not here — confirm there
    // before documenting it.
    /// Discard singular values below this threshold. `None` disables the cutoff.
    pub cutoff: Option<f64>,
}
60
61/// QR decomposition result: `A = Q * R`.
62///
63/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
64///
65/// - `q`: shape `(m, k, *)`
66/// - `r`: shape `(k, n, *)`
67///
68/// # Examples
69///
70/// ```
71/// use tenferro_device::LogicalMemorySpace;
72/// use tenferro_linalg::qr;
73/// use tenferro_prims::CpuContext;
74/// use tenferro_tensor::{MemoryOrder, Tensor};
75///
76/// let mut ctx = CpuContext::new(1);
77/// let a = Tensor::<f64>::zeros(
78/// &[4, 3],
79/// LogicalMemorySpace::MainMemory,
80/// MemoryOrder::ColumnMajor,
81/// ).unwrap();
82/// let result = qr(&mut ctx, &a).unwrap();
83/// assert_eq!(result.q.dims(), &[4, 3]);
84/// assert_eq!(result.r.dims(), &[3, 3]);
85/// ```
86#[derive(Debug)]
87pub struct QrResult<T: Scalar> {
88 /// Orthonormal factor. Shape: `(m, k, *)`.
89 pub q: Tensor<T>,
90 /// Upper-triangular factor. Shape: `(k, n, *)`.
91 pub r: Tensor<T>,
92}
93
/// Row-pivoting strategy requested from the LU decomposition.
///
/// `Partial` is the default. Choosing `NoPivot` makes the `p` field of the
/// LU result an empty tensor of shape `[0]` instead of a permutation matrix.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::LuPivot;
///
/// let strategy = LuPivot::default();
/// assert_eq!(strategy, LuPivot::Partial);
/// assert_ne!(strategy, LuPivot::NoPivot);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LuPivot {
    /// Partial pivoting: rows may be exchanged during factorization.
    #[default]
    Partial,
    /// Factor without any row exchanges.
    NoPivot,
}
112
113/// LU decomposition result: `A = P * L * U`.
114///
115/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
116///
117/// - `p`: permutation matrix tensor of shape `(m, m, *)`
118/// or an empty tensor of shape `[0]` when pivoting is disabled
119/// - `l`: shape `(m, k, *)`
120/// - `u`: shape `(k, n, *)`
121///
122/// # Examples
123///
124/// ```
125/// use tenferro_linalg::{lu, LuPivot};
126/// use tenferro_prims::CpuContext;
127/// use tenferro_tensor::{MemoryOrder, Tensor};
128///
129/// let mut ctx = CpuContext::new(1);
130/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
131/// .unwrap();
132/// let result = lu(&mut ctx, &a, LuPivot::Partial).unwrap();
133/// assert_eq!(result.p.dims(), &[2, 2]);
134/// ```
135#[derive(Debug)]
136pub struct LuResult<T: Scalar> {
137 /// Row permutation matrix tensor.
138 pub p: Tensor<T>,
139 /// Unit lower-triangular factor.
140 pub l: Tensor<T>,
141 /// Upper-triangular factor.
142 pub u: Tensor<T>,
143}
144
145/// Gradient result for `solve_rrule`: cotangents for both `A` and `b`.
146///
147/// # Examples
148///
149/// ```
150/// use tenferro_device::LogicalMemorySpace;
151/// use tenferro_linalg::solve_rrule;
152/// use tenferro_prims::CpuContext;
153/// use tenferro_tensor::{MemoryOrder, Tensor};
154///
155/// let mem = LogicalMemorySpace::MainMemory;
156/// let col = MemoryOrder::ColumnMajor;
157/// let mut ctx = CpuContext::new(1);
158/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
159/// let b = Tensor::<f64>::ones(&[3], mem, col).unwrap();
160/// let cotangent = Tensor::<f64>::ones(&[3], mem, col).unwrap();
161/// let grad = solve_rrule(&mut ctx, &a, &b, &cotangent).unwrap();
162/// assert_eq!(grad.a.dims(), &[3, 3]);
163/// assert_eq!(grad.b.dims(), &[3]);
164/// ```
165#[derive(Debug)]
166pub struct SolveGrad<T: Scalar> {
167 /// Cotangent for the coefficient matrix `A`.
168 pub a: Tensor<T>,
169 /// Cotangent for the right-hand side `b`.
170 pub b: Tensor<T>,
171}
172
173/// Least-squares result.
174///
175/// # Examples
176///
177/// ```
178/// use tenferro_linalg::lstsq;
179/// use tenferro_prims::CpuContext;
180/// use tenferro_tensor::{MemoryOrder, Tensor};
181///
182/// let mut ctx = CpuContext::new(1);
183/// let a = Tensor::from_slice(&[1.0_f64, 0.0, 1.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
184/// .unwrap();
185/// let b = Tensor::from_slice(&[1.0_f64, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
186/// let result = lstsq(&mut ctx, &a, &b).unwrap();
187/// assert_eq!(result.solution.dims(), &[2]);
188/// ```
189#[derive(Debug)]
190pub struct LstsqResult<S: Scalar, R: Scalar> {
191 /// Least-squares solution.
192 pub solution: Tensor<S>,
193 /// Squared residual summaries `||AX - B||_F^2` per right-hand side, or an empty tensor.
194 pub residuals: Tensor<R>,
195}
196
197/// Auxiliary metadata for least-squares solves.
198#[derive(Debug)]
199pub struct LstsqAuxResult<R: Scalar> {
200 /// Numerical rank per batch item, stored as a batch-shaped real-valued count tensor.
201 pub rank: Tensor<R>,
202 /// Singular values used for the rank estimate.
203 pub singular_values: Tensor<R>,
204}
205
206/// Gradient result for `lstsq_rrule`: cotangents for both `A` and `b`.
207///
208/// # Examples
209///
210/// ```
211/// use tenferro_linalg::LstsqGrad;
212/// use tenferro_tensor::{MemoryOrder, Tensor};
213///
214/// let grad = LstsqGrad {
215/// a: Tensor::<f64>::zeros(&[2, 2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap(),
216/// b: Tensor::<f64>::zeros(&[2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap(),
217/// };
218/// assert_eq!(grad.a.ndim(), 2);
219/// assert_eq!(grad.b.ndim(), 1);
220/// ```
221#[derive(Debug)]
222pub struct LstsqGrad<T: Scalar> {
223 /// Cotangent for the system matrix `A`.
224 pub a: Tensor<T>,
225 /// Cotangent for the right-hand side `b`.
226 pub b: Tensor<T>,
227}