tenferro_tropical/
prims.rs

//! [`TensorPrims`] implementations for tropical algebras on [`CpuBackend`].
//!
//! Each tropical algebra gets its own `impl TensorPrims<XxxAlgebra> for CpuBackend`.
//! The orphan rule is satisfied because `XxxAlgebra` is defined in this crate.
//!
//! Extended operations (Contract, ElementwiseMul) are not supported for
//! tropical algebras — `has_extension_for` always returns `false`.

9use std::marker::PhantomData;
10
11use strided_traits::ScalarBase;
12use strided_view::{StridedView, StridedViewMut};
13use tenferro_device::Result;
14use tenferro_prims::{CpuBackend, Extension, PrimDescriptor, ReduceOp, TensorPrims};
15
16use crate::algebra::{MaxMulAlgebra, MaxPlusAlgebra, MinPlusAlgebra};
17
18/// Execution plan for tropical primitive operations on CPU.
19///
20/// Analogous to [`CpuPlan`](tenferro_prims::CpuPlan) but for tropical
21/// algebras. The plan captures pre-computed kernel selection information.
22///
23/// # Examples
24///
25/// ```ignore
26/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
27/// use tenferro_tropical::{MaxPlusAlgebra, TropicalPlan};
28///
29/// let desc = PrimDescriptor::Reduce {
30///     modes_a: vec![0, 1],
31///     modes_c: vec![0],
32///     op: ReduceOp::Max,
33/// };
34/// let plan: TropicalPlan<f64> = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
35/// ```
36pub enum TropicalPlan<T: ScalarBase> {
37    /// Plan for batched GEMM under tropical algebra.
38    BatchedGemm {
39        /// Number of rows.
40        m: usize,
41        /// Number of columns.
42        n: usize,
43        /// Contraction dimension.
44        k: usize,
45        _marker: PhantomData<T>,
46    },
47    /// Plan for reduction under tropical algebra.
48    Reduce {
49        /// Axis to reduce over.
50        axis: usize,
51        /// Reduction operation.
52        op: ReduceOp,
53        _marker: PhantomData<T>,
54    },
55    /// Plan for trace under tropical algebra.
56    Trace {
57        /// Paired modes.
58        paired: Vec<(u32, u32)>,
59        _marker: PhantomData<T>,
60    },
61    /// Plan for permutation.
62    Permute {
63        /// Permutation mapping.
64        perm: Vec<usize>,
65        _marker: PhantomData<T>,
66    },
67    /// Plan for anti-trace (AD backward).
68    AntiTrace {
69        /// Paired modes.
70        paired: Vec<(u32, u32)>,
71        _marker: PhantomData<T>,
72    },
73    /// Plan for anti-diag (AD backward).
74    AntiDiag {
75        /// Paired modes.
76        paired: Vec<(u32, u32)>,
77        _marker: PhantomData<T>,
78    },
79}
80
81// ---------------------------------------------------------------------------
82// impl TensorPrims<MaxPlusAlgebra> for CpuBackend
83// ---------------------------------------------------------------------------
84
85impl TensorPrims<MaxPlusAlgebra> for CpuBackend {
86    type Plan<T: ScalarBase> = TropicalPlan<T>;
87
88    fn plan<T: ScalarBase>(
89        _desc: &PrimDescriptor,
90        _shapes: &[&[usize]],
91    ) -> Result<TropicalPlan<T>> {
92        todo!()
93    }
94
95    fn execute<T: ScalarBase>(
96        _plan: &TropicalPlan<T>,
97        _alpha: T,
98        _inputs: &[&StridedView<T>],
99        _beta: T,
100        _output: &mut StridedViewMut<T>,
101    ) -> Result<()> {
102        todo!()
103    }
104
105    /// Tropical backends do not support extended operations.
106    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
107        false
108    }
109}
110
111// ---------------------------------------------------------------------------
112// impl TensorPrims<MinPlusAlgebra> for CpuBackend
113// ---------------------------------------------------------------------------
114
115impl TensorPrims<MinPlusAlgebra> for CpuBackend {
116    type Plan<T: ScalarBase> = TropicalPlan<T>;
117
118    fn plan<T: ScalarBase>(
119        _desc: &PrimDescriptor,
120        _shapes: &[&[usize]],
121    ) -> Result<TropicalPlan<T>> {
122        todo!()
123    }
124
125    fn execute<T: ScalarBase>(
126        _plan: &TropicalPlan<T>,
127        _alpha: T,
128        _inputs: &[&StridedView<T>],
129        _beta: T,
130        _output: &mut StridedViewMut<T>,
131    ) -> Result<()> {
132        todo!()
133    }
134
135    /// Tropical backends do not support extended operations.
136    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
137        false
138    }
139}
140
141// ---------------------------------------------------------------------------
142// impl TensorPrims<MaxMulAlgebra> for CpuBackend
143// ---------------------------------------------------------------------------
144
145impl TensorPrims<MaxMulAlgebra> for CpuBackend {
146    type Plan<T: ScalarBase> = TropicalPlan<T>;
147
148    fn plan<T: ScalarBase>(
149        _desc: &PrimDescriptor,
150        _shapes: &[&[usize]],
151    ) -> Result<TropicalPlan<T>> {
152        todo!()
153    }
154
155    fn execute<T: ScalarBase>(
156        _plan: &TropicalPlan<T>,
157        _alpha: T,
158        _inputs: &[&StridedView<T>],
159        _beta: T,
160        _output: &mut StridedViewMut<T>,
161    ) -> Result<()> {
162        todo!()
163    }
164
165    /// Tropical backends do not support extended operations.
166    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
167        false
168    }
169}