tenferro_tropical/prims.rs
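//! [`TensorPrims`] implementations for tropical algebras on [`CpuBackend`].
//!
//! Each tropical algebra ([`MaxPlusAlgebra`], [`MinPlusAlgebra`], [`MaxMulAlgebra`])
//! gets its own `impl TensorPrims<XxxAlgebra> for CpuBackend`. Both `TensorPrims`
//! and `CpuBackend` come from `tenferro_prims`. The orphan rule is still satisfied
//! because the trait parameter `XxxAlgebra` is defined in this crate.
//!
//! Extended operations (Contract, ElementwiseMul) are not supported for
//! tropical algebras — `has_extension_for` always returns `false`.
//!
//! Note: `plan` and `execute` are currently unimplemented stubs (`todo!()`) and
//! panic if called.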
8
9use std::marker::PhantomData;
10
11use strided_traits::ScalarBase;
12use strided_view::{StridedView, StridedViewMut};
13use tenferro_device::Result;
14use tenferro_prims::{CpuBackend, Extension, PrimDescriptor, ReduceOp, TensorPrims};
15
16use crate::algebra::{MaxMulAlgebra, MaxPlusAlgebra, MinPlusAlgebra};
17
18/// Execution plan for tropical primitive operations on CPU.
19///
20/// Analogous to [`CpuPlan`](tenferro_prims::CpuPlan) but for tropical
21/// algebras. The plan captures pre-computed kernel selection information.
22///
23/// # Examples
24///
25/// ```ignore
26/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
27/// use tenferro_tropical::{MaxPlusAlgebra, TropicalPlan};
28///
29/// let desc = PrimDescriptor::Reduce {
30/// modes_a: vec![0, 1],
31/// modes_c: vec![0],
32/// op: ReduceOp::Max,
33/// };
34/// let plan: TropicalPlan<f64> = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
35/// ```
36pub enum TropicalPlan<T: ScalarBase> {
37 /// Plan for batched GEMM under tropical algebra.
38 BatchedGemm {
39 /// Number of rows.
40 m: usize,
41 /// Number of columns.
42 n: usize,
43 /// Contraction dimension.
44 k: usize,
45 _marker: PhantomData<T>,
46 },
47 /// Plan for reduction under tropical algebra.
48 Reduce {
49 /// Axis to reduce over.
50 axis: usize,
51 /// Reduction operation.
52 op: ReduceOp,
53 _marker: PhantomData<T>,
54 },
55 /// Plan for trace under tropical algebra.
56 Trace {
57 /// Paired modes.
58 paired: Vec<(u32, u32)>,
59 _marker: PhantomData<T>,
60 },
61 /// Plan for permutation.
62 Permute {
63 /// Permutation mapping.
64 perm: Vec<usize>,
65 _marker: PhantomData<T>,
66 },
67 /// Plan for anti-trace (AD backward).
68 AntiTrace {
69 /// Paired modes.
70 paired: Vec<(u32, u32)>,
71 _marker: PhantomData<T>,
72 },
73 /// Plan for anti-diag (AD backward).
74 AntiDiag {
75 /// Paired modes.
76 paired: Vec<(u32, u32)>,
77 _marker: PhantomData<T>,
78 },
79}
80
81// ---------------------------------------------------------------------------
82// impl TensorPrims<MaxPlusAlgebra> for CpuBackend
83// ---------------------------------------------------------------------------
84
85impl TensorPrims<MaxPlusAlgebra> for CpuBackend {
86 type Plan<T: ScalarBase> = TropicalPlan<T>;
87
88 fn plan<T: ScalarBase>(
89 _desc: &PrimDescriptor,
90 _shapes: &[&[usize]],
91 ) -> Result<TropicalPlan<T>> {
92 todo!()
93 }
94
95 fn execute<T: ScalarBase>(
96 _plan: &TropicalPlan<T>,
97 _alpha: T,
98 _inputs: &[&StridedView<T>],
99 _beta: T,
100 _output: &mut StridedViewMut<T>,
101 ) -> Result<()> {
102 todo!()
103 }
104
105 /// Tropical backends do not support extended operations.
106 fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
107 false
108 }
109}
110
111// ---------------------------------------------------------------------------
112// impl TensorPrims<MinPlusAlgebra> for CpuBackend
113// ---------------------------------------------------------------------------
114
115impl TensorPrims<MinPlusAlgebra> for CpuBackend {
116 type Plan<T: ScalarBase> = TropicalPlan<T>;
117
118 fn plan<T: ScalarBase>(
119 _desc: &PrimDescriptor,
120 _shapes: &[&[usize]],
121 ) -> Result<TropicalPlan<T>> {
122 todo!()
123 }
124
125 fn execute<T: ScalarBase>(
126 _plan: &TropicalPlan<T>,
127 _alpha: T,
128 _inputs: &[&StridedView<T>],
129 _beta: T,
130 _output: &mut StridedViewMut<T>,
131 ) -> Result<()> {
132 todo!()
133 }
134
135 /// Tropical backends do not support extended operations.
136 fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
137 false
138 }
139}
140
141// ---------------------------------------------------------------------------
142// impl TensorPrims<MaxMulAlgebra> for CpuBackend
143// ---------------------------------------------------------------------------
144
145impl TensorPrims<MaxMulAlgebra> for CpuBackend {
146 type Plan<T: ScalarBase> = TropicalPlan<T>;
147
148 fn plan<T: ScalarBase>(
149 _desc: &PrimDescriptor,
150 _shapes: &[&[usize]],
151 ) -> Result<TropicalPlan<T>> {
152 todo!()
153 }
154
155 fn execute<T: ScalarBase>(
156 _plan: &TropicalPlan<T>,
157 _alpha: T,
158 _inputs: &[&StridedView<T>],
159 _beta: T,
160 _output: &mut StridedViewMut<T>,
161 ) -> Result<()> {
162 todo!()
163 }
164
165 /// Tropical backends do not support extended operations.
166 fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
167 false
168 }
169}