tenferro_prims/cpu/
plan.rs

1use std::marker::PhantomData;
2
3use tenferro_algebra::Scalar;
4
5use crate::SemiringBinaryOp;
6
7/// Compute connected components from a list of paired axis positions using union-find.
8///
9/// Returns `(components, comp_dims)` where:
10/// - `components[i]` = sorted list of all axis positions in the i-th component
11/// - `comp_dims[i]` = the shared dimension of the i-th component (looked up from `shape`)
12pub(super) fn compute_paired_components(
13    paired_axes: &[(usize, usize)],
14    shape: &[usize],
15) -> (Vec<Vec<usize>>, Vec<usize>) {
16    use std::collections::HashMap;
17
18    if paired_axes.is_empty() {
19        return (vec![], vec![]);
20    }
21
22    let mut all_axes: Vec<usize> = Vec::new();
23    for &(ax1, ax2) in paired_axes {
24        all_axes.push(ax1);
25        all_axes.push(ax2);
26    }
27    all_axes.sort();
28    all_axes.dedup();
29
30    let mut parent: HashMap<usize, usize> = all_axes.iter().map(|&ax| (ax, ax)).collect();
31
32    fn find(parent: &mut HashMap<usize, usize>, x: usize) -> usize {
33        let p = parent[&x];
34        if p != x {
35            let root = find(parent, p);
36            parent.insert(x, root);
37            root
38        } else {
39            x
40        }
41    }
42
43    for &(ax1, ax2) in paired_axes {
44        let r1 = find(&mut parent, ax1);
45        let r2 = find(&mut parent, ax2);
46        if r1 != r2 {
47            let (lo, hi) = if r1 < r2 { (r1, r2) } else { (r2, r1) };
48            parent.insert(hi, lo);
49        }
50    }
51
52    let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
53    for &ax in &all_axes {
54        let root = find(&mut parent, ax);
55        groups.entry(root).or_default().push(ax);
56    }
57
58    let mut components: Vec<Vec<usize>> = groups.into_values().collect();
59    components.sort_by_key(|c| c[0]);
60
61    let comp_dims: Vec<usize> = components.iter().map(|c| shape[c[0]]).collect();
62
63    (components, comp_dims)
64}
65
/// Pre-computed mode analysis for Contract GEMM fast path.
///
/// Holds the four mode groups of a contraction (batch, m, n, k) together with the
/// per-operand label orders obtained by concatenating those groups. The field
/// descriptions below reflect how [`build_contract_gemm_spec`] populates them.
#[derive(Debug, Clone)]
pub(super) struct ContractGemmSpec {
    /// Target mode order for A: [batch, m, k]
    pub(super) a_target: Vec<u32>,
    /// Target mode order for B: [batch, k, n]
    pub(super) b_target: Vec<u32>,
    /// Target mode order for C: [batch, m, n]
    pub(super) c_target: Vec<u32>,
    /// Labels of C that occur in both A and B, in the order they appear in C.
    pub(super) batch_modes: Vec<u32>,
    /// Labels of C that occur in A but not in B, in the order they appear in C.
    pub(super) m_modes: Vec<u32>,
    /// Labels of C that occur in B but not in A, in the order they appear in C.
    pub(super) n_modes: Vec<u32>,
    /// Labels of A that occur in B but not in C, in the order they appear in A.
    pub(super) k_modes: Vec<u32>,
}
80
81/// Build a [`ContractGemmSpec`] from mode labels, or `None` if the
82/// contraction is not a valid batched-GEMM pattern.
83pub(super) fn build_contract_gemm_spec(
84    modes_a: &[u32],
85    modes_b: &[u32],
86    modes_c: &[u32],
87) -> Option<ContractGemmSpec> {
88    let batch_modes: Vec<u32> = modes_c
89        .iter()
90        .copied()
91        .filter(|m| modes_a.contains(m) && modes_b.contains(m))
92        .collect();
93    let m_modes: Vec<u32> = modes_c
94        .iter()
95        .copied()
96        .filter(|m| modes_a.contains(m) && !modes_b.contains(m))
97        .collect();
98    let n_modes: Vec<u32> = modes_c
99        .iter()
100        .copied()
101        .filter(|m| modes_b.contains(m) && !modes_a.contains(m))
102        .collect();
103    let k_modes: Vec<u32> = modes_a
104        .iter()
105        .copied()
106        .filter(|m| modes_b.contains(m) && !modes_c.contains(m))
107        .collect();
108
109    let expected_a = batch_modes.len() + m_modes.len() + k_modes.len();
110    let expected_b = batch_modes.len() + k_modes.len() + n_modes.len();
111    if expected_a != modes_a.len() || expected_b != modes_b.len() {
112        return None;
113    }
114    if batch_modes.len() + m_modes.len() + n_modes.len() != modes_c.len() {
115        return None;
116    }
117
118    let a_target: Vec<u32> = batch_modes
119        .iter()
120        .chain(m_modes.iter())
121        .chain(k_modes.iter())
122        .copied()
123        .collect();
124    let b_target: Vec<u32> = batch_modes
125        .iter()
126        .chain(k_modes.iter())
127        .chain(n_modes.iter())
128        .copied()
129        .collect();
130    let c_target: Vec<u32> = batch_modes
131        .iter()
132        .chain(m_modes.iter())
133        .chain(n_modes.iter())
134        .copied()
135        .collect();
136
137    Some(ContractGemmSpec {
138        a_target,
139        b_target,
140        c_target,
141        batch_modes,
142        m_modes,
143        n_modes,
144        k_modes,
145    })
146}
147
/// CPU plan — concrete enum, no type erasure.
///
/// Created by the family planners on [`crate::CpuBackend`] and consumed by the
/// semiring family executors.
///
/// Every variant carries a `PhantomData<T>` marker, so a plan is typed by its
/// scalar `T` without holding any values of that type.
#[derive(Debug, Clone)]
pub enum CpuPlan<T: Scalar> {
    /// Plan for batched GEMM.
    BatchedGemm {
        /// Batch dimension sizes.
        batch_dims: Vec<usize>,
        /// Number of rows.
        m: usize,
        /// Number of columns.
        n: usize,
        /// Contraction dimension.
        k: usize,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for semiring-add reduction.
    ReduceAdd {
        /// Axes to reduce over (positions in input tensor).
        reduced_axes: Vec<usize>,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for trace.
    Trace {
        /// Output axis positions mapping.
        free_axes: Vec<usize>,
        /// Connected components of paired axes (union-find groups).
        components: Vec<Vec<usize>>,
        /// Dimension of each component (all axes in a component share the same dim).
        comp_dims: Vec<usize>,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for anti-trace (AD backward).
    AntiTrace {
        /// Paired axis positions in output tensor.
        paired_axes: Vec<(usize, usize)>,
        /// Input axis positions mapping.
        free_axes: Vec<usize>,
        /// Connected components of paired axes (union-find groups).
        components: Vec<Vec<usize>>,
        /// Dimension of each component.
        comp_dims: Vec<usize>,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for anti-diag (AD backward).
    AntiDiag {
        /// Paired axis positions in output tensor.
        paired_axes: Vec<(usize, usize)>,
        /// Input axis positions mapping.
        free_axes: Vec<usize>,
        /// Connected components of paired axes (union-find groups).
        components: Vec<Vec<usize>>,
        /// Dimension of each component.
        comp_dims: Vec<usize>,
        /// Indices of generative components (no overlap with free axes).
        generative_comps: Vec<usize>,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for fused contraction.
    Contract {
        /// Mode labels for input A.
        modes_a: Vec<u32>,
        /// Mode labels for input B.
        modes_b: Vec<u32>,
        /// Mode labels for output C.
        modes_c: Vec<u32>,
        /// Cached GEMM mode analysis (None if not a valid GEMM pattern).
        // `ContractGemmSpec` is only `pub(super)` while this enum is `pub`; the
        // `allow` silences the resulting `private_interfaces` lint on this field.
        #[allow(private_interfaces)]
        gemm_spec: Option<ContractGemmSpec>,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for optional semiring binary fast paths.
    ElementwiseBinary {
        /// The semiring binary operation to apply.
        op: SemiringBinaryOp,
        /// Zero-sized marker tying the plan to scalar type `T`.
        _marker: PhantomData<T>,
    },
    /// Plan for making a tensor contiguous.
    ///
    /// Carries no data beyond the scalar-type marker.
    MakeContiguous { _marker: PhantomData<T> },
}