1use computegraph::compile::CompiledProgram;
2use tenferro_algebra::Algebra;
3use tenferro_ops::dim_expr::DimExpr;
4use tenferro_ops::semiring_op::SemiringOp;
5use tenferro_ops::semiring_op_kind::SemiringOpKind;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{DType, DotGeneralConfig, TensorScalar};
8
9use crate::shape_infer::{infer_output_dtype, infer_output_shapes};
10
11use super::exec::{ExecInstruction, ExecOp, ExecProgram};
12
/// Lower a compiled `StdTensorOp` program into an executable `ExecProgram`.
///
/// Walks the instructions in program order, propagating dtypes and symbolic
/// shapes forward from the program inputs (via `infer_output_dtype` /
/// `infer_output_shapes`) and translating each op with `std_to_exec_op`.
/// Afterwards the post-lowering passes run: contracting-dimension sorting,
/// transpose folding, and last-use analysis.
///
/// # Panics
/// - if `input_dtypes` / `input_shapes` lengths differ from the number of
///   program input slots;
/// - if an instruction reads a slot that no input or earlier instruction has
///   defined;
/// - if shape inference yields a different output count than the instruction
///   declares.
pub fn compile_std_to_exec(
    prog: &CompiledProgram<StdTensorOp>,
    input_dtypes: &[DType],
    input_shapes: &[Vec<DimExpr>],
) -> ExecProgram {
    assert_eq!(
        prog.input_slots.len(),
        input_dtypes.len(),
        "compile_std_to_exec: input dtype count must match input slot count"
    );
    assert_eq!(
        prog.input_slots.len(),
        input_shapes.len(),
        "compile_std_to_exec: input shape count must match input slot count"
    );

    // Per-slot dtype/shape tables; `None` means "not defined yet".
    let mut slot_dtypes: Vec<Option<DType>> = vec![None; prog.n_slots];
    let mut slot_shapes: Vec<Option<Vec<DimExpr>>> = vec![None; prog.n_slots];

    // Seed the tables from the caller-supplied program inputs.
    for (index, &slot) in prog.input_slots.iter().enumerate() {
        slot_dtypes[slot] = Some(input_dtypes[index]);
        slot_shapes[slot] = Some(input_shapes[index].clone());
    }

    let instructions = prog
        .instructions
        .iter()
        .map(|instr| {
            // Deliberately shadows the outer `input_dtypes`: these are the
            // dtypes of *this instruction's* inputs.
            let input_dtypes: Vec<DType> = instr
                .inputs
                .iter()
                .map(|&slot| {
                    slot_dtypes[slot].unwrap_or_else(|| {
                        panic!("compile_std_to_exec: missing dtype for slot {slot}")
                    })
                })
                .collect();
            let input_shapes_owned: Vec<Vec<DimExpr>> = instr
                .inputs
                .iter()
                .map(|&slot| {
                    slot_shapes[slot].clone().unwrap_or_else(|| {
                        panic!("compile_std_to_exec: missing shape for slot {slot}")
                    })
                })
                .collect();
            // Borrowed views over the owned shapes, as the inference API wants.
            let input_shapes_refs: Vec<&[DimExpr]> =
                input_shapes_owned.iter().map(Vec::as_slice).collect();

            let output_dtype = infer_output_dtype(&instr.op, &input_dtypes);
            let output_shapes = infer_output_shapes(&instr.op, &input_shapes_refs);
            assert_eq!(
                output_shapes.len(),
                instr.outputs.len(),
                "compile_std_to_exec: {:?} inferred {} output shapes for {} output slots",
                instr.op,
                output_shapes.len(),
                instr.outputs.len()
            );

            // Record inferred results so downstream instructions can read
            // them; all outputs of one instruction share the single inferred
            // dtype.
            for (slot, shape) in instr.outputs.iter().zip(output_shapes.iter()) {
                slot_dtypes[*slot] = Some(output_dtype);
                slot_shapes[*slot] = Some(shape.clone());
            }

            ExecInstruction {
                op: std_to_exec_op(&instr.op),
                input_slots: instr.inputs.clone(),
                output_slots: instr.outputs.clone(),
                dtype: output_dtype,
                output_shapes,
                // Filled in by `populate_last_use` below.
                last_use: Vec::new(),
            }
        })
        .collect();

    let mut program = ExecProgram {
        instructions,
        input_slots: prog.input_slots.clone(),
        output_slots: prog.output_slots.clone(),
        n_slots: prog.n_slots,
    };
    // Post-lowering cleanup/analysis passes.
    dot_dimension_sorter(&mut program);
    transpose_folding(&mut program);
    populate_last_use(&mut program);
    program
}
100
101pub fn compile_semiring_to_exec<Alg>(
102 prog: &CompiledProgram<SemiringOp<Alg>>,
103 input_shapes: &[Vec<DimExpr>],
104) -> ExecProgram
105where
106 Alg: Algebra + Send + Sync + 'static,
107 Alg::Scalar: TensorScalar,
108{
109 assert_eq!(
110 prog.input_slots.len(),
111 input_shapes.len(),
112 "compile_semiring_to_exec: input shape count must match input slot count"
113 );
114
115 let dtype = <Alg::Scalar as TensorScalar>::dtype();
116 let mut slot_shapes: Vec<Option<Vec<DimExpr>>> = vec![None; prog.n_slots];
117 for (index, &slot) in prog.input_slots.iter().enumerate() {
118 slot_shapes[slot] = Some(input_shapes[index].clone());
119 }
120
121 let instructions = prog
122 .instructions
123 .iter()
124 .map(|instr| {
125 let input_shapes_owned: Vec<Vec<DimExpr>> = instr
126 .inputs
127 .iter()
128 .map(|&slot| {
129 slot_shapes[slot].clone().unwrap_or_else(|| {
130 panic!("compile_semiring_to_exec: missing shape for slot {slot}")
131 })
132 })
133 .collect();
134 let input_shapes_refs: Vec<&[DimExpr]> =
135 input_shapes_owned.iter().map(Vec::as_slice).collect();
136 let output_shapes = infer_semiring_output_shapes(&instr.op.kind, &input_shapes_refs);
137 assert_eq!(
138 output_shapes.len(),
139 instr.outputs.len(),
140 "compile_semiring_to_exec: {:?} inferred {} output shapes for {} output slots",
141 instr.op.kind,
142 output_shapes.len(),
143 instr.outputs.len()
144 );
145
146 for (slot, shape) in instr.outputs.iter().zip(output_shapes.iter()) {
147 slot_shapes[*slot] = Some(shape.clone());
148 }
149
150 ExecInstruction {
151 op: semiring_to_exec_op(&instr.op.kind),
152 input_slots: instr.inputs.clone(),
153 output_slots: instr.outputs.clone(),
154 dtype,
155 output_shapes,
156 last_use: Vec::new(),
157 }
158 })
159 .collect();
160
161 let mut program = ExecProgram {
162 instructions,
163 input_slots: prog.input_slots.clone(),
164 output_slots: prog.output_slots.clone(),
165 n_slots: prog.n_slots,
166 };
167 dot_dimension_sorter(&mut program);
168 transpose_folding(&mut program);
169 populate_last_use(&mut program);
170 program
171}
172
/// Translate a `StdTensorOp` into its `ExecOp` counterpart.
///
/// This is a pure 1:1 structural mapping: op payloads are cloned or copied
/// across, and fields matched with `..` (e.g. the recorded input shapes on
/// reductions, reshape, convert, and einsum) are dropped because the exec
/// layer does not need them.
fn std_to_exec_op(op: &StdTensorOp) -> ExecOp {
    match op {
        // Elementwise arithmetic / comparison / selection ops.
        StdTensorOp::Add => ExecOp::Add,
        StdTensorOp::Mul => ExecOp::Multiply,
        StdTensorOp::Neg => ExecOp::Negate,
        StdTensorOp::Conj => ExecOp::Conj,
        StdTensorOp::Div => ExecOp::Divide,
        StdTensorOp::Abs => ExecOp::Abs,
        StdTensorOp::Sign => ExecOp::Sign,
        StdTensorOp::Maximum => ExecOp::Maximum,
        StdTensorOp::Minimum => ExecOp::Minimum,
        StdTensorOp::Compare(dir) => ExecOp::Compare(dir.clone()),
        StdTensorOp::Select => ExecOp::Select,
        StdTensorOp::Clamp => ExecOp::Clamp,
        // Elementwise transcendental ops.
        StdTensorOp::Exp => ExecOp::Exp,
        StdTensorOp::Log => ExecOp::Log,
        StdTensorOp::Sin => ExecOp::Sin,
        StdTensorOp::Cos => ExecOp::Cos,
        StdTensorOp::Tanh => ExecOp::Tanh,
        StdTensorOp::Sqrt => ExecOp::Sqrt,
        StdTensorOp::Rsqrt => ExecOp::Rsqrt,
        StdTensorOp::Pow => ExecOp::Pow,
        StdTensorOp::Expm1 => ExecOp::Expm1,
        StdTensorOp::Log1p => ExecOp::Log1p,
        // Layout / metadata ops.
        StdTensorOp::Transpose { perm } => ExecOp::Transpose { perm: perm.clone() },
        StdTensorOp::Reshape { to_shape, .. } => ExecOp::Reshape {
            shape: to_shape.clone(),
        },
        StdTensorOp::BroadcastInDim { shape, dims } => ExecOp::BroadcastInDim {
            shape: shape.clone(),
            dims: dims.clone(),
        },
        StdTensorOp::Convert { to, .. } => ExecOp::Convert { to: *to },
        StdTensorOp::Constant { dtype, bytes } => ExecOp::Constant {
            dtype: *dtype,
            bytes: bytes.clone(),
        },
        // Contractions and reductions.
        StdTensorOp::DotGeneral(config) => ExecOp::DotGeneral(config.clone()),
        StdTensorOp::NaryEinsum { subscripts, .. } => ExecOp::NaryEinsum {
            subscripts: subscripts.clone(),
        },
        StdTensorOp::ReduceSum { axes, .. } => ExecOp::ReduceSum { axes: axes.clone() },
        StdTensorOp::ReduceProd { axes, .. } => ExecOp::ReduceProd { axes: axes.clone() },
        StdTensorOp::ReduceMax { axes, .. } => ExecOp::ReduceMax { axes: axes.clone() },
        StdTensorOp::ReduceMin { axes, .. } => ExecOp::ReduceMin { axes: axes.clone() },
        // Diagonal / triangular structure ops.
        StdTensorOp::ExtractDiag { axis_a, axis_b } => ExecOp::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        StdTensorOp::EmbedDiag { axis_a, axis_b } => ExecOp::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        StdTensorOp::Tril { k } => ExecOp::Tril { k: *k },
        StdTensorOp::Triu { k } => ExecOp::Triu { k: *k },
        // Indexing, slicing, and padding ops.
        StdTensorOp::Gather(config) => ExecOp::Gather(config.clone()),
        StdTensorOp::Scatter(config) => ExecOp::Scatter(config.clone()),
        StdTensorOp::Slice(config) => ExecOp::Slice(config.clone()),
        StdTensorOp::DynamicSlice { slice_sizes } => ExecOp::DynamicSlice {
            slice_sizes: slice_sizes.clone(),
        },
        StdTensorOp::Pad(config) => ExecOp::Pad(config.clone()),
        StdTensorOp::Concatenate { axis } => ExecOp::Concatenate { axis: *axis },
        StdTensorOp::Reverse { axes } => ExecOp::Reverse { axes: axes.clone() },
        // Dynamic-shape ops.
        StdTensorOp::ShapeOf { axis } => ExecOp::ShapeOf { axis: *axis },
        StdTensorOp::DynamicTruncate { axis } => ExecOp::DynamicTruncate { axis: *axis },
        StdTensorOp::PadToMatch { axis } => ExecOp::PadToMatch { axis: *axis },
        // Linear-algebra factorizations and solves.
        StdTensorOp::Cholesky { .. } => ExecOp::Cholesky,
        StdTensorOp::Lu { .. } => ExecOp::Lu,
        StdTensorOp::Svd { eps, .. } => ExecOp::Svd { eps: *eps },
        StdTensorOp::Qr { .. } => ExecOp::Qr,
        StdTensorOp::Eigh { eps, .. } => ExecOp::Eigh { eps: *eps },
        StdTensorOp::Eig { .. } => ExecOp::Eig,
        StdTensorOp::TriangularSolve {
            left_side,
            lower,
            transpose_a,
            unit_diagonal,
            ..
        } => ExecOp::TriangularSolve {
            left_side: *left_side,
            lower: *lower,
            transpose_a: *transpose_a,
            unit_diagonal: *unit_diagonal,
        },
        StdTensorOp::ValidateNonsingular { .. } => ExecOp::ValidateNonsingular,
    }
}
261
262fn semiring_to_exec_op(kind: &SemiringOpKind) -> ExecOp {
263 match kind {
264 SemiringOpKind::Add => ExecOp::Add,
265 SemiringOpKind::Mul => ExecOp::Multiply,
266 SemiringOpKind::DotGeneral(config) => ExecOp::DotGeneral(config.clone()),
267 SemiringOpKind::ReduceSum { axes } => ExecOp::ReduceSum { axes: axes.clone() },
268 SemiringOpKind::Transpose { perm } => ExecOp::Transpose { perm: perm.clone() },
269 SemiringOpKind::Reshape { shape } => ExecOp::Reshape {
270 shape: DimExpr::from_concrete(shape),
271 },
272 SemiringOpKind::BroadcastInDim { shape, dims } => ExecOp::BroadcastInDim {
273 shape: DimExpr::from_concrete(shape),
274 dims: dims.clone(),
275 },
276 SemiringOpKind::ExtractDiag { axis_a, axis_b } => ExecOp::ExtractDiag {
277 axis_a: *axis_a,
278 axis_b: *axis_b,
279 },
280 SemiringOpKind::EmbedDiag { axis_a, axis_b } => ExecOp::EmbedDiag {
281 axis_a: *axis_a,
282 axis_b: *axis_b,
283 },
284 }
285}
286
/// Infer output shapes for a semiring op by re-expressing it as the
/// equivalent `StdTensorOp` and delegating to `infer_output_shapes`.
///
/// Kinds whose std counterpart also records an input shape (`ReduceSum`,
/// `Reshape`) fetch it from `input_shapes` via `require_input_shape`, which
/// panics if that input is missing.
fn infer_semiring_output_shapes(
    kind: &SemiringOpKind,
    input_shapes: &[&[DimExpr]],
) -> Vec<Vec<DimExpr>> {
    let op = match kind {
        SemiringOpKind::Add => StdTensorOp::Add,
        SemiringOpKind::Mul => StdTensorOp::Mul,
        SemiringOpKind::DotGeneral(config) => StdTensorOp::DotGeneral(config.clone()),
        // Std's ReduceSum carries the input shape; pull it from input 0.
        SemiringOpKind::ReduceSum { axes } => StdTensorOp::ReduceSum {
            axes: axes.clone(),
            input_shape: require_input_shape(kind, input_shapes, 0).to_vec(),
        },
        SemiringOpKind::Transpose { perm } => StdTensorOp::Transpose { perm: perm.clone() },
        // Std's Reshape carries both from- and to-shapes; the to-shape here
        // is concrete and is lifted into `DimExpr`s.
        SemiringOpKind::Reshape { shape } => StdTensorOp::Reshape {
            from_shape: require_input_shape(kind, input_shapes, 0).to_vec(),
            to_shape: DimExpr::from_concrete(shape),
        },
        SemiringOpKind::BroadcastInDim { shape, dims } => StdTensorOp::BroadcastInDim {
            shape: DimExpr::from_concrete(shape),
            dims: dims.clone(),
        },
        SemiringOpKind::ExtractDiag { axis_a, axis_b } => StdTensorOp::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        SemiringOpKind::EmbedDiag { axis_a, axis_b } => StdTensorOp::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
    };
    infer_output_shapes(&op, input_shapes)
}
319
320fn require_input_shape<'a>(
321 kind: &SemiringOpKind,
322 input_shapes: &'a [&[DimExpr]],
323 index: usize,
324) -> &'a [DimExpr] {
325 input_shapes.get(index).copied().unwrap_or_else(|| {
326 panic!(
327 "semiring shape inference for {kind:?} requires input index {index}, got {}",
328 input_shapes.len()
329 )
330 })
331}
332
333fn populate_last_use(program: &mut ExecProgram) {
334 let all_input_slots: Vec<Vec<usize>> = program
335 .instructions
336 .iter()
337 .map(|instr| instr.input_slots.clone())
338 .collect();
339 for (idx, instr) in program.instructions.iter_mut().enumerate() {
340 instr.last_use = compute_last_use(
341 &instr.input_slots,
342 idx,
343 &all_input_slots,
344 &program.output_slots,
345 );
346 }
347}
348
/// For each input slot of the instruction at `current_idx`, decide whether
/// that instruction is the slot's final reader.
///
/// A slot is never a "last use" when it is one of the program's output slots
/// (outputs must survive execution) or when any later instruction reads it.
fn compute_last_use(
    input_slots: &[usize],
    current_idx: usize,
    all_input_slots: &[Vec<usize>],
    output_slots: &[usize],
) -> Vec<bool> {
    let future_inputs = &all_input_slots[current_idx + 1..];
    input_slots
        .iter()
        .map(|&slot| {
            let is_program_output = output_slots.contains(&slot);
            let read_later = future_inputs.iter().any(|inputs| inputs.contains(&slot));
            !is_program_output && !read_later
        })
        .collect()
}
370
371pub fn dot_dimension_sorter(program: &mut ExecProgram) {
380 for instr in &mut program.instructions {
381 if let ExecOp::DotGeneral(config) = &mut instr.op {
382 sort_contracting_dims(config);
383 }
384 }
385}
386
387fn sort_contracting_dims(config: &mut DotGeneralConfig) {
388 let lhs = &config.lhs_contracting_dims;
389 let rhs = &config.rhs_contracting_dims;
390
391 if lhs.is_empty() {
392 return;
393 }
394
395 if consecutive_if_sorted(lhs) && !is_sorted(lhs) {
396 let perm = argsort(lhs);
397 config.lhs_contracting_dims = apply_perm(lhs, &perm);
398 config.rhs_contracting_dims = apply_perm(rhs, &perm);
399 } else if consecutive_if_sorted(rhs) && !is_sorted(rhs) {
400 let perm = argsort(rhs);
401 config.lhs_contracting_dims = apply_perm(lhs, &perm);
402 config.rhs_contracting_dims = apply_perm(rhs, &perm);
403 }
404}
405
/// True when sorting `dims` would yield a run of consecutive integers
/// (e.g. `[2, 0, 1]` sorts to `[0, 1, 2]`). Duplicates break the property;
/// an empty slice is vacuously consecutive.
fn consecutive_if_sorted(dims: &[usize]) -> bool {
    match (dims.iter().min(), dims.iter().max()) {
        (Some(&lo), Some(&hi)) => hi - lo == dims.len() - 1,
        _ => true,
    }
}
414
/// Non-decreasing check (equal neighbors are allowed).
fn is_sorted(dims: &[usize]) -> bool {
    dims.iter().zip(dims.iter().skip(1)).all(|(a, b)| a <= b)
}
418
/// Indices that would sort `dims` ascending; ties keep their original
/// relative order (stable sort).
fn argsort(dims: &[usize]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..dims.len()).collect();
    order.sort_by(|&a, &b| dims[a].cmp(&dims[b]));
    order
}
424
/// Reorder `source` so element `i` of the result is `source[perm[i]]`.
fn apply_perm(source: &[usize], perm: &[usize]) -> Vec<usize> {
    let mut reordered = Vec::with_capacity(perm.len());
    for &position in perm {
        reordered.push(source[position]);
    }
    reordered
}
428
429pub fn transpose_folding(program: &mut ExecProgram) {
438 loop {
439 let changed = transpose_fold_one_pass(program);
440 if !changed {
441 break;
442 }
443 }
444}
445
446fn transpose_fold_one_pass(program: &mut ExecProgram) -> bool {
447 let mut changed = false;
448
449 for index in 0..program.instructions.len() {
450 if !matches!(program.instructions[index].op, ExecOp::DotGeneral(_)) {
451 continue;
452 }
453
454 if try_fold_operand(program, index, 0) {
455 changed = true;
456 }
457 if program.instructions[index].input_slots.len() > 1 && try_fold_operand(program, index, 1)
458 {
459 changed = true;
460 }
461 }
462
463 changed
464}
465
466fn try_fold_operand(program: &mut ExecProgram, dot_idx: usize, operand_idx: usize) -> bool {
467 let input_slot = program.instructions[dot_idx].input_slots[operand_idx];
468 let Some(producer_idx) = find_producer(program, input_slot) else {
469 return false;
470 };
471
472 let perm = match &program.instructions[producer_idx].op {
473 ExecOp::Transpose { perm } => perm.clone(),
474 _ => return false,
475 };
476
477 let config = match &program.instructions[dot_idx].op {
478 ExecOp::DotGeneral(config) => config.clone(),
479 _ => return false,
480 };
481
482 if !is_transpose_foldable(&config, operand_idx, &perm) {
483 return false;
484 }
485
486 let new_config = fold_transpose_into_dot(&config, operand_idx, &perm);
487 let original_input = program.instructions[producer_idx].input_slots[0];
488 program.instructions[dot_idx].op = ExecOp::DotGeneral(new_config);
489 program.instructions[dot_idx].input_slots[operand_idx] = original_input;
490 true
491}
492
493fn find_producer(program: &ExecProgram, slot: usize) -> Option<usize> {
494 program
495 .instructions
496 .iter()
497 .position(|inst| inst.output_slots.contains(&slot))
498}
499
500fn is_transpose_foldable(config: &DotGeneralConfig, operand_idx: usize, perm: &[usize]) -> bool {
501 let (rank, contracting_dims, batch_dims) = if operand_idx == 0 {
502 (
503 config.lhs_rank,
504 config.lhs_contracting_dims.as_slice(),
505 config.lhs_batch_dims.as_slice(),
506 )
507 } else {
508 (
509 config.rhs_rank,
510 config.rhs_contracting_dims.as_slice(),
511 config.rhs_batch_dims.as_slice(),
512 )
513 };
514
515 if perm.len() != rank || !is_valid_permutation(perm, rank) {
516 return false;
517 }
518
519 let Some(free_dims) = free_axes(rank, contracting_dims, batch_dims) else {
520 return false;
521 };
522
523 is_role_group_order_preserved(&free_dims, perm)
524 && is_role_group_order_preserved(contracting_dims, perm)
525 && is_role_group_order_preserved(batch_dims, perm)
526}
527
/// Axes of an operand that are neither contracting nor batch axes, in
/// ascending order. Returns `None` when any listed axis is out of range or
/// claimed twice (a malformed config).
fn free_axes(rank: usize, contracting_dims: &[usize], batch_dims: &[usize]) -> Option<Vec<usize>> {
    let mut claimed = vec![false; rank];
    for &axis in contracting_dims.iter().chain(batch_dims) {
        // `get_mut` rejects out-of-range axes; `replace` detects duplicates.
        let slot = claimed.get_mut(axis)?;
        if std::mem::replace(slot, true) {
            return None;
        }
    }

    let mut free = Vec::with_capacity(rank);
    for (axis, taken) in claimed.into_iter().enumerate() {
        if !taken {
            free.push(axis);
        }
    }
    Some(free)
}
539
/// True when `perm` contains only in-range axes with no repeats.
///
/// Length is not checked here; callers compare `perm.len()` against the rank
/// themselves.
fn is_valid_permutation(perm: &[usize], rank: usize) -> bool {
    let mut seen = vec![false; rank];
    perm.iter()
        .all(|&axis| axis < rank && !std::mem::replace(&mut seen[axis], true))
}
550
/// Map each axis through `perm`, or `None` if any axis is out of range.
fn map_axes(axes: &[usize], perm: &[usize]) -> Option<Vec<usize>> {
    let mut mapped = Vec::with_capacity(axes.len());
    for &axis in axes {
        mapped.push(*perm.get(axis)?);
    }
    Some(mapped)
}
554
555fn is_role_group_order_preserved(axes: &[usize], perm: &[usize]) -> bool {
556 let Some(mapped_axes) = map_axes(axes, perm) else {
557 return false;
558 };
559 is_strictly_increasing(&mapped_axes)
560}
561
/// True when every element is strictly smaller than its successor.
fn is_strictly_increasing(values: &[usize]) -> bool {
    values
        .iter()
        .zip(values.iter().skip(1))
        .all(|(left, right)| left < right)
}
565
566fn fold_transpose_into_dot(
567 config: &DotGeneralConfig,
568 operand_idx: usize,
569 perm: &[usize],
570) -> DotGeneralConfig {
571 let mut new_config = config.clone();
572 if operand_idx == 0 {
573 new_config.lhs_contracting_dims = config
574 .lhs_contracting_dims
575 .iter()
576 .map(|&dim| perm[dim])
577 .collect();
578 new_config.lhs_batch_dims = config.lhs_batch_dims.iter().map(|&dim| perm[dim]).collect();
579 } else {
580 new_config.rhs_contracting_dims = config
581 .rhs_contracting_dims
582 .iter()
583 .map(|&dim| perm[dim])
584 .collect();
585 new_config.rhs_batch_dims = config.rhs_batch_dims.iter().map(|&dim| perm[dim]).collect();
586 }
587 new_config
588}