Skip to main content

tenferro/
segment.rs

1//! ExecProgram segmentation and segment-based dispatch.
2
3use std::collections::{HashMap, HashSet};
4
5use crate::engine::{NaryEinsumCache, DEFAULT_EINSUM_CACHE_CAPACITY};
6use crate::error::Result;
7use crate::exec::{
8    collect_outputs, execute_backend_op, execute_ffi_instruction, execute_host_instruction,
9    initialize_slots, is_ffi_instruction, is_host_instruction, reclaim_last_use_inputs_backend,
10    reclaim_last_use_inputs_exec, DispatchMode, ExecInstruction, ExecOp, ExecProgram,
11};
12use tenferro_tensor::{
13    ElementwiseFusionInst, ElementwiseFusionOp, ElementwiseFusionPlan, Tensor, TensorBackend,
14};
15
16/// A compiled execution segment.
17///
18/// Fused segments group consecutive non-host, non-FFI instructions that can
19/// share one backend execution session. FFI and host segments remain
20/// single-instruction boundaries in Phase 4.
21///
22/// # Examples
23///
24/// ```
25/// use tenferro::segment::{segment_exec_program, Segment};
26/// use tenferro::exec::{ExecInstruction, ExecOp, ExecProgram};
27/// use tenferro::DType;
28///
29/// let program = ExecProgram {
30///     instructions: vec![
31///         ExecInstruction {
32///             op: ExecOp::Add,
33///             input_slots: vec![0, 1],
34///             output_slots: vec![2],
35///             dtype: DType::F64,
36///             output_shapes: vec![vec![]],
37///             last_use: vec![false, true],
38///         },
39///         ExecInstruction {
40///             op: ExecOp::Negate,
41///             input_slots: vec![2],
42///             output_slots: vec![3],
43///             dtype: DType::F64,
44///             output_shapes: vec![vec![]],
45///             last_use: vec![true],
46///         },
47///     ],
48///     input_slots: vec![0, 1],
49///     output_slots: vec![3],
50///     n_slots: 4,
51/// };
52///
53/// let segments = segment_exec_program(&program);
54/// assert!(matches!(&segments[0], Segment::Fused { instructions, .. } if instructions.len() == 2));
55/// ```
56#[derive(Clone, Debug)]
57pub enum Segment {
58    Fused {
59        instructions: Vec<ExecInstruction>,
60        input_slots: Vec<usize>,
61        output_slots: Vec<usize>,
62        last_use: Vec<bool>,
63    },
64    Ffi(ExecInstruction),
65    Host(ExecInstruction),
66}
67
68/// Compile an [`ExecProgram`] into execution segments.
69///
70/// # Examples
71///
72/// ```
73/// use tenferro::segment::{segment_exec_program, Segment};
74/// use tenferro::exec::{ExecInstruction, ExecOp, ExecProgram};
75/// use tenferro::DType;
76///
77/// let program = ExecProgram {
78///     instructions: vec![
79///         ExecInstruction {
80///             op: ExecOp::Add,
81///             input_slots: vec![0, 1],
82///             output_slots: vec![2],
83///             dtype: DType::F64,
84///             output_shapes: vec![vec![]],
85///             last_use: vec![false, true],
86///         },
87///         ExecInstruction {
88///             op: ExecOp::ShapeOf { axis: 0 },
89///             input_slots: vec![2],
90///             output_slots: vec![3],
91///             dtype: DType::F64,
92///             output_shapes: vec![vec![]],
93///             last_use: vec![true],
94///         },
95///     ],
96///     input_slots: vec![0, 1],
97///     output_slots: vec![2, 3],
98///     n_slots: 4,
99/// };
100///
101/// let segments = segment_exec_program(&program);
102/// assert!(matches!(&segments[0], Segment::Fused { .. }));
103/// assert!(matches!(&segments[1], Segment::Host(_)));
104/// ```
105pub fn segment_exec_program(program: &ExecProgram) -> Vec<Segment> {
106    let mut segments = Vec::new();
107    let mut fused_start: Option<usize> = None;
108
109    for (idx, inst) in program.instructions.iter().enumerate() {
110        if is_host_instruction(inst) {
111            flush_fused_segment(program, &mut segments, fused_start.take(), idx);
112            segments.push(Segment::Host(inst.clone()));
113        } else if is_ffi_instruction(inst) {
114            flush_fused_segment(program, &mut segments, fused_start.take(), idx);
115            segments.push(Segment::Ffi(inst.clone()));
116        } else if fused_start.is_none() {
117            fused_start = Some(idx);
118        }
119    }
120
121    flush_fused_segment(
122        program,
123        &mut segments,
124        fused_start.take(),
125        program.instructions.len(),
126    );
127    segments
128}
129
130/// Evaluate an [`ExecProgram`] via segment-based dispatch.
131///
132/// # Examples
133///
134/// ```
135/// use tenferro::segment::eval_exec_segmented;
136/// use tenferro::exec::ExecProgram;
137/// use tenferro::CpuBackend;
138///
139/// let _eval: fn(&mut CpuBackend, &ExecProgram, Vec<tenferro::Tensor>) -> tenferro::error::Result<Vec<tenferro::Tensor>> =
140///     eval_exec_segmented::<CpuBackend>;
141/// ```
142pub fn eval_exec_segmented<B: TensorBackend>(
143    backend: &mut B,
144    program: &ExecProgram,
145    inputs: Vec<Tensor>,
146) -> Result<Vec<Tensor>> {
147    let mut cache = NaryEinsumCache::new(
148        std::num::NonZeroUsize::new(DEFAULT_EINSUM_CACHE_CAPACITY)
149            .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
150    );
151    eval_exec_segmented_with_cache(backend, program, inputs, &mut cache)
152}
153
154pub(crate) fn eval_exec_segmented_with_cache<B: TensorBackend>(
155    backend: &mut B,
156    program: &ExecProgram,
157    inputs: Vec<Tensor>,
158    cache: &mut crate::engine::NaryEinsumCache,
159) -> Result<Vec<Tensor>> {
160    let segments = segment_exec_program(program);
161    let mut slots = initialize_slots(program, inputs);
162
163    for segment in &segments {
164        match segment {
165            Segment::Fused {
166                instructions,
167                input_slots,
168                output_slots,
169                last_use,
170            } => {
171                backend.with_exec_session(|exec| -> Result<()> {
172                    if let Some(plan) =
173                        build_elementwise_fusion_plan(instructions, input_slots, output_slots)
174                    {
175                        let inputs = collect_segment_inputs(&slots, input_slots)?;
176                        if let Some(outputs) = exec.execute_elementwise_fusion(&inputs, &plan)? {
177                            if outputs.len() != output_slots.len() {
178                                return Err(crate::error::Error::Internal(format!(
179                                    "fused elementwise kernel produced {} outputs for {} slots",
180                                    outputs.len(),
181                                    output_slots.len()
182                                )));
183                            }
184                            for (slot, tensor) in
185                                output_slots.iter().copied().zip(outputs.into_iter())
186                            {
187                                slots[slot] = Some(tensor);
188                            }
189                            reclaim_segment_inputs_exec(&mut slots, input_slots, last_use, exec);
190                            return Ok(());
191                        }
192                    }
193
194                    for inst in instructions {
195                        let result = execute_backend_op(exec, &slots, inst)?;
196                        slots[inst.output_slots[0]] = Some(result);
197                        reclaim_last_use_inputs_exec(&mut slots, inst, exec);
198                    }
199                    Ok(())
200                })?;
201            }
202            Segment::Ffi(inst) => {
203                execute_ffi_instruction(backend, &mut slots, inst, DispatchMode::Segmented, cache)?;
204                reclaim_last_use_inputs_backend(&mut slots, inst, backend);
205            }
206            Segment::Host(inst) => {
207                execute_host_instruction(backend, &mut slots, inst)?;
208                reclaim_last_use_inputs_backend(&mut slots, inst, backend);
209            }
210        }
211    }
212
213    collect_outputs(program, slots)
214}
215
216fn flush_fused_segment(
217    program: &ExecProgram,
218    segments: &mut Vec<Segment>,
219    start: Option<usize>,
220    end: usize,
221) {
222    let Some(start) = start else {
223        return;
224    };
225    if start == end {
226        return;
227    }
228    segments.push(build_fused_segment(program, start, end));
229}
230
231fn build_fused_segment(program: &ExecProgram, start: usize, end: usize) -> Segment {
232    let instructions = program.instructions[start..end].to_vec();
233    let mut produced = HashSet::new();
234    let mut seen_inputs = HashSet::new();
235    let mut input_slots = Vec::new();
236    let mut produced_order = Vec::new();
237
238    for inst in &instructions {
239        for &slot in &inst.input_slots {
240            if !produced.contains(&slot) && seen_inputs.insert(slot) {
241                input_slots.push(slot);
242            }
243        }
244        for &slot in &inst.output_slots {
245            if produced.insert(slot) {
246                produced_order.push(slot);
247            }
248        }
249    }
250
251    let later_instructions = &program.instructions[end..];
252    let output_slots: Vec<usize> = produced_order
253        .into_iter()
254        .filter(|slot| {
255            program.output_slots.contains(slot)
256                || later_instructions
257                    .iter()
258                    .any(|later| later.input_slots.contains(slot))
259        })
260        .collect();
261
262    let last_use = input_slots
263        .iter()
264        .map(|slot| {
265            !program.output_slots.contains(slot)
266                && !later_instructions
267                    .iter()
268                    .any(|later| later.input_slots.contains(slot))
269        })
270        .collect();
271
272    Segment::Fused {
273        instructions,
274        input_slots,
275        output_slots,
276        last_use,
277    }
278}
279
280fn collect_segment_inputs<'a>(
281    slots: &'a [Option<Tensor>],
282    input_slots: &[usize],
283) -> Result<Vec<&'a Tensor>> {
284    input_slots
285        .iter()
286        .map(|&slot| {
287            slots[slot]
288                .as_ref()
289                .ok_or(tenferro_tensor::Error::MissingValue { slot }.into())
290        })
291        .collect()
292}
293
294fn reclaim_segment_inputs_exec(
295    slots: &mut [Option<Tensor>],
296    input_slots: &[usize],
297    last_use: &[bool],
298    exec: &mut dyn tenferro_tensor::TensorExec,
299) {
300    for (&slot, &is_last_use) in input_slots.iter().zip(last_use.iter()) {
301        if is_last_use {
302            if let Some(tensor) = slots[slot].take() {
303                exec.reclaim_buffer(tensor);
304            }
305        }
306    }
307}
308
309fn build_elementwise_fusion_plan(
310    instructions: &[ExecInstruction],
311    input_slots: &[usize],
312    output_slots: &[usize],
313) -> Option<ElementwiseFusionPlan> {
314    let first = instructions.first()?;
315    let dtype = first.dtype;
316    let mut slot_to_value = HashMap::with_capacity(input_slots.len() + instructions.len());
317    for (value, &slot) in input_slots.iter().enumerate() {
318        slot_to_value.insert(slot, value);
319    }
320
321    let mut ops = Vec::with_capacity(instructions.len());
322    let mut next_value = input_slots.len();
323    for inst in instructions {
324        if inst.dtype != dtype || inst.output_slots.len() != 1 {
325            return None;
326        }
327        let op = map_exec_op_to_elementwise_fusion(&inst.op)?;
328        let inputs = inst
329            .input_slots
330            .iter()
331            .map(|slot| slot_to_value.get(slot).copied())
332            .collect::<Option<Vec<_>>>()?;
333        ops.push(ElementwiseFusionInst { op, inputs });
334        slot_to_value.insert(inst.output_slots[0], next_value);
335        next_value += 1;
336    }
337
338    let outputs = output_slots
339        .iter()
340        .map(|slot| slot_to_value.get(slot).copied())
341        .collect::<Option<Vec<_>>>()?;
342
343    Some(ElementwiseFusionPlan {
344        dtype,
345        n_inputs: input_slots.len(),
346        outputs,
347        ops,
348    })
349}
350
351fn map_exec_op_to_elementwise_fusion(op: &ExecOp) -> Option<ElementwiseFusionOp> {
352    match op {
353        ExecOp::Add => Some(ElementwiseFusionOp::Add),
354        ExecOp::Multiply => Some(ElementwiseFusionOp::Multiply),
355        ExecOp::Negate => Some(ElementwiseFusionOp::Negate),
356        ExecOp::Conj => Some(ElementwiseFusionOp::Conj),
357        ExecOp::Divide => Some(ElementwiseFusionOp::Divide),
358        ExecOp::Abs => Some(ElementwiseFusionOp::Abs),
359        ExecOp::Maximum => Some(ElementwiseFusionOp::Maximum),
360        ExecOp::Minimum => Some(ElementwiseFusionOp::Minimum),
361        ExecOp::Compare(dir) => Some(ElementwiseFusionOp::Compare(dir.clone())),
362        ExecOp::Select => Some(ElementwiseFusionOp::Select),
363        ExecOp::Clamp => Some(ElementwiseFusionOp::Clamp),
364        ExecOp::Exp => Some(ElementwiseFusionOp::Exp),
365        ExecOp::Log => Some(ElementwiseFusionOp::Log),
366        ExecOp::Sin => Some(ElementwiseFusionOp::Sin),
367        ExecOp::Cos => Some(ElementwiseFusionOp::Cos),
368        ExecOp::Tanh => Some(ElementwiseFusionOp::Tanh),
369        ExecOp::Sqrt => Some(ElementwiseFusionOp::Sqrt),
370        ExecOp::Rsqrt => Some(ElementwiseFusionOp::Rsqrt),
371        ExecOp::Pow => Some(ElementwiseFusionOp::Pow),
372        ExecOp::Expm1 => Some(ElementwiseFusionOp::Expm1),
373        ExecOp::Log1p => Some(ElementwiseFusionOp::Log1p),
374        _ => None,
375    }
376}