//! Linear algebra API for traced tensors (`tenferro/linalg_api.rs`).
1use num_complex::{Complex32, Complex64};
2use tenferro_ops::dim_expr::DimExpr;
3use tenferro_ops::std_tensor_op::StdTensorOp;
4use tenferro_tensor::{CompareDir, DType, DotGeneralConfig};
5
6use crate::sym_dim::SymDim;
7use crate::traced::{
8    apply_binary, apply_multi_output, apply_nullary, apply_unary, concrete_shape, TracedTensor,
9};
10
11/// Convert a traced tensor to a different dtype.
12///
13/// # Examples
14///
15/// ```rust,ignore
16/// use tenferro::DType;
17///
18/// let y = tenferro::convert(&x, DType::C64);
19/// ```
20pub fn convert(input: &TracedTensor, to: DType) -> TracedTensor {
21    input.convert(to)
22}
23
24fn sym_shape(shape: &[usize]) -> Vec<SymDim> {
25    shape.iter().copied().map(SymDim::from).collect()
26}
27
28fn input_shape_expr(tensor: &TracedTensor) -> Vec<DimExpr> {
29    tensor
30        .shape_hint
31        .as_ref()
32        .and_then(|shape| {
33            shape
34                .iter()
35                .map(SymDim::constant_value)
36                .collect::<Option<Vec<_>>>()
37        })
38        .map(|shape| DimExpr::from_concrete(&shape))
39        .unwrap_or_else(|| DimExpr::input_shape(0, tensor.rank))
40}
41
42/// Singular value decomposition with a default numerical epsilon.
43///
44/// # Examples
45///
46/// ```rust,ignore
47/// let (u, s, vt) = tenferro::svd(&a);
48/// ```
49pub fn svd(a: &TracedTensor) -> (TracedTensor, TracedTensor, TracedTensor) {
50    svd_with_eps(a, 1e-12)
51}
52
/// Singular value decomposition with an explicit numerical epsilon.
///
/// With `a` of shape `[m, n, batch…]` (batch dims trail the matrix dims in
/// this file's convention) and `k = min(m, n)`, the outputs are
/// `u: [m, k, batch…]`, `s: [k, batch…]`, and `vt: [k, n, batch…]`.
///
/// # Examples
///
/// ```rust,ignore
/// let (u, s, vt) = tenferro::svd_with_eps(&a, 1e-10);
/// ```
pub fn svd_with_eps(a: &TracedTensor, eps: f64) -> (TracedTensor, TracedTensor, TracedTensor) {
    let shape = concrete_shape(a);
    let m = shape[0];
    let n = shape[1];
    let k = m.min(n);
    // Everything past the two matrix dims is carried through as batch dims.
    let batch = &shape[2..];
    let op = StdTensorOp::Svd {
        eps,
        input_shape: input_shape_expr(a),
    };
    let mut u_shape = vec![m, k];
    u_shape.extend_from_slice(batch);
    let mut s_shape = vec![k];
    s_shape.extend_from_slice(batch);
    let mut vt_shape = vec![k, n];
    vt_shape.extend_from_slice(batch);
    let mut results = apply_multi_output(
        op,
        a,
        vec![
            sym_shape(&u_shape),
            sym_shape(&s_shape),
            sym_shape(&vt_shape),
        ],
    )
    .into_iter();
    // Destructure exactly three outputs; the fourth `next()` guards against
    // any surplus result.
    match (
        results.next(),
        results.next(),
        results.next(),
        results.next(),
    ) {
        (Some(u), Some(s), Some(vt), None) => (u, s, vt),
        _ => unreachable!("svd must produce exactly three outputs"),
    }
}
96
97/// QR decomposition.
98///
99/// # Examples
100///
101/// ```rust,ignore
102/// let (q, r) = tenferro::qr(&a);
103/// ```
104pub fn qr(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
105    let shape = concrete_shape(a);
106    let m = shape[0];
107    let n = shape[1];
108    let k = m.min(n);
109    let batch = &shape[2..];
110    let mut q_shape = vec![m, k];
111    q_shape.extend_from_slice(batch);
112    let mut r_shape = vec![k, n];
113    r_shape.extend_from_slice(batch);
114    let mut results = apply_multi_output(
115        StdTensorOp::Qr {
116            input_shape: input_shape_expr(a),
117        },
118        a,
119        vec![sym_shape(&q_shape), sym_shape(&r_shape)],
120    )
121    .into_iter();
122    match (results.next(), results.next(), results.next()) {
123        (Some(q), Some(r), None) => (q, r),
124        _ => unreachable!("qr must produce exactly two outputs"),
125    }
126}
127
128/// Hermitian eigenvalue decomposition with a default numerical epsilon.
129///
130/// # Examples
131///
132/// ```rust,ignore
133/// let (values, vectors) = tenferro::eigh(&a);
134/// ```
135pub fn eigh(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
136    eigh_with_eps(a, 1e-12)
137}
138
/// Hermitian eigenvalue decomposition with an explicit numerical epsilon.
///
/// Returns `(values, vectors)` with shapes `[n, batch…]` and
/// `[n, n, batch…]`. Only `shape[0]` is read for `n`, so a square
/// `[n, n, batch…]` input is presumed — TODO confirm against the op's
/// contract.
///
/// # Examples
///
/// ```rust,ignore
/// let (values, vectors) = tenferro::eigh_with_eps(&a, 1e-10);
/// ```
pub fn eigh_with_eps(a: &TracedTensor, eps: f64) -> (TracedTensor, TracedTensor) {
    let shape = concrete_shape(a);
    let n = shape[0];
    // Batch dims trail the two matrix dims.
    let batch = &shape[2..];
    let op = StdTensorOp::Eigh {
        eps,
        input_shape: input_shape_expr(a),
    };
    let mut vals_shape = vec![n];
    vals_shape.extend_from_slice(batch);
    let mut vecs_shape = vec![n, n];
    vecs_shape.extend_from_slice(batch);
    let mut results =
        apply_multi_output(op, a, vec![sym_shape(&vals_shape), sym_shape(&vecs_shape)]).into_iter();
    // Exactly two outputs expected; the third `next()` guards against surplus.
    match (results.next(), results.next(), results.next()) {
        (Some(values), Some(vectors), None) => (values, vectors),
        _ => unreachable!("eigh must produce exactly two outputs"),
    }
}
165
166/// Cholesky factorization.
167///
168/// # Examples
169///
170/// ```rust,ignore
171/// let l = tenferro::cholesky(&a);
172/// ```
173pub fn cholesky(a: &TracedTensor) -> TracedTensor {
174    let shape = concrete_shape(a);
175    apply_unary(
176        StdTensorOp::Cholesky {
177            input_shape: input_shape_expr(a),
178        },
179        a,
180        a.rank,
181        Some(sym_shape(&shape)),
182    )
183}
184
/// LU decomposition with partial pivoting.
///
/// Returns `(P, L, U, parity)` where `P @ A = L @ U`.
///
/// With `a` of shape `[m, n, batch…]` and `k = min(m, n)`: `P` is
/// `[m, m, batch…]`, `L` is `[m, k, batch…]`, `U` is `[k, n, batch…]`,
/// and `parity` has one scalar per batch element (shape `[batch…]`).
///
/// # Examples
///
/// ```rust,ignore
/// let (p, l, u, parity) = tenferro::lu(&a);
/// ```
pub fn lu(a: &TracedTensor) -> (TracedTensor, TracedTensor, TracedTensor, TracedTensor) {
    let shape = concrete_shape(a);
    let m = shape[0];
    let n = shape[1];
    let k = m.min(n);
    let batch = &shape[2..];
    let mut p_shape = vec![m, m];
    p_shape.extend_from_slice(batch);
    let mut l_shape = vec![m, k];
    l_shape.extend_from_slice(batch);
    let mut u_shape = vec![k, n];
    u_shape.extend_from_slice(batch);
    let parity_shape = batch.to_vec();
    let mut results = apply_multi_output(
        StdTensorOp::Lu {
            input_shape: input_shape_expr(a),
        },
        a,
        vec![
            sym_shape(&p_shape),
            sym_shape(&l_shape),
            sym_shape(&u_shape),
            sym_shape(&parity_shape),
        ],
    )
    .into_iter();
    // Destructure exactly four outputs; the fifth `next()` guards against
    // surplus results.
    match (
        results.next(),
        results.next(),
        results.next(),
        results.next(),
        results.next(),
    ) {
        (Some(p), Some(l), Some(u), Some(parity), None) => (p, l, u, parity),
        _ => unreachable!("lu must produce exactly four outputs"),
    }
}
231
/// Non-symmetric eigendecomposition.
///
/// For real `f64` input, both outputs are `Complex64`.
///
/// Returns `(values, vectors)` with shapes `[n, batch…]` and
/// `[n, n, batch…]`. Both outputs are retagged with the complex dtype at
/// the input's precision (see `eig_output_dtype`), since eigenvalues of a
/// real matrix are complex in general.
///
/// # Examples
///
/// ```rust,ignore
/// let (values, vectors) = tenferro::eig(&a);
/// ```
pub fn eig(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
    let shape = concrete_shape(a);
    let n = shape[0];
    let batch = &shape[2..];
    let mut vals_shape = vec![n];
    vals_shape.extend_from_slice(batch);
    let mut vecs_shape = vec![n, n];
    vecs_shape.extend_from_slice(batch);
    let eig_dtype = eig_output_dtype(a.dtype);
    let mut results = apply_multi_output(
        StdTensorOp::Eig {
            // Presumably the kernel needs to know whether the input was real
            // or already complex — TODO confirm against the op definition.
            input_dtype: a.dtype,
            input_shape: input_shape_expr(a),
        },
        a,
        vec![sym_shape(&vals_shape), sym_shape(&vecs_shape)],
    )
    .into_iter();
    match (results.next(), results.next(), results.next()) {
        (Some(mut values), Some(mut vectors), None) => {
            // Override the traced dtype: outputs are complex even when the
            // input was real.
            values.dtype = eig_dtype;
            vectors.dtype = eig_dtype;
            (values, vectors)
        }
        _ => unreachable!("eig must produce exactly two outputs"),
    }
}
268
269fn validate_nonsingular(u: &TracedTensor) -> TracedTensor {
270    apply_unary(
271        StdTensorOp::ValidateNonsingular {
272            input_shape: input_shape_expr(u),
273        },
274        u,
275        u.rank,
276        u.shape_hint.clone(),
277    )
278}
279
/// Solve a linear system using LU decomposition and triangular solves.
///
/// Solves `A x = b`. A zero dimension in either operand short-circuits to
/// a zeros result shaped like `b`.
///
/// # Examples
///
/// ```rust,ignore
/// let x = tenferro::solve(&a, &b);
/// ```
pub fn solve(a: &TracedTensor, b: &TracedTensor) -> TracedTensor {
    let a_shape = concrete_shape(a);
    let b_shape = concrete_shape(b);
    if has_zero_dim(&a_shape) || has_zero_dim(&b_shape) {
        return zeros_like(b);
    }

    let do_solve = |a: &TracedTensor, b: &TracedTensor| -> TracedTensor {
        // P @ A = L @ U, so A x = b becomes L U x = P b.
        let (p, l, u, _) = lu(a);
        // Insert a nonsingularity check on U (see `validate_nonsingular`).
        let u = validate_nonsingular(&u);
        let pb = matmul_preserve_trailing_batch(&p, b);
        // Forward-substitute L z = P b (unit lower triangular) ...
        let z = triangular_solve(&l, &pb, true, true, false, true);
        // ... then back-substitute U x = z.
        triangular_solve(&u, &z, true, false, false, false)
    };

    // Vector-style right-hand sides are solved as single-column matrices
    // and reshaped back to the caller's layout afterwards.
    if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
        let b2d = b.reshape(&matrix_rhs_shape);
        let x2d = do_solve(a, &b2d);
        x2d.reshape(&b_shape)
    } else {
        do_solve(a, b)
    }
}
310
311/// Solve a triangular linear system.
312///
313/// # Examples
314///
315/// ```rust,ignore
316/// let x = tenferro::triangular_solve(&a, &b, true, true, false, false);
317/// ```
318pub fn triangular_solve(
319    a: &TracedTensor,
320    b: &TracedTensor,
321    left_side: bool,
322    lower: bool,
323    transpose_a: bool,
324    unit_diagonal: bool,
325) -> TracedTensor {
326    let a_shape = concrete_shape(a);
327    let b_shape = concrete_shape(b);
328    if has_zero_dim(&a_shape) || has_zero_dim(&b_shape) {
329        return zeros_like(b);
330    }
331    let op = StdTensorOp::TriangularSolve {
332        left_side,
333        lower,
334        transpose_a,
335        unit_diagonal,
336        lhs_shape: input_shape_expr(a),
337        rhs_shape: input_shape_expr(b),
338    };
339    if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
340        let b2d = b.reshape(&matrix_rhs_shape);
341        let x2d = apply_binary(
342            StdTensorOp::TriangularSolve {
343                left_side,
344                lower,
345                transpose_a,
346                unit_diagonal,
347                lhs_shape: input_shape_expr(a),
348                rhs_shape: DimExpr::from_concrete(&matrix_rhs_shape),
349            },
350            a,
351            &b2d,
352            matrix_rhs_shape.len(),
353            Some(sym_shape(&matrix_rhs_shape)),
354        );
355        x2d.reshape(&b_shape)
356    } else {
357        apply_binary(op, a, b, b.rank, b.shape_hint.clone())
358    }
359}
360
/// Sign and log-absolute-determinant from the LU factorization.
///
/// Returns `(sign, logabsdet)` per batch element. For an empty (zero-dim)
/// matrix the determinant is taken to be 1, so `sign = 1` and
/// `logabsdet = 0`.
///
/// # Examples
///
/// ```rust,ignore
/// let (sign, logabsdet) = tenferro::slogdet(&a);
/// ```
pub fn slogdet(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
    let shape = concrete_shape(a);
    let batch_shape = &shape[2..];
    if has_zero_dim(&shape) {
        let sign = broadcast_scalar(one_scalar(a.dtype), batch_shape);
        let logabsdet = broadcast_scalar(zero_scalar(a.dtype), batch_shape);
        return (sign, logabsdet);
    }

    // det(A) = parity(P) * prod(diag(U)); split into a sign factor and the
    // sum of log-magnitudes of U's diagonal.
    let (_, _, u, parity) = lu(a);
    let diag_u = u.extract_diag(0, 1);
    let sign_u = reduce_prod(&diag_u.sign(), &[0]);
    let sign = &parity * &sign_u;
    let logabsdet = diag_u.abs().log().reduce_sum(&[0]);
    (sign, logabsdet)
}
384
385/// Determinant from `slogdet`.
386///
387/// # Examples
388///
389/// ```rust,ignore
390/// let value = tenferro::det(&a);
391/// ```
392pub fn det(a: &TracedTensor) -> TracedTensor {
393    let (sign, logabsdet) = slogdet(a);
394    &sign * &logabsdet.exp()
395}
396
397/// Matrix inverse via `solve(a, eye)`.
398///
399/// # Examples
400///
401/// ```rust,ignore
402/// let value = tenferro::inv(&a);
403/// ```
404pub fn inv(a: &TracedTensor) -> TracedTensor {
405    let shape = concrete_shape(a);
406    let eye = eye_like(a, shape[0]);
407    solve(a, &eye)
408}
409
410/// Hermitian eigenvalues only.
411///
412/// # Examples
413///
414/// ```rust,ignore
415/// let values = tenferro::eigvalsh(&a);
416/// ```
417pub fn eigvalsh(a: &TracedTensor) -> TracedTensor {
418    eigh(a).0
419}
420
421/// General eigenvalues only.
422///
423/// # Examples
424///
425/// ```rust,ignore
426/// let values = tenferro::eigvals(&a);
427/// ```
428pub fn eigvals(a: &TracedTensor) -> TracedTensor {
429    eig(a).0
430}
431
432/// Moore-Penrose pseudoinverse via the SVD.
433///
434/// # Examples
435///
436/// ```rust,ignore
437/// let value = tenferro::pinv(&a);
438/// ```
439pub fn pinv(a: &TracedTensor) -> TracedTensor {
440    let shape = concrete_shape(a);
441    let max_dim = match (shape.first(), shape.get(1)) {
442        (Some(&m), Some(&n)) => m.max(n),
443        (Some(&m), None) => m,
444        _ => 0,
445    };
446    pinv_with_rtol(a, default_pinv_rtol(a.dtype, max_dim))
447}
448
/// Moore-Penrose pseudoinverse via the SVD with an explicit relative cutoff.
///
/// Singular values `<= rtol * max(s)` are discarded.
///
/// # Examples
///
/// ```rust,ignore
/// let value = tenferro::pinv_with_rtol(&a, 1.0e-8);
/// ```
pub fn pinv_with_rtol(a: &TracedTensor, rtol: f64) -> TracedTensor {
    let shape = concrete_shape(a);
    if has_zero_dim(&shape) {
        // Pseudoinverse of an empty matrix: transposed matrix dims, zeros.
        let mut out_shape = vec![shape[1], shape[0]];
        out_shape.extend_from_slice(&shape[2..]);
        return zeros_of_shape(a.dtype, out_shape);
    }

    let (u, s, vt) = svd(a);
    let abs_s = s.abs();
    let s_max = reduce_max(&abs_s, &[0]);
    let s_max_shape = concrete_shape(&s_max);
    // Per-batch cutoff: rtol (clamped to >= 0) times the largest singular value.
    let threshold = &s_max * &broadcast_scalar(scalar_real(s.dtype, rtol.max(0.0)), &s_max_shape);
    let s_shape = concrete_shape(&s);
    let threshold = broadcast_batch_scalar_to_leading_axis(&threshold, &s_shape);
    // mask is 1 where a singular value is kept, 0 where it is discarded.
    let mask = compare_dir(&abs_s, &threshold, CompareDir::Gt);
    let ones = ones_like(&s);
    // Branch-free reciprocal: kept entries give mask/denom = 1/s; discarded
    // entries give 0/(s + 1) = 0 (the +1 avoids dividing by a tiny s).
    let denom = &s + &(&ones + &(-&mask));
    let s_inv = &mask / &denom;

    // pinv(A) = V * diag(1/s) * U^H.
    let v = vt.conj().transpose(&matrix_transpose_perm(vt.rank));
    let uh = u.conj().transpose(&matrix_transpose_perm(u.rank));
    let s_inv_diag = s_inv.embed_diag(0, 1);
    let vs = matmul_preserve_trailing_batch(&v, &s_inv_diag);
    matmul_preserve_trailing_batch(&vs, &uh)
}
484
/// Vector or matrix norm.
///
/// This currently covers Frobenius norms, p-norms, and `±inf` reductions.
///
/// `dim = None` reduces over every axis; `keepdim` keeps reduced axes as
/// size-1 dims in the result.
///
/// # Examples
///
/// ```rust,ignore
/// let value = tenferro::norm(&a, None, None, false);
/// ```
pub fn norm(
    a: &TracedTensor,
    ord: Option<f64>,
    dim: Option<&[usize]>,
    keepdim: bool,
) -> TracedTensor {
    // Default to reducing over all axes of `a`.
    let axes = dim.map_or_else(|| (0..a.rank).collect::<Vec<_>>(), |dims| dims.to_vec());
    if axes.is_empty() {
        // No axes to reduce: return the input unchanged.
        return a.clone();
    }

    // Dispatch on axis count: 1 → vector norm, 2 → matrix norm, otherwise
    // an entrywise reduction over all listed axes.
    let out = match axes.len() {
        1 => vector_norm(a, axes[0], ord),
        2 => matrix_norm(a, &axes, ord),
        _ => {
            let abs = a.abs();
            match ord {
                None => frobenius_norm(&abs, &axes),
                Some(p) if p == f64::INFINITY => reduce_max(&abs, &axes),
                Some(p) if p == f64::NEG_INFINITY => reduce_min(&abs, &axes),
                Some(p) if p == 0.0 => count_nonzero(&abs, &axes),
                Some(p) => p_norm(&abs, &axes, p),
            }
        }
    };
    let shape = concrete_shape(a);
    restore_keepdim(out, &shape, &axes, keepdim)
}
522
523fn eig_output_dtype(dtype: DType) -> DType {
524    match dtype {
525        DType::F64 | DType::C64 => DType::C64,
526        DType::F32 | DType::C32 => DType::C32,
527    }
528}
529
530fn scalar_real(dtype: DType, value: f64) -> TracedTensor {
531    match dtype {
532        DType::F64 => apply_nullary(
533            StdTensorOp::constant_f64(value),
534            0,
535            DType::F64,
536            Some(vec![]),
537        ),
538        DType::F32 => apply_nullary(
539            StdTensorOp::constant_f32(value as f32),
540            0,
541            DType::F32,
542            Some(vec![]),
543        ),
544        DType::C64 => apply_nullary(
545            StdTensorOp::constant_c64(Complex64::new(value, 0.0)),
546            0,
547            DType::C64,
548            Some(vec![]),
549        ),
550        DType::C32 => apply_nullary(
551            StdTensorOp::constant_c32(Complex32::new(value as f32, 0.0)),
552            0,
553            DType::C32,
554            Some(vec![]),
555        ),
556    }
557}
558
559fn zero_scalar(dtype: DType) -> TracedTensor {
560    scalar_real(dtype, 0.0)
561}
562
563fn one_scalar(dtype: DType) -> TracedTensor {
564    scalar_real(dtype, 1.0)
565}
566
567fn zeros_like(input: &TracedTensor) -> TracedTensor {
568    zeros_of_shape(input.dtype, concrete_shape(input))
569}
570
571fn zeros_of_shape(dtype: DType, shape: Vec<usize>) -> TracedTensor {
572    broadcast_scalar(zero_scalar(dtype), &shape)
573}
574
575fn ones_like(input: &TracedTensor) -> TracedTensor {
576    let shape = concrete_shape(input);
577    broadcast_scalar(one_scalar(input.dtype), &shape)
578}
579
580fn eye_like(anchor: &TracedTensor, size: usize) -> TracedTensor {
581    let mut vector_shape = vec![size];
582    let anchor_shape = concrete_shape(anchor);
583    vector_shape.extend_from_slice(&anchor_shape[2..]);
584    let diagonal = broadcast_scalar(one_scalar(anchor.dtype), &vector_shape);
585    diagonal.embed_diag(0, 1)
586}
587
588fn frobenius_norm(abs: &TracedTensor, axes: &[usize]) -> TracedTensor {
589    let squared = abs.pow(&scalar_real(abs.dtype, 2.0));
590    squared.reduce_sum(axes).sqrt()
591}
592
593fn p_norm(abs: &TracedTensor, axes: &[usize], p: f64) -> TracedTensor {
594    let power = abs.pow(&scalar_real(abs.dtype, p));
595    let inv_p = scalar_real(abs.dtype, 1.0 / p);
596    power.reduce_sum(axes).pow(&inv_p)
597}
598
599fn default_pinv_rtol(dtype: DType, max_dim: usize) -> f64 {
600    let eps = match dtype {
601        DType::F32 | DType::C32 => f32::EPSILON as f64,
602        DType::F64 | DType::C64 => f64::EPSILON,
603    };
604    eps * max_dim as f64
605}
606
607fn vector_norm(a: &TracedTensor, axis: usize, ord: Option<f64>) -> TracedTensor {
608    let abs = a.abs();
609    match ord {
610        None => frobenius_norm(&abs, &[axis]),
611        Some(p) if p == 0.0 => count_nonzero(&abs, &[axis]),
612        Some(p) if p == f64::INFINITY => reduce_max(&abs, &[axis]),
613        Some(p) if p == f64::NEG_INFINITY => reduce_min(&abs, &[axis]),
614        Some(p) => p_norm(&abs, &[axis], p),
615    }
616}
617
/// Matrix norm over the two `axes`, which are first moved to the front.
///
/// `ord` follows the usual convention: `None` → Frobenius, `±inf` →
/// max/min absolute row sum, `±1` → max/min absolute column sum, `±2` →
/// largest/smallest singular value, `0` → nonzero count, anything else →
/// entrywise p-norm over both axes.
fn matrix_norm(a: &TracedTensor, axes: &[usize], ord: Option<f64>) -> TracedTensor {
    let matrix = move_axes_to_front(a, axes);
    let abs = matrix.abs();
    match ord {
        None => frobenius_norm(&abs, &[0, 1]),
        Some(p) if p == f64::INFINITY => matrix_row_sum_norm(&abs, true),
        Some(p) if p == f64::NEG_INFINITY => matrix_row_sum_norm(&abs, false),
        Some(p) if p == 1.0 => matrix_col_sum_norm(&abs, true),
        Some(p) if p == -1.0 => matrix_col_sum_norm(&abs, false),
        Some(p) if p == 2.0 => {
            // Spectral norm: the largest singular value.
            let singular_values = svd(&matrix).1.abs();
            reduce_max(&singular_values, &[0])
        }
        Some(p) if p == -2.0 => {
            // The smallest singular value.
            let singular_values = svd(&matrix).1.abs();
            reduce_min(&singular_values, &[0])
        }
        Some(p) if p == 0.0 => count_nonzero(&abs, &[0, 1]),
        Some(p) => p_norm(&abs, &[0, 1], p),
    }
}
639
640fn count_nonzero(abs: &TracedTensor, axes: &[usize]) -> TracedTensor {
641    let mask = compare_dir(abs, &zero_scalar(abs.dtype), CompareDir::Gt);
642    mask.reduce_sum(axes)
643}
644
645fn matrix_row_sum_norm(abs: &TracedTensor, take_max: bool) -> TracedTensor {
646    let row_sums = abs.reduce_sum(&[1]);
647    if take_max {
648        reduce_max(&row_sums, &[0])
649    } else {
650        reduce_min(&row_sums, &[0])
651    }
652}
653
654fn matrix_col_sum_norm(abs: &TracedTensor, take_max: bool) -> TracedTensor {
655    let col_sums = abs.reduce_sum(&[0]);
656    if take_max {
657        reduce_max(&col_sums, &[0])
658    } else {
659        reduce_min(&col_sums, &[0])
660    }
661}
662
663fn move_axes_to_front(tensor: &TracedTensor, axes: &[usize]) -> TracedTensor {
664    if axes.iter().enumerate().all(|(index, &axis)| index == axis) {
665        return tensor.clone();
666    }
667
668    let mut selected = vec![false; tensor.rank];
669    for &axis in axes {
670        selected[axis] = true;
671    }
672
673    let mut perm = Vec::with_capacity(tensor.rank);
674    perm.extend_from_slice(axes);
675    for axis in 0..tensor.rank {
676        if !selected[axis] {
677            perm.push(axis);
678        }
679    }
680    tensor.transpose(&perm)
681}
682
683fn restore_keepdim(
684    reduced: TracedTensor,
685    original_shape: &[usize],
686    axes: &[usize],
687    keepdim: bool,
688) -> TracedTensor {
689    if !keepdim {
690        return reduced;
691    }
692    let mut kept_shape = original_shape.to_vec();
693    for &axis in axes {
694        kept_shape[axis] = 1;
695    }
696    reduced.reshape(&kept_shape)
697}
698
699fn reduce_prod(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
700    let input_shape = concrete_shape(input);
701    let out_shape = reduced_shape(&input_shape, axes);
702    apply_unary(
703        StdTensorOp::ReduceProd {
704            axes: axes.to_vec(),
705            input_shape: DimExpr::input_shape(0, input.rank),
706        },
707        input,
708        input.rank - axes.len(),
709        Some(sym_shape(&out_shape)),
710    )
711}
712
713fn reduce_max(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
714    let input_shape = concrete_shape(input);
715    let out_shape = reduced_shape(&input_shape, axes);
716    apply_unary(
717        StdTensorOp::ReduceMax {
718            axes: axes.to_vec(),
719            input_shape: DimExpr::input_shape(0, input.rank),
720        },
721        input,
722        input.rank - axes.len(),
723        Some(sym_shape(&out_shape)),
724    )
725}
726
727fn reduce_min(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
728    let input_shape = concrete_shape(input);
729    let out_shape = reduced_shape(&input_shape, axes);
730    apply_unary(
731        StdTensorOp::ReduceMin {
732            axes: axes.to_vec(),
733            input_shape: DimExpr::input_shape(0, input.rank),
734        },
735        input,
736        input.rank - axes.len(),
737        Some(sym_shape(&out_shape)),
738    )
739}
740
741fn compare_dir(lhs: &TracedTensor, rhs: &TracedTensor, dir: CompareDir) -> TracedTensor {
742    let (lhs, rhs) = broadcast_binary(lhs, rhs);
743    apply_binary(
744        StdTensorOp::Compare(dir),
745        &lhs,
746        &rhs,
747        lhs.rank,
748        lhs.shape_hint.clone(),
749    )
750}
751
752fn broadcast_scalar(input: TracedTensor, shape: &[usize]) -> TracedTensor {
753    let input_shape = concrete_shape(&input);
754    if input_shape == shape {
755        return input;
756    }
757    input.broadcast_in_dim(shape, &[])
758}
759
760fn broadcast_batch_scalar_to_leading_axis(input: &TracedTensor, shape: &[usize]) -> TracedTensor {
761    let input_shape = concrete_shape(input);
762    if input_shape == shape {
763        return input.clone();
764    }
765    let dims: Vec<usize> = (1..shape.len()).collect();
766    input.broadcast_in_dim(shape, &dims)
767}
768
769fn matmul_preserve_trailing_batch(lhs: &TracedTensor, rhs: &TracedTensor) -> TracedTensor {
770    let rank = lhs.rank;
771    let batch_dims: Vec<usize> = (2..rank).collect();
772    lhs.dot_general(
773        rhs,
774        DotGeneralConfig {
775            lhs_contracting_dims: vec![1],
776            rhs_contracting_dims: vec![0],
777            lhs_batch_dims: batch_dims.clone(),
778            rhs_batch_dims: batch_dims,
779            lhs_rank: rank,
780            rhs_rank: rank,
781        },
782    )
783}
784
/// Identity permutation of length `rank` with the first two axes swapped
/// (transposes the matrix dims, leaving batch dims in place).
fn matrix_transpose_perm(rank: usize) -> Vec<usize> {
    let mut perm = Vec::with_capacity(rank);
    perm.extend(0..rank);
    perm.swap(0, 1);
    perm
}
790
/// Shape remaining after dropping the reduced `axes` from `shape`.
fn reduced_shape(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .enumerate()
        .filter(|(axis, _)| !axes.contains(axis))
        .map(|(_, &dim)| dim)
        .collect()
}
797
/// If `b` is a vector-style right-hand side for `a`, return the matrix
/// shape it should be viewed as; `None` means `b` is already a matrix RHS.
///
/// Recognized cases:
/// * 1-D `b` (`[n]`) → a single column, `[n, 1]`.
/// * Batched vector `b` (`[n, batch…]`) whose leading dim matches `a`'s
///   and whose remaining dims equal `a`'s trailing batch dims →
///   `[n, 1, batch…]`.
fn batched_vector_rhs_shape(a: &TracedTensor, b: &TracedTensor) -> Option<Vec<usize>> {
    let a_shape = concrete_shape(a);
    let b_shape = concrete_shape(b);

    if b_shape.len() == 1 {
        return Some(vec![b_shape[0], 1]);
    }

    // `b` must have exactly one dim fewer than `a`, share `a`'s leading
    // dim, and match `a`'s trailing batch dims.
    let is_batched_vector_rhs = a_shape.len() == b_shape.len() + 1
        && !b_shape.is_empty()
        && b_shape[0] == a_shape[0]
        && b_shape[1..] == a_shape[2..];
    if !is_batched_vector_rhs {
        return None;
    }

    // Insert a singleton column axis after the leading dim.
    let mut rhs_shape = vec![b_shape[0], 1];
    rhs_shape.extend_from_slice(&b_shape[1..]);
    Some(rhs_shape)
}
818
/// True when any dimension of `shape` is zero (i.e. the tensor is empty).
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&dim| dim == 0)
}
822
/// NumPy-style broadcast of two shapes, aligned at the trailing axes.
/// Returns `None` when a pair of dims is incompatible (neither equal nor 1).
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let a_pad = rank - a.len();
    let b_pad = rank - b.len();
    let mut result = Vec::with_capacity(rank);
    for index in 0..rank {
        // Axes missing on one side behave as size 1.
        let a_dim = index.checked_sub(a_pad).map_or(1, |i| a[i]);
        let b_dim = index.checked_sub(b_pad).map_or(1, |i| b[i]);
        let dim = match (a_dim, b_dim) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        result.push(dim);
    }
    Some(result)
}
849
/// Broadcast `tensor` to `target_shape`, NumPy-style (shapes aligned at
/// the trailing axes; size-1 source dims stretch to the target size).
///
/// Size-1 dims that actually stretch are first squeezed away with a
/// `reshape`, because the `dims` list passed to `broadcast_in_dim` only
/// names source axes whose extent is carried through unchanged.
///
/// # Panics
///
/// Panics when `tensor`'s rank exceeds the target's, or when a source dim
/// neither equals the target dim nor equals 1.
fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> TracedTensor {
    let tensor_shape = concrete_shape(tensor);
    if tensor_shape == target_shape {
        return tensor.clone();
    }

    assert!(
        tensor.rank <= target_shape.len(),
        "cannot broadcast higher-rank shape {:?} to {:?}",
        tensor_shape,
        target_shape
    );

    // Source axis `i` aligns with target axis `i + rank_diff`.
    let rank_diff = target_shape.len() - tensor.rank;
    let mut source_shape = Vec::with_capacity(tensor.rank);
    let mut dims = Vec::with_capacity(tensor.rank);
    for (src_axis, &src_dim) in tensor_shape.iter().enumerate() {
        let dst_axis = src_axis + rank_diff;
        let dst_dim = target_shape[dst_axis];
        assert!(
            src_dim == dst_dim || src_dim == 1,
            "cannot broadcast shape {:?} to {:?}",
            tensor_shape,
            target_shape
        );
        if src_dim == 1 && dst_dim != 1 {
            // This axis stretches; drop it from the reshaped source.
            continue;
        }
        source_shape.push(src_dim);
        dims.push(dst_axis);
    }

    // Only reshape when at least one stretching axis was squeezed out.
    let source = if source_shape == tensor_shape {
        tensor.clone()
    } else {
        tensor.reshape(&source_shape)
    };
    source.broadcast_in_dim(target_shape, &dims)
}
889
890fn broadcast_binary(a: &TracedTensor, b: &TracedTensor) -> (TracedTensor, TracedTensor) {
891    if a.shape_hint == b.shape_hint && a.rank == b.rank {
892        return (a.clone(), b.clone());
893    }
894    let a_shape = concrete_shape(a);
895    let b_shape = concrete_shape(b);
896    let target = broadcast_shape(&a_shape, &b_shape).unwrap_or_else(|| {
897        panic!(
898            "incompatible shapes for broadcast: {:?} and {:?}",
899            a_shape, b_shape
900        )
901    });
902    (broadcast_to(a, &target), broadcast_to(b, &target))
903}