strided_einsum2/
trace.rs

1//! Trace-axis reduction for einsum operands.
2//!
3//! Trace axes are axes that appear in only one operand and not in the output.
4//! They must be summed out (reduced) before the main contraction.
5
6use strided_kernel::reduce_axis;
7use strided_view::{ElementOp, StridedArray, StridedView};
8
9/// Reduce all trace axes from a view by summing them out.
10///
11/// `trace_axes` are indices into the original view's dimensions, given in
12/// ascending order. Each axis is reduced (summed) in sequence from back to
13/// front so that axis indices remain valid after each reduction.
14///
15/// Returns a new `StridedArray` with the trace axes removed.
16pub fn reduce_trace_axes<T, Op>(
17    src: &StridedView<T, Op>,
18    trace_axes: &[usize],
19) -> strided_view::Result<StridedArray<T>>
20where
21    T: Copy + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero,
22    Op: ElementOp<T>,
23{
24    if trace_axes.is_empty() {
25        // No trace axes — this shouldn't happen in practice since the caller checks.
26        panic!("reduce_trace_axes called with empty trace_axes");
27    }
28
29    // Sort in descending order so we can reduce from back to front
30    // without invalidating earlier axis indices.
31    let mut axes: Vec<usize> = trace_axes.to_vec();
32    axes.sort_unstable();
33    axes.reverse();
34
35    let first_reduced = reduce_axis(src, axes[0], |x| x, |a, b| a + b, T::zero())?;
36
37    let mut current = first_reduced;
38    for &ax in &axes[1..] {
39        current = reduce_axis(&current.view(), ax, |x| x, |a, b| a + b, T::zero())?;
40    }
41
42    Ok(current)
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::Identity;

    /// Builds the 2x3 matrix [[1, 2, 3], [4, 5, 6]] in row-major order.
    fn matrix_2x3() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64)
    }

    /// Builds a 2x3x4 tensor with a[i, j, k] = 12*i + 4*j + k.
    fn tensor_2x3x4() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        })
    }

    #[test]
    #[should_panic(expected = "reduce_trace_axes called with empty trace_axes")]
    fn test_reduce_no_trace() {
        let a = matrix_2x3();
        let _ = reduce_trace_axes::<f64, Identity>(&a.view(), &[]);
    }

    #[test]
    fn test_reduce_single_trace() {
        // Summing axis 1 of [[1,2,3],[4,5,6]] leaves the row sums.
        let a = matrix_2x3();
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[1]).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.get(&[0]), 6.0); // 1 + 2 + 3
        assert_eq!(result.get(&[1]), 15.0); // 4 + 5 + 6
    }

    #[test]
    fn test_reduce_two_traces() {
        // Summing axes 0 and 2 of a[i, j, k] = 12*i + 4*j + k leaves axis j:
        //   c[j] = sum_{i=0}^{1} sum_{k=0}^{3} (12*i + 4*j + k)
        //        = 4*12*(0 + 1) + (2*4)*4*j + 2*(0 + 1 + 2 + 3)
        //        = 48 + 32*j + 12
        //        = 60 + 32*j
        let a = tensor_2x3x4();
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[0, 2]).unwrap();
        assert_eq!(result.dims(), &[3]);
        assert_eq!(result.get(&[0]), 60.0); // 60 + 32*0
        assert_eq!(result.get(&[1]), 92.0); // 60 + 32*1
        assert_eq!(result.get(&[2]), 124.0); // 60 + 32*2
    }

    #[test]
    fn test_reduce_two_traces_unsorted() {
        // The order of `trace_axes` must not matter: [2, 0] names the same axes
        // as [0, 2], so the result matches `test_reduce_two_traces`.
        let a = tensor_2x3x4();
        let result = reduce_trace_axes::<f64, Identity>(&a.view(), &[2, 0]).unwrap();
        assert_eq!(result.dims(), &[3]);
        assert_eq!(result.get(&[0]), 60.0);
        assert_eq!(result.get(&[1]), 92.0);
        assert_eq!(result.get(&[2]), 124.0);
    }
}