strided_einsum2/
trace.rs

1//! Trace-axis reduction for einsum operands.
2//!
3//! Trace axes are axes that appear in only one operand and not in the output.
4//! They must be summed out (reduced) before the main contraction.
5
6use strided_kernel::reduce_axis;
7use strided_view::{ElementOp, StridedArray, StridedView};
8
/// Locate the trace axes of a single einsum operand.
///
/// A position in `labels` is a trace axis when its label occurs neither in the
/// other operand's labels (`other`) nor in the output labels (`output`).
///
/// Returns the qualifying positions in ascending order. The vector is empty
/// when no label qualifies.
pub fn find_trace_indices<ID: PartialEq>(labels: &[ID], other: &[ID], output: &[ID]) -> Vec<usize> {
    let mut positions = Vec::new();
    for (pos, label) in labels.iter().enumerate() {
        let shared = other.contains(label) || output.contains(label);
        if !shared {
            positions.push(pos);
        }
    }
    positions
}
20
21/// Reduce all trace axes from a view by summing them out.
22///
23/// `trace_axes` are indices into the original view's dimensions, given in
24/// ascending order. Each axis is reduced (summed) in sequence from back to
25/// front so that axis indices remain valid after each reduction.
26///
27/// Returns a new `StridedArray` with the trace axes removed.
28pub fn reduce_trace_axes<T, Op>(
29    src: &StridedView<T, Op>,
30    trace_axes: &[usize],
31) -> strided_view::Result<StridedArray<T>>
32where
33    T: Copy + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero,
34    Op: ElementOp<T>,
35{
36    if trace_axes.is_empty() {
37        // No trace axes — this shouldn't happen in practice since the caller checks.
38        panic!("reduce_trace_axes called with empty trace_axes");
39    }
40
41    // Sort in descending order so we can reduce from back to front
42    // without invalidating earlier axis indices.
43    let mut axes: Vec<usize> = trace_axes.to_vec();
44    axes.sort_unstable();
45    axes.reverse();
46
47    let first_reduced = reduce_axis(src, axes[0], |x| x, |a, b| a + b, T::zero())?;
48
49    let mut current = first_reduced;
50    for &ax in &axes[1..] {
51        current = reduce_axis(&current.view(), ax, |x| x, |a, b| a + b, T::zero())?;
52    }
53
54    Ok(current)
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::Identity;

    #[test]
    fn test_find_trace_indices() {
        // 'k' is absent from both the other operand and the output.
        assert_eq!(find_trace_indices(&['i', 'j', 'k'], &['j'], &['i']), vec![2]);
        // Every label is shared, so there are no trace axes.
        assert_eq!(
            find_trace_indices(&['i', 'j'], &['j'], &['i']),
            Vec::<usize>::new()
        );
    }

    #[test]
    #[should_panic(expected = "reduce_trace_axes called with empty trace_axes")]
    fn test_reduce_no_trace() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let _ = reduce_trace_axes::<f64, Identity>(&a.view(), &[]);
    }

    #[test]
    fn test_reduce_single_trace() {
        // A = [[1, 2, 3], [4, 5, 6]] (2x3); reducing axis 1 gives [6, 15].
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[1]).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.get(&[0]), 6.0); // 1 + 2 + 3
        assert_eq!(result.get(&[1]), 15.0); // 4 + 5 + 6
    }

    #[test]
    fn test_reduce_two_traces() {
        // A: 2x3x4 with a[i, j, k] = 12*i + 4*j + k; reducing axes [0, 2] gives shape [3].
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        });
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[0, 2]).unwrap();
        assert_eq!(result.dims(), &[3]);
        // Sum over i in {0, 1}: b[j, k] = 12 + 8*j + 2*k.
        // Sum over k in 0..4:  c[j] = 4*12 + 4*8*j + 2*(0+1+2+3) = 60 + 32*j.
        assert_eq!(result.get(&[0]), 60.0); // 60 + 32*0
        assert_eq!(result.get(&[1]), 92.0); // 60 + 32*1
        assert_eq!(result.get(&[2]), 124.0); // 60 + 32*2
    }

    #[test]
    fn test_reduce_two_traces_unsorted() {
        // Axis order in the request must not matter; same data as test_reduce_two_traces.
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        });
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[2, 0]).unwrap();
        assert_eq!(result.dims(), &[3]);
        assert_eq!(result.get(&[0]), 60.0);
        assert_eq!(result.get(&[1]), 92.0);
        assert_eq!(result.get(&[2]), 124.0);
    }
}