1use std::collections::HashSet;
2use std::marker::PhantomData;
3
4use tenferro_algebra::Scalar;
5use tenferro_device::{Error, Result};
6use tenferro_prims::SemiringCoreDescriptor;
7
8use super::view::mode_position;
9
10#[derive(Debug)]
50pub enum TropicalPlan<T: Scalar> {
51 BatchedGemm {
53 batch_dims: Vec<usize>,
55 m: usize,
57 n: usize,
59 k: usize,
61 _marker: PhantomData<T>,
62 },
63 Reduce {
65 reduced_axes: Vec<usize>,
67 _marker: PhantomData<T>,
68 },
69 Trace {
71 paired_axes: Vec<(usize, usize)>,
73 free_axes: Vec<usize>,
75 _marker: PhantomData<T>,
76 },
77 AntiTrace {
79 paired_axes: Vec<(usize, usize)>,
81 free_axes: Vec<usize>,
83 _marker: PhantomData<T>,
84 },
85 AntiDiag {
87 paired_axes: Vec<(usize, usize)>,
89 free_axes: Vec<usize>,
91 _marker: PhantomData<T>,
92 },
93 MakeContiguous { _marker: PhantomData<T> },
95}
96
97fn ensure_shape_count(shapes: &[&[usize]], expected: usize, op: &str) -> Result<()> {
98 if shapes.len() != expected {
99 return Err(Error::InvalidArgument(format!(
100 "{op} expects {expected} shapes, got {}",
101 shapes.len()
102 )));
103 }
104 Ok(())
105}
106
107fn ensure_unique_modes(modes: &[u32], name: &str) -> Result<()> {
108 let mut seen = HashSet::new();
109 for &m in modes {
110 if !seen.insert(m) {
111 return Err(Error::InvalidArgument(format!(
112 "{name} contains duplicate mode label {m}"
113 )));
114 }
115 }
116 Ok(())
117}
118
119fn ensure_pair_labels_unique(paired: &[(u32, u32)], name: &str) -> Result<()> {
120 let mut seen = HashSet::new();
121 for &(m1, m2) in paired {
122 if m1 == m2 {
123 return Err(Error::InvalidArgument(format!(
124 "{name} contains invalid pair ({m1},{m2})"
125 )));
126 }
127 if !seen.insert(m1) || !seen.insert(m2) {
128 return Err(Error::InvalidArgument(format!(
129 "{name} contains duplicated paired label"
130 )));
131 }
132 }
133 Ok(())
134}
135
136pub(crate) fn tropical_plan<T: Scalar>(
137 desc: &SemiringCoreDescriptor,
138 shapes: &[&[usize]],
139) -> Result<TropicalPlan<T>> {
140 match desc {
141 SemiringCoreDescriptor::BatchedGemm {
142 batch_dims,
143 m,
144 n,
145 k,
146 } => {
147 ensure_shape_count(shapes, 3, "BatchedGemm")?;
148 let a_shape = shapes[0];
149 let b_shape = shapes[1];
150 let c_shape = shapes[2];
151 let expected_rank = batch_dims.len() + 2;
152 if a_shape.len() != expected_rank
153 || b_shape.len() != expected_rank
154 || c_shape.len() != expected_rank
155 {
156 return Err(Error::InvalidArgument(
157 "BatchedGemm rank mismatch between descriptor and shapes".into(),
158 ));
159 }
160 if a_shape[0] != *m || a_shape[1] != *k {
161 return Err(Error::InvalidArgument(
162 "BatchedGemm A shape mismatch".into(),
163 ));
164 }
165 if b_shape[0] != *k || b_shape[1] != *n {
166 return Err(Error::InvalidArgument(
167 "BatchedGemm B shape mismatch".into(),
168 ));
169 }
170 if c_shape[0] != *m || c_shape[1] != *n {
171 return Err(Error::InvalidArgument(
172 "BatchedGemm C shape mismatch".into(),
173 ));
174 }
175 for (i, &bd) in batch_dims.iter().enumerate() {
176 if a_shape[2 + i] != bd || b_shape[2 + i] != bd || c_shape[2 + i] != bd {
177 return Err(Error::InvalidArgument(
178 "BatchedGemm batch dimensions do not match shapes".into(),
179 ));
180 }
181 }
182
183 Ok(TropicalPlan::BatchedGemm {
184 batch_dims: batch_dims.clone(),
185 m: *m,
186 n: *n,
187 k: *k,
188 _marker: PhantomData,
189 })
190 }
191 SemiringCoreDescriptor::ReduceAdd { modes_a, modes_c } => {
192 ensure_shape_count(shapes, 2, "ReduceAdd")?;
193 ensure_unique_modes(modes_a, "modes_a")?;
194 ensure_unique_modes(modes_c, "modes_c")?;
195 let a_shape = shapes[0];
196 let c_shape = shapes[1];
197 if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
198 return Err(Error::InvalidArgument(
199 "Reduce mode rank does not match shape rank".into(),
200 ));
201 }
202 for &m in modes_c {
203 if !modes_a.contains(&m) {
204 return Err(Error::InvalidArgument(
205 "Reduce modes_c must be a subset of modes_a".into(),
206 ));
207 }
208 }
209 for (out_ax, &m) in modes_c.iter().enumerate() {
210 let in_ax = mode_position(modes_a, m)?;
211 if a_shape[in_ax] != c_shape[out_ax] {
212 return Err(Error::InvalidArgument(
213 "Reduce output shape does not match input modes".into(),
214 ));
215 }
216 }
217
218 let reduced_axes: Vec<usize> = modes_a
219 .iter()
220 .enumerate()
221 .filter(|(_, m)| !modes_c.contains(m))
222 .map(|(i, _)| i)
223 .collect();
224 Ok(TropicalPlan::Reduce {
225 reduced_axes,
226 _marker: PhantomData,
227 })
228 }
229 SemiringCoreDescriptor::Trace {
230 modes_a,
231 modes_c,
232 paired,
233 } => {
234 ensure_shape_count(shapes, 2, "Trace")?;
235 ensure_unique_modes(modes_a, "modes_a")?;
236 ensure_unique_modes(modes_c, "modes_c")?;
237 if paired.is_empty() {
238 return Err(Error::InvalidArgument(
239 "Trace requires non-empty paired axes".into(),
240 ));
241 }
242 ensure_pair_labels_unique(paired, "Trace paired")?;
243 let a_shape = shapes[0];
244 let c_shape = shapes[1];
245 if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
246 return Err(Error::InvalidArgument(
247 "Trace mode rank does not match shape rank".into(),
248 ));
249 }
250
251 let paired_labels: HashSet<u32> =
252 paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
253 for &(m1, m2) in paired {
254 if !modes_a.contains(&m1) || !modes_a.contains(&m2) {
255 return Err(Error::InvalidArgument(
256 "Trace paired labels must exist in modes_a".into(),
257 ));
258 }
259 if modes_c.contains(&m1) || modes_c.contains(&m2) {
260 return Err(Error::InvalidArgument(
261 "Trace paired labels must be reduced (not present in modes_c)".into(),
262 ));
263 }
264 let ax1 = mode_position(modes_a, m1)?;
265 let ax2 = mode_position(modes_a, m2)?;
266 if a_shape[ax1] != a_shape[ax2] {
267 return Err(Error::InvalidArgument(
268 "Trace paired dimensions must be equal".into(),
269 ));
270 }
271 }
272 for &m in modes_a {
273 if !modes_c.contains(&m) && !paired_labels.contains(&m) {
274 return Err(Error::InvalidArgument(
275 "Trace modes_a contains labels neither free nor paired".into(),
276 ));
277 }
278 }
279 for (out_ax, &m) in modes_c.iter().enumerate() {
280 if paired_labels.contains(&m) {
281 return Err(Error::InvalidArgument(
282 "Trace free labels must not be in paired set".into(),
283 ));
284 }
285 let in_ax = mode_position(modes_a, m)?;
286 if a_shape[in_ax] != c_shape[out_ax] {
287 return Err(Error::InvalidArgument(
288 "Trace output shape does not match free modes".into(),
289 ));
290 }
291 }
292
293 let paired_axes: Vec<(usize, usize)> = paired
294 .iter()
295 .map(|(m1, m2)| Ok((mode_position(modes_a, *m1)?, mode_position(modes_a, *m2)?)))
296 .collect::<Result<_>>()?;
297 let free_axes: Vec<usize> = modes_c
298 .iter()
299 .map(|m| mode_position(modes_a, *m))
300 .collect::<Result<_>>()?;
301 Ok(TropicalPlan::Trace {
302 paired_axes,
303 free_axes,
304 _marker: PhantomData,
305 })
306 }
307 SemiringCoreDescriptor::AntiTrace {
308 modes_a,
309 modes_c,
310 paired,
311 } => {
312 ensure_shape_count(shapes, 2, "AntiTrace")?;
313 ensure_unique_modes(modes_a, "modes_a")?;
314 ensure_unique_modes(modes_c, "modes_c")?;
315 if paired.is_empty() {
316 return Err(Error::InvalidArgument(
317 "AntiTrace requires non-empty paired axes".into(),
318 ));
319 }
320 ensure_pair_labels_unique(paired, "AntiTrace paired")?;
321 let a_shape = shapes[0];
322 let c_shape = shapes[1];
323 if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
324 return Err(Error::InvalidArgument(
325 "AntiTrace mode rank does not match shape rank".into(),
326 ));
327 }
328
329 let paired_labels: HashSet<u32> =
330 paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
331 for &(m1, m2) in paired {
332 if !modes_c.contains(&m1) || !modes_c.contains(&m2) {
333 return Err(Error::InvalidArgument(
334 "AntiTrace paired labels must exist in modes_c".into(),
335 ));
336 }
337 if modes_a.contains(&m1) || modes_a.contains(&m2) {
338 return Err(Error::InvalidArgument(
339 "AntiTrace paired labels must not be in modes_a".into(),
340 ));
341 }
342 let ax1 = mode_position(modes_c, m1)?;
343 let ax2 = mode_position(modes_c, m2)?;
344 if c_shape[ax1] != c_shape[ax2] {
345 return Err(Error::InvalidArgument(
346 "AntiTrace paired dimensions must be equal".into(),
347 ));
348 }
349 }
350 for &m in modes_c {
351 if !modes_a.contains(&m) && !paired_labels.contains(&m) {
352 return Err(Error::InvalidArgument(
353 "AntiTrace modes_c contains labels neither free nor paired".into(),
354 ));
355 }
356 }
357 for (in_ax, &m) in modes_a.iter().enumerate() {
358 if paired_labels.contains(&m) {
359 return Err(Error::InvalidArgument(
360 "AntiTrace free labels must not be in paired set".into(),
361 ));
362 }
363 let out_ax = mode_position(modes_c, m)?;
364 if a_shape[in_ax] != c_shape[out_ax] {
365 return Err(Error::InvalidArgument(
366 "AntiTrace input shape does not match output free modes".into(),
367 ));
368 }
369 }
370
371 let paired_axes: Vec<(usize, usize)> = paired
372 .iter()
373 .map(|(m1, m2)| Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?)))
374 .collect::<Result<_>>()?;
375 let free_axes: Vec<usize> = modes_a
376 .iter()
377 .map(|m| mode_position(modes_c, *m))
378 .collect::<Result<_>>()?;
379 Ok(TropicalPlan::AntiTrace {
380 paired_axes,
381 free_axes,
382 _marker: PhantomData,
383 })
384 }
385 SemiringCoreDescriptor::AntiDiag {
386 modes_a,
387 modes_c,
388 paired,
389 } => {
390 ensure_shape_count(shapes, 2, "AntiDiag")?;
391 ensure_unique_modes(modes_a, "modes_a")?;
392 ensure_unique_modes(modes_c, "modes_c")?;
393 if paired.is_empty() {
394 return Err(Error::InvalidArgument(
395 "AntiDiag requires non-empty paired axes".into(),
396 ));
397 }
398 ensure_pair_labels_unique(paired, "AntiDiag paired")?;
399 let a_shape = shapes[0];
400 let c_shape = shapes[1];
401 if modes_a.len() != a_shape.len() || modes_c.len() != c_shape.len() {
402 return Err(Error::InvalidArgument(
403 "AntiDiag mode rank does not match shape rank".into(),
404 ));
405 }
406
407 let paired_labels: HashSet<u32> =
408 paired.iter().flat_map(|(m1, m2)| [*m1, *m2]).collect();
409 let free_labels: HashSet<u32> = modes_a.iter().copied().collect();
410 for &(m1, m2) in paired {
411 if !modes_c.contains(&m1) || !modes_c.contains(&m2) {
412 return Err(Error::InvalidArgument(
413 "AntiDiag paired labels must exist in modes_c".into(),
414 ));
415 }
416 if !free_labels.contains(&m1) {
417 return Err(Error::InvalidArgument(
418 "AntiDiag first paired label must exist in modes_a".into(),
419 ));
420 }
421 if free_labels.contains(&m2) {
422 return Err(Error::InvalidArgument(
423 "AntiDiag second paired label must not exist in modes_a".into(),
424 ));
425 }
426 let ax1 = mode_position(modes_c, m1)?;
427 let ax2 = mode_position(modes_c, m2)?;
428 if c_shape[ax1] != c_shape[ax2] {
429 return Err(Error::InvalidArgument(
430 "AntiDiag paired dimensions must be equal".into(),
431 ));
432 }
433 }
434 for &m in modes_c {
435 if !free_labels.contains(&m) && !paired_labels.contains(&m) {
436 return Err(Error::InvalidArgument(
437 "AntiDiag modes_c contains labels neither free nor paired".into(),
438 ));
439 }
440 }
441 for (in_ax, &m) in modes_a.iter().enumerate() {
442 let out_ax = mode_position(modes_c, m)?;
443 if a_shape[in_ax] != c_shape[out_ax] {
444 return Err(Error::InvalidArgument(
445 "AntiDiag input shape does not match output free modes".into(),
446 ));
447 }
448 }
449
450 let paired_axes: Vec<(usize, usize)> = paired
451 .iter()
452 .map(|(m1, m2)| Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?)))
453 .collect::<Result<_>>()?;
454 let free_axes: Vec<usize> = modes_a
455 .iter()
456 .map(|m| mode_position(modes_c, *m))
457 .collect::<Result<_>>()?;
458 Ok(TropicalPlan::AntiDiag {
459 paired_axes,
460 free_axes,
461 _marker: PhantomData,
462 })
463 }
464 SemiringCoreDescriptor::MakeContiguous => {
465 ensure_shape_count(shapes, 2, "MakeContiguous")?;
466 if shapes[0] != shapes[1] {
467 return Err(Error::InvalidArgument(
468 "MakeContiguous input and output shapes must match".into(),
469 ));
470 }
471 Ok(TropicalPlan::MakeContiguous {
472 _marker: PhantomData,
473 })
474 }
475 }
476}