Skip to main content

tenferro_tensor/
typed_linalg.rs

1use crate::{Error, Result, Tensor, TensorBackend, TensorScalar, TypedTensor};
2
3fn try_into_typed_result<T: TensorScalar>(
4    op: &'static str,
5    tensor: Tensor,
6) -> Result<TypedTensor<T>> {
7    let actual = tensor.dtype();
8    T::try_into_typed(tensor).ok_or_else(|| Error::DTypeMismatch {
9        op,
10        lhs: actual,
11        rhs: T::dtype(),
12    })
13}
14
15impl<T: TensorScalar> TypedTensor<T> {
16    /// Singular value decomposition: `A = U diag(S) Vt`.
17    ///
18    /// Returns `(U, S, Vt)` using the thin/economy SVD.
19    ///
20    /// For complex inputs, this wrapper expects the backend to return the
21    /// singular values as `TypedTensor<T::Real>`. If the backend still returns
22    /// a complex tensor for `S`, this method returns [`Error::DTypeMismatch`].
23    ///
24    /// # Examples
25    ///
26    /// ```
27    /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
28    ///
29    /// let mut ctx = CpuBackend::new();
30    /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 2.0]);
31    /// let (u, s, vt) = a.svd(&mut ctx).unwrap();
32    ///
33    /// assert_eq!(u.shape, vec![2, 2]);
34    /// assert_eq!(s.shape, vec![2]);
35    /// assert_eq!(vt.shape, vec![2, 2]);
36    /// ```
37    pub fn svd(&self, ctx: &mut impl TensorBackend) -> Result<(Self, TypedTensor<T::Real>, Self)> {
38        let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
39        let (u, s, vt) = tensor.svd(ctx)?;
40        Ok((
41            try_into_typed_result("svd", u)?,
42            try_into_typed_result("svd", s)?,
43            try_into_typed_result("svd", vt)?,
44        ))
45    }
46
47    /// QR decomposition: `A = Q R`.
48    ///
49    /// Returns `(Q, R)` using the thin/economy QR decomposition.
50    ///
51    /// # Examples
52    ///
53    /// ```
54    /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
55    ///
56    /// let mut ctx = CpuBackend::new();
57    /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
58    /// let (q, r) = a.qr(&mut ctx).unwrap();
59    ///
60    /// assert_eq!(q.shape, vec![2, 2]);
61    /// assert_eq!(r.shape, vec![2, 2]);
62    /// ```
63    pub fn qr(&self, ctx: &mut impl TensorBackend) -> Result<(Self, Self)> {
64        let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
65        let (q, r) = tensor.qr(ctx)?;
66        Ok((
67            try_into_typed_result("qr", q)?,
68            try_into_typed_result("qr", r)?,
69        ))
70    }
71
72    /// Cholesky factorization: `A = L L^T` or `A = L L^H`.
73    ///
74    /// Returns the lower-triangular factor `L`.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
80    ///
81    /// let mut ctx = CpuBackend::new();
82    /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![4.0, 1.0, 1.0, 3.0]);
83    /// let l = a.cholesky(&mut ctx).unwrap();
84    ///
85    /// assert_eq!(l.shape, vec![2, 2]);
86    /// ```
87    pub fn cholesky(&self, ctx: &mut impl TensorBackend) -> Result<Self> {
88        let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
89        let factor = tensor.cholesky(ctx)?;
90        try_into_typed_result("cholesky", factor)
91    }
92
93    /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
94    ///
95    /// Returns `(eigenvalues, eigenvectors)`.
96    ///
97    /// For complex inputs, this wrapper expects the backend to return the
98    /// eigenvalues as `TypedTensor<T::Real>`. If the backend still returns a
99    /// complex tensor for `W`, this method returns [`Error::DTypeMismatch`].
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
105    ///
106    /// let mut ctx = CpuBackend::new();
107    /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![4.0, 1.0, 1.0, 3.0]);
108    /// let (w, v) = a.eigh(&mut ctx).unwrap();
109    ///
110    /// assert_eq!(w.shape, vec![2]);
111    /// assert_eq!(v.shape, vec![2, 2]);
112    /// ```
113    pub fn eigh(&self, ctx: &mut impl TensorBackend) -> Result<(TypedTensor<T::Real>, Self)> {
114        let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
115        let (w, v) = tensor.eigh(ctx)?;
116        Ok((
117            try_into_typed_result("eigh", w)?,
118            try_into_typed_result("eigh", v)?,
119        ))
120    }
121}