Skip to main content

tenferro_linalg/
traced.rs

1use std::sync::Arc;
2
3use num_complex::{Complex32, Complex64};
4use tenferro_runtime::extension::apply;
5use tenferro_runtime::{CompareDir, DType, DotGeneralConfig, Error, Result, TracedTensor};
6
7use crate::extension::{LinalgExtensionOp, LinalgOp};
8
9/// Linear algebra extension methods for [`TracedTensor`].
10pub trait TracedTensorLinalgExt {
11    fn svd(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor)>;
12    fn svd_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor, TracedTensor)>;
13    fn qr(&self) -> Result<(TracedTensor, TracedTensor)>;
14    fn eigh(&self) -> Result<(TracedTensor, TracedTensor)>;
15    fn eigh_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor)>;
16    fn cholesky(&self) -> Result<TracedTensor>;
17    fn lu(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)>;
18    fn full_piv_lu(
19        &self,
20    ) -> Result<(
21        TracedTensor,
22        TracedTensor,
23        TracedTensor,
24        TracedTensor,
25        TracedTensor,
26    )>;
27    fn eig(&self) -> Result<(TracedTensor, TracedTensor)>;
28    fn solve(&self, b: &TracedTensor) -> Result<TracedTensor>;
29    fn full_piv_lu_solve(&self, b: &TracedTensor) -> Result<TracedTensor>;
30    fn triangular_solve(
31        &self,
32        b: &TracedTensor,
33        left_side: bool,
34        lower: bool,
35        transpose_a: bool,
36        unit_diagonal: bool,
37    ) -> Result<TracedTensor>;
38    fn slogdet(&self) -> Result<(TracedTensor, TracedTensor)>;
39    fn det(&self) -> Result<TracedTensor>;
40    fn inv(&self) -> Result<TracedTensor>;
41    fn eigvalsh(&self) -> Result<TracedTensor>;
42    fn eigvals(&self) -> Result<TracedTensor>;
43    fn pinv(&self) -> Result<TracedTensor>;
44    fn pinv_with_rtol(&self, rtol: f64) -> Result<TracedTensor>;
45    fn norm(&self, ord: Option<f64>, dim: Option<&[usize]>, keepdim: bool) -> Result<TracedTensor>;
46}
47
48impl TracedTensorLinalgExt for TracedTensor {
49    fn svd(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
50        svd(self)
51    }
52
53    fn svd_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
54        svd_with_eps(self, eps)
55    }
56
57    fn qr(&self) -> Result<(TracedTensor, TracedTensor)> {
58        qr(self)
59    }
60
61    fn eigh(&self) -> Result<(TracedTensor, TracedTensor)> {
62        eigh(self)
63    }
64
65    fn eigh_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor)> {
66        eigh_with_eps(self, eps)
67    }
68
69    fn cholesky(&self) -> Result<TracedTensor> {
70        cholesky(self)
71    }
72
73    fn lu(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
74        lu(self)
75    }
76
77    fn full_piv_lu(
78        &self,
79    ) -> Result<(
80        TracedTensor,
81        TracedTensor,
82        TracedTensor,
83        TracedTensor,
84        TracedTensor,
85    )> {
86        full_piv_lu(self)
87    }
88
89    fn eig(&self) -> Result<(TracedTensor, TracedTensor)> {
90        eig(self)
91    }
92
93    fn solve(&self, b: &TracedTensor) -> Result<TracedTensor> {
94        solve(self, b)
95    }
96
97    fn full_piv_lu_solve(&self, b: &TracedTensor) -> Result<TracedTensor> {
98        full_piv_lu_solve(self, b)
99    }
100
101    fn triangular_solve(
102        &self,
103        b: &TracedTensor,
104        left_side: bool,
105        lower: bool,
106        transpose_a: bool,
107        unit_diagonal: bool,
108    ) -> Result<TracedTensor> {
109        triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
110    }
111
112    fn slogdet(&self) -> Result<(TracedTensor, TracedTensor)> {
113        slogdet(self)
114    }
115
116    fn det(&self) -> Result<TracedTensor> {
117        det(self)
118    }
119
120    fn inv(&self) -> Result<TracedTensor> {
121        inv(self)
122    }
123
124    fn eigvalsh(&self) -> Result<TracedTensor> {
125        eigvalsh(self)
126    }
127
128    fn eigvals(&self) -> Result<TracedTensor> {
129        eigvals(self)
130    }
131
132    fn pinv(&self) -> Result<TracedTensor> {
133        pinv(self)
134    }
135
136    fn pinv_with_rtol(&self, rtol: f64) -> Result<TracedTensor> {
137        pinv_with_rtol(self, rtol)
138    }
139
140    fn norm(&self, ord: Option<f64>, dim: Option<&[usize]>, keepdim: bool) -> Result<TracedTensor> {
141        norm(self, ord, dim, keepdim)
142    }
143}
144
145/// Build a traced singular value decomposition op using the default epsilon.
146///
147/// # Examples
148///
149/// ```
150/// use tenferro_linalg::TracedTensorLinalgExt;
151/// use tenferro_runtime::TracedTensor;
152///
153/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap();
154/// let (u, s, vt) = a.svd().unwrap();
155/// assert_eq!(u.rank, 2);
156/// assert_eq!(s.rank, 1);
157/// assert_eq!(vt.rank, 2);
158/// ```
159pub fn svd(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
160    svd_with_eps(a, 1e-12)
161}
162
163/// Build a traced singular value decomposition op with an explicit epsilon.
164///
165/// # Examples
166///
167/// ```
168/// use tenferro_linalg::TracedTensorLinalgExt;
169/// use tenferro_runtime::TracedTensor;
170///
171/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap();
172/// let (_u, s, _vt) = a.svd_with_eps(1e-10).unwrap();
173/// assert_eq!(s.rank, 1);
174/// ```
175pub fn svd_with_eps(
176    a: &TracedTensor,
177    eps: f64,
178) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
179    three_outputs(
180        apply(
181            Arc::new(LinalgExtensionOp::new(LinalgOp::Svd { eps })),
182            &[a],
183        )?,
184        "svd",
185    )
186}
187
188/// Build a traced QR decomposition op.
189///
190/// # Examples
191///
192/// ```
193/// use tenferro_linalg::TracedTensorLinalgExt;
194/// use tenferro_runtime::TracedTensor;
195///
196/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 1.0]).unwrap();
197/// let (q, r) = a.qr().unwrap();
198/// assert_eq!(q.rank, 2);
199/// assert_eq!(r.rank, 2);
200/// ```
201pub fn qr(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
202    two_outputs(
203        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Qr)), &[a])?,
204        "qr",
205    )
206}
207
208/// Build a traced Hermitian eigenvalue decomposition op using the default epsilon.
209///
210/// # Examples
211///
212/// ```
213/// use tenferro_linalg::TracedTensorLinalgExt;
214/// use tenferro_runtime::TracedTensor;
215///
216/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
217/// let (values, vectors) = a.eigh().unwrap();
218/// assert_eq!(values.rank, 1);
219/// assert_eq!(vectors.rank, 2);
220/// ```
221pub fn eigh(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
222    eigh_with_eps(a, 1e-12)
223}
224
225/// Build a traced Hermitian eigenvalue decomposition op with an explicit epsilon.
226///
227/// # Examples
228///
229/// ```
230/// use tenferro_linalg::TracedTensorLinalgExt;
231/// use tenferro_runtime::TracedTensor;
232///
233/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
234/// let (values, _vectors) = a.eigh_with_eps(1e-10).unwrap();
235/// assert_eq!(values.rank, 1);
236/// ```
237pub fn eigh_with_eps(a: &TracedTensor, eps: f64) -> Result<(TracedTensor, TracedTensor)> {
238    two_outputs(
239        apply(
240            Arc::new(LinalgExtensionOp::new(LinalgOp::Eigh { eps })),
241            &[a],
242        )?,
243        "eigh",
244    )
245}
246
247/// Build a traced Cholesky decomposition op.
248///
249/// # Examples
250///
251/// ```
252/// use tenferro_linalg::TracedTensorLinalgExt;
253/// use tenferro_runtime::TracedTensor;
254///
255/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![4.0_f64, 2.0, 2.0, 3.0]).unwrap();
256/// let factor = a.cholesky().unwrap();
257/// assert_eq!(factor.rank, 2);
258/// ```
259pub fn cholesky(a: &TracedTensor) -> Result<TracedTensor> {
260    one_output(
261        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Cholesky)), &[a])?,
262        "cholesky",
263    )
264}
265
266/// Build a traced LU decomposition op.
267///
268/// # Examples
269///
270/// ```
271/// use tenferro_linalg::TracedTensorLinalgExt;
272/// use tenferro_runtime::TracedTensor;
273///
274/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]).unwrap();
275/// let (p, l, u, parity) = a.lu().unwrap();
276/// assert_eq!(p.rank, 2);
277/// assert_eq!(l.rank, 2);
278/// assert_eq!(u.rank, 2);
279/// assert_eq!(parity.rank, 0);
280/// ```
281pub fn lu(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
282    four_outputs(
283        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Lu)), &[a])?,
284        "lu",
285    )
286}
287
288/// Build a traced full-pivot LU decomposition op.
289///
290/// Returns `(P, L, U, Q, parity)` with reconstruction convention
291/// `A = P^T * L * U * Q`, equivalently `P * A * Q^T = L * U`. `parity` is a
292/// scalar real tensor containing `+1` or `-1`: `F32` for `F32`/`C32` inputs and
293/// `F64` for `F64`/`C64` inputs.
294///
295/// # Examples
296///
297/// ```
298/// use tenferro_linalg::TracedTensorLinalgExt;
299/// use tenferro_runtime::TracedTensor;
300///
301/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]).unwrap();
302/// let (p, l, u, q, parity) = a.full_piv_lu().unwrap();
303/// assert_eq!(p.rank, 2);
304/// assert_eq!(l.rank, 2);
305/// assert_eq!(u.rank, 2);
306/// assert_eq!(q.rank, 2);
307/// assert_eq!(parity.rank, 0);
308/// ```
309pub fn full_piv_lu(
310    a: &TracedTensor,
311) -> Result<(
312    TracedTensor,
313    TracedTensor,
314    TracedTensor,
315    TracedTensor,
316    TracedTensor,
317)> {
318    five_outputs(
319        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::FullPivLu)), &[a])?,
320        "full_piv_lu",
321    )
322}
323
324/// Build a traced general eigendecomposition op.
325///
326/// # Examples
327///
328/// ```
329/// use tenferro_linalg::TracedTensorLinalgExt;
330/// use tenferro_runtime::TracedTensor;
331///
332/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap();
333/// let (values, vectors) = a.eig().unwrap();
334/// assert_eq!(values.rank, 1);
335/// assert_eq!(vectors.rank, 2);
336/// ```
337pub fn eig(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
338    two_outputs(
339        apply(
340            Arc::new(LinalgExtensionOp::new(LinalgOp::Eig {
341                input_dtype: a.dtype,
342            })),
343            &[a],
344        )?,
345        "eig",
346    )
347}
348
349/// Build a traced linear solve op.
350///
351/// # Examples
352///
353/// ```
354/// use tenferro_linalg::TracedTensorLinalgExt;
355/// use tenferro_runtime::TracedTensor;
356///
357/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
358/// let b = TracedTensor::from_vec_col_major(vec![2, 1], vec![4.0_f64, 9.0]).unwrap();
359/// let x = a.solve(&b).unwrap();
360/// assert_eq!(x.rank, 2);
361/// ```
362pub fn solve(a: &TracedTensor, b: &TracedTensor) -> Result<TracedTensor> {
363    let mut factor_outputs =
364        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::LuFactor)), &[a])?.into_iter();
365    let (packed_lu, pivots) = match (
366        factor_outputs.next(),
367        factor_outputs.next(),
368        factor_outputs.next(),
369        factor_outputs.next(),
370    ) {
371        (Some(packed_lu), Some(pivots), Some(_parity), None) => (packed_lu, pivots),
372        _ => return Err(unexpected_output_count("lu_factor", 3)),
373    };
374    one_output(
375        apply(
376            Arc::new(LinalgExtensionOp::new(LinalgOp::LuSolvePrepared {
377                transpose_a: false,
378                conjugate_a: false,
379            })),
380            &[a, &packed_lu, &pivots, b],
381        )?,
382        "solve",
383    )
384}
385
386/// Build a traced full-pivot LU solve op.
387///
388/// # Examples
389///
390/// ```
391/// use tenferro_linalg::TracedTensorLinalgExt;
392/// use tenferro_runtime::TracedTensor;
393///
394/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
395/// let b = TracedTensor::from_vec_col_major(vec![2, 1], vec![4.0_f64, 9.0]).unwrap();
396/// let x = a.full_piv_lu_solve(&b).unwrap();
397/// assert_eq!(x.rank, 2);
398/// ```
399pub fn full_piv_lu_solve(a: &TracedTensor, b: &TracedTensor) -> Result<TracedTensor> {
400    one_output(
401        apply(
402            Arc::new(LinalgExtensionOp::new(LinalgOp::FullPivLuSolve {
403                transpose_a: false,
404            })),
405            &[a, b],
406        )?,
407        "full_piv_lu_solve",
408    )
409}
410
411/// Build a traced triangular solve op.
412///
413/// # Examples
414///
415/// ```
416/// use tenferro_linalg::TracedTensorLinalgExt;
417/// use tenferro_runtime::TracedTensor;
418///
419/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 1.0, 3.0]).unwrap();
420/// let b = TracedTensor::from_vec_col_major(vec![2, 1], vec![4.0_f64, 9.0]).unwrap();
421/// let x = a.triangular_solve(&b, true, true, false, false).unwrap();
422/// assert_eq!(x.rank, 2);
423/// ```
424pub fn triangular_solve(
425    a: &TracedTensor,
426    b: &TracedTensor,
427    left_side: bool,
428    lower: bool,
429    transpose_a: bool,
430    unit_diagonal: bool,
431) -> Result<TracedTensor> {
432    one_output(
433        apply(
434            Arc::new(LinalgExtensionOp::new(LinalgOp::TriangularSolve {
435                left_side,
436                lower,
437                transpose_a,
438                unit_diagonal,
439            })),
440            &[a, b],
441        )?,
442        "triangular_solve",
443    )
444}
445
446/// Build traced sign and log-absolute-determinant ops.
447///
448/// # Examples
449///
450/// ```
451/// use tenferro_linalg::TracedTensorLinalgExt;
452/// use tenferro_runtime::TracedTensor;
453///
454/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
455/// let (sign, logabsdet) = a.slogdet().unwrap();
456/// assert_eq!(sign.rank, 0);
457/// assert_eq!(logabsdet.rank, 0);
458/// ```
459pub fn slogdet(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
460    let mut factor_outputs =
461        apply(Arc::new(LinalgExtensionOp::new(LinalgOp::LuFactor)), &[a])?.into_iter();
462    let (packed_lu, parity) = match (
463        factor_outputs.next(),
464        factor_outputs.next(),
465        factor_outputs.next(),
466        factor_outputs.next(),
467    ) {
468        (Some(packed_lu), Some(_pivots), Some(parity), None) => (packed_lu, parity),
469        _ => return Err(unexpected_output_count("lu_factor", 3)),
470    };
471    let diag_u = packed_lu.extract_diag(0, 1)?;
472    let sign_u = diag_u.sign().reduce_prod(&[0])?;
473    let sign = (&parity * &sign_u)?;
474    let logabsdet = diag_u.abs().log().reduce_sum(&[0])?;
475    Ok((sign, logabsdet))
476}
477
478/// Build a traced determinant op.
479///
480/// # Examples
481///
482/// ```
483/// use tenferro_linalg::TracedTensorLinalgExt;
484/// use tenferro_runtime::TracedTensor;
485///
486/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
487/// let determinant = a.det().unwrap();
488/// assert_eq!(determinant.rank, 0);
489/// ```
490pub fn det(a: &TracedTensor) -> Result<TracedTensor> {
491    let (sign, logabsdet) = slogdet(a)?;
492    &sign * &logabsdet.exp()
493}
494
495/// Build a traced matrix inverse op.
496///
497/// # Examples
498///
499/// ```
500/// use tenferro_linalg::TracedTensorLinalgExt;
501/// use tenferro_runtime::TracedTensor;
502///
503/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
504/// let inverse = a.inv().unwrap();
505/// assert_eq!(inverse.rank, 2);
506/// ```
507pub fn inv(a: &TracedTensor) -> Result<TracedTensor> {
508    ensure_min_rank("inv", a.rank, 2)?;
509    let shape = require_concrete_shape("inv", a)?;
510    let eye = eye_like(a, shape[0])?;
511    solve(a, &eye)
512}
513
514/// Build a traced Hermitian eigenvalue-only op.
515///
516/// # Examples
517///
518/// ```
519/// use tenferro_linalg::TracedTensorLinalgExt;
520/// use tenferro_runtime::TracedTensor;
521///
522/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 0.0, 0.0, 3.0]).unwrap();
523/// let values = a.eigvalsh().unwrap();
524/// assert_eq!(values.rank, 1);
525/// ```
526pub fn eigvalsh(a: &TracedTensor) -> Result<TracedTensor> {
527    eigh_values(a)
528}
529
530/// Build a traced general eigenvalue-only op.
531///
532/// # Examples
533///
534/// ```
535/// use tenferro_linalg::TracedTensorLinalgExt;
536/// use tenferro_runtime::TracedTensor;
537///
538/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap();
539/// let values = a.eigvals().unwrap();
540/// assert_eq!(values.rank, 1);
541/// ```
542pub fn eigvals(a: &TracedTensor) -> Result<TracedTensor> {
543    eig_values(a)
544}
545
546/// Build a traced Moore-Penrose pseudoinverse op.
547///
548/// Floating-point and complex inputs are supported. Integer and boolean inputs
549/// return an unsupported-dtype error.
550///
551/// # Examples
552///
553/// ```
554/// use tenferro_linalg::TracedTensorLinalgExt;
555/// use tenferro_runtime::TracedTensor;
556///
557/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap();
558/// let inverse = a.pinv().unwrap();
559/// assert_eq!(inverse.rank, 2);
560/// ```
561pub fn pinv(a: &TracedTensor) -> Result<TracedTensor> {
562    ensure_float_or_complex("pinv", a.dtype)?;
563    let shape = require_concrete_shape("pinv", a)?;
564    let max_dim = match (shape.first(), shape.get(1)) {
565        (Some(&m), Some(&n)) => m.max(n),
566        (Some(&m), None) => m,
567        _ => 0,
568    };
569    pinv_with_rtol(a, default_pinv_rtol(a.dtype, max_dim))
570}
571
572/// Build a traced Moore-Penrose pseudoinverse op with an explicit relative tolerance.
573///
574/// Floating-point and complex inputs are supported. Integer and boolean inputs
575/// return an unsupported-dtype error.
576///
577/// # Examples
578///
579/// ```
580/// use tenferro_linalg::TracedTensorLinalgExt;
581/// use tenferro_runtime::TracedTensor;
582///
583/// let a = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap();
584/// let inverse = a.pinv_with_rtol(1e-12).unwrap();
585/// assert_eq!(inverse.rank, 2);
586/// ```
587pub fn pinv_with_rtol(a: &TracedTensor, rtol: f64) -> Result<TracedTensor> {
588    ensure_float_or_complex("pinv_with_rtol", a.dtype)?;
589    require_concrete_shape("pinv_with_rtol", a)?;
590    let (u, s, vt) = svd(a)?;
591    let abs_s = s.abs();
592    let s_max = abs_s.reduce_max(&[0])?;
593    let s_max_shape = s_max.concrete_shape()?;
594    let threshold_scalar = broadcast_scalar(scalar_real(s.dtype, rtol.max(0.0))?, &s_max_shape)?;
595    let threshold = (&s_max * &threshold_scalar)?;
596    let s_shape = s.concrete_shape()?;
597    let threshold = broadcast_batch_scalar_to_leading_axis(&threshold, &s_shape)?;
598    let mask = abs_s.compare(&threshold, CompareDir::Gt)?;
599    let mask = mask.convert(s.dtype)?;
600    let ones = ones_like(&s)?;
601    let denom = (&s + &(&ones + &(-&mask))?)?;
602    let s_inv = (&mask / &denom)?;
603
604    let v = vt.conj().transpose(&matrix_transpose_perm(vt.rank))?;
605    let uh = u.conj().transpose(&matrix_transpose_perm(u.rank))?;
606    let vs = scale_matrix_columns(&v, &s_inv)?;
607    matmul_preserve_trailing_batch(&vs, &uh)
608}
609
610/// Build a traced vector, matrix, or tensor norm op.
611///
612/// Floating-point and complex inputs are supported. Integer and boolean inputs
613/// return an unsupported-dtype error.
614///
615/// # Examples
616///
617/// ```
618/// use tenferro_linalg::TracedTensorLinalgExt;
619/// use tenferro_runtime::TracedTensor;
620///
621/// let x = TracedTensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap();
622/// let length = x.norm(Some(2.0), Some(&[0]), false).unwrap();
623/// assert_eq!(length.rank, 0);
624/// ```
625pub fn norm(
626    a: &TracedTensor,
627    ord: Option<f64>,
628    dim: Option<&[usize]>,
629    keepdim: bool,
630) -> Result<TracedTensor> {
631    ensure_float_or_complex("norm", a.dtype)?;
632    let shape = require_concrete_shape("norm", a)?;
633    let axes = dim.map_or_else(|| (0..a.rank).collect::<Vec<_>>(), |dims| dims.to_vec());
634    if axes.is_empty() {
635        return Ok(a.clone());
636    }
637    validate_axes("norm", a.rank, &axes)?;
638
639    let out = match axes.len() {
640        1 => vector_norm(a, axes[0], ord)?,
641        2 => matrix_norm(a, &axes, ord)?,
642        _ => {
643            let abs = a.abs();
644            match ord {
645                None => frobenius_norm(&abs, &axes)?,
646                Some(p) if p == f64::INFINITY => abs.reduce_max(&axes)?,
647                Some(p) if p == f64::NEG_INFINITY => abs.reduce_min(&axes)?,
648                Some(0.0) => count_nonzero(&abs, &axes)?,
649                Some(p) => p_norm(&abs, &axes, p)?,
650            }
651        }
652    };
653    Ok(restore_keepdim(out, &shape, &axes, keepdim))
654}
655
656fn unexpected_output_count(name: &str, expected: usize) -> Error {
657    Error::Internal(format!("{name} must produce exactly {expected} outputs"))
658}
659
660fn one_output(outputs: Vec<TracedTensor>, name: &str) -> Result<TracedTensor> {
661    let mut outputs = outputs.into_iter();
662    match (outputs.next(), outputs.next()) {
663        (Some(output), None) => Ok(output),
664        _ => Err(unexpected_output_count(name, 1)),
665    }
666}
667
668fn two_outputs(outputs: Vec<TracedTensor>, name: &str) -> Result<(TracedTensor, TracedTensor)> {
669    let mut outputs = outputs.into_iter();
670    match (outputs.next(), outputs.next(), outputs.next()) {
671        (Some(lhs), Some(rhs), None) => Ok((lhs, rhs)),
672        _ => Err(unexpected_output_count(name, 2)),
673    }
674}
675
676fn three_outputs(
677    outputs: Vec<TracedTensor>,
678    name: &str,
679) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
680    let mut outputs = outputs.into_iter();
681    match (
682        outputs.next(),
683        outputs.next(),
684        outputs.next(),
685        outputs.next(),
686    ) {
687        (Some(first), Some(second), Some(third), None) => Ok((first, second, third)),
688        _ => Err(unexpected_output_count(name, 3)),
689    }
690}
691
692fn four_outputs(
693    outputs: Vec<TracedTensor>,
694    name: &str,
695) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
696    let mut outputs = outputs.into_iter();
697    match (
698        outputs.next(),
699        outputs.next(),
700        outputs.next(),
701        outputs.next(),
702        outputs.next(),
703    ) {
704        (Some(first), Some(second), Some(third), Some(fourth), None) => {
705            Ok((first, second, third, fourth))
706        }
707        _ => Err(unexpected_output_count(name, 4)),
708    }
709}
710
711fn five_outputs(
712    outputs: Vec<TracedTensor>,
713    name: &str,
714) -> Result<(
715    TracedTensor,
716    TracedTensor,
717    TracedTensor,
718    TracedTensor,
719    TracedTensor,
720)> {
721    let mut outputs = outputs.into_iter();
722    match (
723        outputs.next(),
724        outputs.next(),
725        outputs.next(),
726        outputs.next(),
727        outputs.next(),
728        outputs.next(),
729    ) {
730        (Some(first), Some(second), Some(third), Some(fourth), Some(fifth), None) => {
731            Ok((first, second, third, fourth, fifth))
732        }
733        _ => Err(unexpected_output_count(name, 5)),
734    }
735}
736
737fn scalar_real(dtype: DType, value: f64) -> Result<TracedTensor> {
738    match dtype {
739        DType::F64 => TracedTensor::from_vec_col_major(vec![], vec![value]),
740        DType::F32 => TracedTensor::from_vec_col_major(vec![], vec![value as f32]),
741        DType::I32 => TracedTensor::from_vec_col_major(vec![], vec![value.round() as i32]),
742        DType::I64 => TracedTensor::from_vec_col_major(vec![], vec![value.round() as i64]),
743        DType::Bool => TracedTensor::from_vec_col_major(vec![], vec![value != 0.0]),
744        DType::C64 => TracedTensor::from_vec_col_major(vec![], vec![Complex64::new(value, 0.0)]),
745        DType::C32 => {
746            TracedTensor::from_vec_col_major(vec![], vec![Complex32::new(value as f32, 0.0)])
747        }
748    }
749}
750
751fn ensure_float_or_complex(op: &'static str, dtype: DType) -> Result<()> {
752    match dtype {
753        DType::F32 | DType::F64 | DType::C32 | DType::C64 => Ok(()),
754        DType::I32 | DType::I64 | DType::Bool => Err(Error::TensorRuntime(
755            tenferro_tensor::Error::backend_failure(op, format!("unsupported dtype {dtype:?}")),
756        )),
757    }
758}
759
760fn ensure_min_rank(op: &'static str, actual: usize, expected: usize) -> Result<()> {
761    if actual < expected {
762        return Err(Error::TensorRuntime(tenferro_tensor::Error::RankMismatch {
763            op,
764            expected,
765            actual,
766        }));
767    }
768    Ok(())
769}
770
771fn validate_axes(op: &'static str, rank: usize, axes: &[usize]) -> Result<()> {
772    for &axis in axes {
773        if axis >= rank {
774            return Err(Error::TensorRuntime(
775                tenferro_tensor::Error::AxisOutOfBounds { op, axis, rank },
776            ));
777        }
778    }
779    Ok(())
780}
781
782fn require_concrete_shape(op: &'static str, input: &TracedTensor) -> Result<Vec<usize>> {
783    input.try_concrete_shape().ok_or_else(|| {
784        Error::TensorRuntime(tenferro_tensor::Error::backend_failure(
785            op,
786            "symbolic shape is not supported by this traced linalg helper",
787        ))
788    })
789}
790
791fn zero_scalar(dtype: DType) -> Result<TracedTensor> {
792    scalar_real(dtype, 0.0)
793}
794
795fn one_scalar(dtype: DType) -> Result<TracedTensor> {
796    scalar_real(dtype, 1.0)
797}
798
799fn ones_like(input: &TracedTensor) -> Result<TracedTensor> {
800    let shape = input.concrete_shape()?;
801    broadcast_scalar(one_scalar(input.dtype)?, &shape)
802}
803
804fn eye_like(anchor: &TracedTensor, size: usize) -> Result<TracedTensor> {
805    let mut vector_shape = vec![size];
806    let anchor_shape = anchor.concrete_shape()?;
807    vector_shape.extend_from_slice(&anchor_shape[2..]);
808    let diagonal = broadcast_scalar(one_scalar(anchor.dtype)?, &vector_shape)?;
809    diagonal.embed_diag(0, 1)
810}
811
812fn broadcast_scalar(input: TracedTensor, shape: &[usize]) -> Result<TracedTensor> {
813    let input_shape = input.concrete_shape()?;
814    if input_shape == shape {
815        return Ok(input);
816    }
817    input.broadcast_in_dim(shape, &[])
818}
819
820fn broadcast_batch_scalar_to_leading_axis(
821    input: &TracedTensor,
822    shape: &[usize],
823) -> Result<TracedTensor> {
824    let input_shape = input.concrete_shape()?;
825    if input_shape == shape {
826        return Ok(input.clone());
827    }
828    let dims: Vec<usize> = (1..shape.len()).collect();
829    input.broadcast_in_dim(shape, &dims)
830}
831
832fn matmul_preserve_trailing_batch(lhs: &TracedTensor, rhs: &TracedTensor) -> Result<TracedTensor> {
833    let rank = lhs.rank;
834    let batch_dims: Vec<usize> = (2..rank).collect();
835    lhs.dot_general(
836        rhs,
837        DotGeneralConfig {
838            lhs_contracting_dims: vec![1],
839            rhs_contracting_dims: vec![0],
840            lhs_batch_dims: batch_dims.clone(),
841            rhs_batch_dims: batch_dims,
842        },
843    )
844}
845
/// Build the axis permutation that swaps the two leading (matrix) axes of a
/// rank-`rank` tensor and leaves every other axis in place.
///
/// For `rank < 2` there is no pair of axes to swap, so the identity
/// permutation is returned instead of panicking on an out-of-bounds swap.
fn matrix_transpose_perm(rank: usize) -> Vec<usize> {
    let mut perm: Vec<usize> = (0..rank).collect();
    if rank >= 2 {
        perm.swap(0, 1);
    }
    perm
}
851
852fn frobenius_norm(abs: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
853    let squared = abs.pow(&scalar_real(abs.dtype, 2.0)?)?;
854    Ok(squared.reduce_sum(axes)?.sqrt())
855}
856
857fn p_norm(abs: &TracedTensor, axes: &[usize], p: f64) -> Result<TracedTensor> {
858    let power = abs.pow(&scalar_real(abs.dtype, p)?)?;
859    let inv_p = scalar_real(abs.dtype, 1.0 / p)?;
860    power.reduce_sum(axes)?.pow(&inv_p)
861}
862
863fn default_pinv_rtol(dtype: DType, max_dim: usize) -> f64 {
864    let eps = match dtype {
865        DType::F32 | DType::C32 => f32::EPSILON as f64,
866        DType::F64 | DType::C64 => f64::EPSILON,
867        DType::I32 | DType::I64 | DType::Bool => 0.0,
868    };
869    eps * max_dim as f64
870}
871
872fn vector_norm(a: &TracedTensor, axis: usize, ord: Option<f64>) -> Result<TracedTensor> {
873    let abs = a.abs();
874    match ord {
875        None => frobenius_norm(&abs, &[axis]),
876        Some(0.0) => count_nonzero(&abs, &[axis]),
877        Some(p) if p == f64::INFINITY => abs.reduce_max(&[axis]),
878        Some(p) if p == f64::NEG_INFINITY => abs.reduce_min(&[axis]),
879        Some(p) => p_norm(&abs, &[axis], p),
880    }
881}
882
883fn matrix_norm(a: &TracedTensor, axes: &[usize], ord: Option<f64>) -> Result<TracedTensor> {
884    let matrix = move_axes_to_front(a, axes)?;
885    let abs = matrix.abs();
886    match ord {
887        None => frobenius_norm(&abs, &[0, 1]),
888        Some(p) if p == f64::INFINITY => matrix_row_sum_norm(&abs, true),
889        Some(p) if p == f64::NEG_INFINITY => matrix_row_sum_norm(&abs, false),
890        Some(1.0) => matrix_col_sum_norm(&abs, true),
891        Some(-1.0) => matrix_col_sum_norm(&abs, false),
892        Some(2.0) => {
893            let singular_values = svd_values(&matrix)?.abs();
894            singular_values.reduce_max(&[0])
895        }
896        Some(-2.0) => {
897            let singular_values = svd_values(&matrix)?.abs();
898            singular_values.reduce_min(&[0])
899        }
900        Some(0.0) => count_nonzero(&abs, &[0, 1]),
901        Some(p) => p_norm(&abs, &[0, 1], p),
902    }
903}
904
905fn svd_values(a: &TracedTensor) -> Result<TracedTensor> {
906    one_output(
907        apply(
908            Arc::new(LinalgExtensionOp::new(LinalgOp::SvdVals { eps: 1e-12 })),
909            &[a],
910        )?,
911        "svd_values",
912    )
913}
914
915fn eigh_values(a: &TracedTensor) -> Result<TracedTensor> {
916    one_output(
917        apply(
918            Arc::new(LinalgExtensionOp::new(LinalgOp::EighVals { eps: 1e-12 })),
919            &[a],
920        )?,
921        "eigh_values",
922    )
923}
924
925fn eig_values(a: &TracedTensor) -> Result<TracedTensor> {
926    one_output(
927        apply(
928            Arc::new(LinalgExtensionOp::new(LinalgOp::EigVals {
929                input_dtype: a.dtype,
930            })),
931            &[a],
932        )?,
933        "eig_values",
934    )
935}
936
937fn scale_matrix_columns(matrix: &TracedTensor, scale: &TracedTensor) -> Result<TracedTensor> {
938    let matrix_shape = matrix.concrete_shape()?;
939    let scale_shape_input = scale.concrete_shape()?;
940    let mut scale_shape = vec![1, scale_shape_input[0]];
941    scale_shape.extend_from_slice(&matrix_shape[2..]);
942    let dims: Vec<usize> = (0..matrix_shape.len()).collect();
943    let scale = scale
944        .reshape(&scale_shape)
945        .broadcast_in_dim(&matrix_shape, &dims)?;
946    matrix * &scale
947}
948
949fn count_nonzero(abs: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
950    let mask = abs.compare(&zero_scalar(abs.dtype)?, CompareDir::Gt)?;
951    mask.convert(abs.dtype)?.reduce_sum(axes)
952}
953
954fn matrix_row_sum_norm(abs: &TracedTensor, take_max: bool) -> Result<TracedTensor> {
955    let row_sums = abs.reduce_sum(&[1])?;
956    if take_max {
957        row_sums.reduce_max(&[0])
958    } else {
959        row_sums.reduce_min(&[0])
960    }
961}
962
963fn matrix_col_sum_norm(abs: &TracedTensor, take_max: bool) -> Result<TracedTensor> {
964    let col_sums = abs.reduce_sum(&[0])?;
965    if take_max {
966        col_sums.reduce_max(&[0])
967    } else {
968        col_sums.reduce_min(&[0])
969    }
970}
971
972fn move_axes_to_front(tensor: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
973    if axes.iter().enumerate().all(|(index, &axis)| index == axis) {
974        return Ok(tensor.clone());
975    }
976
977    let mut selected = vec![false; tensor.rank];
978    for &axis in axes {
979        selected[axis] = true;
980    }
981
982    let mut perm = Vec::with_capacity(tensor.rank);
983    perm.extend_from_slice(axes);
984    for (axis, is_selected) in selected.iter().enumerate().take(tensor.rank) {
985        if !*is_selected {
986            perm.push(axis);
987        }
988    }
989    tensor.transpose(&perm)
990}
991
992fn restore_keepdim(
993    reduced: TracedTensor,
994    original_shape: &[usize],
995    axes: &[usize],
996    keepdim: bool,
997) -> TracedTensor {
998    if !keepdim {
999        return reduced;
1000    }
1001    let mut kept_shape = original_shape.to_vec();
1002    for &axis in axes {
1003        kept_shape[axis] = 1;
1004    }
1005    reduced.reshape(&kept_shape)
1006}