tenferro_tensor/typed_linalg.rs
1use crate::{Error, Result, Tensor, TensorBackend, TensorScalar, TypedTensor};
2
3fn try_into_typed_result<T: TensorScalar>(
4 op: &'static str,
5 tensor: Tensor,
6) -> Result<TypedTensor<T>> {
7 let actual = tensor.dtype();
8 T::try_into_typed(tensor).ok_or_else(|| Error::DTypeMismatch {
9 op,
10 lhs: actual,
11 rhs: T::dtype(),
12 })
13}
14
15impl<T: TensorScalar> TypedTensor<T> {
16 /// Singular value decomposition: `A = U diag(S) Vt`.
17 ///
18 /// Returns `(U, S, Vt)` using the thin/economy SVD.
19 ///
20 /// For complex inputs, this wrapper expects the backend to return the
21 /// singular values as `TypedTensor<T::Real>`. If the backend still returns
22 /// a complex tensor for `S`, this method returns [`Error::DTypeMismatch`].
23 ///
24 /// # Examples
25 ///
26 /// ```
27 /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
28 ///
29 /// let mut ctx = CpuBackend::new();
30 /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 2.0]);
31 /// let (u, s, vt) = a.svd(&mut ctx).unwrap();
32 ///
33 /// assert_eq!(u.shape, vec![2, 2]);
34 /// assert_eq!(s.shape, vec![2]);
35 /// assert_eq!(vt.shape, vec![2, 2]);
36 /// ```
37 pub fn svd(&self, ctx: &mut impl TensorBackend) -> Result<(Self, TypedTensor<T::Real>, Self)> {
38 let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
39 let (u, s, vt) = tensor.svd(ctx)?;
40 Ok((
41 try_into_typed_result("svd", u)?,
42 try_into_typed_result("svd", s)?,
43 try_into_typed_result("svd", vt)?,
44 ))
45 }
46
47 /// QR decomposition: `A = Q R`.
48 ///
49 /// Returns `(Q, R)` using the thin/economy QR decomposition.
50 ///
51 /// # Examples
52 ///
53 /// ```
54 /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
55 ///
56 /// let mut ctx = CpuBackend::new();
57 /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
58 /// let (q, r) = a.qr(&mut ctx).unwrap();
59 ///
60 /// assert_eq!(q.shape, vec![2, 2]);
61 /// assert_eq!(r.shape, vec![2, 2]);
62 /// ```
63 pub fn qr(&self, ctx: &mut impl TensorBackend) -> Result<(Self, Self)> {
64 let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
65 let (q, r) = tensor.qr(ctx)?;
66 Ok((
67 try_into_typed_result("qr", q)?,
68 try_into_typed_result("qr", r)?,
69 ))
70 }
71
72 /// Cholesky factorization: `A = L L^T` or `A = L L^H`.
73 ///
74 /// Returns the lower-triangular factor `L`.
75 ///
76 /// # Examples
77 ///
78 /// ```
79 /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
80 ///
81 /// let mut ctx = CpuBackend::new();
82 /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![4.0, 1.0, 1.0, 3.0]);
83 /// let l = a.cholesky(&mut ctx).unwrap();
84 ///
85 /// assert_eq!(l.shape, vec![2, 2]);
86 /// ```
87 pub fn cholesky(&self, ctx: &mut impl TensorBackend) -> Result<Self> {
88 let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
89 let factor = tensor.cholesky(ctx)?;
90 try_into_typed_result("cholesky", factor)
91 }
92
93 /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
94 ///
95 /// Returns `(eigenvalues, eigenvectors)`.
96 ///
97 /// For complex inputs, this wrapper expects the backend to return the
98 /// eigenvalues as `TypedTensor<T::Real>`. If the backend still returns a
99 /// complex tensor for `W`, this method returns [`Error::DTypeMismatch`].
100 ///
101 /// # Examples
102 ///
103 /// ```
104 /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend, TypedTensor};
105 ///
106 /// let mut ctx = CpuBackend::new();
107 /// let a = TypedTensor::<f64>::from_vec(vec![2, 2], vec![4.0, 1.0, 1.0, 3.0]);
108 /// let (w, v) = a.eigh(&mut ctx).unwrap();
109 ///
110 /// assert_eq!(w.shape, vec![2]);
111 /// assert_eq!(v.shape, vec![2, 2]);
112 /// ```
113 pub fn eigh(&self, ctx: &mut impl TensorBackend) -> Result<(TypedTensor<T::Real>, Self)> {
114 let tensor = T::into_tensor(self.shape.clone(), self.host_data().to_vec());
115 let (w, v) = tensor.eigh(ctx)?;
116 Ok((
117 try_into_typed_result("eigh", w)?,
118 try_into_typed_result("eigh", v)?,
119 ))
120 }
121}