Skip to main content

tenferro_linalg/
eager_ext.rs

1use std::sync::Arc;
2
3use tenferro_ad::error::{Error, Result};
4use tenferro_ad::extension::apply_eager;
5use tenferro_ad::EagerTensor;
6
7use crate::extension::{LinalgExtensionOp, LinalgOp};
8use crate::register_runtime;
9
10/// Linear algebra extension methods for [`EagerTensor`].
11pub trait EagerTensorLinalgExt {
12    fn svd(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor)>;
13    fn qr(&self) -> Result<(EagerTensor, EagerTensor)>;
14    fn lu(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)>;
15    fn full_piv_lu(
16        &self,
17    ) -> Result<(
18        EagerTensor,
19        EagerTensor,
20        EagerTensor,
21        EagerTensor,
22        EagerTensor,
23    )>;
24    fn full_piv_lu_solve(&self, b: &EagerTensor) -> Result<EagerTensor>;
25    fn solve(&self, b: &EagerTensor) -> Result<EagerTensor>;
26    fn cholesky(&self) -> Result<EagerTensor>;
27    fn eigh(&self) -> Result<(EagerTensor, EagerTensor)>;
28    fn eig(&self) -> Result<(EagerTensor, EagerTensor)>;
29    fn triangular_solve(
30        &self,
31        b: &EagerTensor,
32        left_side: bool,
33        lower: bool,
34        transpose_a: bool,
35        unit_diagonal: bool,
36    ) -> Result<EagerTensor>;
37}
38
39impl EagerTensorLinalgExt for EagerTensor {
40    fn svd(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
41        svd(self)
42    }
43
44    fn qr(&self) -> Result<(EagerTensor, EagerTensor)> {
45        qr(self)
46    }
47
48    fn lu(&self) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)> {
49        lu(self)
50    }
51
52    fn full_piv_lu(
53        &self,
54    ) -> Result<(
55        EagerTensor,
56        EagerTensor,
57        EagerTensor,
58        EagerTensor,
59        EagerTensor,
60    )> {
61        full_piv_lu(self)
62    }
63
64    fn full_piv_lu_solve(&self, b: &EagerTensor) -> Result<EagerTensor> {
65        full_piv_lu_solve(self, b)
66    }
67
68    fn solve(&self, b: &EagerTensor) -> Result<EagerTensor> {
69        solve(self, b)
70    }
71
72    fn cholesky(&self) -> Result<EagerTensor> {
73        cholesky(self)
74    }
75
76    fn eigh(&self) -> Result<(EagerTensor, EagerTensor)> {
77        eigh(self)
78    }
79
80    fn eig(&self) -> Result<(EagerTensor, EagerTensor)> {
81        eig(self)
82    }
83
84    fn triangular_solve(
85        &self,
86        b: &EagerTensor,
87        left_side: bool,
88        lower: bool,
89        transpose_a: bool,
90        unit_diagonal: bool,
91    ) -> Result<EagerTensor> {
92        triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
93    }
94}
95
96fn apply_linalg_eager(op: LinalgOp, inputs: &[&EagerTensor]) -> Result<Vec<EagerTensor>> {
97    if let Some(first) = inputs.first() {
98        first
99            .runtime()
100            .register_extension(register_runtime)
101            .map_err(|err| Error::Internal(err.to_string()))?;
102    }
103    apply_eager(Arc::new(LinalgExtensionOp::new(op)), inputs)
104}
105
106/// Singular value decomposition for eager tensors.
107///
108/// # Examples
109///
110/// ```rust
111/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
112/// use tenferro_linalg::EagerTensorLinalgExt;
113///
114/// let ctx = EagerRuntime::new();
115/// let a = EagerTensor::from_tensor_in(
116///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap(),
117///     ctx,
118/// ).unwrap();
119/// let (_u, s, _vt) = a.svd()?;
120/// assert_eq!(s.shape(), &[2]);
121/// # Ok::<(), tenferro_ad::Error>(())
122/// ```
123pub fn svd(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
124    let mut outputs = apply_linalg_eager(LinalgOp::Svd { eps: 0.0 }, &[a])?.into_iter();
125    match (
126        outputs.next(),
127        outputs.next(),
128        outputs.next(),
129        outputs.next(),
130    ) {
131        (Some(u), Some(s), Some(vt), None) => Ok((u, s, vt)),
132        _ => Err(Error::Internal(
133            "svd eager op returned an unexpected number of outputs".to_string(),
134        )),
135    }
136}
137
138/// QR decomposition for eager tensors.
139///
140/// # Examples
141///
142/// ```rust
143/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
144/// use tenferro_linalg::EagerTensorLinalgExt;
145///
146/// let ctx = EagerRuntime::new();
147/// let a = EagerTensor::from_tensor_in(
148///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap(),
149///     ctx,
150/// ).unwrap();
151/// let (q, r) = a.qr()?;
152/// assert_eq!(q.shape(), &[2, 2]);
153/// assert_eq!(r.shape(), &[2, 2]);
154/// # Ok::<(), tenferro_ad::Error>(())
155/// ```
156pub fn qr(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
157    two_outputs(apply_linalg_eager(LinalgOp::Qr, &[a])?, "qr")
158}
159
160/// LU factorization for eager tensors.
161///
162/// # Examples
163///
164/// ```rust
165/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
166/// use tenferro_linalg::EagerTensorLinalgExt;
167///
168/// let ctx = EagerRuntime::new();
169/// let a = EagerTensor::from_tensor_in(
170///     Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]).unwrap(),
171///     ctx,
172/// ).unwrap();
173/// let (_p, l, u, parity) = a.lu()?;
174/// assert_eq!(l.shape(), &[2, 2]);
175/// assert_eq!(u.shape(), &[2, 2]);
176/// assert_eq!(parity.shape(), &[] as &[usize]);
177/// # Ok::<(), tenferro_ad::Error>(())
178/// ```
179pub fn lu(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor, EagerTensor, EagerTensor)> {
180    let mut outputs = apply_linalg_eager(LinalgOp::Lu, &[a])?.into_iter();
181    match (
182        outputs.next(),
183        outputs.next(),
184        outputs.next(),
185        outputs.next(),
186        outputs.next(),
187    ) {
188        (Some(p), Some(l), Some(u), Some(parity), None) => Ok((p, l, u, parity)),
189        _ => Err(Error::Internal(
190            "lu eager op returned an unexpected number of outputs".to_string(),
191        )),
192    }
193}
194
195/// Complete-pivot LU factorization for eager tensors.
196///
197/// Returns `(P, L, U, Q, parity)` with reconstruction convention
198/// `A = P^T * L * U * Q`, equivalently `P * A * Q^T = L * U`. `parity` is a
199/// scalar real tensor containing `+1` or `-1`: `F32` for `F32`/`C32` inputs and
200/// `F64` for `F64`/`C64` inputs.
201///
202/// # Examples
203///
204/// ```rust
205/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
206/// use tenferro_linalg::EagerTensorLinalgExt;
207///
208/// let ctx = EagerRuntime::new();
209/// let a = EagerTensor::from_tensor_in(
210///     Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]).unwrap(),
211///     ctx,
212/// ).unwrap();
213/// let (p, _l, _u, q, parity) = a.full_piv_lu()?;
214/// assert_eq!(p.shape(), &[2, 2]);
215/// assert_eq!(q.shape(), &[2, 2]);
216/// assert_eq!(parity.shape(), &[] as &[usize]);
217/// # Ok::<(), tenferro_ad::Error>(())
218/// ```
219pub fn full_piv_lu(
220    a: &EagerTensor,
221) -> Result<(
222    EagerTensor,
223    EagerTensor,
224    EagerTensor,
225    EagerTensor,
226    EagerTensor,
227)> {
228    let mut outputs = apply_linalg_eager(LinalgOp::FullPivLu, &[a])?.into_iter();
229    match (
230        outputs.next(),
231        outputs.next(),
232        outputs.next(),
233        outputs.next(),
234        outputs.next(),
235        outputs.next(),
236    ) {
237        (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok((p, l, u, q, parity)),
238        _ => Err(Error::Internal(
239            "full_piv_lu eager op returned an unexpected number of outputs".to_string(),
240        )),
241    }
242}
243
244/// Solve a linear system using complete-pivot LU behavior.
245///
246/// # Examples
247///
248/// ```rust
249/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
250/// use tenferro_linalg::EagerTensorLinalgExt;
251///
252/// let ctx = EagerRuntime::new();
253/// let a = EagerTensor::from_tensor_in(
254///     Tensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]).unwrap(),
255///     ctx.clone(),
256/// ).unwrap();
257/// let b = EagerTensor::from_tensor_in(
258///     Tensor::from_vec_col_major(vec![2, 1], vec![-1.0_f64, 5.0]).unwrap(),
259///     ctx,
260/// ).unwrap();
261/// let x = a.full_piv_lu_solve(&b)?;
262/// assert_eq!(x.shape(), &[2, 1]);
263/// # Ok::<(), tenferro_ad::Error>(())
264/// ```
265pub fn full_piv_lu_solve(a: &EagerTensor, b: &EagerTensor) -> Result<EagerTensor> {
266    one_output(
267        apply_linalg_eager(LinalgOp::FullPivLuSolve { transpose_a: false }, &[a, b])?,
268        "full_piv_lu_solve",
269    )
270}
271
272/// Solve a linear system for eager tensors.
273///
274/// # Examples
275///
276/// ```rust
277/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
278/// use tenferro_linalg::EagerTensorLinalgExt;
279///
280/// let ctx = EagerRuntime::new();
281/// let a = EagerTensor::from_tensor_in(
282///     Tensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 4.0]).unwrap(),
283///     ctx.clone(),
284/// ).unwrap();
285/// let b = EagerTensor::from_tensor_in(
286///     Tensor::from_vec_col_major(vec![2, 1], vec![4.0_f64, 8.0]).unwrap(),
287///     ctx,
288/// ).unwrap();
289/// let x = a.solve(&b)?;
290/// assert_eq!(x.shape(), &[2, 1]);
291/// # Ok::<(), tenferro_ad::Error>(())
292/// ```
293pub fn solve(a: &EagerTensor, b: &EagerTensor) -> Result<EagerTensor> {
294    let mut factor_outputs = apply_linalg_eager(LinalgOp::LuFactor, &[a])?.into_iter();
295    let (packed_lu, pivots) = match (
296        factor_outputs.next(),
297        factor_outputs.next(),
298        factor_outputs.next(),
299        factor_outputs.next(),
300    ) {
301        (Some(packed_lu), Some(pivots), Some(_parity), None) => (packed_lu, pivots),
302        _ => {
303            return Err(Error::Internal(
304                "lu_factor eager op returned an unexpected number of outputs".to_string(),
305            ));
306        }
307    };
308    one_output(
309        apply_linalg_eager(
310            LinalgOp::LuSolvePrepared {
311                transpose_a: false,
312                conjugate_a: false,
313            },
314            &[a, &packed_lu, &pivots, b],
315        )?,
316        "solve",
317    )
318}
319
320/// Cholesky factorization for eager tensors.
321///
322/// # Examples
323///
324/// ```rust
325/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
326/// use tenferro_linalg::EagerTensorLinalgExt;
327///
328/// let ctx = EagerRuntime::new();
329/// let a = EagerTensor::from_tensor_in(
330///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap(),
331///     ctx,
332/// ).unwrap();
333/// let l = a.cholesky()?;
334/// assert_eq!(l.shape(), &[2, 2]);
335/// # Ok::<(), tenferro_ad::Error>(())
336/// ```
337pub fn cholesky(a: &EagerTensor) -> Result<EagerTensor> {
338    one_output(apply_linalg_eager(LinalgOp::Cholesky, &[a])?, "cholesky")
339}
340
341/// Hermitian eigenvalue decomposition for eager tensors.
342///
343/// # Examples
344///
345/// ```rust
346/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
347/// use tenferro_linalg::EagerTensorLinalgExt;
348///
349/// let ctx = EagerRuntime::new();
350/// let a = EagerTensor::from_tensor_in(
351///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]).unwrap(),
352///     ctx,
353/// ).unwrap();
354/// let (values, vectors) = a.eigh()?;
355/// assert_eq!(values.shape(), &[2]);
356/// assert_eq!(vectors.shape(), &[2, 2]);
357/// # Ok::<(), tenferro_ad::Error>(())
358/// ```
359pub fn eigh(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
360    two_outputs(
361        apply_linalg_eager(LinalgOp::Eigh { eps: 0.0 }, &[a])?,
362        "eigh",
363    )
364}
365
366/// General eigenvalue decomposition for eager tensors.
367///
368/// # Examples
369///
370/// ```rust
371/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
372/// use tenferro_linalg::EagerTensorLinalgExt;
373///
374/// let ctx = EagerRuntime::new();
375/// let a = EagerTensor::from_tensor_in(
376///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]).unwrap(),
377///     ctx,
378/// ).unwrap();
379/// let (values, vectors) = a.eig()?;
380/// assert_eq!(values.shape(), &[2]);
381/// assert_eq!(vectors.shape(), &[2, 2]);
382/// # Ok::<(), tenferro_ad::Error>(())
383/// ```
384pub fn eig(a: &EagerTensor) -> Result<(EagerTensor, EagerTensor)> {
385    two_outputs(
386        apply_linalg_eager(
387            LinalgOp::Eig {
388                input_dtype: a.dtype(),
389            },
390            &[a],
391        )?,
392        "eig",
393    )
394}
395
396/// Triangular solve for eager tensors.
397///
398/// # Examples
399///
400/// ```rust
401/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
402/// use tenferro_linalg::EagerTensorLinalgExt;
403///
404/// let ctx = EagerRuntime::new();
405/// let a = EagerTensor::from_tensor_in(
406///     Tensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 1.0, 3.0]).unwrap(),
407///     ctx.clone(),
408/// ).unwrap();
409/// let b = EagerTensor::from_tensor_in(
410///     Tensor::from_vec_col_major(vec![2, 1], vec![2.0_f64, 7.0]).unwrap(),
411///     ctx,
412/// ).unwrap();
413/// let x = a.triangular_solve(&b, true, true, false, false)?;
414/// assert_eq!(x.shape(), &[2, 1]);
415/// # Ok::<(), tenferro_ad::Error>(())
416/// ```
417pub fn triangular_solve(
418    a: &EagerTensor,
419    b: &EagerTensor,
420    left_side: bool,
421    lower: bool,
422    transpose_a: bool,
423    unit_diagonal: bool,
424) -> Result<EagerTensor> {
425    one_output(
426        apply_linalg_eager(
427            LinalgOp::TriangularSolve {
428                left_side,
429                lower,
430                transpose_a,
431                unit_diagonal,
432            },
433            &[a, b],
434        )?,
435        "triangular_solve",
436    )
437}
438
439fn one_output(outputs: Vec<EagerTensor>, name: &str) -> Result<EagerTensor> {
440    let mut outputs = outputs.into_iter();
441    match (outputs.next(), outputs.next()) {
442        (Some(output), None) => Ok(output),
443        _ => Err(Error::Internal(format!(
444            "{name} eager op returned an unexpected number of outputs"
445        ))),
446    }
447}
448
449fn two_outputs(outputs: Vec<EagerTensor>, name: &str) -> Result<(EagerTensor, EagerTensor)> {
450    let mut outputs = outputs.into_iter();
451    match (outputs.next(), outputs.next(), outputs.next()) {
452        (Some(lhs), Some(rhs), None) => Ok((lhs, rhs)),
453        _ => Err(Error::Internal(format!(
454            "{name} eager op returned an unexpected number of outputs"
455        ))),
456    }
457}