Skip to main content

tenferro/
eager_ops_linalg.rs

1use tenferro_ops::dim_expr::DimExpr;
2use tenferro_ops::std_tensor_op::StdTensorOp;
3use tenferro_tensor::{DType, TensorBackend};
4
5use crate::eager::EagerTensor;
6use crate::error::{Error, Result};
7
8impl<B: TensorBackend> EagerTensor<B> {
9    /// Singular value decomposition: `A = U diag(S) Vh`.
10    ///
11    /// # Examples
12    ///
13    /// ```
14    /// use tenferro::{EagerTensor, Tensor};
15    ///
16    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]));
17    /// let (u, s, vh) = a.svd().unwrap();
18    ///
19    /// assert_eq!(u.data().shape(), &[2, 2]);
20    /// assert_eq!(s.data().shape(), &[2]);
21    /// assert_eq!(vh.data().shape(), &[2, 2]);
22    /// ```
23    pub fn svd(&self) -> Result<(Self, Self, Self)> {
24        let mut outputs = self
25            .multi_output_unary_op(
26                StdTensorOp::Svd {
27                    eps: 0.0,
28                    input_shape: DimExpr::from_concrete(self.data.shape()),
29                },
30                3,
31            )?
32            .into_iter();
33        match (
34            outputs.next(),
35            outputs.next(),
36            outputs.next(),
37            outputs.next(),
38        ) {
39            (Some(u), Some(s), Some(vh), None) => Ok((u, s, vh)),
40            _ => Err(Error::Internal(
41                "svd eager op returned an unexpected number of outputs".to_string(),
42            )),
43        }
44    }
45
46    /// QR decomposition: `A = Q R`.
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use tenferro::{EagerTensor, Tensor};
52    ///
53    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]));
54    /// let (q, r) = a.qr().unwrap();
55    ///
56    /// assert_eq!(q.data().shape(), &[2, 2]);
57    /// assert_eq!(r.data().shape(), &[2, 2]);
58    /// ```
59    pub fn qr(&self) -> Result<(Self, Self)> {
60        let mut outputs = self
61            .multi_output_unary_op(
62                StdTensorOp::Qr {
63                    input_shape: DimExpr::from_concrete(self.data.shape()),
64                },
65                2,
66            )?
67            .into_iter();
68        match (outputs.next(), outputs.next(), outputs.next()) {
69            (Some(q), Some(r), None) => Ok((q, r)),
70            _ => Err(Error::Internal(
71                "qr eager op returned an unexpected number of outputs".to_string(),
72            )),
73        }
74    }
75
76    /// LU decomposition with partial pivoting: `P A = L U`.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// use tenferro::{EagerTensor, Tensor};
82    ///
83    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]));
84    /// let (p, l, u, parity) = a.lu().unwrap();
85    ///
86    /// assert_eq!(p.data().shape(), &[2, 2]);
87    /// assert_eq!(l.data().shape(), &[2, 2]);
88    /// assert_eq!(u.data().shape(), &[2, 2]);
89    /// assert_eq!(parity.data().shape(), &[] as &[usize]);
90    /// ```
91    pub fn lu(&self) -> Result<(Self, Self, Self, Self)> {
92        let mut outputs = self
93            .multi_output_unary_op(
94                StdTensorOp::Lu {
95                    input_shape: DimExpr::from_concrete(self.data.shape()),
96                },
97                4,
98            )?
99            .into_iter();
100        match (
101            outputs.next(),
102            outputs.next(),
103            outputs.next(),
104            outputs.next(),
105            outputs.next(),
106        ) {
107            (Some(p), Some(l), Some(u), Some(parity), None) => Ok((p, l, u, parity)),
108            _ => Err(Error::Internal(
109                "lu eager op returned an unexpected number of outputs".to_string(),
110            )),
111        }
112    }
113
114    /// Cholesky factorization: `A = L L^T` for real inputs.
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use tenferro::{EagerTensor, Tensor};
120    ///
121    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]));
122    /// let l = a.cholesky().unwrap();
123    ///
124    /// assert_eq!(l.data().shape(), &[2, 2]);
125    /// assert_eq!(l.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 1.0]);
126    /// ```
127    pub fn cholesky(&self) -> Result<Self> {
128        self.unary_op(StdTensorOp::Cholesky {
129            input_shape: DimExpr::from_concrete(self.data.shape()),
130        })
131    }
132
133    /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
134    ///
135    /// # Examples
136    ///
137    /// ```
138    /// use tenferro::{EagerTensor, Tensor};
139    ///
140    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]));
141    /// let (values, vectors) = a.eigh().unwrap();
142    ///
143    /// assert_eq!(values.data().shape(), &[2]);
144    /// assert_eq!(vectors.data().shape(), &[2, 2]);
145    /// ```
146    pub fn eigh(&self) -> Result<(Self, Self)> {
147        let mut outputs = self
148            .multi_output_unary_op(
149                StdTensorOp::Eigh {
150                    eps: 0.0,
151                    input_shape: DimExpr::from_concrete(self.data.shape()),
152                },
153                2,
154            )?
155            .into_iter();
156        match (outputs.next(), outputs.next(), outputs.next()) {
157            (Some(values), Some(vectors), None) => Ok((values, vectors)),
158            _ => Err(Error::Internal(
159                "eigh eager op returned an unexpected number of outputs".to_string(),
160            )),
161        }
162    }
163
164    /// General eigendecomposition.
165    ///
166    /// # Examples
167    ///
168    /// ```
169    /// use tenferro::{EagerTensor, Tensor};
170    ///
171    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]));
172    /// let (values, vectors) = a.eig().unwrap();
173    ///
174    /// assert_eq!(values.data().shape(), &[2]);
175    /// assert_eq!(vectors.data().shape(), &[2, 2]);
176    /// ```
177    pub fn eig(&self) -> Result<(Self, Self)> {
178        let input_dtype: DType = self.data.dtype();
179        let mut outputs = self
180            .multi_output_unary_op(
181                StdTensorOp::Eig {
182                    input_dtype,
183                    input_shape: DimExpr::from_concrete(self.data.shape()),
184                },
185                2,
186            )?
187            .into_iter();
188        match (outputs.next(), outputs.next(), outputs.next()) {
189            (Some(values), Some(vectors), None) => Ok((values, vectors)),
190            _ => Err(Error::Internal(
191                "eig eager op returned an unexpected number of outputs".to_string(),
192            )),
193        }
194    }
195
196    /// Solve a triangular linear system.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use tenferro::{EagerTensor, Tensor};
202    ///
203    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 4.0]));
204    /// let b = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 1], vec![4.0_f64, 8.0]));
205    /// let x = a
206    ///     .triangular_solve(&b, true, true, false, false)
207    ///     .unwrap();
208    ///
209    /// assert_eq!(x.data().shape(), &[2, 1]);
210    /// assert_eq!(x.data().as_slice::<f64>().unwrap(), &[2.0, 2.0]);
211    /// ```
212    pub fn triangular_solve(
213        &self,
214        b: &Self,
215        left_side: bool,
216        lower: bool,
217        transpose_a: bool,
218        unit_diagonal: bool,
219    ) -> Result<Self> {
220        self.binary_op(
221            b,
222            StdTensorOp::TriangularSolve {
223                left_side,
224                lower,
225                transpose_a,
226                unit_diagonal,
227                lhs_shape: DimExpr::from_concrete(self.data.shape()),
228                rhs_shape: DimExpr::from_concrete(b.data.shape()),
229            },
230        )
231    }
232}