Skip to main content

tenferro_linalg/
backend.rs

1use tenferro_tensor::{Tensor, TensorBackend, TensorView};
2
3/// Backend surface required by the linalg extension runtime.
4///
5/// # Examples
6///
7/// ```rust
8/// use tenferro_linalg::backend::LinalgBackend;
9/// use tenferro_cpu::CpuBackend;
10///
11/// fn accepts_linalg_backend<B: LinalgBackend>(_backend: &mut B) {}
12///
13/// let mut backend = CpuBackend::new();
14/// accepts_linalg_backend(&mut backend);
15/// ```
16pub trait LinalgBackend: TensorBackend {
17    /// Compute a Cholesky factorization.
18    fn cholesky(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor>;
19
20    /// Solve a triangular linear system with explicit side, triangle,
21    /// transpose, and unit-diagonal flags.
22    fn triangular_solve(
23        &mut self,
24        a: &Tensor,
25        b: &Tensor,
26        left_side: bool,
27        lower: bool,
28        transpose_a: bool,
29        unit_diagonal: bool,
30    ) -> tenferro_tensor::Result<Tensor>;
31
32    /// Compute public LU outputs `(P, L, U, parity)`.
33    fn lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
34
35    #[doc(hidden)]
36    fn lu_factor(&mut self, _input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
37        Err(tenferro_tensor::Error::backend_failure(
38            "lu_factor",
39            format!(
40                "backend {} does not implement internal packed LU factorization",
41                std::any::type_name::<Self>()
42            ),
43        ))
44    }
45
46    /// Compute complete-pivot LU outputs `(P, L, U, Q, parity)`.
47    ///
48    /// The reconstruction convention is `A = P^T * L * U * Q`, equivalently
49    /// `P * A * Q^T = L * U`. `parity` is a scalar real tensor containing
50    /// `+1` or `-1`: `F32` for `F32`/`C32` inputs and `F64` for `F64`/`C64`
51    /// inputs.
52    fn full_piv_lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
53
54    /// Solve a linear system through the complete-pivot LU path.
55    ///
56    /// With `transpose_a = false`, this solves `A * x = b`. With
57    /// `transpose_a = true`, this solves `A^T * x = b`.
58    fn full_piv_lu_solve(
59        &mut self,
60        a: &Tensor,
61        b: &Tensor,
62        transpose_a: bool,
63    ) -> tenferro_tensor::Result<Tensor>;
64
65    /// Compute public SVD outputs `(U, S, Vt)`.
66    fn svd(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
67
68    #[doc(hidden)]
69    fn svd_values(&mut self, _input: &Tensor) -> tenferro_tensor::Result<Tensor> {
70        Err(tenferro_tensor::Error::backend_failure(
71            "svd_values",
72            format!(
73                "backend {} does not implement internal singular-values-only decomposition",
74                std::any::type_name::<Self>()
75            ),
76        ))
77    }
78
79    /// Compute a singular value decomposition from a borrowed tensor view.
80    ///
81    /// Backends may canonicalize the view inside the same placement family, but
82    /// must not silently transfer between CPU and GPU memory.
83    ///
84    /// # Examples
85    ///
86    /// ```rust
87    /// use tenferro_linalg::LinalgBackend;
88    /// use tenferro_cpu::CpuBackend;
89    /// use tenferro_tensor::{TensorView, TypedTensor};
90    ///
91    /// let input = TypedTensor::<f64>::from_vec_col_major(
92    ///     vec![2, 2],
93    ///     vec![1.0, 0.0, 0.0, 2.0],
94    /// )?;
95    /// let outputs = CpuBackend::new().svd_read(TensorView::F64(input.as_view()))?;
96    /// assert_eq!(outputs[1].shape(), &[2]);
97    /// # Ok::<(), tenferro_tensor::Error>(())
98    /// ```
99    fn svd_read(&mut self, _input: TensorView<'_>) -> tenferro_tensor::Result<Vec<Tensor>> {
100        Err(tenferro_tensor::Error::backend_failure(
101            "svd",
102            "backend does not accept borrowed tensor views at this execution boundary",
103        ))
104    }
105
106    /// Compute public QR outputs `(Q, R)`.
107    ///
108    /// QR is thin: for an `m x n` input, `Q` has shape `m x min(m, n)` and
109    /// `R` has shape `min(m, n) x n`.
110    fn qr(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
111
112    /// Compute public Hermitian eigendecomposition outputs `(values, vectors)`.
113    ///
114    /// The returned vector order is `[values, vectors]`, where `values` has
115    /// shape `[n]` and `vectors` has shape `[n, n]`.
116    fn eigh(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
117
118    #[doc(hidden)]
119    fn eigh_values(&mut self, _input: &Tensor) -> tenferro_tensor::Result<Tensor> {
120        Err(tenferro_tensor::Error::backend_failure(
121            "eigh_values",
122            format!(
123                "backend {} does not implement internal Hermitian eigenvalues-only decomposition",
124                std::any::type_name::<Self>()
125            ),
126        ))
127    }
128
129    /// Compute public general eigendecomposition outputs `(values, vectors)`.
130    fn eig(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>>;
131
132    #[doc(hidden)]
133    fn eig_values(&mut self, _input: &Tensor) -> tenferro_tensor::Result<Tensor> {
134        Err(tenferro_tensor::Error::backend_failure(
135            "eig_values",
136            format!(
137                "backend {} does not implement internal general eigenvalues-only decomposition",
138                std::any::type_name::<Self>()
139            ),
140        ))
141    }
142
143    /// Solve a dense linear system.
144    fn solve(&mut self, a: &Tensor, b: &Tensor) -> tenferro_tensor::Result<Tensor>;
145
146    #[doc(hidden)]
147    fn lu_solve_prepared(
148        &mut self,
149        _a: &Tensor,
150        _packed_lu: &Tensor,
151        _pivots: &Tensor,
152        _b: &Tensor,
153        _transpose_a: bool,
154        _conjugate_a: bool,
155    ) -> tenferro_tensor::Result<Tensor> {
156        Err(tenferro_tensor::Error::backend_failure(
157            "lu_solve_prepared",
158            format!(
159                "backend {} does not implement internal prepared LU solve",
160                std::any::type_name::<Self>()
161            ),
162        ))
163    }
164}