tenferro_linalg/primal/
decompositions.rs

1use super::linear_systems_sign::{lu_permutation_matrix_tensor, LiftPermutationMatrixTensor};
2use super::*;
3
4/// Compute the SVD of a batched matrix.
5///
6/// Input shape: `(m, n, *)`.
7///
8/// The function internally normalizes input to column-major contiguous layout.
9/// If the input is not already contiguous, an internal copy is performed.
10///
11/// # Arguments
12///
13/// * `tensor` — Input tensor of shape `(m, n, *)`
14/// * `options` — Optional truncation parameters
15///
16/// # Examples
17///
18/// ```
19/// use tenferro_device::LogicalMemorySpace;
20/// use tenferro_linalg::{svd, SvdOptions};
21/// use tenferro_prims::CpuContext;
22/// use tenferro_tensor::{MemoryOrder, Tensor};
23///
24/// let col = MemoryOrder::ColumnMajor;
25/// let mut ctx = CpuContext::new(1);
26/// let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
27///
28/// let _full = svd(&mut ctx, &a, None).unwrap();
29/// let opts = SvdOptions {
30///     max_rank: Some(2),
31///     cutoff: None,
32/// };
33/// let _truncated = svd(&mut ctx, &a, Some(&opts)).unwrap();
34/// ```
35///
36/// # Errors
37///
38/// Returns an error if the input has fewer than 2 dimensions.
39pub fn svd<T: KernelLinalgScalar, C>(
40    ctx: &mut C,
41    tensor: &Tensor<T>,
42    options: Option<&SvdOptions>,
43) -> Result<SvdResult<T, T::Real>>
44where
45    C: backend::TensorLinalgContextFor<T>
46        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
47    C::Backend: 'static,
48    T::Real: tenferro_tensor::KeepCountScalar,
49{
50    let result = <C::Backend as backend::TensorLinalgBackend<T>>::thin_svd(ctx, tensor)?;
51
52    let Some(opts) = options else {
53        return Ok(SvdResult {
54            u: result.u,
55            s: result.s,
56            vt: result.vt,
57        });
58    };
59
60    let k = result.s.dims()[0];
61    let max_k = opts.max_rank.map_or(k, |r| r.min(k));
62    let u = result.u.narrow(1, 0, max_k)?;
63    let s = result.s.narrow(0, 0, max_k)?;
64    let vt = result.vt.narrow(0, 0, max_k)?;
65
66    if max_k == 0 {
67        return Ok(SvdResult { u, s, vt });
68    }
69
70    if opts.cutoff.is_none() {
71        return Ok(SvdResult { u, s, vt });
72    }
73
74    let cutoff_r: T::Real = scalar_from(opts.cutoff.unwrap())?;
75    let cutoff_tensor =
76        crate::prims_bridge::full_like_constant(cutoff_r, s.dims(), s.logical_memory_space())?;
77    let mask = crate::prims_bridge::scalar_binary_same_shape::<T::Real, C>(
78        ctx,
79        &s,
80        &cutoff_tensor,
81        tenferro_prims::ScalarBinaryOp::GreaterEqual,
82    )?;
83    let kept_axes: Vec<usize> = (1..s.ndim()).collect();
84    let keep_counts =
85        crate::prims_bridge::scalar_sum_keep_axes::<T::Real, C>(ctx, &mask, &kept_axes)?;
86
87    Ok(SvdResult {
88        u: backend::tensor_helpers::zero_trailing_by_counts(&u, &keep_counts, 1, 2)?,
89        s: backend::tensor_helpers::zero_trailing_by_counts(&s, &keep_counts, 0, 1)?,
90        vt: backend::tensor_helpers::zero_trailing_by_counts(&vt, &keep_counts, 0, 2)?,
91    })
92}
93
/// Compute singular values only for a batched matrix.
///
/// Input shape: `(m, n, *)`.
///
/// This is a thin wrapper that delegates directly to the backend's
/// `svdvals` implementation; no truncation options are applied.
///
/// # Examples
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_linalg::svdvals;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(
///     &[3, 4],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// let _values = svdvals(&mut ctx, &a).unwrap();
/// ```
///
/// # Errors
///
/// Propagates any error returned by the backend implementation.
pub fn svdvals<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T::Real>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    <C::Backend as backend::TensorLinalgBackend<T>>::svdvals(ctx, tensor)
}
121
/// Compute the QR decomposition of a batched matrix.
///
/// Input shape: `(m, n, *)`.
///
/// The function internally normalizes input to column-major contiguous layout.
/// If the input is not already contiguous, an internal copy is performed.
///
/// After the backend factorization, each complex diagonal entry of `R` is
/// rotated onto the real axis: column `i` of `Q` is multiplied by the unit
/// phase of `R[i, i]` and row `i` of `R` by its conjugate, so the product
/// `Q * R` is unchanged while `R`'s diagonal becomes real. Real inputs have
/// zero imaginary parts and are skipped by the loop.
///
/// # Examples
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_linalg::qr;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(
///     &[4, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// let _result = qr(&mut ctx, &a).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions.
pub fn qr<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<QrResult<T>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let result = <C::Backend as backend::TensorLinalgBackend<T>>::qr(ctx, tensor)?;
    // Q is (m, k, *) and R is (k, n, *) with k = min(m, n); `bc` is the
    // number of batch matrices.
    let (m, n, batch_dims) = validate_2d(tensor)?;
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Flat copies of Q and R, indexed column-major within each batch slab
    // (matching the column-major normalization noted above).
    let (mut q_data, _) =
        extract_data(&result.q).map_err(|e| Error::InvalidArgument(e.to_string()))?;
    let (mut r_data, _) =
        extract_data(&result.r).map_err(|e| Error::InvalidArgument(e.to_string()))?;

    for batch in 0..bc {
        // Per-batch base offsets: Q slab is m*k elements, R slab is k*n.
        let q_base = batch * m * k;
        let r_base = batch * k * n;
        for i in 0..k {
            // R[i, i] at column-major offset i + i*k within this batch.
            let diag = r_data[r_base + i + i * k];
            // Already real (always the case for real scalar types): skip.
            if diag.imag_part() == T::Real::zero() {
                continue;
            }
            // Zero magnitude has no well-defined phase: skip.
            let mag = diag.abs_real();
            if mag == T::Real::zero() {
                continue;
            }
            // Unit phase of the diagonal entry: phase = diag / |diag|.
            let phase = diag / T::from_real(mag);
            // Scale column i of Q by `phase`...
            for row in 0..m {
                q_data[q_base + row + i * m] = q_data[q_base + row + i * m] * phase;
            }
            // ...and row i of R by conj(phase), so Q*R is preserved and
            // R[i, i] becomes conj(phase)*diag = |diag| (real).
            let phase_inv = phase.conj();
            for col in 0..n {
                r_data[r_base + i + col * k] = r_data[r_base + i + col * k] * phase_inv;
            }
        }
    }

    Ok(QrResult {
        q: tensor_from_data(q_data, result.q.dims())?,
        r: tensor_from_data(r_data, result.r.dims())?,
    })
}
191
192/// Compute the LU decomposition of a batched matrix.
193///
194/// Input shape: `(m, n, *)`.
195///
196/// The function internally normalizes input to column-major contiguous layout.
197/// If the input is not already contiguous, an internal copy is performed.
198///
199/// # Arguments
200///
201/// * `tensor` — Input tensor of shape `(m, n, *)`
202/// * `pivot` — Pivoting strategy
203///
204/// # Examples
205///
206/// ```
207/// use tenferro_linalg::{lu, LuPivot};
208/// use tenferro_prims::CpuContext;
209/// use tenferro_tensor::{MemoryOrder, Tensor};
210///
211/// let mut ctx = CpuContext::new(1);
212/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
213///     .unwrap();
214/// let _partial = lu(&mut ctx, &a, LuPivot::Partial).unwrap();
215/// let no_pivot = lu(&mut ctx, &a, LuPivot::NoPivot).unwrap();
216/// assert_eq!(no_pivot.p.dims(), &[0]);
217/// ```
218pub fn lu<T: KernelLinalgScalar, C>(
219    ctx: &mut C,
220    tensor: &Tensor<T>,
221    pivot: LuPivot,
222) -> Result<LuResult<T>>
223where
224    C: backend::TensorLinalgContextFor<T>
225        + tenferro_prims::TensorMetadataContextFor
226        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
227    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
228    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
229        tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
230    T: LiftPermutationMatrixTensor<C>,
231    C::Backend: 'static,
232{
233    let (m, _n, _batch_dims) = validate_2d(tensor)?;
234    if pivot == LuPivot::NoPivot {
235        let result =
236            <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor_no_pivot(ctx, tensor)?;
237        return Ok(LuResult {
238            p: Tensor::empty(
239                &[0],
240                tensor.logical_memory_space(),
241                MemoryOrder::ColumnMajor,
242            )?,
243            l: result.l,
244            u: result.u,
245        });
246    }
247
248    let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
249    let p = lu_permutation_matrix_tensor(ctx, &result.pivots, m)?;
250
251    Ok(LuResult {
252        p,
253        l: result.l,
254        u: result.u,
255    })
256}
257
258/// Compute the packed LU factorization of a batched matrix.
259///
260/// # Examples
261///
262/// ```
263/// use tenferro_linalg::lu_factor;
264/// use tenferro_prims::CpuContext;
265/// use tenferro_tensor::{MemoryOrder, Tensor};
266///
267/// let mut ctx = CpuContext::new(1);
268/// let col = MemoryOrder::ColumnMajor;
269/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
270/// let result = lu_factor(&mut ctx, &a).unwrap();
271/// assert_eq!(result.factors.dims(), &[2, 2]);
272/// assert_eq!(result.pivots.len(), 2);
273/// ```
274pub fn lu_factor<T: KernelLinalgScalar, C>(
275    ctx: &mut C,
276    tensor: &Tensor<T>,
277) -> Result<LuFactorResult<T>>
278where
279    C: backend::TensorLinalgContextFor<T>,
280    C::Backend: 'static,
281{
282    let _ = validate_2d(tensor)?;
283    let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
284    Ok(LuFactorResult {
285        factors: pack_lu_factors(&result.l, &result.u)?,
286        pivots: result.pivots,
287    })
288}
289
290/// Compute the packed LU factorization with numerical status information.
291///
292/// # Examples
293///
294/// ```
295/// use tenferro_linalg::lu_factor_ex;
296/// use tenferro_prims::CpuContext;
297/// use tenferro_tensor::{MemoryOrder, Tensor};
298///
299/// let mut ctx = CpuContext::new(1);
300/// let col = MemoryOrder::ColumnMajor;
301/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
302/// let result = lu_factor_ex(&mut ctx, &a).unwrap();
303/// assert_eq!(result.factors.dims(), &[2, 2]);
304/// assert_eq!(result.info.len(), 1);
305/// ```
306pub fn lu_factor_ex<T: KernelLinalgScalar, C>(
307    ctx: &mut C,
308    tensor: &Tensor<T>,
309) -> Result<LuFactorExResult<T>>
310where
311    C: backend::TensorLinalgContextFor<T>,
312    C::Backend: 'static,
313{
314    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::LuFactorEx, "lu_factor_ex")?;
315
316    let _ = validate_2d(tensor)?;
317    let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor_ex(ctx, tensor)?;
318    let factors = pack_lu_factors(&result.l, &result.u)?;
319
320    Ok(LuFactorExResult {
321        factors,
322        pivots: result.pivots,
323        info: result.info,
324    })
325}
326
327/// Solve `A x = b` from a packed LU factorization.
328///
329/// # Examples
330///
331/// ```
332/// use tenferro_linalg::{lu_factor, lu_solve};
333/// use tenferro_prims::CpuContext;
334/// use tenferro_tensor::{MemoryOrder, Tensor};
335///
336/// let mut ctx = CpuContext::new(1);
337/// let col = MemoryOrder::ColumnMajor;
338/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
339/// let lu = lu_factor(&mut ctx, &a).unwrap();
340/// let b = Tensor::<f64>::from_slice(&[5.0, 7.0], &[2], col).unwrap();
341/// let x = lu_solve(&mut ctx, &lu.factors, &lu.pivots, &b).unwrap();
342/// assert_eq!(x.dims(), &[2]);
343/// ```
344pub fn lu_solve<T: KernelLinalgScalar, C>(
345    ctx: &mut C,
346    factors: &Tensor<T>,
347    pivots: &Tensor<i32>,
348    b: &Tensor<T>,
349) -> Result<Tensor<T>>
350where
351    T: KernelLinalgScalar,
352    C: backend::TensorLinalgContextFor<T>,
353    C::Backend: 'static,
354{
355    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::LuSolve, "lu_solve")?;
356    <C::Backend as backend::TensorLinalgBackend<T>>::lu_solve(ctx, factors, pivots, b)
357}
358
359/// Compute the eigendecomposition of a batched square matrix.
360///
361/// Input shape: `(n, n, *)`.
362/// The lower triangle is treated as canonical input, matching the default
363/// `UPLO='L'` behavior used by PyTorch's `linalg.eigh`.
364///
365/// # Examples
366///
367/// ```
368/// use tenferro_device::LogicalMemorySpace;
369/// use tenferro_linalg::eigen;
370/// use tenferro_prims::CpuContext;
371/// use tenferro_tensor::{MemoryOrder, Tensor};
372///
373/// let mut ctx = CpuContext::new(1);
374/// let a = Tensor::<f64>::from_slice(
375///     &[2.0, 1.0, 1.0, 2.0],
376///     &[2, 2],
377///     MemoryOrder::ColumnMajor,
378/// ).unwrap();
379/// let result = eigen(&mut ctx, &a).unwrap();
380/// assert_eq!(result.values.dims(), &[2]);
381/// ```
382pub fn eigen<T: KernelLinalgScalar, C>(
383    ctx: &mut C,
384    tensor: &Tensor<T>,
385) -> Result<EigenResult<T, T::Real>>
386where
387    C: backend::TensorLinalgContextFor<T>,
388    C::Backend: 'static,
389{
390    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::EigenSym, "eigen")?;
391    let _ = validate_square(tensor)?;
392    let result = <C::Backend as backend::TensorLinalgBackend<T>>::eigen_sym(ctx, tensor)?;
393
394    Ok(EigenResult {
395        values: result.values,
396        vectors: result.vectors,
397    })
398}