tenferro/eager_ops_linalg.rs
1use tenferro_ops::dim_expr::DimExpr;
2use tenferro_ops::std_tensor_op::StdTensorOp;
3use tenferro_tensor::{DType, TensorBackend};
4
5use crate::eager::EagerTensor;
6use crate::error::{Error, Result};
7
8impl<B: TensorBackend> EagerTensor<B> {
9 /// Singular value decomposition: `A = U diag(S) Vh`.
10 ///
11 /// # Examples
12 ///
13 /// ```
14 /// use tenferro::{EagerTensor, Tensor};
15 ///
16 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]));
17 /// let (u, s, vh) = a.svd().unwrap();
18 ///
19 /// assert_eq!(u.data().shape(), &[2, 2]);
20 /// assert_eq!(s.data().shape(), &[2]);
21 /// assert_eq!(vh.data().shape(), &[2, 2]);
22 /// ```
23 pub fn svd(&self) -> Result<(Self, Self, Self)> {
24 let mut outputs = self
25 .multi_output_unary_op(
26 StdTensorOp::Svd {
27 eps: 0.0,
28 input_shape: DimExpr::from_concrete(self.data.shape()),
29 },
30 3,
31 )?
32 .into_iter();
33 match (
34 outputs.next(),
35 outputs.next(),
36 outputs.next(),
37 outputs.next(),
38 ) {
39 (Some(u), Some(s), Some(vh), None) => Ok((u, s, vh)),
40 _ => Err(Error::Internal(
41 "svd eager op returned an unexpected number of outputs".to_string(),
42 )),
43 }
44 }
45
46 /// QR decomposition: `A = Q R`.
47 ///
48 /// # Examples
49 ///
50 /// ```
51 /// use tenferro::{EagerTensor, Tensor};
52 ///
53 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]));
54 /// let (q, r) = a.qr().unwrap();
55 ///
56 /// assert_eq!(q.data().shape(), &[2, 2]);
57 /// assert_eq!(r.data().shape(), &[2, 2]);
58 /// ```
59 pub fn qr(&self) -> Result<(Self, Self)> {
60 let mut outputs = self
61 .multi_output_unary_op(
62 StdTensorOp::Qr {
63 input_shape: DimExpr::from_concrete(self.data.shape()),
64 },
65 2,
66 )?
67 .into_iter();
68 match (outputs.next(), outputs.next(), outputs.next()) {
69 (Some(q), Some(r), None) => Ok((q, r)),
70 _ => Err(Error::Internal(
71 "qr eager op returned an unexpected number of outputs".to_string(),
72 )),
73 }
74 }
75
76 /// LU decomposition with partial pivoting: `P A = L U`.
77 ///
78 /// # Examples
79 ///
80 /// ```
81 /// use tenferro::{EagerTensor, Tensor};
82 ///
83 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]));
84 /// let (p, l, u, parity) = a.lu().unwrap();
85 ///
86 /// assert_eq!(p.data().shape(), &[2, 2]);
87 /// assert_eq!(l.data().shape(), &[2, 2]);
88 /// assert_eq!(u.data().shape(), &[2, 2]);
89 /// assert_eq!(parity.data().shape(), &[] as &[usize]);
90 /// ```
91 pub fn lu(&self) -> Result<(Self, Self, Self, Self)> {
92 let mut outputs = self
93 .multi_output_unary_op(
94 StdTensorOp::Lu {
95 input_shape: DimExpr::from_concrete(self.data.shape()),
96 },
97 4,
98 )?
99 .into_iter();
100 match (
101 outputs.next(),
102 outputs.next(),
103 outputs.next(),
104 outputs.next(),
105 outputs.next(),
106 ) {
107 (Some(p), Some(l), Some(u), Some(parity), None) => Ok((p, l, u, parity)),
108 _ => Err(Error::Internal(
109 "lu eager op returned an unexpected number of outputs".to_string(),
110 )),
111 }
112 }
113
114 /// Cholesky factorization: `A = L L^T` for real inputs.
115 ///
116 /// # Examples
117 ///
118 /// ```
119 /// use tenferro::{EagerTensor, Tensor};
120 ///
121 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]));
122 /// let l = a.cholesky().unwrap();
123 ///
124 /// assert_eq!(l.data().shape(), &[2, 2]);
125 /// assert_eq!(l.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 1.0]);
126 /// ```
127 pub fn cholesky(&self) -> Result<Self> {
128 self.unary_op(StdTensorOp::Cholesky {
129 input_shape: DimExpr::from_concrete(self.data.shape()),
130 })
131 }
132
133 /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
134 ///
135 /// # Examples
136 ///
137 /// ```
138 /// use tenferro::{EagerTensor, Tensor};
139 ///
140 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]));
141 /// let (values, vectors) = a.eigh().unwrap();
142 ///
143 /// assert_eq!(values.data().shape(), &[2]);
144 /// assert_eq!(vectors.data().shape(), &[2, 2]);
145 /// ```
146 pub fn eigh(&self) -> Result<(Self, Self)> {
147 let mut outputs = self
148 .multi_output_unary_op(
149 StdTensorOp::Eigh {
150 eps: 0.0,
151 input_shape: DimExpr::from_concrete(self.data.shape()),
152 },
153 2,
154 )?
155 .into_iter();
156 match (outputs.next(), outputs.next(), outputs.next()) {
157 (Some(values), Some(vectors), None) => Ok((values, vectors)),
158 _ => Err(Error::Internal(
159 "eigh eager op returned an unexpected number of outputs".to_string(),
160 )),
161 }
162 }
163
164 /// General eigendecomposition.
165 ///
166 /// # Examples
167 ///
168 /// ```
169 /// use tenferro::{EagerTensor, Tensor};
170 ///
171 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]));
172 /// let (values, vectors) = a.eig().unwrap();
173 ///
174 /// assert_eq!(values.data().shape(), &[2]);
175 /// assert_eq!(vectors.data().shape(), &[2, 2]);
176 /// ```
177 pub fn eig(&self) -> Result<(Self, Self)> {
178 let input_dtype: DType = self.data.dtype();
179 let mut outputs = self
180 .multi_output_unary_op(
181 StdTensorOp::Eig {
182 input_dtype,
183 input_shape: DimExpr::from_concrete(self.data.shape()),
184 },
185 2,
186 )?
187 .into_iter();
188 match (outputs.next(), outputs.next(), outputs.next()) {
189 (Some(values), Some(vectors), None) => Ok((values, vectors)),
190 _ => Err(Error::Internal(
191 "eig eager op returned an unexpected number of outputs".to_string(),
192 )),
193 }
194 }
195
196 /// Solve a triangular linear system.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// use tenferro::{EagerTensor, Tensor};
202 ///
203 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 4.0]));
204 /// let b = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 1], vec![4.0_f64, 8.0]));
205 /// let x = a
206 /// .triangular_solve(&b, true, true, false, false)
207 /// .unwrap();
208 ///
209 /// assert_eq!(x.data().shape(), &[2, 1]);
210 /// assert_eq!(x.data().as_slice::<f64>().unwrap(), &[2.0, 2.0]);
211 /// ```
212 pub fn triangular_solve(
213 &self,
214 b: &Self,
215 left_side: bool,
216 lower: bool,
217 transpose_a: bool,
218 unit_diagonal: bool,
219 ) -> Result<Self> {
220 self.binary_op(
221 b,
222 StdTensorOp::TriangularSolve {
223 left_side,
224 lower,
225 transpose_a,
226 unit_diagonal,
227 lhs_shape: DimExpr::from_concrete(self.data.shape()),
228 rhs_shape: DimExpr::from_concrete(b.data.shape()),
229 },
230 )
231 }
232}