tenferro_linalg/result_types/decomposition.rs

use super::*;
2
3/// SVD result: `A = U * diag(S) * Vt`.
4///
5/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
6///
7/// - `u`: shape `(m, k, *)`
8/// - `s`: shape `(k, *)` in descending order
9/// - `vt`: shape `(k, n, *)`
10///
11/// # Examples
12///
13/// ```
14/// use tenferro_device::LogicalMemorySpace;
15/// use tenferro_linalg::svd;
16/// use tenferro_prims::CpuContext;
17/// use tenferro_tensor::{MemoryOrder, Tensor};
18///
19/// let mut ctx = CpuContext::new(1);
20/// let a = Tensor::<f64>::zeros(
21///     &[3, 4],
22///     LogicalMemorySpace::MainMemory,
23///     MemoryOrder::ColumnMajor,
24/// ).unwrap();
25/// let result = svd(&mut ctx, &a, None).unwrap();
26/// assert_eq!(result.s.ndim(), 1);
27/// ```
28#[derive(Debug)]
29pub struct SvdResult<T: Scalar, R: Scalar = T> {
30    /// Left singular vectors. Shape: `(m, k, *)`.
31    pub u: Tensor<T>,
32    /// Singular values. Shape: `(k, *)`.
33    pub s: Tensor<R>,
34    /// Right singular vectors (conjugate-transposed). Shape: `(k, n, *)`.
35    pub vt: Tensor<T>,
36}
37
/// Options for truncated SVD.
///
/// Both criteria are optional; when both are specified, the more
/// restrictive condition applies. The default value keeps every
/// singular value.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::SvdOptions;
///
/// let opts = SvdOptions {
///     max_rank: Some(8),
///     cutoff: Some(1e-12),
/// };
/// assert_eq!(opts.max_rank, Some(8));
/// ```
// Both fields are small `Copy` types, so the options struct itself can be
// `Copy` and compared with `==` — convenient for passing by value and for tests.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct SvdOptions {
    /// Maximum number of singular values to keep.
    pub max_rank: Option<usize>,
    /// Discard singular values below this threshold.
    pub cutoff: Option<f64>,
}
60
61/// QR decomposition result: `A = Q * R`.
62///
63/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
64///
65/// - `q`: shape `(m, k, *)`
66/// - `r`: shape `(k, n, *)`
67///
68/// # Examples
69///
70/// ```
71/// use tenferro_device::LogicalMemorySpace;
72/// use tenferro_linalg::qr;
73/// use tenferro_prims::CpuContext;
74/// use tenferro_tensor::{MemoryOrder, Tensor};
75///
76/// let mut ctx = CpuContext::new(1);
77/// let a = Tensor::<f64>::zeros(
78///     &[4, 3],
79///     LogicalMemorySpace::MainMemory,
80///     MemoryOrder::ColumnMajor,
81/// ).unwrap();
82/// let result = qr(&mut ctx, &a).unwrap();
83/// assert_eq!(result.q.dims(), &[4, 3]);
84/// assert_eq!(result.r.dims(), &[3, 3]);
85/// ```
86#[derive(Debug)]
87pub struct QrResult<T: Scalar> {
88    /// Orthonormal factor. Shape: `(m, k, *)`.
89    pub q: Tensor<T>,
90    /// Upper-triangular factor. Shape: `(k, n, *)`.
91    pub r: Tensor<T>,
92}
93
/// Selects the pivoting strategy used by LU decomposition.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::LuPivot;
///
/// assert_eq!(LuPivot::default(), LuPivot::Partial);
/// assert_eq!(LuPivot::NoPivot, LuPivot::NoPivot);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LuPivot {
    /// Partial row pivoting; this is the default strategy.
    #[default]
    Partial,
    /// Factorize without pivoting.
    NoPivot,
}
112
113/// LU decomposition result: `A = P * L * U`.
114///
115/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
116///
117/// - `p`: permutation matrix tensor of shape `(m, m, *)`
118///   or an empty tensor of shape `[0]` when pivoting is disabled
119/// - `l`: shape `(m, k, *)`
120/// - `u`: shape `(k, n, *)`
121///
122/// # Examples
123///
124/// ```
125/// use tenferro_linalg::{lu, LuPivot};
126/// use tenferro_prims::CpuContext;
127/// use tenferro_tensor::{MemoryOrder, Tensor};
128///
129/// let mut ctx = CpuContext::new(1);
130/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
131///     .unwrap();
132/// let result = lu(&mut ctx, &a, LuPivot::Partial).unwrap();
133/// assert_eq!(result.p.dims(), &[2, 2]);
134/// ```
135#[derive(Debug)]
136pub struct LuResult<T: Scalar> {
137    /// Row permutation matrix tensor.
138    pub p: Tensor<T>,
139    /// Unit lower-triangular factor.
140    pub l: Tensor<T>,
141    /// Upper-triangular factor.
142    pub u: Tensor<T>,
143}
144
145/// Gradient result for `solve_rrule`: cotangents for both `A` and `b`.
146///
147/// # Examples
148///
149/// ```
150/// use tenferro_device::LogicalMemorySpace;
151/// use tenferro_linalg::solve_rrule;
152/// use tenferro_prims::CpuContext;
153/// use tenferro_tensor::{MemoryOrder, Tensor};
154///
155/// let mem = LogicalMemorySpace::MainMemory;
156/// let col = MemoryOrder::ColumnMajor;
157/// let mut ctx = CpuContext::new(1);
158/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
159/// let b = Tensor::<f64>::ones(&[3], mem, col).unwrap();
160/// let cotangent = Tensor::<f64>::ones(&[3], mem, col).unwrap();
161/// let grad = solve_rrule(&mut ctx, &a, &b, &cotangent).unwrap();
162/// assert_eq!(grad.a.dims(), &[3, 3]);
163/// assert_eq!(grad.b.dims(), &[3]);
164/// ```
165#[derive(Debug)]
166pub struct SolveGrad<T: Scalar> {
167    /// Cotangent for the coefficient matrix `A`.
168    pub a: Tensor<T>,
169    /// Cotangent for the right-hand side `b`.
170    pub b: Tensor<T>,
171}
172
173/// Least-squares result.
174///
175/// # Examples
176///
177/// ```
178/// use tenferro_linalg::lstsq;
179/// use tenferro_prims::CpuContext;
180/// use tenferro_tensor::{MemoryOrder, Tensor};
181///
182/// let mut ctx = CpuContext::new(1);
183/// let a = Tensor::from_slice(&[1.0_f64, 0.0, 1.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
184///     .unwrap();
185/// let b = Tensor::from_slice(&[1.0_f64, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
186/// let result = lstsq(&mut ctx, &a, &b).unwrap();
187/// assert_eq!(result.solution.dims(), &[2]);
188/// ```
189#[derive(Debug)]
190pub struct LstsqResult<S: Scalar, R: Scalar> {
191    /// Least-squares solution.
192    pub solution: Tensor<S>,
193    /// Squared residual summaries `||AX - B||_F^2` per right-hand side, or an empty tensor.
194    pub residuals: Tensor<R>,
195}
196
197/// Auxiliary metadata for least-squares solves.
198#[derive(Debug)]
199pub struct LstsqAuxResult<R: Scalar> {
200    /// Numerical rank per batch item, stored as a batch-shaped real-valued count tensor.
201    pub rank: Tensor<R>,
202    /// Singular values used for the rank estimate.
203    pub singular_values: Tensor<R>,
204}
205
206/// Gradient result for `lstsq_rrule`: cotangents for both `A` and `b`.
207///
208/// # Examples
209///
210/// ```
211/// use tenferro_linalg::LstsqGrad;
212/// use tenferro_tensor::{MemoryOrder, Tensor};
213///
214/// let grad = LstsqGrad {
215///     a: Tensor::<f64>::zeros(&[2, 2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap(),
216///     b: Tensor::<f64>::zeros(&[2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap(),
217/// };
218/// assert_eq!(grad.a.ndim(), 2);
219/// assert_eq!(grad.b.ndim(), 1);
220/// ```
221#[derive(Debug)]
222pub struct LstsqGrad<T: Scalar> {
223    /// Cotangent for the system matrix `A`.
224    pub a: Tensor<T>,
225    /// Cotangent for the right-hand side `b`.
226    pub b: Tensor<T>,
227}