tenferro_tropical/
argmax.rs

//! Argmax tracking for the tropical backward pass (automatic differentiation).
//!
//! In tropical semirings, the "gradient" of a contraction is determined by
//! which elements "won" the max (or min) comparisons. [`ArgmaxTracker`]
//! records the winner indices during the forward pass so that the backward
//! pass can route gradients to the correct elements.
//!
//! This is the tropical analogue of storing intermediate activations for
//! backpropagation in standard neural-network training.
10
/// Tracks winner indices from tropical forward-pass operations.
///
/// During a tropical contraction `C[i,j] = max_k (A[i,k] + B[k,j])`,
/// the tracker records which `k` achieved the maximum for each `(i,j)`.
/// The backward pass uses these indices to route gradients.
///
/// The tracker owns its shape and its flat index buffer, so it can be cloned,
/// compared and debug-printed like any other plain value.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tropical::ArgmaxTracker;
///
/// // Create a tracker for a 3×5 output
/// let tracker = ArgmaxTracker::new(&[3, 5]);
///
/// // After forward pass, query the winner index for output element (1, 2)
/// let k_winner = tracker.winner_index(&[1, 2]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArgmaxTracker {
    /// Shape of the output tensor.
    output_shape: Vec<usize>,
    /// Winner indices (flat storage, one per output element).
    /// Each entry records the contraction index that achieved the optimum.
    indices: Vec<usize>,
}
35
36impl ArgmaxTracker {
37    /// Create a new tracker for an output of the given shape.
38    ///
39    /// All winner indices are initialized to 0.
40    ///
41    /// # Examples
42    ///
43    /// ```ignore
44    /// use tenferro_tropical::ArgmaxTracker;
45    ///
46    /// let tracker = ArgmaxTracker::new(&[3, 5]);
47    /// assert_eq!(tracker.output_shape(), &[3, 5]);
48    /// ```
49    pub fn new(_output_shape: &[usize]) -> Self {
50        todo!()
51    }
52
53    /// Return the output shape.
54    pub fn output_shape(&self) -> &[usize] {
55        &self.output_shape
56    }
57
58    /// Return the winner indices as a flat slice.
59    pub fn indices(&self) -> &[usize] {
60        &self.indices
61    }
62
63    /// Return a mutable reference to the winner indices.
64    pub fn indices_mut(&mut self) -> &mut [usize] {
65        &mut self.indices
66    }
67
68    /// Look up the winner index for a given multi-dimensional output position.
69    ///
70    /// # Examples
71    ///
72    /// ```ignore
73    /// use tenferro_tropical::ArgmaxTracker;
74    ///
75    /// let tracker = ArgmaxTracker::new(&[3, 5]);
76    /// let k = tracker.winner_index(&[1, 2]);
77    /// ```
78    pub fn winner_index(&self, _position: &[usize]) -> usize {
79        todo!()
80    }
81}