tenferro_linalg_prims/backend/
mod.rs

//! Backend-owned linalg context bindings and marker types.
//!
//! This module is the long-term home of the tensor-level linalg ownership
//! boundary: high-level crates should depend on these bindings rather than
//! owning backend markers themselves.

7mod context;
8#[cfg(any(feature = "cuda", test))]
9pub(crate) mod linalg_utils;
10mod tensor_api;
11mod tensor_helpers;
12
13#[cfg(feature = "linalg-lapack")]
14#[path = "../../../tenferro-linalg/src/backend/blas_lapack_backend/mod.rs"]
15mod blas_lapack_backend;
16mod cpu;
17#[cfg(feature = "linalg-faer")]
18#[path = "../../../tenferro-linalg/src/backend/cpu_faer.rs"]
19mod cpu_faer;
20#[cfg(feature = "linalg-lapack")]
21#[path = "../../../tenferro-linalg/src/backend/cpu_lapack.rs"]
22mod cpu_lapack;
23#[path = "../../../tenferro-linalg/src/backend/cpu_tensor_impl.rs"]
24mod cpu_tensor_impl;
25mod cuda;
26#[cfg(feature = "linalg-faer")]
27mod faer_backend;
28mod hip;
29
30use tenferro_device::Result;
31
32#[cfg(feature = "linalg-lapack")]
33pub use blas_lapack_backend::BlasLapackBackend;
34pub use context::TensorLinalgContextFor;
35pub use cpu::CpuTensorLinalgBackend;
36pub use cuda::{CudaDataType, CudaLinalgScalar, CudaTensorLinalgBackend};
37#[cfg(feature = "linalg-faer")]
38pub use faer_backend::FaerBackend;
39pub use hip::HipTensorLinalgBackend;
40
41#[doc(hidden)]
42pub use tensor_helpers::{
43    batch_count, broadcast_batch_dims, ensure_col_major, extract_contiguous_slice,
44    materialize_broadcasted_batches, materialize_broadcasted_pivot_batches,
45    validate_lu_pivot_shape, validate_matrix_shape, validate_solve_rhs_shape, validate_square,
46    zero_trailing_by_counts, BroadcastBatchIndexer, SolveRhsLayout,
47};
48
49/// Slice-level backend interface for matrix linear algebra operations.
50#[allow(dead_code)]
51pub(crate) trait LinalgBackend<T: Copy + 'static> {
52    type Real: Copy + 'static;
53
54    fn thin_svd(
55        &mut self,
56        a: &[T],
57        m: usize,
58        n: usize,
59        u: &mut [T],
60        s: &mut [Self::Real],
61        vt: &mut [T],
62    ) -> Result<()>;
63
64    fn qr(&mut self, a: &[T], m: usize, n: usize, q: &mut [T], r: &mut [T]) -> Result<()>;
65
66    fn lu(
67        &mut self,
68        a: &[T],
69        m: usize,
70        n: usize,
71        perm: &mut [usize],
72        l: &mut [T],
73        u_out: &mut [T],
74    ) -> Result<()>;
75
76    fn cholesky(&mut self, a: &[T], n: usize, l: &mut [T]) -> Result<()>;
77
78    fn eigen_sym(
79        &mut self,
80        a: &[T],
81        n: usize,
82        values: &mut [Self::Real],
83        vectors: &mut [T],
84    ) -> Result<()>;
85
86    fn mat_mul(
87        &mut self,
88        a: &[T],
89        m: usize,
90        k: usize,
91        b: &[T],
92        n: usize,
93        c: &mut [T],
94    ) -> Result<()>;
95
96    fn solve(&mut self, a: &[T], b: &[T], n: usize, nrhs: usize, x: &mut [T]) -> Result<()>;
97
98    fn solve_triangular(
99        &mut self,
100        a: &[T],
101        b: &[T],
102        n: usize,
103        nrhs: usize,
104        upper: bool,
105        x: &mut [T],
106    ) -> Result<()>;
107
108    fn eig_general(
109        &mut self,
110        a: &[T],
111        n: usize,
112        values_ri: &mut [T],
113        vectors_ri: &mut [T],
114    ) -> Result<()>;
115}
116
/// Compute column-major strides for the provided dimensions.
///
/// The stride of axis `i` is the product of the extents of axes `0..i`. The
/// first axis therefore always has stride `1`, and the extent of the last axis
/// is never used. An empty `dims` yields an empty vector. A zero extent makes
/// every later stride `0`.
///
/// # Panics
///
/// Panics if a stride overflows `usize` or does not fit in `isize`. This
/// replaces silently returning wrapped (possibly negative) strides, which is
/// what unchecked `as isize` arithmetic does in release builds.
#[doc(hidden)]
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    if dims.is_empty() {
        return Vec::new();
    }
    let mut strides = Vec::with_capacity(dims.len());
    strides.push(1);
    // Accumulate in `usize` so a zero extent can absorb later huge extents
    // without a spurious overflow, then range-check before converting.
    let mut stride: usize = 1;
    // The last extent only bounds the last axis; it never feeds a stride.
    for &extent in &dims[..dims.len() - 1] {
        stride = stride.checked_mul(extent).unwrap_or_else(|| {
            panic!("col_major_strides: stride overflows usize for dims {:?}", dims)
        });
        assert!(
            stride <= isize::MAX as usize,
            "col_major_strides: stride {} does not fit in isize for dims {:?}",
            stride,
            dims
        );
        strides.push(stride as isize);
    }
    strides
}
130
131#[cfg(test)]
132mod tests;