tenferro_linalg/
eager_backend.rs1use crate::backend::LinalgBackend;
2use tenferro_ad::EagerBackend;
3use tenferro_tensor::{Tensor, TensorView};
4
/// Forwards a linear-algebra call to whichever device backend an
/// `EagerBackend` value currently wraps.
///
/// Usage: `dispatch_linalg!(receiver, method(arg1, arg2, ...))`, where an
/// optional trailing comma is accepted inside the argument list. The receiver
/// expression is matched once; in the matching arm the wrapped backend's
/// `method` is called with the given arguments, and that call's value becomes
/// the value of the whole macro expression. The `Cuda` arm is only compiled
/// when this crate is built with the `cuda` feature.
macro_rules! dispatch_linalg {
    ($target:expr, $op:ident($($a:expr),* $(,)?)) => {
        match $target {
            EagerBackend::Cpu(inner) => inner.$op($($a),*),
            #[cfg(feature = "cuda")]
            EagerBackend::Cuda(inner) => inner.$op($($a),*),
        }
    };
}
14
15impl LinalgBackend for EagerBackend {
16 fn cholesky(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
17 dispatch_linalg!(self, cholesky(input))
18 }
19
20 fn triangular_solve(
21 &mut self,
22 a: &Tensor,
23 b: &Tensor,
24 left_side: bool,
25 lower: bool,
26 transpose_a: bool,
27 unit_diagonal: bool,
28 ) -> tenferro_tensor::Result<Tensor> {
29 dispatch_linalg!(
30 self,
31 triangular_solve(a, b, left_side, lower, transpose_a, unit_diagonal)
32 )
33 }
34
35 fn lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
36 dispatch_linalg!(self, lu(input))
37 }
38
39 fn lu_factor(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
40 dispatch_linalg!(self, lu_factor(input))
41 }
42
43 fn full_piv_lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
44 dispatch_linalg!(self, full_piv_lu(input))
45 }
46
47 fn full_piv_lu_solve(
48 &mut self,
49 a: &Tensor,
50 b: &Tensor,
51 transpose_a: bool,
52 ) -> tenferro_tensor::Result<Tensor> {
53 dispatch_linalg!(self, full_piv_lu_solve(a, b, transpose_a))
54 }
55
56 fn svd(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
57 dispatch_linalg!(self, svd(input))
58 }
59
60 fn svd_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
61 dispatch_linalg!(self, svd_values(input))
62 }
63
64 fn svd_read(&mut self, input: TensorView<'_>) -> tenferro_tensor::Result<Vec<Tensor>> {
65 dispatch_linalg!(self, svd_read(input))
66 }
67
68 fn qr(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
69 dispatch_linalg!(self, qr(input))
70 }
71
72 fn eigh(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
73 dispatch_linalg!(self, eigh(input))
74 }
75
76 fn eigh_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
77 dispatch_linalg!(self, eigh_values(input))
78 }
79
80 fn eig(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
81 dispatch_linalg!(self, eig(input))
82 }
83
84 fn solve(&mut self, a: &Tensor, b: &Tensor) -> tenferro_tensor::Result<Tensor> {
85 dispatch_linalg!(self, solve(a, b))
86 }
87
88 fn lu_solve_prepared(
89 &mut self,
90 a: &Tensor,
91 packed_lu: &Tensor,
92 pivots: &Tensor,
93 b: &Tensor,
94 transpose_a: bool,
95 conjugate_a: bool,
96 ) -> tenferro_tensor::Result<Tensor> {
97 dispatch_linalg!(
98 self,
99 lu_solve_prepared(a, packed_lu, pivots, b, transpose_a, conjugate_a)
100 )
101 }
102}