Skip to main content

tensor4all_tensorbackend/
backend.rs

//! Backend dispatch helpers for linear algebra operations.
//!
//! This module keeps tensor4all's typed factorization entry points thin while
//! routing the actual work through the shared tenferro CPU backend.
5
6use anyhow::{anyhow, Result};
7use num_complex::{Complex32, Complex64};
8use tenferro::{DType, Tensor, TensorBackend, TensorScalar, TypedTensor};
9
10use crate::context::with_default_backend;
11use crate::matrix::Matrix;
12
/// Result of SVD decomposition `A = U * diag(S) * Vt`.
///
/// The singular values are stored in a real-valued typed tensor, even when the
/// input matrix is complex.
///
/// Values of this type are produced by [`svd_backend`], which computes the
/// thin/economy form of the decomposition.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::svd_backend;
/// use tenferro::TypedTensor;
///
/// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 2.0]);
/// let result = svd_backend(&a).unwrap();
///
/// assert_eq!(result.u.shape, vec![2, 2]);
/// assert_eq!(result.s.shape, vec![2]);
/// assert_eq!(result.vt.shape, vec![2, 2]);
/// ```
#[derive(Debug, Clone)]
pub struct SvdResult<T: TensorScalar> {
    /// Left singular vectors (`U`), in the input scalar type `T`.
    pub u: TypedTensor<T>,
    /// Singular values (`S`), stored in the real counterpart `T::Real`.
    pub s: TypedTensor<T::Real>,
    /// Right singular vectors transposed (`Vt`), in the input scalar type `T`.
    pub vt: TypedTensor<T>,
}
40
/// Result of complete-pivoting LU decomposition `P A Q^T = L U`.
///
/// The parity output from tenferro is intentionally omitted because current
/// tensor4all callers only need the permutation matrices and the upper
/// triangular factor for pivot selection.
///
/// Values of this type are produced by [`full_piv_lu_backend`], which converts
/// every factor to the input scalar type `T` before returning.
#[derive(Debug, Clone)]
pub struct FullPivLuResult<T: TensorScalar> {
    /// Left permutation matrix (`P`).
    pub p: TypedTensor<T>,
    /// Lower triangular factor (`L`).
    pub l: TypedTensor<T>,
    /// Upper triangular factor (`U`).
    pub u: TypedTensor<T>,
    /// Right permutation matrix (`Q`).
    pub q: TypedTensor<T>,
}
57
/// Result of complete-pivoting LU decomposition on [`Matrix`] values.
///
/// This is the matrix-shaped counterpart of [`FullPivLuResult`]. It exists so
/// downstream crates can use backend linalg without hand-writing
/// `TypedTensor` conversion code.
///
/// Values of this type are produced by [`full_piv_lu_matrix`].
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{from_vec2d, full_piv_lu_matrix};
///
/// let matrix = from_vec2d(vec![vec![0.0_f64, 1.0], vec![2.0, 3.0]]);
/// let factors = full_piv_lu_matrix(&matrix).unwrap();
/// assert_eq!(factors.u.nrows(), 2);
/// assert_eq!(factors.u.ncols(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct FullPivLuMatrixResult<T> {
    /// Left permutation matrix (`P`).
    pub p: Matrix<T>,
    /// Lower triangular factor (`L`).
    pub l: Matrix<T>,
    /// Upper triangular factor (`U`).
    pub u: Matrix<T>,
    /// Right permutation matrix (`Q`).
    pub q: Matrix<T>,
}
85
86/// Scalar bound accepted by tensor4all's typed linalg wrappers.
87pub trait BackendLinalgScalar: TensorScalar {}
88
89impl<T: TensorScalar> BackendLinalgScalar for T {}
90
/// Scalar types supported by [`solve_matrix`].
///
/// `f64` and `Complex64` are solved directly. `f32` and `Complex32` are
/// promoted to the corresponding 64-bit dtype for the backend solve and then
/// converted back, because the current tenferro CPU LU solve is double
/// precision only.
///
/// This file provides implementations for `f64`, `Complex64`, `f32`, and
/// `Complex32`.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{from_vec2d, solve_matrix};
///
/// let a = from_vec2d(vec![vec![2.0_f32, 1.0], vec![1.0, 2.0]]);
/// let b = from_vec2d(vec![vec![1.0_f32], vec![0.0]]);
/// let x = solve_matrix(&a, &b).unwrap();
/// assert!((x[[0, 0]] - 2.0 / 3.0).abs() < 1.0e-6);
/// ```
pub trait MatrixSolveScalar: BackendLinalgScalar + crate::matrix::MatrixScalar {
    /// Per-scalar implementation of `A X = B`; callers should go through
    /// [`solve_matrix`] rather than invoking this directly.
    #[doc(hidden)]
    fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>>;
}
112
113fn solve_matrix_direct<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
114where
115    T: BackendLinalgScalar + Copy,
116{
117    let a_tensor = matrix_to_typed_tensor(a);
118    let b_tensor = matrix_to_typed_tensor(b);
119    let x = solve_backend(&a_tensor, &b_tensor)?;
120    typed_tensor_to_matrix("solve", x)
121}
122
123impl MatrixSolveScalar for f64 {
124    fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
125        solve_matrix_direct(a, b)
126    }
127}
128
129impl MatrixSolveScalar for Complex64 {
130    fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
131        solve_matrix_direct(a, b)
132    }
133}
134
135impl MatrixSolveScalar for f32 {
136    fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
137        let a64 = Matrix::from_col_major_vec(
138            a.nrows(),
139            a.ncols(),
140            a.as_col_major_slice()
141                .iter()
142                .map(|&value| value as f64)
143                .collect(),
144        );
145        let b64 = Matrix::from_col_major_vec(
146            b.nrows(),
147            b.ncols(),
148            b.as_col_major_slice()
149                .iter()
150                .map(|&value| value as f64)
151                .collect(),
152        );
153        let x64 = solve_matrix_direct(&a64, &b64)?;
154        Ok(Matrix::from_col_major_vec(
155            x64.nrows(),
156            x64.ncols(),
157            x64.as_col_major_slice()
158                .iter()
159                .map(|&value| value as f32)
160                .collect(),
161        ))
162    }
163}
164
165impl MatrixSolveScalar for Complex32 {
166    fn solve_matrix_impl(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
167        let a64 = Matrix::from_col_major_vec(
168            a.nrows(),
169            a.ncols(),
170            a.as_col_major_slice()
171                .iter()
172                .map(|&value| Complex64::new(value.re as f64, value.im as f64))
173                .collect(),
174        );
175        let b64 = Matrix::from_col_major_vec(
176            b.nrows(),
177            b.ncols(),
178            b.as_col_major_slice()
179                .iter()
180                .map(|&value| Complex64::new(value.re as f64, value.im as f64))
181                .collect(),
182        );
183        let x64 = solve_matrix_direct(&a64, &b64)?;
184        Ok(Matrix::from_col_major_vec(
185            x64.nrows(),
186            x64.ncols(),
187            x64.as_col_major_slice()
188                .iter()
189                .map(|&value| Complex32::new(value.re as f32, value.im as f32))
190                .collect(),
191        ))
192    }
193}
194
195fn tensor_scalar_dtype<T: TensorScalar>() -> DType {
196    T::into_tensor(vec![0], Vec::new()).dtype()
197}
198
199fn try_into_typed_result<T: TensorScalar>(
200    op: &'static str,
201    tensor: Tensor,
202) -> Result<TypedTensor<T>> {
203    let actual = tensor.dtype();
204    T::try_into_typed(tensor).ok_or_else(|| {
205        anyhow!(
206            "{op}: dtype mismatch lhs={actual:?} rhs={:?}",
207            tensor_scalar_dtype::<T>()
208        )
209    })
210}
211
212fn convert_for_typed<T: TensorScalar>(op: &'static str, tensor: Tensor) -> Result<TypedTensor<T>> {
213    let expected = tensor_scalar_dtype::<T>();
214    let tensor = if tensor.dtype() == expected {
215        tensor
216    } else {
217        with_default_backend(|backend| {
218            backend.with_exec_session(|exec| exec.convert(&tensor, expected))
219        })
220        .map_err(|e| anyhow!("{op}: dtype conversion to {expected:?} failed: {e}"))?
221    };
222    try_into_typed_result::<T>(op, tensor)
223}
224
225fn matrix_to_typed_tensor<T>(matrix: &Matrix<T>) -> TypedTensor<T>
226where
227    T: TensorScalar + Copy,
228{
229    TypedTensor::from_vec(
230        vec![matrix.nrows(), matrix.ncols()],
231        matrix.as_col_major_slice().to_vec(),
232    )
233}
234
235fn typed_tensor_to_matrix<T>(op: &'static str, tensor: TypedTensor<T>) -> Result<Matrix<T>>
236where
237    T: TensorScalar + Copy,
238{
239    if tensor.shape.len() != 2 {
240        return Err(anyhow!(
241            "{op}: expected a rank-2 tensor, got shape {:?}",
242            tensor.shape
243        ));
244    }
245    Ok(Matrix::from_col_major_vec(
246        tensor.shape[0],
247        tensor.shape[1],
248        tensor.as_slice().to_vec(),
249    ))
250}
251
252/// Compute a thin/economy SVD on a typed tensor.
253///
254/// # Errors
255///
256/// Returns an error if the backend rejects the input or the decomposition
257/// fails to converge.
258pub fn svd_backend<T>(a: &TypedTensor<T>) -> Result<SvdResult<T>>
259where
260    T: BackendLinalgScalar,
261{
262    let tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
263    let (u, s, vt) = with_default_backend(|backend| tensor.svd(backend))
264        .map_err(|e| anyhow!("SVD computation failed via tenferro-tensor: {e}"))?;
265    Ok(SvdResult {
266        u: convert_for_typed::<T>("svd", u)?,
267        s: convert_for_typed::<T::Real>("svd", s)?,
268        vt: convert_for_typed::<T>("svd", vt)?,
269    })
270}
271
272/// Compute a thin/economy QR decomposition on a typed tensor.
273///
274/// # Errors
275///
276/// Returns an error if the backend rejects the input or the decomposition
277/// fails.
278pub fn qr_backend<T>(a: &TypedTensor<T>) -> Result<(TypedTensor<T>, TypedTensor<T>)>
279where
280    T: BackendLinalgScalar,
281{
282    with_default_backend(|backend| a.qr(backend))
283        .map_err(|e| anyhow!("QR computation failed via tenferro-tensor: {e}"))
284}
285
286/// Solve `A X = B` with the configured tenferro backend.
287///
288/// # Errors
289///
290/// Returns an error if the backend rejects the input shapes, the scalar dtype,
291/// or the coefficient matrix is singular.
292pub fn solve_backend<T>(a: &TypedTensor<T>, b: &TypedTensor<T>) -> Result<TypedTensor<T>>
293where
294    T: BackendLinalgScalar,
295{
296    let a_tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
297    let b_tensor = T::into_tensor(b.shape.clone(), b.host_data().to_vec());
298    let result = with_default_backend(|backend| a_tensor.solve(&b_tensor, backend))
299        .map_err(|e| anyhow!("linear solve failed via tenferro-tensor: {e}"))?;
300    try_into_typed_result::<T>("solve", result)
301}
302
303/// Solve `A X = B` for column-major [`Matrix`] values.
304///
305/// This routes the operation through the configured tenferro backend and keeps
306/// matrix-to-tensor conversion centralized in `tensor4all-tensorbackend`.
307///
308/// # Errors
309///
310/// Returns an error if the backend rejects the input shapes, scalar dtype, or
311/// coefficient matrix.
312///
313/// # Examples
314///
315/// ```
316/// use tensor4all_tensorbackend::{from_vec2d, solve_matrix};
317///
318/// let a = from_vec2d(vec![vec![2.0_f64, 1.0], vec![1.0, 2.0]]);
319/// let b = from_vec2d(vec![vec![1.0_f64], vec![0.0]]);
320/// let x = solve_matrix(&a, &b).unwrap();
321/// assert!((x[[0, 0]] - 2.0 / 3.0).abs() < 1.0e-12);
322/// assert!((x[[1, 0]] + 1.0 / 3.0).abs() < 1.0e-12);
323/// ```
324pub fn solve_matrix<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
325where
326    T: MatrixSolveScalar,
327{
328    T::solve_matrix_impl(a, b)
329}
330
331/// Compute complete-pivoting LU with the configured tenferro backend.
332///
333/// # Errors
334///
335/// Returns an error if the backend does not support the input dtype or if the
336/// factorization fails.
337pub fn full_piv_lu_backend<T>(a: &TypedTensor<T>) -> Result<FullPivLuResult<T>>
338where
339    T: BackendLinalgScalar,
340{
341    let tensor = T::into_tensor(a.shape.clone(), a.host_data().to_vec());
342    let (p, l, u, q, _parity) = with_default_backend(|backend| tensor.full_piv_lu(backend))
343        .map_err(|e| anyhow!("complete-pivoting LU failed via tenferro-tensor: {e}"))?;
344    Ok(FullPivLuResult {
345        p: convert_for_typed::<T>("full_piv_lu", p)?,
346        l: convert_for_typed::<T>("full_piv_lu", l)?,
347        u: convert_for_typed::<T>("full_piv_lu", u)?,
348        q: convert_for_typed::<T>("full_piv_lu", q)?,
349    })
350}
351
352/// Compute complete-pivoting LU for a column-major [`Matrix`].
353///
354/// This is a convenience wrapper over [`full_piv_lu_backend`] for callers that
355/// use [`Matrix`] as their dense boundary type.
356///
357/// # Errors
358///
359/// Returns an error if the backend does not support the input dtype or if the
360/// factorization fails.
361///
362/// # Examples
363///
364/// ```
365/// use tensor4all_tensorbackend::{from_vec2d, full_piv_lu_matrix};
366///
367/// let matrix = from_vec2d(vec![vec![0.0_f64, 1.0], vec![2.0, 3.0]]);
368/// let factors = full_piv_lu_matrix(&matrix).unwrap();
369/// assert_eq!(factors.p.nrows(), 2);
370/// assert_eq!(factors.q.ncols(), 2);
371/// ```
372pub fn full_piv_lu_matrix<T>(a: &Matrix<T>) -> Result<FullPivLuMatrixResult<T>>
373where
374    T: BackendLinalgScalar + Copy,
375{
376    let tensor = matrix_to_typed_tensor(a);
377    let decomp = full_piv_lu_backend(&tensor)?;
378    Ok(FullPivLuMatrixResult {
379        p: typed_tensor_to_matrix("full_piv_lu", decomp.p)?,
380        l: typed_tensor_to_matrix("full_piv_lu", decomp.l)?,
381        u: typed_tensor_to_matrix("full_piv_lu", decomp.u)?,
382        q: typed_tensor_to_matrix("full_piv_lu", decomp.q)?,
383    })
384}
385
// Unit tests for this module live in an out-of-line `tests` submodule and are
// compiled only for test builds.
#[cfg(test)]
mod tests;