tenferro_tropical/argmax.rs
//! Argmax tracking for the tropical backward pass (automatic differentiation).
//!
//! In tropical semirings, the "gradient" of a contraction is determined by
//! which elements "won" the max (or min) comparisons. [`ArgmaxTracker`]
//! records the winner indices during the forward pass so that the backward
//! pass can route gradients to the correct elements.
//!
//! This is the tropical analogue of storing intermediate activations for
//! backpropagation in standard neural network training.
10
/// Tracks winner indices from tropical forward-pass operations.
///
/// During a tropical contraction `C[i,j] = max_k (A[i,k] + B[k,j])`,
/// the tracker records which `k` achieved the maximum for each `(i,j)`.
/// The backward pass uses these indices to route gradients.
///
/// Winner indices are stored flat, one per output element, in column-major
/// order (the first axis varies fastest). Code that fills the tracker through
/// [`ArgmaxTracker::indices_mut`] must write in that same order for
/// [`ArgmaxTracker::winner_index`] to return the matching entry.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tropical::ArgmaxTracker;
///
/// // Create a tracker for a 3×5 output
/// let tracker = ArgmaxTracker::new(&[3, 5]);
///
/// // After forward pass, query the winner index for output element (1, 2)
/// let k_winner = tracker.winner_index(&[1, 2]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArgmaxTracker {
    /// Shape of the output tensor.
    output_shape: Vec<usize>,
    /// Winner indices (flat storage, one per output element, column-major).
    /// Each entry records the contraction index that achieved the optimum.
    indices: Vec<usize>,
}

impl ArgmaxTracker {
    /// Create a new tracker for an output of the given shape.
    ///
    /// Allocates one winner slot per output element, i.e. the product of
    /// `output_shape`. An empty shape denotes a scalar output with one slot,
    /// and a shape containing a zero-sized axis yields no slots. All winner
    /// indices are initialized to 0.
    ///
    /// # Panics
    ///
    /// Panics if the total number of output elements overflows `usize`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tropical::ArgmaxTracker;
    ///
    /// let tracker = ArgmaxTracker::new(&[3, 5]);
    /// assert_eq!(tracker.output_shape(), &[3, 5]);
    /// ```
    pub fn new(output_shape: &[usize]) -> Self {
        // Checked product: a silently wrapped element count would allocate a
        // too-small buffer and make later lookups index out of range.
        let len = output_shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .expect("ArgmaxTracker::new: output element count overflows usize");
        Self {
            output_shape: output_shape.to_vec(),
            indices: vec![0; len],
        }
    }

    /// Return the output shape.
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape
    }

    /// Return the winner indices as a flat slice (column-major order).
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }

    /// Return a mutable reference to the winner indices (column-major order).
    pub fn indices_mut(&mut self) -> &mut [usize] {
        &mut self.indices
    }

    /// Look up the winner index for a given multi-dimensional output position.
    ///
    /// `position` must have one coordinate per output axis. The position is
    /// mapped to the flat storage in column-major order.
    ///
    /// # Panics
    ///
    /// Panics if `position.len()` differs from the output rank, or if any
    /// coordinate is out of bounds for its axis.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tropical::ArgmaxTracker;
    ///
    /// let tracker = ArgmaxTracker::new(&[3, 5]);
    /// let k = tracker.winner_index(&[1, 2]);
    /// ```
    pub fn winner_index(&self, position: &[usize]) -> usize {
        self.indices[self.flat_offset(position)]
    }

    /// Convert a multi-dimensional output position into a column-major flat offset.
    fn flat_offset(&self, position: &[usize]) -> usize {
        // NOTE(review): column-major is assumed to match the order in which
        // the forward kernel writes through `indices_mut` — confirm.
        assert_eq!(
            position.len(),
            self.output_shape.len(),
            "ArgmaxTracker::winner_index: position has rank {} but output has rank {}",
            position.len(),
            self.output_shape.len()
        );
        let mut offset = 0usize;
        let mut stride = 1usize;
        for (axis, (&idx, &dim)) in position.iter().zip(&self.output_shape).enumerate() {
            assert!(
                idx < dim,
                "ArgmaxTracker::winner_index: index {} out of bounds for axis {} of size {}",
                idx,
                axis,
                dim
            );
            // Cannot overflow: `new` verified the full element product fits in usize.
            offset += idx * stride;
            stride *= dim;
        }
        offset
    }
}
81}