tenferro_linalg/primal/
linear_systems.rs

1use super::*;
2use crate::primal::linear_systems_sign::*;
3use num_complex::{Complex32, Complex64, ComplexFloat};
4use num_traits::{NumCast, One};
5use tenferro_algebra::Conjugate;
6use tenferro_prims::{TensorMetadataCastPrims, TensorMetadataContextFor, TensorMetadataPrims};
7fn inverse_rhs<T: KernelLinalgScalar>(
8    n: usize,
9    batch_dims: &[usize],
10    memory_space: tenferro_device::LogicalMemorySpace,
11) -> Result<Tensor<T>> {
12    let mut rhs = crate::prims_bridge::identity_matrix(n, memory_space)?;
13    for _ in batch_dims {
14        rhs = rhs.unsqueeze(-1)?;
15    }
16    rhs.broadcast(&output_dims(&[n, n], batch_dims))
17}
18
19/// Solve a square linear system `A x = b`.
20///
21/// # Examples
22///
23/// ```
24/// use tenferro_linalg::solve;
25/// use tenferro_prims::CpuContext;
26/// use tenferro_tensor::{MemoryOrder, Tensor};
27///
28/// let mut ctx = CpuContext::new(1);
29/// let col = MemoryOrder::ColumnMajor;
30/// let a = Tensor::<f64>::from_slice(&[2.0, 1.0, 1.0, 3.0], &[2, 2], col).unwrap();
31/// let b = Tensor::<f64>::from_slice(&[5.0, 7.0], &[2], col).unwrap();
32/// let x = solve(&mut ctx, &a, &b).unwrap();
33/// assert_eq!(x.dims(), &[2]);
34/// ```
35pub fn solve<T: KernelLinalgScalar, C>(
36    ctx: &mut C,
37    a: &Tensor<T>,
38    b: &Tensor<T>,
39) -> Result<Tensor<T>>
40where
41    C: backend::TensorLinalgContextFor<T>,
42    C::Backend: 'static,
43{
44    <C::Backend as backend::TensorLinalgBackend<T>>::solve(ctx, a, b)
45}
46
47/// Solve a square linear system with numerical status information.
48///
49/// # Examples
50///
51/// ```
52/// use tenferro_linalg::solve_ex;
53/// use tenferro_prims::CpuContext;
54/// use tenferro_tensor::{MemoryOrder, Tensor};
55///
56/// let mut ctx = CpuContext::new(1);
57/// let col = MemoryOrder::ColumnMajor;
58/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
59/// let b = Tensor::<f64>::from_slice(&[3.0, 4.0], &[2], col).unwrap();
60/// let result = solve_ex(&mut ctx, &a, &b).unwrap();
61/// assert_eq!(result.solution.dims(), &[2]);
62/// assert_eq!(result.info.len(), 1);
63/// ```
64pub fn solve_ex<T: KernelLinalgScalar, C>(
65    ctx: &mut C,
66    a: &Tensor<T>,
67    b: &Tensor<T>,
68) -> Result<SolveExResult<T>>
69where
70    T: KernelLinalgScalar,
71    C: backend::TensorLinalgContextFor<T>,
72    C::Backend: 'static,
73{
74    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::SolveEx, "solve_ex")?;
75    let result = <C::Backend as backend::TensorLinalgBackend<T>>::solve_ex(ctx, a, b)?;
76    Ok(SolveExResult {
77        solution: result.solution,
78        info: result.info,
79    })
80}
81
82/// Compute the inverse of a square matrix.
83///
84/// # Examples
85///
86/// ```
87/// use tenferro_linalg::inv;
88/// use tenferro_prims::CpuContext;
89/// use tenferro_tensor::{MemoryOrder, Tensor};
90///
91/// let mut ctx = CpuContext::new(1);
92/// let col = MemoryOrder::ColumnMajor;
93/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
94/// let inv_a = inv(&mut ctx, &a).unwrap();
95/// assert_eq!(inv_a.dims(), &[2, 2]);
96/// ```
97pub fn inv<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
98where
99    T: KernelLinalgScalar,
100    C: backend::TensorLinalgContextFor<T>,
101    C::Backend: 'static,
102{
103    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv")?;
104
105    let (n, batch_dims) = validate_square(tensor)?;
106    let rhs = inverse_rhs::<T>(n, batch_dims, tensor.logical_memory_space())?;
107    if n == 0 {
108        return Ok(rhs);
109    }
110    solve(ctx, tensor, &rhs)
111}
112
113/// Compute the inverse with numerical status information.
114///
115/// # Examples
116///
117/// ```
118/// use tenferro_linalg::inv_ex;
119/// use tenferro_prims::CpuContext;
120/// use tenferro_tensor::{MemoryOrder, Tensor};
121///
122/// let mut ctx = CpuContext::new(1);
123/// let col = MemoryOrder::ColumnMajor;
124/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
125/// let result = inv_ex(&mut ctx, &a).unwrap();
126/// assert_eq!(result.inverse.dims(), &[2, 2]);
127/// assert_eq!(result.info.len(), 1);
128/// ```
129pub fn inv_ex<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<InvExResult<T>>
130where
131    T: KernelLinalgScalar,
132    C: backend::TensorLinalgContextFor<T>,
133    C::Backend: 'static,
134{
135    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_ex")?;
136    let (n, batch_dims) = validate_square(tensor)?;
137    let rhs = inverse_rhs::<T>(n, batch_dims, tensor.logical_memory_space())?;
138    if n == 0 {
139        return Ok(InvExResult {
140            inverse: rhs,
141            info: crate::backend::tensor_helpers::info_tensor_from_vec_on_space(
142                vec![0; batch_count(batch_dims)],
143                batch_dims,
144                tensor.logical_memory_space(),
145            )?,
146        });
147    }
148    let result = solve_ex(ctx, tensor, &rhs)?;
149    Ok(InvExResult {
150        inverse: result.solution,
151        info: result.info,
152    })
153}
154
155/// Compute the determinant of a square matrix.
156///
157/// # Examples
158///
159/// ```
160/// use tenferro_linalg::det;
161/// use tenferro_prims::CpuContext;
162/// use tenferro_tensor::{MemoryOrder, Tensor};
163///
164/// let mut ctx = CpuContext::new(1);
165/// let col = MemoryOrder::ColumnMajor;
166/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 2.0], &[2, 2], col).unwrap();
167/// let d = det(&mut ctx, &a).unwrap();
168/// assert_eq!(d.ndim(), 0);
169/// ```
170pub fn det<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
171where
172    C: backend::TensorLinalgContextFor<T>
173        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
174        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
175        + TensorMetadataContextFor,
176    T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
177    T::Real: Scalar + NumCast + One + 'static,
178    C::MetadataBackend: TensorMetadataPrims<Context = C>,
179    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
180        TensorMetadataCastPrims<T::Real, Context = C>,
181    C::Backend: 'static,
182{
183    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det")?;
184
185    let (_n, batch_dims) = validate_square(tensor)?;
186    let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
187    let diagonal = lu.u.diagonal(&[(0, 1)])?;
188    let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
189    let diagonal_prod = crate::prims_bridge::scalar_reduce_keep_axes(
190        ctx,
191        &diagonal,
192        &kept_axes,
193        tenferro_prims::ScalarReductionOp::Prod,
194    )?;
195    let sign_tensor = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
196
197    <T as crate::prims_bridge::ScaleTensorByRealSameShape<C>>::scale_tensor_by_real_same_shape(
198        ctx,
199        &diagonal_prod,
200        &sign_tensor,
201    )
202}
203
204/// Dispatch trait for [`slogdet`] — selects real or complex implementation.
205///
206/// This trait is `#[doc(hidden)]` and not intended for external use.
207/// Use [`slogdet`] directly instead.
208///
209/// # Examples
210///
211/// ```ignore
212/// // Internal dispatch: users should call `slogdet` instead.
213/// use tenferro_linalg::SlogdetDispatch;
214/// ```
215#[doc(hidden)]
216pub trait SlogdetDispatch<C>: KernelLinalgScalar {
217    fn slogdet_dispatch(
218        ctx: &mut C,
219        tensor: &Tensor<Self>,
220    ) -> Result<SlogdetResult<Self, Self::Real>>;
221}
222
223fn slogdet_real_impl<T, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<SlogdetResult<T, T::Real>>
224where
225    T: KernelLinalgScalar<Real = T> + num_traits::Float,
226    C: backend::TensorLinalgContextFor<T>
227        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
228        + TensorMetadataContextFor,
229    C::Backend: 'static,
230    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
231        'static
232            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>
233            + TensorMetadataCastPrims<T, Context = C>,
234    C::MetadataBackend: TensorMetadataPrims<Context = C>,
235{
236    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet")?;
237
238    let (_n, batch_dims) = validate_square(tensor)?;
239    let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
240    let diagonal = lu.u.diagonal(&[(0, 1)])?;
241    let abs_diagonal = crate::prims_bridge::scalar_unary_same_shape(
242        ctx,
243        &diagonal,
244        tenferro_prims::ScalarUnaryOp::Abs,
245    )?;
246    let logabsdet_factor = crate::prims_bridge::analytic_unary_same_shape(
247        ctx,
248        &abs_diagonal,
249        tenferro_prims::AnalyticUnaryOp::Log,
250    )?;
251    let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
252    let logabsdet = crate::prims_bridge::scalar_reduce_keep_axes(
253        ctx,
254        &logabsdet_factor,
255        &kept_axes,
256        tenferro_prims::ScalarReductionOp::Sum,
257    )?;
258    let sign_perm = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
259
260    let zero_diagonal = crate::prims_bridge::full_like_constant(
261        T::zero(),
262        diagonal.dims(),
263        tensor.logical_memory_space(),
264    )?;
265    let negative_mask = crate::prims_bridge::scalar_binary_same_shape(
266        ctx,
267        &zero_diagonal,
268        &diagonal,
269        tenferro_prims::ScalarBinaryOp::Greater,
270    )?;
271    let double_negative = crate::prims_bridge::scalar_binary_same_shape(
272        ctx,
273        &negative_mask,
274        &negative_mask,
275        tenferro_prims::ScalarBinaryOp::Add,
276    )?;
277    let one = crate::prims_bridge::full_like_constant(
278        T::one(),
279        diagonal.dims(),
280        tensor.logical_memory_space(),
281    )?;
282    let sign_factors = crate::prims_bridge::scalar_binary_same_shape(
283        ctx,
284        &one,
285        &double_negative,
286        tenferro_prims::ScalarBinaryOp::Sub,
287    )?;
288    let sign_from_diag = crate::prims_bridge::scalar_reduce_keep_axes(
289        ctx,
290        &sign_factors,
291        &kept_axes,
292        tenferro_prims::ScalarReductionOp::Prod,
293    )?;
294    let sign = crate::prims_bridge::scalar_binary_same_shape(
295        ctx,
296        &sign_perm,
297        &sign_from_diag,
298        tenferro_prims::ScalarBinaryOp::Mul,
299    )?;
300
301    Ok(SlogdetResult { sign, logabsdet })
302}
303
304fn slogdet_complex_impl<T, R, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<SlogdetResult<T, R>>
305where
306    T: KernelLinalgScalar<Real = R> + Conjugate + ComplexFloat<Real = R>,
307    T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
308    C: backend::TensorLinalgContextFor<T>
309        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
310        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
311        + tenferro_prims::TensorComplexRealContextFor<T>,
312    C: TensorMetadataContextFor,
313    C::Backend: 'static,
314    R: KernelLinalgScalar<Real = R> + num_traits::Float,
315    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
316        'static
317            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
318            + TensorMetadataCastPrims<R, Context = C>,
319    <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
320        tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
321    C::MetadataBackend: TensorMetadataPrims<Context = C>,
322{
323    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet")?;
324
325    let (_n, batch_dims) = validate_square(tensor)?;
326    let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
327    let diagonal = lu.u.diagonal(&[(0, 1)])?;
328    let abs_diagonal = crate::prims_bridge::complex_real_unary_same_shape(
329        ctx,
330        &diagonal,
331        tenferro_prims::ComplexRealUnaryOp::Abs,
332    )?;
333    let logabsdet_factor = crate::prims_bridge::analytic_unary_same_shape(
334        ctx,
335        &abs_diagonal,
336        tenferro_prims::AnalyticUnaryOp::Log,
337    )?;
338    let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
339    let logabsdet = crate::prims_bridge::scalar_reduce_keep_axes(
340        ctx,
341        &logabsdet_factor,
342        &kept_axes,
343        tenferro_prims::ScalarReductionOp::Sum,
344    )?;
345    let sign_perm = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
346
347    let zero_real = crate::prims_bridge::full_like_constant(
348        R::zero(),
349        abs_diagonal.dims(),
350        tensor.logical_memory_space(),
351    )?;
352    let positive_mask = crate::prims_bridge::scalar_binary_same_shape(
353        ctx,
354        &abs_diagonal,
355        &zero_real,
356        tenferro_prims::ScalarBinaryOp::Greater,
357    )?;
358    let reciprocal_abs = crate::prims_bridge::scalar_unary_same_shape(
359        ctx,
360        &abs_diagonal,
361        tenferro_prims::ScalarUnaryOp::Reciprocal,
362    )?;
363    let zero_recip = crate::prims_bridge::full_like_constant(
364        R::zero(),
365        abs_diagonal.dims(),
366        tensor.logical_memory_space(),
367    )?;
368    let safe_recip = crate::prims_bridge::scalar_where_same_shape(
369        ctx,
370        &positive_mask,
371        &reciprocal_abs,
372        &zero_recip,
373    )?;
374    let phase_factors = crate::prims_bridge::complex_scale_same_shape(ctx, &diagonal, &safe_recip)?;
375    let sign_from_diag = crate::prims_bridge::scalar_reduce_keep_axes(
376        ctx,
377        &phase_factors,
378        &kept_axes,
379        tenferro_prims::ScalarReductionOp::Prod,
380    )?;
381    let sign = crate::prims_bridge::complex_scale_same_shape(ctx, &sign_from_diag, &sign_perm)?;
382
383    Ok(SlogdetResult { sign, logabsdet })
384}
385
386impl<C> SlogdetDispatch<C> for f32
387where
388    C: backend::TensorLinalgContextFor<f32>
389        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
390        + TensorMetadataContextFor,
391    C::Backend: 'static,
392    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
393        'static
394            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
395            + TensorMetadataCastPrims<f32, Context = C>,
396    C::MetadataBackend: TensorMetadataPrims<Context = C>,
397{
398    fn slogdet_dispatch(
399        ctx: &mut C,
400        tensor: &Tensor<Self>,
401    ) -> Result<SlogdetResult<Self, Self::Real>> {
402        slogdet_real_impl(ctx, tensor)
403    }
404}
405
406impl<C> SlogdetDispatch<C> for f64
407where
408    C: backend::TensorLinalgContextFor<f64>
409        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
410        + TensorMetadataContextFor,
411    C::Backend: 'static,
412    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
413        'static
414            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
415            + TensorMetadataCastPrims<f64, Context = C>,
416    C::MetadataBackend: TensorMetadataPrims<Context = C>,
417{
418    fn slogdet_dispatch(
419        ctx: &mut C,
420        tensor: &Tensor<Self>,
421    ) -> Result<SlogdetResult<Self, Self::Real>> {
422        slogdet_real_impl(ctx, tensor)
423    }
424}
425
426impl<C> SlogdetDispatch<C> for Complex32
427where
428    Complex32: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
429    C: backend::TensorLinalgContextFor<Complex32>
430        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
431        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
432        + tenferro_prims::TensorComplexRealContextFor<Complex32>
433        + TensorMetadataContextFor,
434    C::Backend: 'static,
435    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
436        'static
437            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
438            + TensorMetadataCastPrims<f32, Context = C>,
439    <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
440        tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
441    C::MetadataBackend: TensorMetadataPrims<Context = C>,
442{
443    fn slogdet_dispatch(
444        ctx: &mut C,
445        tensor: &Tensor<Self>,
446    ) -> Result<SlogdetResult<Self, Self::Real>> {
447        slogdet_complex_impl::<Complex32, f32, C>(ctx, tensor)
448    }
449}
450
451impl<C> SlogdetDispatch<C> for Complex64
452where
453    Complex64: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
454    C: backend::TensorLinalgContextFor<Complex64>
455        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
456        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
457        + tenferro_prims::TensorComplexRealContextFor<Complex64>
458        + TensorMetadataContextFor,
459    C::Backend: 'static,
460    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
461        'static
462            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
463            + TensorMetadataCastPrims<f64, Context = C>,
464    <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
465        tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
466    C::MetadataBackend: TensorMetadataPrims<Context = C>,
467{
468    fn slogdet_dispatch(
469        ctx: &mut C,
470        tensor: &Tensor<Self>,
471    ) -> Result<SlogdetResult<Self, Self::Real>> {
472        slogdet_complex_impl::<Complex64, f64, C>(ctx, tensor)
473    }
474}
475
476/// Compute sign and log-absolute-determinant of a square matrix.
477///
478/// # Examples
479///
480/// ```
481/// use tenferro_linalg::slogdet;
482/// use tenferro_prims::CpuContext;
483/// use tenferro_tensor::{MemoryOrder, Tensor};
484///
485/// let mut ctx = CpuContext::new(1);
486/// let col = MemoryOrder::ColumnMajor;
487/// let a = Tensor::<f64>::from_slice(&[2.0, 0.0, 0.0, 3.0], &[2, 2], col).unwrap();
488/// let result = slogdet(&mut ctx, &a).unwrap();
489/// assert_eq!(result.sign.ndim(), 0);
490/// assert_eq!(result.logabsdet.ndim(), 0);
491/// ```
492pub fn slogdet<T: KernelLinalgScalar, C>(
493    ctx: &mut C,
494    tensor: &Tensor<T>,
495) -> Result<SlogdetResult<T, T::Real>>
496where
497    T: SlogdetDispatch<C>,
498{
499    T::slogdet_dispatch(ctx, tensor)
500}