strided_einsum2/trace.rs
1//! Trace-axis reduction for einsum operands.
2//!
3//! Trace axes are axes that appear in only one operand and not in the output.
4//! They must be summed out (reduced) before the main contraction.
5
6use strided_kernel::reduce_axis;
7use strided_view::{ElementOp, StridedArray, StridedView};
8
/// Locate the trace axes of one operand.
///
/// A label is a trace label when it occurs in neither `other` nor `output`.
/// The returned vector holds the positions (indices into `labels`) of all
/// such labels, in increasing order; it is empty when there are none.
pub fn find_trace_indices<ID: PartialEq>(labels: &[ID], other: &[ID], output: &[ID]) -> Vec<usize> {
    let mut positions = Vec::new();
    for (pos, label) in labels.iter().enumerate() {
        // A label that is visible anywhere else is not a trace label.
        let used_elsewhere = other.contains(label) || output.contains(label);
        if !used_elsewhere {
            positions.push(pos);
        }
    }
    positions
}
20
21/// Reduce all trace axes from a view by summing them out.
22///
23/// `trace_axes` are indices into the original view's dimensions, given in
24/// ascending order. Each axis is reduced (summed) in sequence from back to
25/// front so that axis indices remain valid after each reduction.
26///
27/// Returns a new `StridedArray` with the trace axes removed.
28pub fn reduce_trace_axes<T, Op>(
29 src: &StridedView<T, Op>,
30 trace_axes: &[usize],
31) -> strided_view::Result<StridedArray<T>>
32where
33 T: Copy + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero,
34 Op: ElementOp<T>,
35{
36 if trace_axes.is_empty() {
37 // No trace axes — this shouldn't happen in practice since the caller checks.
38 panic!("reduce_trace_axes called with empty trace_axes");
39 }
40
41 // Sort in descending order so we can reduce from back to front
42 // without invalidating earlier axis indices.
43 let mut axes: Vec<usize> = trace_axes.to_vec();
44 axes.sort_unstable();
45 axes.reverse();
46
47 let first_reduced = reduce_axis(src, axes[0], |x| x, |a, b| a + b, T::zero())?;
48
49 let mut current = first_reduced;
50 for &ax in &axes[1..] {
51 current = reduce_axis(¤t.view(), ax, |x| x, |a, b| a + b, T::zero())?;
52 }
53
54 Ok(current)
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::Identity;

    // An empty axis list is a caller error and must trigger the panic.
    #[test]
    #[should_panic(expected = "reduce_trace_axes called with empty trace_axes")]
    fn test_reduce_no_trace() {
        let input =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let _ = reduce_trace_axes::<f64, Identity>(&input.view(), &[]);
    }

    // Summing a 2x3 matrix over its second axis yields the two row sums.
    #[test]
    fn test_reduce_single_trace() {
        // input = [[1, 2, 3], [4, 5, 6]]
        let input =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let row_sums = reduce_trace_axes::<f64, Identity>(&input.view(), &[1]).unwrap();
        assert_eq!(row_sums.dims(), &[2]);
        // 1+2+3 = 6 and 4+5+6 = 15
        for (row, &expected) in [6.0, 15.0].iter().enumerate() {
            assert_eq!(row_sums.get(&[row]), expected);
        }
    }

    // Summing a 2x3x4 tensor over axes 0 and 2 leaves only the middle axis.
    #[test]
    fn test_reduce_two_traces() {
        // input[i, j, k] = 12*i + 4*j + k
        let input = StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        });
        let reduced = reduce_trace_axes::<f64, Identity>(&input.view(), &[0, 2]).unwrap();
        assert_eq!(reduced.dims(), &[3]);
        // The double sum does not depend on the order of reduction:
        //   c[j] = sum_{i=0..1} sum_{k=0..3} (12*i + 4*j + k)
        //        = 12*(0+1)*4 + 4*j*8 + (0+1+2+3)*2
        //        = 48 + 32*j + 12
        //        = 60 + 32*j
        for (j, &expected) in [60.0, 92.0, 124.0].iter().enumerate() {
            assert_eq!(reduced.get(&[j]), expected);
        }
    }
}