tenferro_ext_tropical/
argmax.rs

1//! Argmax tracking for tropical backward pass (automatic differentiation).
2//!
3//! In tropical semirings, the "gradient" of a contraction is determined by
4//! which elements "won" the max (or min) comparisons. [`ArgmaxTracker`]
5//! records the winner indices during the forward pass so that the backward
6//! pass can route gradients to the correct elements.
7//!
8//! This is the tropical analogue of storing intermediate activations for
9//! backpropagation in standard neural network training.
10
/// Tracks winner indices from tropical forward-pass operations.
///
/// During a tropical contraction `C[i,j] = max_k (A[i,k] + B[k,j])`,
/// the tracker records which `k` achieved the maximum for each `(i,j)`.
/// The backward pass uses these indices to route gradients.
///
/// Winner indices are kept in a flat buffer with one entry per output
/// element, laid out in column-major order (the first axis varies fastest).
///
/// # Examples
///
/// ```
/// use tenferro_ext_tropical::ArgmaxTracker;
///
/// // Create a tracker for a 3x5 output
/// let tracker = ArgmaxTracker::new(&[3, 5]);
///
/// // After forward pass, query the winner index for output element (1, 2)
/// let k_winner = tracker.winner_index(&[1, 2]);
/// assert_eq!(k_winner, 0); // initialized to 0
/// ```
// `Default` is intentionally not derived: an empty shape denotes a scalar
// output, which must own exactly one winner slot, not zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArgmaxTracker {
    /// Shape of the output tensor.
    output_shape: Vec<usize>,
    /// Winner indices (flat column-major storage, one per output element).
    /// Each entry records the contraction index that achieved the optimum.
    indices: Vec<usize>,
}
36
37impl ArgmaxTracker {
38    /// Create a new tracker for an output of the given shape.
39    ///
40    /// All winner indices are initialized to 0.
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// use tenferro_ext_tropical::ArgmaxTracker;
46    ///
47    /// let tracker = ArgmaxTracker::new(&[3, 5]);
48    /// assert_eq!(tracker.output_shape(), &[3, 5]);
49    /// ```
50    pub fn new(output_shape: &[usize]) -> Self {
51        let total: usize = output_shape.iter().product();
52        Self {
53            output_shape: output_shape.to_vec(),
54            indices: vec![0; total],
55        }
56    }
57
58    /// Return the output shape.
59    pub fn output_shape(&self) -> &[usize] {
60        &self.output_shape
61    }
62
63    /// Return the winner indices as a flat slice.
64    pub fn indices(&self) -> &[usize] {
65        &self.indices
66    }
67
68    /// Return a mutable reference to the winner indices.
69    pub fn indices_mut(&mut self) -> &mut [usize] {
70        &mut self.indices
71    }
72
73    /// Look up the winner index for a given multi-dimensional output position.
74    ///
75    /// # Examples
76    ///
77    /// ```
78    /// use tenferro_ext_tropical::ArgmaxTracker;
79    ///
80    /// let tracker = ArgmaxTracker::new(&[3, 5]);
81    /// let k = tracker.winner_index(&[1, 2]);
82    /// assert_eq!(k, 0); // initialized to 0
83    /// ```
84    ///
85    /// # Panics
86    ///
87    /// Panics if `position` has the wrong rank or any index is out of bounds
88    /// for the tracked output shape.
89    pub fn winner_index(&self, position: &[usize]) -> usize {
90        assert_eq!(
91            position.len(),
92            self.output_shape.len(),
93            "winner_index: expected {} indices, got {}",
94            self.output_shape.len(),
95            position.len()
96        );
97
98        // Column-major linear index
99        let mut linear = 0;
100        let mut stride = 1;
101        for (i, &p) in position.iter().enumerate() {
102            assert!(
103                p < self.output_shape[i],
104                "winner_index: index {} out of bounds for axis {} with size {}",
105                p,
106                i,
107                self.output_shape[i]
108            );
109            linear += p * stride;
110            stride *= self.output_shape[i];
111        }
112        self.indices[linear]
113    }
114}