//! Lowering of compiled std/semiring tensor programs into executable
//! `ExecProgram`s, plus the post-lowering passes: dot contracting-dimension
//! sorting, transpose folding into `DotGeneral`, and last-use analysis.
1use computegraph::compile::CompiledProgram;
2use tenferro_algebra::Algebra;
3use tenferro_ops::dim_expr::DimExpr;
4use tenferro_ops::semiring_op::SemiringOp;
5use tenferro_ops::semiring_op_kind::SemiringOpKind;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{DType, DotGeneralConfig, TensorScalar};
8
9use crate::shape_infer::{infer_output_dtype, infer_output_shapes};
10
11use super::exec::{ExecInstruction, ExecOp, ExecProgram};
12
/// Lower a compiled `StdTensorOp` program into an `ExecProgram`.
///
/// Walks the instructions in order, propagating dtypes and symbolic shapes
/// from the program inputs (seeded from `input_dtypes` / `input_shapes`)
/// via `infer_output_dtype` / `infer_output_shapes`, then runs the
/// post-lowering passes: dot dimension sorting, transpose folding, and
/// last-use population.
///
/// # Panics
/// * if `input_dtypes` / `input_shapes` lengths differ from the number of
///   program input slots;
/// * if an instruction reads a slot whose dtype or shape has not been set
///   (instructions are assumed to be in execution order, so every slot an
///   instruction reads was written earlier);
/// * if shape inference yields a different number of outputs than the
///   instruction declares.
pub fn compile_std_to_exec(
    prog: &CompiledProgram<StdTensorOp>,
    input_dtypes: &[DType],
    input_shapes: &[Vec<DimExpr>],
) -> ExecProgram {
    assert_eq!(
        prog.input_slots.len(),
        input_dtypes.len(),
        "compile_std_to_exec: input dtype count must match input slot count"
    );
    assert_eq!(
        prog.input_slots.len(),
        input_shapes.len(),
        "compile_std_to_exec: input shape count must match input slot count"
    );

    // Per-slot metadata, filled in as instructions are processed.
    let mut slot_dtypes: Vec<Option<DType>> = vec![None; prog.n_slots];
    let mut slot_shapes: Vec<Option<Vec<DimExpr>>> = vec![None; prog.n_slots];

    // Seed metadata for the program's input slots.
    for (index, &slot) in prog.input_slots.iter().enumerate() {
        slot_dtypes[slot] = Some(input_dtypes[index]);
        slot_shapes[slot] = Some(input_shapes[index].clone());
    }

    let instructions = prog
        .instructions
        .iter()
        .map(|instr| {
            // NOTE: shadows the outer `input_dtypes` parameter — here it
            // means the dtypes of *this instruction's* inputs.
            let input_dtypes: Vec<DType> = instr
                .inputs
                .iter()
                .map(|&slot| {
                    slot_dtypes[slot].unwrap_or_else(|| {
                        panic!("compile_std_to_exec: missing dtype for slot {slot}")
                    })
                })
                .collect();
            let input_shapes_owned: Vec<Vec<DimExpr>> = instr
                .inputs
                .iter()
                .map(|&slot| {
                    slot_shapes[slot].clone().unwrap_or_else(|| {
                        panic!("compile_std_to_exec: missing shape for slot {slot}")
                    })
                })
                .collect();
            // Borrowed view over the owned shapes, as shape inference expects.
            let input_shapes_refs: Vec<&[DimExpr]> =
                input_shapes_owned.iter().map(Vec::as_slice).collect();

            let output_dtype = infer_output_dtype(&instr.op, &input_dtypes);
            let output_shapes = infer_output_shapes(&instr.op, &input_shapes_refs);
            assert_eq!(
                output_shapes.len(),
                instr.outputs.len(),
                "compile_std_to_exec: {:?} inferred {} output shapes for {} output slots",
                instr.op,
                output_shapes.len(),
                instr.outputs.len()
            );

            // Record output metadata so later instructions can read it.
            // NOTE: all outputs receive the single inferred `output_dtype`.
            for (slot, shape) in instr.outputs.iter().zip(output_shapes.iter()) {
                slot_dtypes[*slot] = Some(output_dtype);
                slot_shapes[*slot] = Some(shape.clone());
            }

            ExecInstruction {
                op: std_to_exec_op(&instr.op),
                input_slots: instr.inputs.clone(),
                output_slots: instr.outputs.clone(),
                dtype: output_dtype,
                output_shapes,
                // Placeholder; filled by `populate_last_use` below.
                last_use: Vec::new(),
            }
        })
        .collect();

    let mut program = ExecProgram {
        instructions,
        input_slots: prog.input_slots.clone(),
        output_slots: prog.output_slots.clone(),
        n_slots: prog.n_slots,
    };
    dot_dimension_sorter(&mut program);
    transpose_folding(&mut program);
    populate_last_use(&mut program);
    program
}
100
/// Lower a compiled semiring program into an `ExecProgram`.
///
/// The element dtype is fixed for the whole program by the algebra's scalar
/// type, so only shapes are propagated through the instructions. The same
/// post-lowering passes as `compile_std_to_exec` run afterwards (dot
/// dimension sorting, transpose folding, last-use population).
///
/// # Panics
/// Panics if `input_shapes` length differs from the program's input slot
/// count, if an instruction reads a slot whose shape has not been set, or if
/// shape inference yields a different number of outputs than the instruction
/// declares.
pub fn compile_semiring_to_exec<Alg>(
    prog: &CompiledProgram<SemiringOp<Alg>>,
    input_shapes: &[Vec<DimExpr>],
) -> ExecProgram
where
    Alg: Algebra + Send + Sync + 'static,
    Alg::Scalar: TensorScalar,
{
    assert_eq!(
        prog.input_slots.len(),
        input_shapes.len(),
        "compile_semiring_to_exec: input shape count must match input slot count"
    );

    // Every instruction in a semiring program uses the algebra's scalar dtype.
    let dtype = <Alg::Scalar as TensorScalar>::dtype();
    let mut slot_shapes: Vec<Option<Vec<DimExpr>>> = vec![None; prog.n_slots];
    for (index, &slot) in prog.input_slots.iter().enumerate() {
        slot_shapes[slot] = Some(input_shapes[index].clone());
    }

    let instructions = prog
        .instructions
        .iter()
        .map(|instr| {
            let input_shapes_owned: Vec<Vec<DimExpr>> = instr
                .inputs
                .iter()
                .map(|&slot| {
                    slot_shapes[slot].clone().unwrap_or_else(|| {
                        panic!("compile_semiring_to_exec: missing shape for slot {slot}")
                    })
                })
                .collect();
            // Borrowed view over the owned shapes, as shape inference expects.
            let input_shapes_refs: Vec<&[DimExpr]> =
                input_shapes_owned.iter().map(Vec::as_slice).collect();
            let output_shapes = infer_semiring_output_shapes(&instr.op.kind, &input_shapes_refs);
            assert_eq!(
                output_shapes.len(),
                instr.outputs.len(),
                "compile_semiring_to_exec: {:?} inferred {} output shapes for {} output slots",
                instr.op.kind,
                output_shapes.len(),
                instr.outputs.len()
            );

            // Record output shapes so later instructions can read them.
            for (slot, shape) in instr.outputs.iter().zip(output_shapes.iter()) {
                slot_shapes[*slot] = Some(shape.clone());
            }

            ExecInstruction {
                op: semiring_to_exec_op(&instr.op.kind),
                input_slots: instr.inputs.clone(),
                output_slots: instr.outputs.clone(),
                dtype,
                output_shapes,
                // Placeholder; filled by `populate_last_use` below.
                last_use: Vec::new(),
            }
        })
        .collect();

    let mut program = ExecProgram {
        instructions,
        input_slots: prog.input_slots.clone(),
        output_slots: prog.output_slots.clone(),
        n_slots: prog.n_slots,
    };
    dot_dimension_sorter(&mut program);
    transpose_folding(&mut program);
    populate_last_use(&mut program);
    program
}
172
/// Translate a `StdTensorOp` into the corresponding `ExecOp`.
///
/// Configuration payloads (permutations, shapes, axes, op configs) are
/// cloned across. Fields matched with `..` (e.g. `ReduceSum::input_shape`,
/// `Reshape::from_shape`) are not carried over; `ExecInstruction` records
/// shapes per instruction separately.
fn std_to_exec_op(op: &StdTensorOp) -> ExecOp {
    match op {
        StdTensorOp::Add => ExecOp::Add,
        StdTensorOp::Mul => ExecOp::Multiply,
        StdTensorOp::Neg => ExecOp::Negate,
        StdTensorOp::Conj => ExecOp::Conj,
        StdTensorOp::Div => ExecOp::Divide,
        StdTensorOp::Abs => ExecOp::Abs,
        StdTensorOp::Sign => ExecOp::Sign,
        StdTensorOp::Maximum => ExecOp::Maximum,
        StdTensorOp::Minimum => ExecOp::Minimum,
        StdTensorOp::Compare(dir) => ExecOp::Compare(dir.clone()),
        StdTensorOp::Select => ExecOp::Select,
        StdTensorOp::Clamp => ExecOp::Clamp,
        StdTensorOp::Exp => ExecOp::Exp,
        StdTensorOp::Log => ExecOp::Log,
        StdTensorOp::Sin => ExecOp::Sin,
        StdTensorOp::Cos => ExecOp::Cos,
        StdTensorOp::Tanh => ExecOp::Tanh,
        StdTensorOp::Sqrt => ExecOp::Sqrt,
        StdTensorOp::Rsqrt => ExecOp::Rsqrt,
        StdTensorOp::Pow => ExecOp::Pow,
        StdTensorOp::Expm1 => ExecOp::Expm1,
        StdTensorOp::Log1p => ExecOp::Log1p,
        StdTensorOp::Transpose { perm } => ExecOp::Transpose { perm: perm.clone() },
        StdTensorOp::Reshape { to_shape, .. } => ExecOp::Reshape {
            shape: to_shape.clone(),
        },
        StdTensorOp::BroadcastInDim { shape, dims } => ExecOp::BroadcastInDim {
            shape: shape.clone(),
            dims: dims.clone(),
        },
        StdTensorOp::Convert { to, .. } => ExecOp::Convert { to: *to },
        StdTensorOp::Constant { dtype, bytes } => ExecOp::Constant {
            dtype: *dtype,
            bytes: bytes.clone(),
        },
        StdTensorOp::DotGeneral(config) => ExecOp::DotGeneral(config.clone()),
        StdTensorOp::NaryEinsum { subscripts, .. } => ExecOp::NaryEinsum {
            subscripts: subscripts.clone(),
        },
        StdTensorOp::ReduceSum { axes, .. } => ExecOp::ReduceSum { axes: axes.clone() },
        StdTensorOp::ReduceProd { axes, .. } => ExecOp::ReduceProd { axes: axes.clone() },
        StdTensorOp::ReduceMax { axes, .. } => ExecOp::ReduceMax { axes: axes.clone() },
        StdTensorOp::ReduceMin { axes, .. } => ExecOp::ReduceMin { axes: axes.clone() },
        StdTensorOp::ExtractDiag { axis_a, axis_b } => ExecOp::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        StdTensorOp::EmbedDiag { axis_a, axis_b } => ExecOp::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        StdTensorOp::Tril { k } => ExecOp::Tril { k: *k },
        StdTensorOp::Triu { k } => ExecOp::Triu { k: *k },
        StdTensorOp::Gather(config) => ExecOp::Gather(config.clone()),
        StdTensorOp::Scatter(config) => ExecOp::Scatter(config.clone()),
        StdTensorOp::Slice(config) => ExecOp::Slice(config.clone()),
        StdTensorOp::DynamicSlice { slice_sizes } => ExecOp::DynamicSlice {
            slice_sizes: slice_sizes.clone(),
        },
        StdTensorOp::Pad(config) => ExecOp::Pad(config.clone()),
        StdTensorOp::Concatenate { axis } => ExecOp::Concatenate { axis: *axis },
        StdTensorOp::Reverse { axes } => ExecOp::Reverse { axes: axes.clone() },
        StdTensorOp::ShapeOf { axis } => ExecOp::ShapeOf { axis: *axis },
        StdTensorOp::DynamicTruncate { axis } => ExecOp::DynamicTruncate { axis: *axis },
        StdTensorOp::PadToMatch { axis } => ExecOp::PadToMatch { axis: *axis },
        // Linear-algebra ops: only the numeric knobs (`eps`, solve flags)
        // survive the translation; other fields are dropped via `..`.
        StdTensorOp::Cholesky { .. } => ExecOp::Cholesky,
        StdTensorOp::Lu { .. } => ExecOp::Lu,
        StdTensorOp::Svd { eps, .. } => ExecOp::Svd { eps: *eps },
        StdTensorOp::Qr { .. } => ExecOp::Qr,
        StdTensorOp::Eigh { eps, .. } => ExecOp::Eigh { eps: *eps },
        StdTensorOp::Eig { .. } => ExecOp::Eig,
        StdTensorOp::TriangularSolve {
            left_side,
            lower,
            transpose_a,
            unit_diagonal,
            ..
        } => ExecOp::TriangularSolve {
            left_side: *left_side,
            lower: *lower,
            transpose_a: *transpose_a,
            unit_diagonal: *unit_diagonal,
        },
        StdTensorOp::ValidateNonsingular { .. } => ExecOp::ValidateNonsingular,
    }
}
261
/// Translate a `SemiringOpKind` into the corresponding `ExecOp`.
///
/// Semiring reshape/broadcast shapes are concrete and get lifted to symbolic
/// `DimExpr` shapes via `DimExpr::from_concrete`; all other payloads are
/// cloned across unchanged.
fn semiring_to_exec_op(kind: &SemiringOpKind) -> ExecOp {
    match kind {
        SemiringOpKind::Add => ExecOp::Add,
        SemiringOpKind::Mul => ExecOp::Multiply,
        SemiringOpKind::DotGeneral(config) => ExecOp::DotGeneral(config.clone()),
        SemiringOpKind::ReduceSum { axes } => ExecOp::ReduceSum { axes: axes.clone() },
        SemiringOpKind::Transpose { perm } => ExecOp::Transpose { perm: perm.clone() },
        SemiringOpKind::Reshape { shape } => ExecOp::Reshape {
            shape: DimExpr::from_concrete(shape),
        },
        SemiringOpKind::BroadcastInDim { shape, dims } => ExecOp::BroadcastInDim {
            shape: DimExpr::from_concrete(shape),
            dims: dims.clone(),
        },
        SemiringOpKind::ExtractDiag { axis_a, axis_b } => ExecOp::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        SemiringOpKind::EmbedDiag { axis_a, axis_b } => ExecOp::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
    }
}
286
/// Infer output shapes for a semiring op by translating it into the
/// equivalent `StdTensorOp` and delegating to `infer_output_shapes`.
///
/// Std ops that additionally carry input shapes (`ReduceSum::input_shape`,
/// `Reshape::from_shape`) have those reconstructed from `input_shapes` via
/// `require_input_shape`, which panics if the input is missing.
fn infer_semiring_output_shapes(
    kind: &SemiringOpKind,
    input_shapes: &[&[DimExpr]],
) -> Vec<Vec<DimExpr>> {
    let op = match kind {
        SemiringOpKind::Add => StdTensorOp::Add,
        SemiringOpKind::Mul => StdTensorOp::Mul,
        SemiringOpKind::DotGeneral(config) => StdTensorOp::DotGeneral(config.clone()),
        SemiringOpKind::ReduceSum { axes } => StdTensorOp::ReduceSum {
            axes: axes.clone(),
            input_shape: require_input_shape(kind, input_shapes, 0).to_vec(),
        },
        SemiringOpKind::Transpose { perm } => StdTensorOp::Transpose { perm: perm.clone() },
        SemiringOpKind::Reshape { shape } => StdTensorOp::Reshape {
            from_shape: require_input_shape(kind, input_shapes, 0).to_vec(),
            to_shape: DimExpr::from_concrete(shape),
        },
        SemiringOpKind::BroadcastInDim { shape, dims } => StdTensorOp::BroadcastInDim {
            shape: DimExpr::from_concrete(shape),
            dims: dims.clone(),
        },
        SemiringOpKind::ExtractDiag { axis_a, axis_b } => StdTensorOp::ExtractDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
        SemiringOpKind::EmbedDiag { axis_a, axis_b } => StdTensorOp::EmbedDiag {
            axis_a: *axis_a,
            axis_b: *axis_b,
        },
    };
    infer_output_shapes(&op, input_shapes)
}
319
320fn require_input_shape<'a>(
321    kind: &SemiringOpKind,
322    input_shapes: &'a [&[DimExpr]],
323    index: usize,
324) -> &'a [DimExpr] {
325    input_shapes.get(index).copied().unwrap_or_else(|| {
326        panic!(
327            "semiring shape inference for {kind:?} requires input index {index}, got {}",
328            input_shapes.len()
329        )
330    })
331}
332
/// Fill in the `last_use` flags of every instruction in `program`.
///
/// The input-slot lists are cloned up front so the instructions can be
/// iterated mutably while `compute_last_use` still reads every instruction's
/// inputs (the borrow checker forbids doing both on `program.instructions`
/// directly).
fn populate_last_use(program: &mut ExecProgram) {
    let all_input_slots: Vec<Vec<usize>> = program
        .instructions
        .iter()
        .map(|instr| instr.input_slots.clone())
        .collect();
    for (idx, instr) in program.instructions.iter_mut().enumerate() {
        instr.last_use = compute_last_use(
            &instr.input_slots,
            idx,
            &all_input_slots,
            &program.output_slots,
        );
    }
}
348
/// For each input slot of the instruction at `current_idx`, decide whether
/// this is the final read of that slot: a slot is a last use unless it is a
/// program output or appears as an input of some later instruction.
fn compute_last_use(
    input_slots: &[usize],
    current_idx: usize,
    all_input_slots: &[Vec<usize>],
    output_slots: &[usize],
) -> Vec<bool> {
    // Inputs of all instructions strictly after the current one.
    let later = &all_input_slots[current_idx + 1..];
    let mut flags = Vec::with_capacity(input_slots.len());
    for &slot in input_slots {
        let still_needed = output_slots.contains(&slot)
            || later.iter().any(|inputs| inputs.contains(&slot));
        flags.push(!still_needed);
    }
    flags
}
370
371// ============================================================================
372// Pass 1: DotDimensionSorter
373// ============================================================================
374//
375// Sort contracting dimensions of DotGeneral so that downstream execution sees
376// a stable canonical ordering.
377
378/// Sort contracting dimensions of all DotGeneral instructions in place.
379pub fn dot_dimension_sorter(program: &mut ExecProgram) {
380    for instr in &mut program.instructions {
381        if let ExecOp::DotGeneral(config) = &mut instr.op {
382            sort_contracting_dims(config);
383        }
384    }
385}
386
/// Sort the contracting dimensions of `config` into ascending order,
/// permuting lhs and rhs in lockstep so each contraction pair stays matched.
///
/// Sorting is only applied when one side's dims would form a consecutive run
/// once sorted and are currently out of order; the lhs is tried first, then
/// the rhs. Other configurations are left untouched.
fn sort_contracting_dims(config: &mut DotGeneralConfig) {
    let lhs = &config.lhs_contracting_dims;
    let rhs = &config.rhs_contracting_dims;

    // No contracting dims — nothing to canonicalize.
    if lhs.is_empty() {
        return;
    }

    if consecutive_if_sorted(lhs) && !is_sorted(lhs) {
        // Sort into ascending lhs order; rhs is permuted identically so the
        // i-th lhs dim still contracts against the i-th rhs dim.
        let perm = argsort(lhs);
        config.lhs_contracting_dims = apply_perm(lhs, &perm);
        config.rhs_contracting_dims = apply_perm(rhs, &perm);
    } else if consecutive_if_sorted(rhs) && !is_sorted(rhs) {
        // Same, driven by the rhs order instead.
        let perm = argsort(rhs);
        config.lhs_contracting_dims = apply_perm(lhs, &perm);
        config.rhs_contracting_dims = apply_perm(rhs, &perm);
    }
}
405
/// Returns `true` when `dims`, once sorted, would form a run of consecutive
/// integers (e.g. `[2, 0, 1]` sorts to `0, 1, 2`). The empty slice counts as
/// trivially consecutive.
///
/// The span test `max - min == len - 1` alone wrongly accepts slices with
/// duplicates such as `[1, 2, 2, 4]` (span 3, length 4), which do NOT sort
/// into a consecutive run; we therefore also require all values distinct.
/// Given the span test, distinctness guarantees the values are exactly
/// `min..=max`.
fn consecutive_if_sorted(dims: &[usize]) -> bool {
    let Some(&min_val) = dims.iter().min() else {
        return true;
    };
    let max_val = *dims.iter().max().expect("non-empty");
    if max_val - min_val != dims.len() - 1 {
        return false;
    }
    // Duplicate detection over the candidate range [min_val, max_val].
    let mut seen = vec![false; dims.len()];
    for &dim in dims {
        let offset = dim - min_val;
        if seen[offset] {
            return false;
        }
        seen[offset] = true;
    }
    true
}
414
/// True when `dims` is in non-decreasing order; empty and single-element
/// slices count as sorted.
fn is_sorted(dims: &[usize]) -> bool {
    dims.iter()
        .zip(dims.iter().skip(1))
        .all(|(left, right)| left <= right)
}
418
/// Indices that would sort `dims` ascending. The sort is stable, so equal
/// values keep their original relative order.
fn argsort(dims: &[usize]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..dims.len()).collect();
    order.sort_by(|&a, &b| dims[a].cmp(&dims[b]));
    order
}
424
/// Gather the elements of `source` in the order given by `perm`.
fn apply_perm(source: &[usize], perm: &[usize]) -> Vec<usize> {
    let mut permuted = Vec::with_capacity(perm.len());
    for &position in perm {
        permuted.push(source[position]);
    }
    permuted
}
428
429// ============================================================================
430// Pass 2: TransposeFolding
431// ============================================================================
432//
433// Absorb Transpose instructions that feed directly into DotGeneral by
434// adjusting the DotGeneral dimension numbers and bypassing the Transpose.
435
436/// Fold Transpose instructions into DotGeneral dimension numbers.
437pub fn transpose_folding(program: &mut ExecProgram) {
438    loop {
439        let changed = transpose_fold_one_pass(program);
440        if !changed {
441            break;
442        }
443    }
444}
445
446fn transpose_fold_one_pass(program: &mut ExecProgram) -> bool {
447    let mut changed = false;
448
449    for index in 0..program.instructions.len() {
450        if !matches!(program.instructions[index].op, ExecOp::DotGeneral(_)) {
451            continue;
452        }
453
454        if try_fold_operand(program, index, 0) {
455            changed = true;
456        }
457        if program.instructions[index].input_slots.len() > 1 && try_fold_operand(program, index, 1)
458        {
459            changed = true;
460        }
461    }
462
463    changed
464}
465
/// Try to absorb a Transpose feeding operand `operand_idx` of the DotGeneral
/// at `dot_idx`. On success the dot's dimension numbers are rewritten through
/// the transpose permutation and its input is rewired to the transpose's own
/// input; the Transpose instruction itself is left in place (bypassed, not
/// removed). Returns `true` when a fold was performed.
///
/// NOTE(review): `find_producer` returns the *first* instruction writing the
/// slot — this assumes each slot is written by at most one instruction;
/// confirm the slot allocator never reuses slots before this pass runs.
fn try_fold_operand(program: &mut ExecProgram, dot_idx: usize, operand_idx: usize) -> bool {
    let input_slot = program.instructions[dot_idx].input_slots[operand_idx];
    // The operand must be produced by some instruction (not a program input).
    let Some(producer_idx) = find_producer(program, input_slot) else {
        return false;
    };

    // Only a Transpose producer can be folded.
    let perm = match &program.instructions[producer_idx].op {
        ExecOp::Transpose { perm } => perm.clone(),
        _ => return false,
    };

    let config = match &program.instructions[dot_idx].op {
        ExecOp::DotGeneral(config) => config.clone(),
        _ => return false,
    };

    if !is_transpose_foldable(&config, operand_idx, &perm) {
        return false;
    }

    let new_config = fold_transpose_into_dot(&config, operand_idx, &perm);
    // Rewire the dot to read the transpose's own input directly.
    let original_input = program.instructions[producer_idx].input_slots[0];
    program.instructions[dot_idx].op = ExecOp::DotGeneral(new_config);
    program.instructions[dot_idx].input_slots[operand_idx] = original_input;
    true
}
492
493fn find_producer(program: &ExecProgram, slot: usize) -> Option<usize> {
494    program
495        .instructions
496        .iter()
497        .position(|inst| inst.output_slots.contains(&slot))
498}
499
/// Decide whether the transpose `perm` feeding the given dot operand can be
/// folded away. Requirements:
/// * `perm` has the operand's rank and is a valid permutation;
/// * the operand's contracting/batch dims are well-formed (in range, no
///   duplicates — checked by `free_axes`);
/// * within each role group (free, contracting, batch) the permuted axis
///   numbers remain strictly increasing, i.e. the transpose does not reorder
///   axes inside any single group.
fn is_transpose_foldable(config: &DotGeneralConfig, operand_idx: usize, perm: &[usize]) -> bool {
    // Select the lhs or rhs view of the dot configuration.
    let (rank, contracting_dims, batch_dims) = if operand_idx == 0 {
        (
            config.lhs_rank,
            config.lhs_contracting_dims.as_slice(),
            config.lhs_batch_dims.as_slice(),
        )
    } else {
        (
            config.rhs_rank,
            config.rhs_contracting_dims.as_slice(),
            config.rhs_batch_dims.as_slice(),
        )
    };

    if perm.len() != rank || !is_valid_permutation(perm, rank) {
        return false;
    }

    // Also validates contracting/batch dims (range + distinctness).
    let Some(free_dims) = free_axes(rank, contracting_dims, batch_dims) else {
        return false;
    };

    is_role_group_order_preserved(&free_dims, perm)
        && is_role_group_order_preserved(contracting_dims, perm)
        && is_role_group_order_preserved(batch_dims, perm)
}
527
/// Axes of a rank-`rank` tensor that are neither contracting nor batch.
/// Returns `None` when any listed axis is out of range or appears more than
/// once (across both lists combined).
fn free_axes(rank: usize, contracting_dims: &[usize], batch_dims: &[usize]) -> Option<Vec<usize>> {
    let mut taken = vec![false; rank];
    for &axis in contracting_dims.iter().chain(batch_dims) {
        match taken.get_mut(axis) {
            Some(flag) if !*flag => *flag = true,
            // Out of range, or already claimed by another role.
            _ => return None,
        }
    }

    let free = (0..rank).filter(|&axis| !taken[axis]).collect();
    Some(free)
}
539
/// True when every element of `perm` is a distinct axis below `rank`.
/// The length of `perm` is deliberately NOT checked here; callers compare
/// `perm.len()` against the rank separately.
fn is_valid_permutation(perm: &[usize], rank: usize) -> bool {
    let mut seen = vec![false; rank];
    perm.iter().all(|&axis| {
        if axis < rank && !seen[axis] {
            seen[axis] = true;
            true
        } else {
            false
        }
    })
}
550
/// Map each axis through `perm`, returning `None` if any axis is out of
/// range for the permutation.
fn map_axes(axes: &[usize], perm: &[usize]) -> Option<Vec<usize>> {
    let mut mapped = Vec::with_capacity(axes.len());
    for &axis in axes {
        mapped.push(*perm.get(axis)?);
    }
    Some(mapped)
}
554
/// True when mapping every axis in `axes` through `perm` yields a strictly
/// increasing sequence — i.e. the permutation does not reorder the axes of
/// this role group relative to one another. Returns `false` if any axis is
/// out of range for `perm`.
fn is_role_group_order_preserved(axes: &[usize], perm: &[usize]) -> bool {
    let mut previous: Option<usize> = None;
    for &axis in axes {
        let Some(&mapped) = perm.get(axis) else {
            return false;
        };
        if let Some(prev) = previous {
            if prev >= mapped {
                return false;
            }
        }
        previous = Some(mapped);
    }
    true
}
561
/// True when each value is strictly greater than its predecessor; empty and
/// single-element slices qualify.
fn is_strictly_increasing(values: &[usize]) -> bool {
    values
        .iter()
        .zip(values.iter().skip(1))
        .all(|(left, right)| left < right)
}
565
566fn fold_transpose_into_dot(
567    config: &DotGeneralConfig,
568    operand_idx: usize,
569    perm: &[usize],
570) -> DotGeneralConfig {
571    let mut new_config = config.clone();
572    if operand_idx == 0 {
573        new_config.lhs_contracting_dims = config
574            .lhs_contracting_dims
575            .iter()
576            .map(|&dim| perm[dim])
577            .collect();
578        new_config.lhs_batch_dims = config.lhs_batch_dims.iter().map(|&dim| perm[dim]).collect();
579    } else {
580        new_config.rhs_contracting_dims = config
581            .rhs_contracting_dims
582            .iter()
583            .map(|&dim| perm[dim])
584            .collect();
585        new_config.rhs_batch_dims = config.rhs_batch_dims.iter().map(|&dim| perm[dim]).collect();
586    }
587    new_config
588}