tensor4all_tensorbackend/
backend.rs1use anyhow::{anyhow, Result};
7use num_complex::{Complex32, Complex64};
8use tenferro::{DType, Tensor, TensorBackend, TensorScalar, TypedTensor};
9
10use crate::context::with_default_backend;
11use crate::matrix::Matrix;
12
13#[derive(Debug, Clone)]
32pub struct SvdResult<T: TensorScalar> {
33 pub u: TypedTensor<T>,
35 pub s: TypedTensor<T::Real>,
37 pub vt: TypedTensor<T>,
39}
40
41#[derive(Debug, Clone)]
47pub struct FullPivLuResult<T: TensorScalar> {
48 pub p: TypedTensor<T>,
50 pub l: TypedTensor<T>,
52 pub u: TypedTensor<T>,
54 pub q: TypedTensor<T>,
56}
57
58#[derive(Debug, Clone)]
75pub struct FullPivLuMatrixResult<T> {
76 pub p: Matrix<T>,
78 pub l: Matrix<T>,
80 pub u: Matrix<T>,
82 pub q: Matrix<T>,
84}
85
86pub trait BackendLinalgScalar: TensorScalar {}
88
89impl<T: TensorScalar> BackendLinalgScalar for T {}
90
91pub trait MatrixSolveScalar: BackendLinalgScalar + crate::matrix::MatrixScalar {
109 #[doc(hidden)]
110 fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>>;
111}
112
113fn solve_matrix_direct<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
114where
115 T: BackendLinalgScalar + Copy,
116{
117 let a_tensor = matrix_to_typed_tensor(a);
118 let b_tensor = matrix_to_typed_tensor(b);
119 let x = solve_backend(&a_tensor, &b_tensor)?;
120 typed_tensor_to_matrix("solve", x)
121}
122
123impl MatrixSolveScalar for f64 {
124 fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
125 solve_matrix_direct(a, b)
126 }
127}
128
129impl MatrixSolveScalar for Complex64 {
130 fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
131 solve_matrix_direct(a, b)
132 }
133}
134
135impl MatrixSolveScalar for f32 {
136 fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
137 let a64 = Matrix::from_col_major_vec(
138 a.nrows(),
139 a.ncols(),
140 a.as_col_major_slice()
141 .iter()
142 .map(|&value| value as f64)
143 .collect(),
144 );
145 let b64 = Matrix::from_col_major_vec(
146 b.nrows(),
147 b.ncols(),
148 b.as_col_major_slice()
149 .iter()
150 .map(|&value| value as f64)
151 .collect(),
152 );
153 let x64 = solve_matrix_direct(&a64, &b64)?;
154 Ok(Matrix::from_col_major_vec(
155 x64.nrows(),
156 x64.ncols(),
157 x64.as_col_major_slice()
158 .iter()
159 .map(|&value| value as f32)
160 .collect(),
161 ))
162 }
163}
164
165impl MatrixSolveScalar for Complex32 {
166 fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
167 let a64 = Matrix::from_col_major_vec(
168 a.nrows(),
169 a.ncols(),
170 a.as_col_major_slice()
171 .iter()
172 .map(|&value| Complex64::new(value.re as f64, value.im as f64))
173 .collect(),
174 );
175 let b64 = Matrix::from_col_major_vec(
176 b.nrows(),
177 b.ncols(),
178 b.as_col_major_slice()
179 .iter()
180 .map(|&value| Complex64::new(value.re as f64, value.im as f64))
181 .collect(),
182 );
183 let x64 = solve_matrix_direct(&a64, &b64)?;
184 Ok(Matrix::from_col_major_vec(
185 x64.nrows(),
186 x64.ncols(),
187 x64.as_col_major_slice()
188 .iter()
189 .map(|&value| Complex32::new(value.re as f32, value.im as f32))
190 .collect(),
191 ))
192 }
193}
194
195fn tensor_scalar_dtype<T: TensorScalar>() -> DType {
196 T::into_tensor(vec![0], Vec::new()).dtype()
197}
198
199fn try_into_typed_result<T: TensorScalar>(
200 op: &'static str,
201 tensor: Tensor,
202) -> Result<TypedTensor<T>> {
203 let actual = tensor.dtype();
204 T::try_into_typed(tensor).ok_or_else(|| {
205 anyhow!(
206 "{op}: dtype mismatch lhs={actual:?} rhs={:?}",
207 tensor_scalar_dtype::<T>()
208 )
209 })
210}
211
212fn convert_for_typed<T: TensorScalar>(op: &'static str, tensor: Tensor) -> Result<TypedTensor<T>> {
213 let expected = tensor_scalar_dtype::<T>();
214 let tensor = if tensor.dtype() == expected {
215 tensor
216 } else {
217 with_default_backend(|backend| {
218 backend.with_exec_session(|exec| exec.convert(&tensor, expected))
219 })
220 .map_err(|e| anyhow!("{op}: dtype conversion to {expected:?} failed: {e}"))?
221 };
222 try_into_typed_result::<T>(op, tensor)
223}
224
225fn matrix_to_typed_tensor<T>(matrix: &Matrix<T>) -> TypedTensor<T>
226where
227 T: TensorScalar + Copy,
228{
229 TypedTensor::from_vec(
230 vec![matrix.nrows(), matrix.ncols()],
231 matrix.as_col_major_slice().to_vec(),
232 )
233}
234
235fn typed_tensor_to_matrix<T>(op: &'static str, tensor: TypedTensor<T>) -> Result<Matrix<T>>
236where
237 T: TensorScalar + Copy,
238{
239 if tensor.shape.len() != 2 {
240 return Err(anyhow!(
241 "{op}: expected a rank-2 tensor, got shape {:?}",
242 tensor.shape
243 ));
244 }
245 Ok(Matrix::from_col_major_vec(
246 tensor.shape[0],
247 tensor.shape[1],
248 tensor.as_slice().to_vec(),
249 ))
250}
251
252pub fn svd_backend<T>(a: &TypedTensor<T>) -> Result<SvdResult<T>>
259where
260 T: BackendLinalgScalar,
261{
262 let tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
263 let (u, s, vt) = with_default_backend(|backend| tensor.svd(backend))
264 .map_err(|e| anyhow!("SVD computation failed via tenferro-tensor: {e}"))?;
265 Ok(SvdResult {
266 u: convert_for_typed::<T>("svd", u)?,
267 s: convert_for_typed::<T::Real>("svd", s)?,
268 vt: convert_for_typed::<T>("svd", vt)?,
269 })
270}
271
272pub fn qr_backend<T>(a: &TypedTensor<T>) -> Result<(TypedTensor<T>, TypedTensor<T>)>
279where
280 T: BackendLinalgScalar,
281{
282 with_default_backend(|backend| a.qr(backend))
283 .map_err(|e| anyhow!("QR computation failed via tenferro-tensor: {e}"))
284}
285
286pub fn solve_backend<T>(a: &TypedTensor<T>, b: &TypedTensor<T>) -> Result<TypedTensor<T>>
293where
294 T: BackendLinalgScalar,
295{
296 let a_tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
297 let b_tensor = T::into_tensor(b.shape.clone(), b.host_data().to_vec());
298 let result = with_default_backend(|backend| a_tensor.solve(&b_tensor, backend))
299 .map_err(|e| anyhow!("linear solve failed via tenferro-tensor: {e}"))?;
300 try_into_typed_result::<T>("solve", result)
301}
302
303pub fn solve_matrix<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
325where
326 T: MatrixSolveScalar,
327{
328 T::solve_matrix_impl(a, b)
329}
330
331pub fn full_piv_lu_backend<T>(a: &TypedTensor<T>) -> Result<FullPivLuResult<T>>
338where
339 T: BackendLinalgScalar,
340{
341 let tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
342 let (p, l, u, q, _parity) = with_default_backend(|backend| tensor.full_piv_lu(backend))
343 .map_err(|e| anyhow!("complete-pivoting LU failed via tenferro-tensor: {e}"))?;
344 Ok(FullPivLuResult {
345 p: convert_for_typed::<T>("full_piv_lu", p)?,
346 l: convert_for_typed::<T>("full_piv_lu", l)?,
347 u: convert_for_typed::<T>("full_piv_lu", u)?,
348 q: convert_for_typed::<T>("full_piv_lu", q)?,
349 })
350}
351
352pub fn full_piv_lu_matrix<T>(a: &Matrix<T>) -> Result<FullPivLuMatrixResult<T>>
373where
374 T: BackendLinalgScalar + Copy,
375{
376 let tensor = matrix_to_typed_tensor(a);
377 let decomp = full_piv_lu_backend(&tensor)?;
378 Ok(FullPivLuMatrixResult {
379 p: typed_tensor_to_matrix("full_piv_lu", decomp.p)?,
380 l: typed_tensor_to_matrix("full_piv_lu", decomp.l)?,
381 u: typed_tensor_to_matrix("full_piv_lu", decomp.u)?,
382 q: typed_tensor_to_matrix("full_piv_lu", decomp.q)?,
383 })
384}
385
386#[cfg(test)]
387mod tests;