tenferro_linalg/primal/decompositions.rs
1use super::linear_systems_sign::{lu_permutation_matrix_tensor, LiftPermutationMatrixTensor};
2use super::*;
3
4/// Compute the SVD of a batched matrix.
5///
6/// Input shape: `(m, n, *)`.
7///
8/// The function internally normalizes input to column-major contiguous layout.
9/// If the input is not already contiguous, an internal copy is performed.
10///
11/// # Arguments
12///
13/// * `tensor` — Input tensor of shape `(m, n, *)`
14/// * `options` — Optional truncation parameters
15///
16/// # Examples
17///
18/// ```
19/// use tenferro_device::LogicalMemorySpace;
20/// use tenferro_linalg::{svd, SvdOptions};
21/// use tenferro_prims::CpuContext;
22/// use tenferro_tensor::{MemoryOrder, Tensor};
23///
24/// let col = MemoryOrder::ColumnMajor;
25/// let mut ctx = CpuContext::new(1);
26/// let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
27///
28/// let _full = svd(&mut ctx, &a, None).unwrap();
29/// let opts = SvdOptions {
30/// max_rank: Some(2),
31/// cutoff: None,
32/// };
33/// let _truncated = svd(&mut ctx, &a, Some(&opts)).unwrap();
34/// ```
35///
36/// # Errors
37///
38/// Returns an error if the input has fewer than 2 dimensions.
39pub fn svd<T: KernelLinalgScalar, C>(
40 ctx: &mut C,
41 tensor: &Tensor<T>,
42 options: Option<&SvdOptions>,
43) -> Result<SvdResult<T, T::Real>>
44where
45 C: backend::TensorLinalgContextFor<T>
46 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
47 C::Backend: 'static,
48 T::Real: tenferro_tensor::KeepCountScalar,
49{
50 let result = <C::Backend as backend::TensorLinalgBackend<T>>::thin_svd(ctx, tensor)?;
51
52 let Some(opts) = options else {
53 return Ok(SvdResult {
54 u: result.u,
55 s: result.s,
56 vt: result.vt,
57 });
58 };
59
60 let k = result.s.dims()[0];
61 let max_k = opts.max_rank.map_or(k, |r| r.min(k));
62 let u = result.u.narrow(1, 0, max_k)?;
63 let s = result.s.narrow(0, 0, max_k)?;
64 let vt = result.vt.narrow(0, 0, max_k)?;
65
66 if max_k == 0 {
67 return Ok(SvdResult { u, s, vt });
68 }
69
70 if opts.cutoff.is_none() {
71 return Ok(SvdResult { u, s, vt });
72 }
73
74 let cutoff_r: T::Real = scalar_from(opts.cutoff.unwrap())?;
75 let cutoff_tensor =
76 crate::prims_bridge::full_like_constant(cutoff_r, s.dims(), s.logical_memory_space())?;
77 let mask = crate::prims_bridge::scalar_binary_same_shape::<T::Real, C>(
78 ctx,
79 &s,
80 &cutoff_tensor,
81 tenferro_prims::ScalarBinaryOp::GreaterEqual,
82 )?;
83 let kept_axes: Vec<usize> = (1..s.ndim()).collect();
84 let keep_counts =
85 crate::prims_bridge::scalar_sum_keep_axes::<T::Real, C>(ctx, &mask, &kept_axes)?;
86
87 Ok(SvdResult {
88 u: backend::tensor_helpers::zero_trailing_by_counts(&u, &keep_counts, 1, 2)?,
89 s: backend::tensor_helpers::zero_trailing_by_counts(&s, &keep_counts, 0, 1)?,
90 vt: backend::tensor_helpers::zero_trailing_by_counts(&vt, &keep_counts, 0, 2)?,
91 })
92}
93
94/// Compute singular values only for a batched matrix.
95///
96/// Input shape: `(m, n, *)`.
97///
98/// # Examples
99///
100/// ```
101/// use tenferro_device::LogicalMemorySpace;
102/// use tenferro_linalg::svdvals;
103/// use tenferro_prims::CpuContext;
104/// use tenferro_tensor::{MemoryOrder, Tensor};
105///
106/// let mut ctx = CpuContext::new(1);
107/// let a = Tensor::<f64>::zeros(
108/// &[3, 4],
109/// LogicalMemorySpace::MainMemory,
110/// MemoryOrder::ColumnMajor,
111/// ).unwrap();
112/// let _values = svdvals(&mut ctx, &a).unwrap();
113/// ```
114pub fn svdvals<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T::Real>>
115where
116 C: backend::TensorLinalgContextFor<T>,
117 C::Backend: 'static,
118{
119 <C::Backend as backend::TensorLinalgBackend<T>>::svdvals(ctx, tensor)
120}
121
/// Compute the QR decomposition of a batched matrix.
///
/// Input shape: `(m, n, *)`.
///
/// The function internally normalizes input to column-major contiguous layout.
/// If the input is not already contiguous, an internal copy is performed.
///
/// After the backend factorization, any complex phase on the diagonal of `R`
/// is rotated into `Q` so that complex diagonal entries of `R` become real
/// (their magnitude). The product `Q * R` is unchanged by this fix-up.
///
/// # Examples
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_linalg::qr;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(
///     &[4, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// let _result = qr(&mut ctx, &a).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions.
pub fn qr<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<QrResult<T>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    let result = <C::Backend as backend::TensorLinalgBackend<T>>::qr(ctx, tensor)?;
    let (m, n, batch_dims) = validate_2d(tensor)?;
    // Thin QR: Q is (m, k), R is (k, n) per batch, with k = min(m, n).
    let k = m.min(n);
    let bc = batch_count(batch_dims);
    // Pull the raw column-major buffers out so the phase fix-up can index directly.
    let (mut q_data, _) =
        extract_data(&result.q).map_err(|e| Error::InvalidArgument(e.to_string()))?;
    let (mut r_data, _) =
        extract_data(&result.r).map_err(|e| Error::InvalidArgument(e.to_string()))?;

    for batch in 0..bc {
        // Per-batch offsets into the flat column-major buffers.
        let q_base = batch * m * k;
        let r_base = batch * k * n;
        for i in 0..k {
            // Diagonal entry R[i, i] of this batch's R factor (column-major: i + i * k).
            let diag = r_data[r_base + i + i * k];
            // Already real: nothing to rotate. Note real scalar types always
            // take this branch, so a negative real diagonal is left as-is —
            // only complex phases are canonicalized here.
            if diag.imag_part() == T::Real::zero() {
                continue;
            }
            let mag = diag.abs_real();
            // Zero diagonal has no well-defined phase; skip to avoid dividing by zero.
            if mag == T::Real::zero() {
                continue;
            }
            // Unit-modulus phase of the diagonal entry: diag / |diag|.
            let phase = diag / T::from_real(mag);
            // Scale Q's column i by the phase ...
            for row in 0..m {
                q_data[q_base + row + i * m] = q_data[q_base + row + i * m] * phase;
            }
            // ... and R's row i by its inverse (conj == inverse for |phase| = 1),
            // so Q*R is preserved while R[i, i] becomes the real value `mag`.
            let phase_inv = phase.conj();
            for col in 0..n {
                r_data[r_base + i + col * k] = r_data[r_base + i + col * k] * phase_inv;
            }
        }
    }

    Ok(QrResult {
        q: tensor_from_data(q_data, result.q.dims())?,
        r: tensor_from_data(r_data, result.r.dims())?,
    })
}
191
192/// Compute the LU decomposition of a batched matrix.
193///
194/// Input shape: `(m, n, *)`.
195///
196/// The function internally normalizes input to column-major contiguous layout.
197/// If the input is not already contiguous, an internal copy is performed.
198///
199/// # Arguments
200///
201/// * `tensor` — Input tensor of shape `(m, n, *)`
202/// * `pivot` — Pivoting strategy
203///
204/// # Examples
205///
206/// ```
207/// use tenferro_linalg::{lu, LuPivot};
208/// use tenferro_prims::CpuContext;
209/// use tenferro_tensor::{MemoryOrder, Tensor};
210///
211/// let mut ctx = CpuContext::new(1);
212/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], MemoryOrder::ColumnMajor)
213/// .unwrap();
214/// let _partial = lu(&mut ctx, &a, LuPivot::Partial).unwrap();
215/// let no_pivot = lu(&mut ctx, &a, LuPivot::NoPivot).unwrap();
216/// assert_eq!(no_pivot.p.dims(), &[0]);
217/// ```
218pub fn lu<T: KernelLinalgScalar, C>(
219 ctx: &mut C,
220 tensor: &Tensor<T>,
221 pivot: LuPivot,
222) -> Result<LuResult<T>>
223where
224 C: backend::TensorLinalgContextFor<T>
225 + tenferro_prims::TensorMetadataContextFor
226 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
227 C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
228 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
229 tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
230 T: LiftPermutationMatrixTensor<C>,
231 C::Backend: 'static,
232{
233 let (m, _n, _batch_dims) = validate_2d(tensor)?;
234 if pivot == LuPivot::NoPivot {
235 let result =
236 <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor_no_pivot(ctx, tensor)?;
237 return Ok(LuResult {
238 p: Tensor::empty(
239 &[0],
240 tensor.logical_memory_space(),
241 MemoryOrder::ColumnMajor,
242 )?,
243 l: result.l,
244 u: result.u,
245 });
246 }
247
248 let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
249 let p = lu_permutation_matrix_tensor(ctx, &result.pivots, m)?;
250
251 Ok(LuResult {
252 p,
253 l: result.l,
254 u: result.u,
255 })
256}
257
258/// Compute the packed LU factorization of a batched matrix.
259///
260/// # Examples
261///
262/// ```
263/// use tenferro_linalg::lu_factor;
264/// use tenferro_prims::CpuContext;
265/// use tenferro_tensor::{MemoryOrder, Tensor};
266///
267/// let mut ctx = CpuContext::new(1);
268/// let col = MemoryOrder::ColumnMajor;
269/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
270/// let result = lu_factor(&mut ctx, &a).unwrap();
271/// assert_eq!(result.factors.dims(), &[2, 2]);
272/// assert_eq!(result.pivots.len(), 2);
273/// ```
274pub fn lu_factor<T: KernelLinalgScalar, C>(
275 ctx: &mut C,
276 tensor: &Tensor<T>,
277) -> Result<LuFactorResult<T>>
278where
279 C: backend::TensorLinalgContextFor<T>,
280 C::Backend: 'static,
281{
282 let _ = validate_2d(tensor)?;
283 let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
284 Ok(LuFactorResult {
285 factors: pack_lu_factors(&result.l, &result.u)?,
286 pivots: result.pivots,
287 })
288}
289
290/// Compute the packed LU factorization with numerical status information.
291///
292/// # Examples
293///
294/// ```
295/// use tenferro_linalg::lu_factor_ex;
296/// use tenferro_prims::CpuContext;
297/// use tenferro_tensor::{MemoryOrder, Tensor};
298///
299/// let mut ctx = CpuContext::new(1);
300/// let col = MemoryOrder::ColumnMajor;
301/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
302/// let result = lu_factor_ex(&mut ctx, &a).unwrap();
303/// assert_eq!(result.factors.dims(), &[2, 2]);
304/// assert_eq!(result.info.len(), 1);
305/// ```
306pub fn lu_factor_ex<T: KernelLinalgScalar, C>(
307 ctx: &mut C,
308 tensor: &Tensor<T>,
309) -> Result<LuFactorExResult<T>>
310where
311 C: backend::TensorLinalgContextFor<T>,
312 C::Backend: 'static,
313{
314 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::LuFactorEx, "lu_factor_ex")?;
315
316 let _ = validate_2d(tensor)?;
317 let result = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor_ex(ctx, tensor)?;
318 let factors = pack_lu_factors(&result.l, &result.u)?;
319
320 Ok(LuFactorExResult {
321 factors,
322 pivots: result.pivots,
323 info: result.info,
324 })
325}
326
327/// Solve `A x = b` from a packed LU factorization.
328///
329/// # Examples
330///
331/// ```
332/// use tenferro_linalg::{lu_factor, lu_solve};
333/// use tenferro_prims::CpuContext;
334/// use tenferro_tensor::{MemoryOrder, Tensor};
335///
336/// let mut ctx = CpuContext::new(1);
337/// let col = MemoryOrder::ColumnMajor;
338/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
339/// let lu = lu_factor(&mut ctx, &a).unwrap();
340/// let b = Tensor::<f64>::from_slice(&[5.0, 7.0], &[2], col).unwrap();
341/// let x = lu_solve(&mut ctx, &lu.factors, &lu.pivots, &b).unwrap();
342/// assert_eq!(x.dims(), &[2]);
343/// ```
344pub fn lu_solve<T: KernelLinalgScalar, C>(
345 ctx: &mut C,
346 factors: &Tensor<T>,
347 pivots: &Tensor<i32>,
348 b: &Tensor<T>,
349) -> Result<Tensor<T>>
350where
351 T: KernelLinalgScalar,
352 C: backend::TensorLinalgContextFor<T>,
353 C::Backend: 'static,
354{
355 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::LuSolve, "lu_solve")?;
356 <C::Backend as backend::TensorLinalgBackend<T>>::lu_solve(ctx, factors, pivots, b)
357}
358
359/// Compute the eigendecomposition of a batched square matrix.
360///
361/// Input shape: `(n, n, *)`.
362/// The lower triangle is treated as canonical input, matching the default
363/// `UPLO='L'` behavior used by PyTorch's `linalg.eigh`.
364///
365/// # Examples
366///
367/// ```
368/// use tenferro_device::LogicalMemorySpace;
369/// use tenferro_linalg::eigen;
370/// use tenferro_prims::CpuContext;
371/// use tenferro_tensor::{MemoryOrder, Tensor};
372///
373/// let mut ctx = CpuContext::new(1);
374/// let a = Tensor::<f64>::from_slice(
375/// &[2.0, 1.0, 1.0, 2.0],
376/// &[2, 2],
377/// MemoryOrder::ColumnMajor,
378/// ).unwrap();
379/// let result = eigen(&mut ctx, &a).unwrap();
380/// assert_eq!(result.values.dims(), &[2]);
381/// ```
382pub fn eigen<T: KernelLinalgScalar, C>(
383 ctx: &mut C,
384 tensor: &Tensor<T>,
385) -> Result<EigenResult<T, T::Real>>
386where
387 C: backend::TensorLinalgContextFor<T>,
388 C::Backend: 'static,
389{
390 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::EigenSym, "eigen")?;
391 let _ = validate_square(tensor)?;
392 let result = <C::Backend as backend::TensorLinalgBackend<T>>::eigen_sym(ctx, tensor)?;
393
394 Ok(EigenResult {
395 values: result.values,
396 vectors: result.vectors,
397 })
398}