1use std::collections::{HashMap, HashSet};
4
5use crate::engine::{NaryEinsumCache, DEFAULT_EINSUM_CACHE_CAPACITY};
6use crate::error::Result;
7use crate::exec::{
8 collect_outputs, execute_backend_op, execute_ffi_instruction, execute_host_instruction,
9 initialize_slots, is_ffi_instruction, is_host_instruction, reclaim_last_use_inputs_backend,
10 reclaim_last_use_inputs_exec, DispatchMode, ExecInstruction, ExecOp, ExecProgram,
11};
12use tenferro_tensor::{
13 ElementwiseFusionInst, ElementwiseFusionOp, ElementwiseFusionPlan, Tensor, TensorBackend,
14};
15
16#[derive(Clone, Debug)]
57pub enum Segment {
58 Fused {
59 instructions: Vec<ExecInstruction>,
60 input_slots: Vec<usize>,
61 output_slots: Vec<usize>,
62 last_use: Vec<bool>,
63 },
64 Ffi(ExecInstruction),
65 Host(ExecInstruction),
66}
67
68pub fn segment_exec_program(program: &ExecProgram) -> Vec<Segment> {
106 let mut segments = Vec::new();
107 let mut fused_start: Option<usize> = None;
108
109 for (idx, inst) in program.instructions.iter().enumerate() {
110 if is_host_instruction(inst) {
111 flush_fused_segment(program, &mut segments, fused_start.take(), idx);
112 segments.push(Segment::Host(inst.clone()));
113 } else if is_ffi_instruction(inst) {
114 flush_fused_segment(program, &mut segments, fused_start.take(), idx);
115 segments.push(Segment::Ffi(inst.clone()));
116 } else if fused_start.is_none() {
117 fused_start = Some(idx);
118 }
119 }
120
121 flush_fused_segment(
122 program,
123 &mut segments,
124 fused_start.take(),
125 program.instructions.len(),
126 );
127 segments
128}
129
130pub fn eval_exec_segmented<B: TensorBackend>(
143 backend: &mut B,
144 program: &ExecProgram,
145 inputs: Vec<Tensor>,
146) -> Result<Vec<Tensor>> {
147 let mut cache = NaryEinsumCache::new(
148 std::num::NonZeroUsize::new(DEFAULT_EINSUM_CACHE_CAPACITY)
149 .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
150 );
151 eval_exec_segmented_with_cache(backend, program, inputs, &mut cache)
152}
153
154pub(crate) fn eval_exec_segmented_with_cache<B: TensorBackend>(
155 backend: &mut B,
156 program: &ExecProgram,
157 inputs: Vec<Tensor>,
158 cache: &mut crate::engine::NaryEinsumCache,
159) -> Result<Vec<Tensor>> {
160 let segments = segment_exec_program(program);
161 let mut slots = initialize_slots(program, inputs);
162
163 for segment in &segments {
164 match segment {
165 Segment::Fused {
166 instructions,
167 input_slots,
168 output_slots,
169 last_use,
170 } => {
171 backend.with_exec_session(|exec| -> Result<()> {
172 if let Some(plan) =
173 build_elementwise_fusion_plan(instructions, input_slots, output_slots)
174 {
175 let inputs = collect_segment_inputs(&slots, input_slots)?;
176 if let Some(outputs) = exec.execute_elementwise_fusion(&inputs, &plan)? {
177 if outputs.len() != output_slots.len() {
178 return Err(crate::error::Error::Internal(format!(
179 "fused elementwise kernel produced {} outputs for {} slots",
180 outputs.len(),
181 output_slots.len()
182 )));
183 }
184 for (slot, tensor) in
185 output_slots.iter().copied().zip(outputs.into_iter())
186 {
187 slots[slot] = Some(tensor);
188 }
189 reclaim_segment_inputs_exec(&mut slots, input_slots, last_use, exec);
190 return Ok(());
191 }
192 }
193
194 for inst in instructions {
195 let result = execute_backend_op(exec, &slots, inst)?;
196 slots[inst.output_slots[0]] = Some(result);
197 reclaim_last_use_inputs_exec(&mut slots, inst, exec);
198 }
199 Ok(())
200 })?;
201 }
202 Segment::Ffi(inst) => {
203 execute_ffi_instruction(backend, &mut slots, inst, DispatchMode::Segmented, cache)?;
204 reclaim_last_use_inputs_backend(&mut slots, inst, backend);
205 }
206 Segment::Host(inst) => {
207 execute_host_instruction(backend, &mut slots, inst)?;
208 reclaim_last_use_inputs_backend(&mut slots, inst, backend);
209 }
210 }
211 }
212
213 collect_outputs(program, slots)
214}
215
216fn flush_fused_segment(
217 program: &ExecProgram,
218 segments: &mut Vec<Segment>,
219 start: Option<usize>,
220 end: usize,
221) {
222 let Some(start) = start else {
223 return;
224 };
225 if start == end {
226 return;
227 }
228 segments.push(build_fused_segment(program, start, end));
229}
230
231fn build_fused_segment(program: &ExecProgram, start: usize, end: usize) -> Segment {
232 let instructions = program.instructions[start..end].to_vec();
233 let mut produced = HashSet::new();
234 let mut seen_inputs = HashSet::new();
235 let mut input_slots = Vec::new();
236 let mut produced_order = Vec::new();
237
238 for inst in &instructions {
239 for &slot in &inst.input_slots {
240 if !produced.contains(&slot) && seen_inputs.insert(slot) {
241 input_slots.push(slot);
242 }
243 }
244 for &slot in &inst.output_slots {
245 if produced.insert(slot) {
246 produced_order.push(slot);
247 }
248 }
249 }
250
251 let later_instructions = &program.instructions[end..];
252 let output_slots: Vec<usize> = produced_order
253 .into_iter()
254 .filter(|slot| {
255 program.output_slots.contains(slot)
256 || later_instructions
257 .iter()
258 .any(|later| later.input_slots.contains(slot))
259 })
260 .collect();
261
262 let last_use = input_slots
263 .iter()
264 .map(|slot| {
265 !program.output_slots.contains(slot)
266 && !later_instructions
267 .iter()
268 .any(|later| later.input_slots.contains(slot))
269 })
270 .collect();
271
272 Segment::Fused {
273 instructions,
274 input_slots,
275 output_slots,
276 last_use,
277 }
278}
279
280fn collect_segment_inputs<'a>(
281 slots: &'a [Option<Tensor>],
282 input_slots: &[usize],
283) -> Result<Vec<&'a Tensor>> {
284 input_slots
285 .iter()
286 .map(|&slot| {
287 slots[slot]
288 .as_ref()
289 .ok_or(tenferro_tensor::Error::MissingValue { slot }.into())
290 })
291 .collect()
292}
293
294fn reclaim_segment_inputs_exec(
295 slots: &mut [Option<Tensor>],
296 input_slots: &[usize],
297 last_use: &[bool],
298 exec: &mut dyn tenferro_tensor::TensorExec,
299) {
300 for (&slot, &is_last_use) in input_slots.iter().zip(last_use.iter()) {
301 if is_last_use {
302 if let Some(tensor) = slots[slot].take() {
303 exec.reclaim_buffer(tensor);
304 }
305 }
306 }
307}
308
309fn build_elementwise_fusion_plan(
310 instructions: &[ExecInstruction],
311 input_slots: &[usize],
312 output_slots: &[usize],
313) -> Option<ElementwiseFusionPlan> {
314 let first = instructions.first()?;
315 let dtype = first.dtype;
316 let mut slot_to_value = HashMap::with_capacity(input_slots.len() + instructions.len());
317 for (value, &slot) in input_slots.iter().enumerate() {
318 slot_to_value.insert(slot, value);
319 }
320
321 let mut ops = Vec::with_capacity(instructions.len());
322 let mut next_value = input_slots.len();
323 for inst in instructions {
324 if inst.dtype != dtype || inst.output_slots.len() != 1 {
325 return None;
326 }
327 let op = map_exec_op_to_elementwise_fusion(&inst.op)?;
328 let inputs = inst
329 .input_slots
330 .iter()
331 .map(|slot| slot_to_value.get(slot).copied())
332 .collect::<Option<Vec<_>>>()?;
333 ops.push(ElementwiseFusionInst { op, inputs });
334 slot_to_value.insert(inst.output_slots[0], next_value);
335 next_value += 1;
336 }
337
338 let outputs = output_slots
339 .iter()
340 .map(|slot| slot_to_value.get(slot).copied())
341 .collect::<Option<Vec<_>>>()?;
342
343 Some(ElementwiseFusionPlan {
344 dtype,
345 n_inputs: input_slots.len(),
346 outputs,
347 ops,
348 })
349}
350
351fn map_exec_op_to_elementwise_fusion(op: &ExecOp) -> Option<ElementwiseFusionOp> {
352 match op {
353 ExecOp::Add => Some(ElementwiseFusionOp::Add),
354 ExecOp::Multiply => Some(ElementwiseFusionOp::Multiply),
355 ExecOp::Negate => Some(ElementwiseFusionOp::Negate),
356 ExecOp::Conj => Some(ElementwiseFusionOp::Conj),
357 ExecOp::Divide => Some(ElementwiseFusionOp::Divide),
358 ExecOp::Abs => Some(ElementwiseFusionOp::Abs),
359 ExecOp::Maximum => Some(ElementwiseFusionOp::Maximum),
360 ExecOp::Minimum => Some(ElementwiseFusionOp::Minimum),
361 ExecOp::Compare(dir) => Some(ElementwiseFusionOp::Compare(dir.clone())),
362 ExecOp::Select => Some(ElementwiseFusionOp::Select),
363 ExecOp::Clamp => Some(ElementwiseFusionOp::Clamp),
364 ExecOp::Exp => Some(ElementwiseFusionOp::Exp),
365 ExecOp::Log => Some(ElementwiseFusionOp::Log),
366 ExecOp::Sin => Some(ElementwiseFusionOp::Sin),
367 ExecOp::Cos => Some(ElementwiseFusionOp::Cos),
368 ExecOp::Tanh => Some(ElementwiseFusionOp::Tanh),
369 ExecOp::Sqrt => Some(ElementwiseFusionOp::Sqrt),
370 ExecOp::Rsqrt => Some(ElementwiseFusionOp::Rsqrt),
371 ExecOp::Pow => Some(ElementwiseFusionOp::Pow),
372 ExecOp::Expm1 => Some(ElementwiseFusionOp::Expm1),
373 ExecOp::Log1p => Some(ElementwiseFusionOp::Log1p),
374 _ => None,
375 }
376}