tenferro_linalg/eager_ext.rs
1use std::sync::Arc;
2
3use tenferro_ad::error::{Error, Result};
4use tenferro_ad::extension::apply_eager;
5use tenferro_ad::EagerTensor;
6
7use crate::extension::{LinalgExtensionOp, LinalgOp};
8use crate::register_runtime;
9
10/// Linear algebra extension methods for [`EagerTensor`].
11pub trait EagerTensorLinalgExt {
12 fn svd(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor)>;
13 fn qr(&self) -> Result<(EagerTensor, EagerTensor)>;
14 fn lu(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)>;
15 fn full_piv_lu(
16 &self,
17 ) -> Result<(
18 EagerTensor,
19 EagerTensor,
20 EagerTensor,
21 EagerTensor,
22 EagerTensor,
23 )>;
24 fn full_piv_lu_solve(&self, b: &EagerTensor) -> Result<EagerTensor>;
25 fn solve(&self, b: &EagerTensor) -> Result<EagerTensor>;
26 fn cholesky(&self) -> Result<EagerTensor>;
27 fn eigh(&self) -> Result<(EagerTensor, EagerTensor)>;
28 fn eig(&self) -> Result<(EagerTensor, EagerTensor)>;
29 fn triangular_solve(
30 &self,
31 b: &EagerTensor,
32 left_side: bool,
33 lower: bool,
34 transpose_a: bool,
35 unit_diagonal: bool,
36 ) -> Result<EagerTensor>;
37}
38
39impl EagerTensorLinalgExt for EagerTensor {
40 fn svd(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
41 svd(self)
42 }
43
44 fn qr(&self) -> Result<(EagerTensor, EagerTensor)> {
45 qr(self)
46 }
47
48 fn lu(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)> {
49 lu(self)
50 }
51
52 fn full_piv_lu(
53 &self,
54 ) -> Result<(
55 EagerTensor,
56 EagerTensor,
57 EagerTensor,
58 EagerTensor,
59 EagerTensor,
60 )> {
61 full_piv_lu(self)
62 }
63
64 fn full_piv_lu_solve(&self, b: &EagerTensor) -> Result<EagerTensor> {
65 full_piv_lu_solve(self, b)
66 }
67
68 fn solve(&self, b: &EagerTensor) -> Result<EagerTensor> {
69 solve(self, b)
70 }
71
72 fn cholesky(&self) -> Result<EagerTensor> {
73 cholesky(self)
74 }
75
76 fn eigh(&self) -> Result<(EagerTensor, EagerTensor)> {
77 eigh(self)
78 }
79
80 fn eig(&self) -> Result<(EagerTensor, EagerTensor)> {
81 eig(self)
82 }
83
84 fn triangular_solve(
85 &self,
86 b: &EagerTensor,
87 left_side: bool,
88 lower: bool,
89 transpose_a: bool,
90 unit_diagonal: bool,
91 ) -> Result<EagerTensor> {
92 triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
93 }
94}
95
96fn apply_linalg_eager(op: LinalgOp, inputs: &[&EagerTensor]) -> Result<Vec<EagerTensor>> {
97 if let Some(first) = inputs.first() {
98 first
99 .runtime()
100 .register_extension(register_runtime)
101 .map_err(|err| Error::Internal(err.to_string()))?;
102 }
103 apply_eager(Arc::new(LinalgExtensionOp::new(op)), inputs)
104}
105
106/// Singular value decomposition for eager tensors.
107///
108/// # Examples
109///
110/// ```rust
111/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
112/// use tenferro_linalg::EagerTensorLinalgExt;
113///
114/// let ctx = EagerRuntime::new();
115/// let a = EagerTensor::from_tensor_in(
116/// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap(),
117/// ctx,
118/// ).unwrap();
119/// let (_u, s, _vt) = a.svd()?;
120/// assert_eq!(s.shape(), &[2]);
121/// # Ok::<(), tenferro_ad::Error>(())
122/// ```
123pub fn svd(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
124 let mut outputs = apply_linalg_eager(LinalgOp::Svd { eps: 0.0 }, &[a])?.into_iter();
125 match (
126 outputs.next(),
127 outputs.next(),
128 outputs.next(),
129 outputs.next(),
130 ) {
131 (Some(u), Some(s), Some(vt), None) => Ok((u, s, vt)),
132 _ => Err(Error::Internal(
133 "svd eager op returned an unexpected number of outputs".to_string(),
134 )),
135 }
136}
137
138/// QR decomposition for eager tensors.
139///
140/// # Examples
141///
142/// ```rust
143/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
144/// use tenferro_linalg::EagerTensorLinalgExt;
145///
146/// let ctx = EagerRuntime::new();
147/// let a = EagerTensor::from_tensor_in(
148/// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap(),
149/// ctx,
150/// ).unwrap();
151/// let (q, r) = a.qr()?;
152/// assert_eq!(q.shape(), &[2, 2]);
153/// assert_eq!(r.shape(), &[2, 2]);
154/// # Ok::<(), tenferro_ad::Error>(())
155/// ```
156pub fn qr(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
157 two_outputs(apply_linalg_eager(LinalgOp::Qr, &[a])?, "qr")
158}
159
160/// LU factorization for eager tensors.
161///
162/// # Examples
163///
164/// ```rust
165/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
166/// use tenferro_linalg::EagerTensorLinalgExt;
167///
168/// let ctx = EagerRuntime::new();
169/// let a = EagerTensor::from_tensor_in(
170/// Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]).unwrap(),
171/// ctx,
172/// ).unwrap();
173/// let (_p, l, u, parity) = a.lu()?;
174/// assert_eq!(l.shape(), &[2, 2]);
175/// assert_eq!(u.shape(), &[2, 2]);
176/// assert_eq!(parity.shape(), &[] as &[usize]);
177/// # Ok::<(), tenferro_ad::Error>(())
178/// ```
179pub fn lu(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)> {
180 let mut outputs = apply_linalg_eager(LinalgOp::Lu, &[a])?.into_iter();
181 match (
182 outputs.next(),
183 outputs.next(),
184 outputs.next(),
185 outputs.next(),
186 outputs.next(),
187 ) {
188 (Some(p), Some(l), Some(u), Some(parity), None) => Ok((p, l, u, parity)),
189 _ => Err(Error::Internal(
190 "lu eager op returned an unexpected number of outputs".to_string(),
191 )),
192 }
193}
194
195/// Complete-pivot LU factorization for eager tensors.
196///
197/// Returns `(P, L, U, Q, parity)` with reconstruction convention
198/// `A = P^T * L * U * Q`, equivalently `P * A * Q^T = L * U`. `parity` is a
199/// scalar real tensor containing `+1` or `-1`: `F32` for `F32`/`C32` inputs and
200/// `F64` for `F64`/`C64` inputs.
201///
202/// # Examples
203///
204/// ```rust
205/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
206/// use tenferro_linalg::EagerTensorLinalgExt;
207///
208/// let ctx = EagerRuntime::new();
209/// let a = EagerTensor::from_tensor_in(
210/// Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]).unwrap(),
211/// ctx,
212/// ).unwrap();
213/// let (p, _l, _u, q, parity) = a.full_piv_lu()?;
214/// assert_eq!(p.shape(), &[2, 2]);
215/// assert_eq!(q.shape(), &[2, 2]);
216/// assert_eq!(parity.shape(), &[] as &[usize]);
217/// # Ok::<(), tenferro_ad::Error>(())
218/// ```
219pub fn full_piv_lu(
220 a: &EagerTensor,
221) -> Result<(
222 EagerTensor,
223 EagerTensor,
224 EagerTensor,
225 EagerTensor,
226 EagerTensor,
227)> {
228 let mut outputs = apply_linalg_eager(LinalgOp::FullPivLu, &[a])?.into_iter();
229 match (
230 outputs.next(),
231 outputs.next(),
232 outputs.next(),
233 outputs.next(),
234 outputs.next(),
235 outputs.next(),
236 ) {
237 (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok((p, l, u, q, parity)),
238 _ => Err(Error::Internal(
239 "full_piv_lu eager op returned an unexpected number of outputs".to_string(),
240 )),
241 }
242}
243
244/// Solve a linear system using complete-pivot LU behavior.
245///
246/// # Examples
247///
248/// ```rust
249/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
250/// use tenferro_linalg::EagerTensorLinalgExt;
251///
252/// let ctx = EagerRuntime::new();
253/// let a = EagerTensor::from_tensor_in(
254/// Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]).unwrap(),
255/// ctx.clone(),
256/// ).unwrap();
257/// let b = EagerTensor::from_tensor_in(
258/// Tensor::from_vec_col_major(vec![2, 1], vec![-1.0_f64, 5.0]).unwrap(),
259/// ctx,
260/// ).unwrap();
261/// let x = a.full_piv_lu_solve(&b)?;
262/// assert_eq!(x.shape(), &[2, 1]);
263/// # Ok::<(), tenferro_ad::Error>(())
264/// ```
265pub fn full_piv_lu_solve(a: &EagerTensor, b: &EagerTensor) -> Result<EagerTensor> {
266 one_output(
267 apply_linalg_eager(LinalgOp::FullPivLuSolve { transpose_a: false }, &[a, b])?,
268 "full_piv_lu_solve",
269 )
270}
271
272/// Solve a linear system for eager tensors.
273///
274/// # Examples
275///
276/// ```rust
277/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
278/// use tenferro_linalg::EagerTensorLinalgExt;
279///
280/// let ctx = EagerRuntime::new();
281/// let a = EagerTensor::from_tensor_in(
282/// Tensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 4.0]).unwrap(),
283/// ctx.clone(),
284/// ).unwrap();
285/// let b = EagerTensor::from_tensor_in(
286/// Tensor::from_vec_col_major(vec![2, 1], vec![4.0_f64, 8.0]).unwrap(),
287/// ctx,
288/// ).unwrap();
289/// let x = a.solve(&b)?;
290/// assert_eq!(x.shape(), &[2, 1]);
291/// # Ok::<(), tenferro_ad::Error>(())
292/// ```
293pub fn solve(a: &EagerTensor, b: &EagerTensor) -> Result<EagerTensor> {
294 let mut factor_outputs = apply_linalg_eager(LinalgOp::LuFactor, &[a])?.into_iter();
295 let (packed_lu, pivots) = match (
296 factor_outputs.next(),
297 factor_outputs.next(),
298 factor_outputs.next(),
299 factor_outputs.next(),
300 ) {
301 (Some(packed_lu), Some(pivots), Some(_parity), None) => (packed_lu, pivots),
302 _ => {
303 return Err(Error::Internal(
304 "lu_factor eager op returned an unexpected number of outputs".to_string(),
305 ));
306 }
307 };
308 one_output(
309 apply_linalg_eager(
310 LinalgOp::LuSolvePrepared {
311 transpose_a: false,
312 conjugate_a: false,
313 },
314 &[a, &packed_lu, &pivots, b],
315 )?,
316 "solve",
317 )
318}
319
320/// Cholesky factorization for eager tensors.
321///
322/// # Examples
323///
324/// ```rust
325/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
326/// use tenferro_linalg::EagerTensorLinalgExt;
327///
328/// let ctx = EagerRuntime::new();
329/// let a = EagerTensor::from_tensor_in(
330/// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap(),
331/// ctx,
332/// ).unwrap();
333/// let l = a.cholesky()?;
334/// assert_eq!(l.shape(), &[2, 2]);
335/// # Ok::<(), tenferro_ad::Error>(())
336/// ```
337pub fn cholesky(a: &EagerTensor) -> Result<EagerTensor> {
338 one_output(apply_linalg_eager(LinalgOp::Cholesky, &[a])?, "cholesky")
339}
340
341/// Hermitian eigenvalue decomposition for eager tensors.
342///
343/// # Examples
344///
345/// ```rust
346/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
347/// use tenferro_linalg::EagerTensorLinalgExt;
348///
349/// let ctx = EagerRuntime::new();
350/// let a = EagerTensor::from_tensor_in(
351/// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]).unwrap(),
352/// ctx,
353/// ).unwrap();
354/// let (values, vectors) = a.eigh()?;
355/// assert_eq!(values.shape(), &[2]);
356/// assert_eq!(vectors.shape(), &[2, 2]);
357/// # Ok::<(), tenferro_ad::Error>(())
358/// ```
359pub fn eigh(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
360 two_outputs(
361 apply_linalg_eager(LinalgOp::Eigh { eps: 0.0 }, &[a])?,
362 "eigh",
363 )
364}
365
366/// General eigenvalue decomposition for eager tensors.
367///
368/// # Examples
369///
370/// ```rust
371/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
372/// use tenferro_linalg::EagerTensorLinalgExt;
373///
374/// let ctx = EagerRuntime::new();
375/// let a = EagerTensor::from_tensor_in(
376/// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]).unwrap(),
377/// ctx,
378/// ).unwrap();
379/// let (values, vectors) = a.eig()?;
380/// assert_eq!(values.shape(), &[2]);
381/// assert_eq!(vectors.shape(), &[2, 2]);
382/// # Ok::<(), tenferro_ad::Error>(())
383/// ```
384pub fn eig(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
385 two_outputs(
386 apply_linalg_eager(
387 LinalgOp::Eig {
388 input_dtype: a.dtype(),
389 },
390 &[a],
391 )?,
392 "eig",
393 )
394}
395
396/// Triangular solve for eager tensors.
397///
398/// # Examples
399///
400/// ```rust
401/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
402/// use tenferro_linalg::EagerTensorLinalgExt;
403///
404/// let ctx = EagerRuntime::new();
405/// let a = EagerTensor::from_tensor_in(
406/// Tensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 1.0, 3.0]).unwrap(),
407/// ctx.clone(),
408/// ).unwrap();
409/// let b = EagerTensor::from_tensor_in(
410/// Tensor::from_vec_col_major(vec![2, 1], vec![2.0_f64, 7.0]).unwrap(),
411/// ctx,
412/// ).unwrap();
413/// let x = a.triangular_solve(&b, true, true, false, false)?;
414/// assert_eq!(x.shape(), &[2, 1]);
415/// # Ok::<(), tenferro_ad::Error>(())
416/// ```
417pub fn triangular_solve(
418 a: &EagerTensor,
419 b: &EagerTensor,
420 left_side: bool,
421 lower: bool,
422 transpose_a: bool,
423 unit_diagonal: bool,
424) -> Result<EagerTensor> {
425 one_output(
426 apply_linalg_eager(
427 LinalgOp::TriangularSolve {
428 left_side,
429 lower,
430 transpose_a,
431 unit_diagonal,
432 },
433 &[a, b],
434 )?,
435 "triangular_solve",
436 )
437}
438
439fn one_output(outputs: Vec<EagerTensor>, name: &str) -> Result<EagerTensor> {
440 let mut outputs = outputs.into_iter();
441 match (outputs.next(), outputs.next()) {
442 (Some(output), None) => Ok(output),
443 _ => Err(Error::Internal(format!(
444 "{name} eager op returned an unexpected number of outputs"
445 ))),
446 }
447}
448
449fn two_outputs(outputs: Vec<EagerTensor>, name: &str) -> Result<(EagerTensor, EagerTensor)> {
450 let mut outputs = outputs.into_iter();
451 match (outputs.next(), outputs.next(), outputs.next()) {
452 (Some(lhs), Some(rhs), None) => Ok((lhs, rhs)),
453 _ => Err(Error::Internal(format!(
454 "{name} eager op returned an unexpected number of outputs"
455 ))),
456 }
457}