tenferro_linalg_prims/backend/
mod.rs1mod context;
8#[cfg(any(feature = "cuda", test))]
9pub(crate) mod linalg_utils;
10mod tensor_api;
11mod tensor_helpers;
12
13#[cfg(feature = "linalg-lapack")]
14#[path = "../../../tenferro-linalg/src/backend/blas_lapack_backend/mod.rs"]
15mod blas_lapack_backend;
16mod cpu;
17#[cfg(feature = "linalg-faer")]
18#[path = "../../../tenferro-linalg/src/backend/cpu_faer.rs"]
19mod cpu_faer;
20#[cfg(feature = "linalg-lapack")]
21#[path = "../../../tenferro-linalg/src/backend/cpu_lapack.rs"]
22mod cpu_lapack;
23#[path = "../../../tenferro-linalg/src/backend/cpu_tensor_impl.rs"]
24mod cpu_tensor_impl;
25mod cuda;
26#[cfg(feature = "linalg-faer")]
27mod faer_backend;
28mod hip;
29
30use tenferro_device::Result;
31
32#[cfg(feature = "linalg-lapack")]
33pub use blas_lapack_backend::BlasLapackBackend;
34pub use context::TensorLinalgContextFor;
35pub use cpu::CpuTensorLinalgBackend;
36pub use cuda::{CudaDataType, CudaLinalgScalar, CudaTensorLinalgBackend};
37#[cfg(feature = "linalg-faer")]
38pub use faer_backend::FaerBackend;
39pub use hip::HipTensorLinalgBackend;
40
41#[doc(hidden)]
42pub use tensor_helpers::{
43 batch_count, broadcast_batch_dims, ensure_col_major, extract_contiguous_slice,
44 materialize_broadcasted_batches, materialize_broadcasted_pivot_batches,
45 validate_lu_pivot_shape, validate_matrix_shape, validate_solve_rhs_shape, validate_square,
46 zero_trailing_by_counts, BroadcastBatchIndexer, SolveRhsLayout,
47};
48
/// Slice-level dense linear-algebra kernels implemented by a concrete backend.
///
/// Every operation reads its inputs from borrowed slices and writes its results
/// into caller-provided output slices. It reports failure through
/// `tenferro_device::Result`. Methods take `&mut self`, so an implementation may
/// keep mutable state (for example scratch buffers) between calls.
///
/// NOTE(review): the element layout of the matrix slices is not stated here. This
/// module also exposes `col_major_strides` / `ensure_col_major`, which suggests
/// column-major storage. Confirm against the implementations before relying on it.
///
/// NOTE(review): `dead_code` is allowed on this trait without a stated reason.
/// Presumably not every method is called under every enabled backend feature
/// combination. Confirm, and narrow or remove the allow if that no longer holds.
#[allow(dead_code)]
pub(crate) trait LinalgBackend<T: Copy + 'static> {
    /// Scalar type used for real-valued outputs (singular values, symmetric
    /// eigenvalues). Presumably the real counterpart of `T` (e.g. `f64` for a
    /// complex `f64` element type). Verify in each implementation.
    type Real: Copy + 'static;

    /// Thin singular value decomposition of the `m x n` matrix `a`.
    ///
    /// Writes the left singular vectors to `u`, the singular values to `s` and
    /// the (conjugate-)transposed right singular vectors to `vt`. Required
    /// output lengths are defined by the implementation (TODO confirm).
    fn thin_svd(
        &mut self,
        a: &[T],
        m: usize,
        n: usize,
        u: &mut [T],
        s: &mut [Self::Real],
        vt: &mut [T],
    ) -> Result<()>;

    /// QR decomposition of the `m x n` matrix `a`, writing the factors to `q` and `r`.
    fn qr(&mut self, a: &[T], m: usize, n: usize, q: &mut [T], r: &mut [T]) -> Result<()>;

    /// LU decomposition of the `m x n` matrix `a`.
    ///
    /// Writes the row permutation to `perm`, and the factors to `l` and `u_out`.
    /// NOTE(review): the exact permutation convention (row index map vs. pivot
    /// sequence) is implementation-defined here. Confirm before use.
    fn lu(
        &mut self,
        a: &[T],
        m: usize,
        n: usize,
        perm: &mut [usize],
        l: &mut [T],
        u_out: &mut [T],
    ) -> Result<()>;

    /// Cholesky factorization of the `n x n` matrix `a`, writing the factor to `l`.
    fn cholesky(&mut self, a: &[T], n: usize, l: &mut [T]) -> Result<()>;

    /// Eigen-decomposition of the `n x n` symmetric (presumably Hermitian for
    /// complex `T`) matrix `a`.
    ///
    /// Writes the real eigenvalues to `values` and the eigenvectors to `vectors`.
    fn eigen_sym(
        &mut self,
        a: &[T],
        n: usize,
        values: &mut [Self::Real],
        vectors: &mut [T],
    ) -> Result<()>;

    /// Matrix product `c = a * b`.
    ///
    /// Presumably `a` is `m x k`, `b` is `k x n` and `c` is `m x n`, as the
    /// parameter names suggest. Confirm.
    fn mat_mul(
        &mut self,
        a: &[T],
        m: usize,
        k: usize,
        b: &[T],
        n: usize,
        c: &mut [T],
    ) -> Result<()>;

    /// Solves `a * x = b` for the `n x n` matrix `a` and `nrhs` right-hand sides,
    /// writing the solution to `x`.
    fn solve(&mut self, a: &[T], b: &[T], n: usize, nrhs: usize, x: &mut [T]) -> Result<()>;

    /// Solves `a * x = b` where `a` is an `n x n` triangular matrix and there are
    /// `nrhs` right-hand sides.
    ///
    /// `upper` selects whether `a` is treated as upper (`true`) or lower
    /// (`false`) triangular. The solution is written to `x`.
    fn solve_triangular(
        &mut self,
        a: &[T],
        b: &[T],
        n: usize,
        nrhs: usize,
        upper: bool,
        x: &mut [T],
    ) -> Result<()>;

    /// Eigen-decomposition of the general (non-symmetric) `n x n` matrix `a`.
    ///
    /// NOTE(review): the `_ri` suffix presumably means the eigenvalues and
    /// eigenvectors are returned as real/imaginary parts stored in `T`.
    /// Confirm the packing convention in each implementation.
    fn eig_general(
        &mut self,
        a: &[T],
        n: usize,
        values_ri: &mut [T],
        vectors_ri: &mut [T],
    ) -> Result<()>;
}
116
#[doc(hidden)]
/// Computes column-major (Fortran-order) element strides for a tensor shape.
///
/// Axis 0 always has stride `1`. The stride of every later axis is the product
/// of the extents of all axes before it. The extent of the last axis never
/// contributes to any stride. An empty shape produces an empty stride vector.
///
/// Extents are converted with `as isize`, matching the original arithmetic.
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    match dims.split_last() {
        // Rank-0 shape: there are no axes, hence no strides.
        None => Vec::new(),
        // Only the leading extents feed the running product. The last extent is
        // skipped, so the final axis's own size is never multiplied in.
        Some((_, leading)) => std::iter::once(1isize)
            .chain(leading.iter().scan(1isize, |running, &extent| {
                *running *= extent as isize;
                Some(*running)
            }))
            .collect(),
    }
}
130
131#[cfg(test)]
132mod tests;