tenferro_ext_tropical/argmax.rs
1//! Argmax tracking for tropical backward pass (automatic differentiation).
2//!
3//! In tropical semirings, the "gradient" of a contraction is determined by
4//! which elements "won" the max (or min) comparisons. [`ArgmaxTracker`]
5//! records the winner indices during the forward pass so that the backward
6//! pass can route gradients to the correct elements.
7//!
8//! This is the tropical analogue of storing intermediate activations for
9//! backpropagation in standard neural network training.
10
/// Tracks winner indices from tropical forward-pass operations.
///
/// During a tropical contraction `C[i,j] = max_k (A[i,k] + B[k,j])`,
/// the tracker records which `k` achieved the maximum for each `(i,j)`.
/// The backward pass uses these indices to route gradients.
///
/// Winner indices are stored in a flat buffer in column-major order: the
/// first axis of the output shape varies fastest.
///
/// # Examples
///
/// ```
/// use tenferro_ext_tropical::ArgmaxTracker;
///
/// // Create a tracker for a 3x5 output
/// let tracker = ArgmaxTracker::new(&[3, 5]);
///
/// // After forward pass, query the winner index for output element (1, 2)
/// let k_winner = tracker.winner_index(&[1, 2]);
/// assert_eq!(k_winner, 0); // initialized to 0
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArgmaxTracker {
    /// Shape of the output tensor.
    output_shape: Vec<usize>,
    /// Winner indices (flat column-major storage, one per output element).
    /// Each entry records the contraction index that achieved the optimum.
    indices: Vec<usize>,
}

impl ArgmaxTracker {
    /// Create a new tracker for an output of the given shape.
    ///
    /// All winner indices are initialized to 0. An empty shape describes a
    /// scalar output and allocates exactly one entry; a shape containing a
    /// zero-length axis allocates no entries.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ext_tropical::ArgmaxTracker;
    ///
    /// let tracker = ArgmaxTracker::new(&[3, 5]);
    /// assert_eq!(tracker.output_shape(), &[3, 5]);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the total number of output elements overflows `usize`.
    pub fn new(output_shape: &[usize]) -> Self {
        // A zero-length axis empties the output regardless of the other
        // extents, so overflow is only possible (and only checked) when every
        // axis is non-empty. An unchecked `product()` would wrap in release
        // builds and silently allocate a wrongly-sized buffer.
        let total = if output_shape.contains(&0) {
            0
        } else {
            output_shape
                .iter()
                .try_fold(1usize, |acc, &size| acc.checked_mul(size))
                .unwrap_or_else(|| {
                    panic!(
                        "ArgmaxTracker::new: element count of shape {:?} overflows usize",
                        output_shape
                    )
                })
        };
        Self {
            output_shape: output_shape.to_vec(),
            indices: vec![0; total],
        }
    }

    /// Return the output shape.
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape
    }

    /// Return the winner indices as a flat slice.
    ///
    /// The slice has one entry per output element (the product of
    /// [`output_shape`](Self::output_shape)) laid out in column-major order,
    /// i.e. the first axis varies fastest.
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }

    /// Return the winner indices as a mutable flat slice.
    ///
    /// Uses the same column-major layout as [`indices`](Self::indices); this
    /// is where winners are recorded during the forward pass.
    pub fn indices_mut(&mut self) -> &mut [usize] {
        &mut self.indices
    }

    /// Look up the winner index for a given multi-dimensional output position.
    ///
    /// The position is mapped to the flat storage in column-major order: for
    /// shape `[3, 5]`, position `[1, 2]` reads flat entry `1 + 2 * 3 = 7`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ext_tropical::ArgmaxTracker;
    ///
    /// let tracker = ArgmaxTracker::new(&[3, 5]);
    /// let k = tracker.winner_index(&[1, 2]);
    /// assert_eq!(k, 0); // initialized to 0
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `position` has the wrong rank or any index is out of bounds
    /// for the tracked output shape.
    pub fn winner_index(&self, position: &[usize]) -> usize {
        self.indices[self.linear_index(position)]
    }

    /// Validate `position` and convert it to a column-major flat offset.
    fn linear_index(&self, position: &[usize]) -> usize {
        assert_eq!(
            position.len(),
            self.output_shape.len(),
            "winner_index: expected {} indices, got {}",
            self.output_shape.len(),
            position.len()
        );

        // Validate every axis before doing any arithmetic: once all indices
        // are in bounds, every axis is non-empty and the offset is strictly
        // less than the element count, which `new` proved fits in `usize`.
        for (axis, (&p, &size)) in position.iter().zip(&self.output_shape).enumerate() {
            assert!(
                p < size,
                "winner_index: index {} out of bounds for axis {} with size {}",
                p,
                axis,
                size
            );
        }

        // Horner evaluation of `sum_i p_i * prod_{j < i} size_j`, walking from
        // the slowest (last) axis to the fastest (first) one. Every
        // intermediate value is bounded by the element count, so it cannot
        // overflow.
        position
            .iter()
            .zip(&self.output_shape)
            .rev()
            .fold(0, |acc, (&p, &size)| acc * size + p)
    }
}
114}