Skip to main content

tenferro_linalg/
eager_backend.rs

1use crate::backend::LinalgBackend;
2use tenferro_ad::EagerBackend;
3use tenferro_tensor::{Tensor, TensorView};
4
/// Forwards a method call to whichever concrete backend an `EagerBackend` wraps.
///
/// `dispatch_linalg!(receiver, method(arg, ...))` expands to a `match` on
/// `receiver` that calls `method(arg, ...)` on the value held by the `Cpu`
/// variant. It does the same for the `Cuda` variant, but that arm is only
/// compiled when the `cuda` feature is enabled.
///
/// When `receiver` is a reference, such as `self` inside a `&mut self` method,
/// default binding modes bind the inner backend by reference, so nothing is
/// moved out of the enum.
///
/// The whole invocation evaluates to the forwarded call's result, so it can be
/// used directly as a function's tail expression. A trailing comma after the
/// last argument is accepted.
macro_rules! dispatch_linalg {
    ($target:expr, $op:ident($($a:expr),* $(,)?)) => {
        match $target {
            EagerBackend::Cpu(inner) => inner.$op($($a),*),
            #[cfg(feature = "cuda")]
            EagerBackend::Cuda(inner) => inner.$op($($a),*),
        }
    };
}
14
15impl LinalgBackend for EagerBackend {
16    fn cholesky(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
17        dispatch_linalg!(self, cholesky(input))
18    }
19
20    fn triangular_solve(
21        &mut self,
22        a: &Tensor,
23        b: &Tensor,
24        left_side: bool,
25        lower: bool,
26        transpose_a: bool,
27        unit_diagonal: bool,
28    ) -> tenferro_tensor::Result<Tensor> {
29        dispatch_linalg!(
30            self,
31            triangular_solve(a, b, left_side, lower, transpose_a, unit_diagonal)
32        )
33    }
34
35    fn lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
36        dispatch_linalg!(self, lu(input))
37    }
38
39    fn lu_factor(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
40        dispatch_linalg!(self, lu_factor(input))
41    }
42
43    fn full_piv_lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
44        dispatch_linalg!(self, full_piv_lu(input))
45    }
46
47    fn full_piv_lu_solve(
48        &mut self,
49        a: &Tensor,
50        b: &Tensor,
51        transpose_a: bool,
52    ) -> tenferro_tensor::Result<Tensor> {
53        dispatch_linalg!(self, full_piv_lu_solve(a, b, transpose_a))
54    }
55
56    fn svd(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
57        dispatch_linalg!(self, svd(input))
58    }
59
60    fn svd_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
61        dispatch_linalg!(self, svd_values(input))
62    }
63
64    fn svd_read(&mut self, input: TensorView<'_>) -> tenferro_tensor::Result<Vec<Tensor>> {
65        dispatch_linalg!(self, svd_read(input))
66    }
67
68    fn qr(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
69        dispatch_linalg!(self, qr(input))
70    }
71
72    fn eigh(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
73        dispatch_linalg!(self, eigh(input))
74    }
75
76    fn eigh_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
77        dispatch_linalg!(self, eigh_values(input))
78    }
79
80    fn eig(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
81        dispatch_linalg!(self, eig(input))
82    }
83
84    fn solve(&mut self, a: &Tensor, b: &Tensor) -> tenferro_tensor::Result<Tensor> {
85        dispatch_linalg!(self, solve(a, b))
86    }
87
88    fn lu_solve_prepared(
89        &mut self,
90        a: &Tensor,
91        packed_lu: &Tensor,
92        pivots: &Tensor,
93        b: &Tensor,
94        transpose_a: bool,
95        conjugate_a: bool,
96    ) -> tenferro_tensor::Result<Tensor> {
97        dispatch_linalg!(
98            self,
99            lu_solve_prepared(a, packed_lu, pivots, b, transpose_a, conjugate_a)
100        )
101    }
102}