tenferro_ext_tropical/prims/
plan.rs

1use std::collections::HashSet;
2use std::marker::PhantomData;
3
4use tenferro_algebra::Scalar;
5use tenferro_device::{Error, Result};
6use tenferro_prims::SemiringCoreDescriptor;
7
8use super::view::mode_position;
9
10/// Execution plan for tropical primitive operations on CPU.
11///
12/// Analogous to [`CpuPlan`](tenferro_prims::CpuPlan) but for tropical
13/// algebras. The plan captures pre-computed kernel selection information.
14///
15/// # Examples
16///
17/// ```ignore
18/// use tenferro_device::LogicalMemorySpace;
19/// use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
20/// use tenferro_tensor::{MemoryOrder, Tensor};
21/// use tenferro_ext_tropical::{MaxPlus, MaxPlusAlgebra, TropicalPlan};
22///
23/// let mut ctx = CpuContext::new(1);
24/// let col = MemoryOrder::ColumnMajor;
25/// let mem = LogicalMemorySpace::MainMemory;
26/// let a = Tensor::<MaxPlus<f64>>::zeros(&[3, 4], mem, col);
27/// let mut c = Tensor::<MaxPlus<f64>>::zeros(&[3], mem, col);
28/// let desc = SemiringCoreDescriptor::ReduceAdd {
29///     modes_a: vec![0, 1],
30///     modes_c: vec![0],
31/// };
32/// let plan =
33///     <CpuBackend as TensorSemiringCore<MaxPlusAlgebra<f64>>>::plan(
34///         &mut ctx,
35///         &desc,
36///         &[&[3, 4], &[3]],
37///     )
38///         .unwrap();
39/// <CpuBackend as TensorSemiringCore<MaxPlusAlgebra<f64>>>::execute(
40///     &mut ctx,
41///     &plan,
42///     MaxPlus::one(),
43///     &[&a],
44///     MaxPlus::zero(),
45///     &mut c,
46/// )
47/// .unwrap();
48/// ```
49#[derive(Debug)]
50pub enum TropicalPlan<T: Scalar> {
51    /// Plan for batched GEMM under tropical algebra.
52    BatchedGemm {
53        /// Batch dimension sizes.
54        batch_dims: Vec<usize>,
55        /// Number of rows.
56        m: usize,
57        /// Number of columns.
58        n: usize,
59        /// Contraction dimension.
60        k: usize,
61        _marker: PhantomData<T>,
62    },
63    /// Plan for reduction under tropical algebra.
64    Reduce {
65        /// Axes to reduce over (positions in input).
66        reduced_axes: Vec<usize>,
67        _marker: PhantomData<T>,
68    },
69    /// Plan for trace under tropical algebra.
70    Trace {
71        /// Paired axis positions in input.
72        paired_axes: Vec<(usize, usize)>,
73        /// Free axis positions in input (corresponding to output modes).
74        free_axes: Vec<usize>,
75        _marker: PhantomData<T>,
76    },
77    /// Plan for anti-trace (AD backward).
78    AntiTrace {
79        /// Paired axis positions in output.
80        paired_axes: Vec<(usize, usize)>,
81        /// Free axis positions in output (corresponding to input modes).
82        free_axes: Vec<usize>,
83        _marker: PhantomData<T>,
84    },
85    /// Plan for anti-diag (AD backward).
86    AntiDiag {
87        /// Paired axis positions in output.
88        paired_axes: Vec<(usize, usize)>,
89        /// Free axis positions in output (corresponding to input modes).
90        free_axes: Vec<usize>,
91        _marker: PhantomData<T>,
92    },
93    /// Plan for making a tensor contiguous.
94    MakeContiguous { _marker: PhantomData<T> },
95}
96
97fn ensure_shape_count(shapes: &[&[usize]], expected: usize, op: &str) -> Result<()> {
98    if shapes.len() != expected {
99        return Err(Error::InvalidArgument(format!(
100            "{op} expects {expected} shapes, got {}",
101            shapes.len()
102        )));
103    }
104    Ok(())
105}
106
107fn ensure_unique_modes(modes: &[u32], name: &str) -> Result<()> {
108    let mut seen = HashSet::new();
109    for &m in modes {
110        if !seen.insert(m) {
111            return Err(Error::InvalidArgument(format!(
112                "{name} contains duplicate mode label {m}"
113            )));
114        }
115    }
116    Ok(())
117}
118
119fn ensure_pair_labels_unique(paired: &[(u32, u32)], name: &str) -> Result<()> {
120    let mut seen = HashSet::new();
121    for &(m1, m2) in paired {
122        if m1 == m2 {
123            return Err(Error::InvalidArgument(format!(
124                "{name} contains invalid pair ({m1},{m2})"
125            )));
126        }
127        if !seen.insert(m1) || !seen.insert(m2) {
128            return Err(Error::InvalidArgument(format!(
129                "{name} contains duplicated paired label"
130            )));
131        }
132    }
133    Ok(())
134}
135
136pub(crate) fn tropical_plan<T: Scalar>(
137    desc: &SemiringCoreDescriptor,
138    shapes: &[&[usize]],
139) -> Result<TropicalPlan<T>> {
140    match desc {
141        SemiringCoreDescriptor::BatchedGemm {
142            batch_dims,
143            m,
144            n,
145            k,
146        } => {
147            ensure_shape_count(shapes, 3, "BatchedGemm")?;
148            let a_shape = shapes[0];
149            let b_shape = shapes[1];
150            let c_shape = shapes[2];
151            let expected_rank = batch_dims.len() + 2;
152            if a_shape.len() != expected_rank
153                || b_shape.len() != expected_rank
154                || c_shape.len() != expected_rank
155            {
156                return Err(Error::InvalidArgument(
157                    "BatchedGemm rank mismatch between descriptor and shapes".into(),
158                ));
159            }
160            if a_shape[0] != *m || a_shape[1] != *k {
161                return Err(Error::InvalidArgument(
162                    "BatchedGemm A shape mismatch".into(),
163                ));
164            }
165            if b_shape[0] != *k || b_shape[1] != *n {
166                return Err(Error::InvalidArgument(
167                    "BatchedGemm B shape mismatch".into(),
168                ));
169            }
170            if c_shape[0] != *m || c_shape[1] != *n {
171                return Err(Error::InvalidArgument(
172                    "BatchedGemm C shape mismatch".into(),
173                ));
174            }
175            for (i, &bd) in batch_dims.iter().enumerate() {
176                if a_shape[2 + i] != bd || b_shape[2 + i] != bd || c_shape[2 + i] != bd {
177                    return Err(Error::InvalidArgument(
178                        "BatchedGemm batch dimensions do not match shapes".into(),
179                    ));
180                }
181            }
182
183            Ok(TropicalPlan::BatchedGemm {
184                batch_dims: batch_dims.clone(),
185                m: *m,
186                n: *n,
187                k: *k,
188                _marker: PhantomData,
189            })
190        }
191        SemiringCoreDescriptor::ReduceAdd { modes_a, modes_c } => {
192            ensure_shape_count(shapes, 2, "ReduceAdd")?;
193            ensure_unique_modes(modes_a, "modes_a")?;
194            ensure_unique_modes(modes_c, "modes_c")?;
195            let a_shape = shapes[0];
196            let c_shape = shapes[1];
197            if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
198                return Err(Error::InvalidArgument(
199                    "Reduce mode rank does not match shape rank".into(),
200                ));
201            }
202            for &m in modes_c {
203                if !modes_a.contains(&m) {
204                    return Err(Error::InvalidArgument(
205                        "Reduce modes_c must be a subset of modes_a".into(),
206                    ));
207                }
208            }
209            for (out_ax, &m) in modes_c.iter().enumerate() {
210                let in_ax = mode_position(modes_a, m)?;
211                if a_shape[in_ax] != c_shape[out_ax] {
212                    return Err(Error::InvalidArgument(
213                        "Reduce output shape does not match input modes".into(),
214                    ));
215                }
216            }
217
218            let reduced_axes: Vec<usize> = modes_a
219                .iter()
220                .enumerate()
221                .filter(|(_, m)| !modes_c.contains(m))
222                .map(|(i, _)| i)
223                .collect();
224            Ok(TropicalPlan::Reduce {
225                reduced_axes,
226                _marker: PhantomData,
227            })
228        }
229        SemiringCoreDescriptor::Trace {
230            modes_a,
231            modes_c,
232            paired,
233        } => {
234            ensure_shape_count(shapes, 2, "Trace")?;
235            ensure_unique_modes(modes_a, "modes_a")?;
236            ensure_unique_modes(modes_c, "modes_c")?;
237            if paired.is_empty() {
238                return Err(Error::InvalidArgument(
239                    "Trace requires non-empty paired axes".into(),
240                ));
241            }
242            ensure_pair_labels_unique(paired, "Trace paired")?;
243            let a_shape = shapes[0];
244            let c_shape = shapes[1];
245            if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
246                return Err(Error::InvalidArgument(
247                    "Trace mode rank does not match shape rank".into(),
248                ));
249            }
250
251            let paired_labels: HashSet<u32> =
252                paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
253            for &(m1, m2) in paired {
254                if !modes_a.contains(&m1) || !modes_a.contains(&m2) {
255                    return Err(Error::InvalidArgument(
256                        "Trace paired labels must exist in modes_a".into(),
257                    ));
258                }
259                if modes_c.contains(&m1) || modes_c.contains(&m2) {
260                    return Err(Error::InvalidArgument(
261                        "Trace paired labels must be reduced (not present in modes_c)".into(),
262                    ));
263                }
264                let ax1 = mode_position(modes_a, m1)?;
265                let ax2 = mode_position(modes_a, m2)?;
266                if a_shape[ax1] != a_shape[ax2] {
267                    return Err(Error::InvalidArgument(
268                        "Trace paired dimensions must be equal".into(),
269                    ));
270                }
271            }
272            for &m in modes_a {
273                if !modes_c.contains(&m) && !paired_labels.contains(&m) {
274                    return Err(Error::InvalidArgument(
275                        "Trace modes_a contains labels neither free nor paired".into(),
276                    ));
277                }
278            }
279            for (out_ax, &m) in modes_c.iter().enumerate() {
280                if paired_labels.contains(&m) {
281                    return Err(Error::InvalidArgument(
282                        "Trace free labels must not be in paired set".into(),
283                    ));
284                }
285                let in_ax = mode_position(modes_a, m)?;
286                if a_shape[in_ax] != c_shape[out_ax] {
287                    return Err(Error::InvalidArgument(
288                        "Trace output shape does not match free modes".into(),
289                    ));
290                }
291            }
292
293            let paired_axes: Vec<(usize, usize)> = paired
294                .iter()
295                .map(|(m1, m2)| Ok((mode_position(modes_a, *m1)?, mode_position(modes_a, *m2)?)))
296                .collect::<Result<_>>()?;
297            let free_axes: Vec<usize> = modes_c
298                .iter()
299                .map(|m| mode_position(modes_a, *m))
300                .collect::<Result<_>>()?;
301            Ok(TropicalPlan::Trace {
302                paired_axes,
303                free_axes,
304                _marker: PhantomData,
305            })
306        }
307        SemiringCoreDescriptor::AntiTrace {
308            modes_a,
309            modes_c,
310            paired,
311        } => {
312            ensure_shape_count(shapes, 2, "AntiTrace")?;
313            ensure_unique_modes(modes_a, "modes_a")?;
314            ensure_unique_modes(modes_c, "modes_c")?;
315            if paired.is_empty() {
316                return Err(Error::InvalidArgument(
317                    "AntiTrace requires non-empty paired axes".into(),
318                ));
319            }
320            ensure_pair_labels_unique(paired, "AntiTrace paired")?;
321            let a_shape = shapes[0];
322            let c_shape = shapes[1];
323            if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
324                return Err(Error::InvalidArgument(
325                    "AntiTrace mode rank does not match shape rank".into(),
326                ));
327            }
328
329            let paired_labels: HashSet<u32> =
330                paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
331            for &(m1, m2) in paired {
332                if !modes_c.contains(&m1) || !modes_c.contains(&m2) {
333                    return Err(Error::InvalidArgument(
334                        "AntiTrace paired labels must exist in modes_c".into(),
335                    ));
336                }
337                if modes_a.contains(&m1) || modes_a.contains(&m2) {
338                    return Err(Error::InvalidArgument(
339                        "AntiTrace paired labels must not be in modes_a".into(),
340                    ));
341                }
342                let ax1 = mode_position(modes_c, m1)?;
343                let ax2 = mode_position(modes_c, m2)?;
344                if c_shape[ax1] != c_shape[ax2] {
345                    return Err(Error::InvalidArgument(
346                        "AntiTrace paired dimensions must be equal".into(),
347                    ));
348                }
349            }
350            for &m in modes_c {
351                if !modes_a.contains(&m) && !paired_labels.contains(&m) {
352                    return Err(Error::InvalidArgument(
353                        "AntiTrace modes_c contains labels neither free nor paired".into(),
354                    ));
355                }
356            }
357            for (in_ax, &m) in modes_a.iter().enumerate() {
358                if paired_labels.contains(&m) {
359                    return Err(Error::InvalidArgument(
360                        "AntiTrace free labels must not be in paired set".into(),
361                    ));
362                }
363                let out_ax = mode_position(modes_c, m)?;
364                if a_shape[in_ax] != c_shape[out_ax] {
365                    return Err(Error::InvalidArgument(
366                        "AntiTrace input shape does not match output free modes".into(),
367                    ));
368                }
369            }
370
371            let paired_axes: Vec<(usize, usize)> = paired
372                .iter()
373                .map(|(m1, m2)| Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?)))
374                .collect::<Result<_>>()?;
375            let free_axes: Vec<usize> = modes_a
376                .iter()
377                .map(|m| mode_position(modes_c, *m))
378                .collect::<Result<_>>()?;
379            Ok(TropicalPlan::AntiTrace {
380                paired_axes,
381                free_axes,
382                _marker: PhantomData,
383            })
384        }
385        SemiringCoreDescriptor::AntiDiag {
386            modes_a,
387            modes_c,
388            paired,
389        } => {
390            ensure_shape_count(shapes, 2, "AntiDiag")?;
391            ensure_unique_modes(modes_a, "modes_a")?;
392            ensure_unique_modes(modes_c, "modes_c")?;
393            if paired.is_empty() {
394                return Err(Error::InvalidArgument(
395                    "AntiDiag requires non-empty paired axes".into(),
396                ));
397            }
398            ensure_pair_labels_unique(paired, "AntiDiag paired")?;
399            let a_shape = shapes[0];
400            let c_shape = shapes[1];
401            if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
402                return Err(Error::InvalidArgument(
403                    "AntiDiag mode rank does not match shape rank".into(),
404                ));
405            }
406
407            let paired_labels: HashSet<u32> =
408                paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
409            let free_labels: HashSet<u32> = modes_a.iter().copied().collect();
410            for &(m1, m2) in paired {
411                if !modes_c.contains(&m1) || !modes_c.contains(&m2) {
412                    return Err(Error::InvalidArgument(
413                        "AntiDiag paired labels must exist in modes_c".into(),
414                    ));
415                }
416                if !free_labels.contains(&m1) {
417                    return Err(Error::InvalidArgument(
418                        "AntiDiag first paired label must exist in modes_a".into(),
419                    ));
420                }
421                if free_labels.contains(&m2) {
422                    return Err(Error::InvalidArgument(
423                        "AntiDiag second paired label must not exist in modes_a".into(),
424                    ));
425                }
426                let ax1 = mode_position(modes_c, m1)?;
427                let ax2 = mode_position(modes_c, m2)?;
428                if c_shape[ax1] != c_shape[ax2] {
429                    return Err(Error::InvalidArgument(
430                        "AntiDiag paired dimensions must be equal".into(),
431                    ));
432                }
433            }
434            for &m in modes_c {
435                if !free_labels.contains(&m) && !paired_labels.contains(&m) {
436                    return Err(Error::InvalidArgument(
437                        "AntiDiag modes_c contains labels neither free nor paired".into(),
438                    ));
439                }
440            }
441            for (in_ax, &m) in modes_a.iter().enumerate() {
442                let out_ax = mode_position(modes_c, m)?;
443                if a_shape[in_ax] != c_shape[out_ax] {
444                    return Err(Error::InvalidArgument(
445                        "AntiDiag input shape does not match output free modes".into(),
446                    ));
447                }
448            }
449
450            let paired_axes: Vec<(usize, usize)> = paired
451                .iter()
452                .map(|(m1, m2)| Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?)))
453                .collect::<Result<_>>()?;
454            let free_axes: Vec<usize> = modes_a
455                .iter()
456                .map(|m| mode_position(modes_c, *m))
457                .collect::<Result<_>>()?;
458            Ok(TropicalPlan::AntiDiag {
459                paired_axes,
460                free_axes,
461                _marker: PhantomData,
462            })
463        }
464        SemiringCoreDescriptor::MakeContiguous => {
465            ensure_shape_count(shapes, 2, "MakeContiguous")?;
466            if shapes[0] != shapes[1] {
467                return Err(Error::InvalidArgument(
468                    "MakeContiguous input and output shapes must match".into(),
469                ));
470            }
471            Ok(TropicalPlan::MakeContiguous {
472                _marker: PhantomData,
473            })
474        }
475    }
476}