// tenferro/exec.rs
1use std::sync::Arc;
2
3use crate::compiler::compile_std_to_exec;
4use crate::engine::NaryEinsumCache;
5use crate::error::{Error, Result};
6use computegraph::compile::compile;
7use computegraph::fragment::FragmentBuilder;
8use computegraph::materialize::materialize_merge;
9use computegraph::resolve::resolve;
10use computegraph::types::{GlobalValKey, ValRef};
11use num_complex::{Complex32, Complex64};
12use tenferro_algebra::Semiring;
13use tenferro_ops::dim_expr::DimExpr;
14use tenferro_ops::input_key::TensorInputKey;
15use tenferro_ops::std_tensor_op::StdTensorOp;
16use tenferro_tensor::cpu::structural::{
17    typed_broadcast_in_dim, typed_embed_diagonal, typed_extract_diagonal, typed_reshape,
18    typed_transpose,
19};
20use tenferro_tensor::validate::validate_nonsingular_u;
21use tenferro_tensor::Error as TensorError;
22use tenferro_tensor::{
23    CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SemiringBackend,
24    SliceConfig, Tensor, TensorBackend, TensorExec, TypedTensor,
25};
26
/// The operation set of the execution IR.
///
/// Each variant is one primitive operation an [`ExecInstruction`] can carry.
/// Shape arguments are stored as [`DimExpr`] so they can be evaluated against
/// the runtime shapes of the instruction's inputs.
#[derive(Clone, Debug)]
pub enum ExecOp {
    // --- Structural ops ---
    /// Permute axes according to `perm`.
    Transpose {
        perm: Vec<usize>,
    },
    Reshape {
        shape: Vec<DimExpr>,
    },
    BroadcastInDim {
        shape: Vec<DimExpr>,
        dims: Vec<usize>,
    },
    /// Cast to dtype `to`.
    Convert {
        to: DType,
    },
    /// Scalar constant; `bytes` is the little-endian encoding for `dtype`
    /// (decoded by `constant_tensor`).
    Constant {
        dtype: DType,
        bytes: Vec<u8>,
    },
    DotGeneral(DotGeneralConfig),
    /// N-ary contraction described by an einsum subscript string.
    NaryEinsum {
        subscripts: String,
    },
    ReduceSum {
        axes: Vec<usize>,
    },
    ExtractDiag {
        axis_a: usize,
        axis_b: usize,
    },
    EmbedDiag {
        axis_a: usize,
        axis_b: usize,
    },
    Tril {
        k: i64,
    },
    Triu {
        k: i64,
    },
    // --- Element-wise ops (arity implied by the op) ---
    Add,
    Multiply,
    Negate,
    Conj,
    Divide,
    Abs,
    Sign,
    Maximum,
    Minimum,
    Compare(CompareDir),
    Select,
    Clamp,
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,
    // --- Indexing / layout ops ---
    Gather(GatherConfig),
    Scatter(ScatterConfig),
    Slice(SliceConfig),
    DynamicSlice {
        slice_sizes: Vec<usize>,
    },
    Pad(PadConfig),
    Concatenate {
        axis: usize,
    },
    Reverse {
        axes: Vec<usize>,
    },
    // --- Host-evaluated ops (see `is_host_instruction`) ---
    /// Produce the extent of `axis` as a scalar f64 tensor.
    ShapeOf {
        axis: usize,
    },
    /// Truncate `axis` to a size given by a scalar tensor input.
    DynamicTruncate {
        axis: usize,
    },
    /// Zero-pad `axis` up to the extent of a reference tensor input.
    PadToMatch {
        axis: usize,
    },
    // --- Further reductions ---
    ReduceProd {
        axes: Vec<usize>,
    },
    ReduceMax {
        axes: Vec<usize>,
    },
    ReduceMin {
        axes: Vec<usize>,
    },
    // --- Linear algebra (FFI-dispatched, see `is_ffi_instruction`) ---
    Cholesky,
    /// NOTE(review): `eps` is not forwarded to the backend `svd` call in this
    /// executor — presumably consumed at compile time; confirm.
    Svd {
        eps: f64,
    },
    Qr,
    Lu,
    /// NOTE(review): `eps` is likewise not forwarded to `backend.eigh` here.
    Eigh {
        eps: f64,
    },
    Eig,
    /// Host-side singularity check on a triangular factor (passes the input
    /// through unchanged on success).
    ValidateNonsingular,
    TriangularSolve {
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    },
}
138
/// One instruction of an [`ExecProgram`].
#[derive(Clone, Debug)]
pub struct ExecInstruction {
    /// The operation to perform.
    pub op: ExecOp,
    /// Slot indices the inputs are read from, in operand order.
    pub input_slots: Vec<usize>,
    /// Slot indices the outputs are written to (most ops write exactly one).
    pub output_slots: Vec<usize>,
    /// Element dtype associated with the instruction (presumably the result
    /// dtype — not consumed by the executors in this file).
    pub dtype: tenferro_tensor::DType,
    /// Symbolic output shapes, one `Vec<DimExpr>` per output.
    pub output_shapes: Vec<Vec<tenferro_ops::dim_expr::DimExpr>>,
    /// `last_use[i]` is true when this instruction is the final reader of
    /// `input_slots[i]`; the buffer is reclaimed after execution.
    pub last_use: Vec<bool>,
}
148
/// A flat, slot-based program of [`ExecInstruction`]s.
///
/// Tensors live in `n_slots` numbered slots; instructions read and write
/// slots, while `input_slots` / `output_slots` name the program's external
/// inputs and outputs.
#[derive(Clone, Debug)]
pub struct ExecProgram {
    pub instructions: Vec<ExecInstruction>,
    /// Slots populated from caller-supplied inputs, in input order.
    pub input_slots: Vec<usize>,
    /// Slots whose final contents are returned, in output order.
    pub output_slots: Vec<usize>,
    /// Total number of slots the program requires.
    pub n_slots: usize,
}
156
/// How backend operations are dispatched during evaluation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DispatchMode {
    /// One backend call per instruction.
    Unsegmented,
    /// Consecutive fusible ops run within one backend execution session.
    Segmented,
}
162
163pub(crate) fn get<'a, T>(
164    slots: &'a [Option<T>],
165    input_slots: &[usize],
166    idx: usize,
167) -> Result<&'a T> {
168    let slot = input_slots[idx];
169    slots[slot]
170        .as_ref()
171        .ok_or(TensorError::MissingValue { slot }.into())
172}
173
174pub(crate) fn initialize_slots(program: &ExecProgram, inputs: Vec<Tensor>) -> Vec<Option<Tensor>> {
175    let mut slots: Vec<Option<Tensor>> = vec![None; program.n_slots];
176    for (i, tensor) in inputs.into_iter().enumerate() {
177        slots[program.input_slots[i]] = Some(tensor);
178    }
179    slots
180}
181
182pub(crate) fn collect_outputs(
183    program: &ExecProgram,
184    mut slots: Vec<Option<Tensor>>,
185) -> Result<Vec<Tensor>> {
186    program
187        .output_slots
188        .iter()
189        .map(|&slot| {
190            slots[slot]
191                .take()
192                .ok_or(TensorError::MissingValue { slot }.into())
193        })
194        .collect()
195}
196
197pub(crate) fn is_host_instruction(inst: &ExecInstruction) -> bool {
198    matches!(
199        &inst.op,
200        ExecOp::ShapeOf { .. }
201            | ExecOp::DynamicTruncate { .. }
202            | ExecOp::PadToMatch { .. }
203            | ExecOp::Constant { .. }
204            | ExecOp::ValidateNonsingular
205    )
206}
207
208pub(crate) fn is_ffi_instruction(inst: &ExecInstruction) -> bool {
209    matches!(
210        &inst.op,
211        ExecOp::DotGeneral(_)
212            | ExecOp::NaryEinsum { .. }
213            | ExecOp::Cholesky
214            | ExecOp::Svd { .. }
215            | ExecOp::Qr
216            | ExecOp::Lu
217            | ExecOp::Eigh { .. }
218            | ExecOp::Eig
219            | ExecOp::TriangularSolve { .. }
220    )
221}
222
223pub(crate) fn resolve_tensor_shape_exprs(
224    slots: &[Option<Tensor>],
225    input_slots: &[usize],
226    exprs: &[DimExpr],
227) -> Result<Vec<usize>> {
228    let mut input_shapes = Vec::with_capacity(input_slots.len());
229    for &slot in input_slots {
230        input_shapes.push(
231            slots[slot]
232                .as_ref()
233                .ok_or(TensorError::MissingValue { slot })?
234                .shape(),
235        );
236    }
237    Ok(DimExpr::eval_all(exprs, &input_shapes))
238}
239
240fn resolve_semiring_shape_exprs<Alg: Semiring>(
241    slots: &[Option<TypedTensor<Alg::Scalar>>],
242    input_slots: &[usize],
243    exprs: &[DimExpr],
244) -> Result<Vec<usize>> {
245    let mut input_shapes = Vec::with_capacity(input_slots.len());
246    for &slot in input_slots {
247        input_shapes.push(
248            slots[slot]
249                .as_ref()
250                .ok_or(TensorError::MissingValue { slot })?
251                .shape
252                .as_slice(),
253        );
254    }
255    Ok(DimExpr::eval_all(exprs, &input_shapes))
256}
257
258/// Evaluate an [`ExecProgram`] using segmented dispatch.
259///
260/// Consecutive fusible ops are executed within one backend execution session.
261///
262/// # Examples
263///
264/// ```
265/// use tenferro::exec::{eval_exec_ir, ExecProgram};
266/// use tenferro::CpuBackend;
267///
268/// let _eval: fn(&mut CpuBackend, &ExecProgram, Vec<tenferro::Tensor>) -> tenferro::error::Result<Vec<tenferro::Tensor>> =
269///     eval_exec_ir::<CpuBackend>;
270/// ```
271pub fn eval_exec_ir<B: TensorBackend>(
272    backend: &mut B,
273    program: &ExecProgram,
274    inputs: Vec<Tensor>,
275) -> Result<Vec<Tensor>> {
276    crate::segment::eval_exec_segmented(backend, program, inputs)
277}
278
279/// Evaluate an [`ExecProgram`] one instruction at a time.
280///
281/// This is retained for parity tests against segmented dispatch.
282///
283/// # Examples
284///
285/// ```
286/// use tenferro::exec::{eval_exec_ir_unsegmented, ExecProgram};
287/// use tenferro::CpuBackend;
288///
289/// let _eval: fn(&mut CpuBackend, &ExecProgram, Vec<tenferro::Tensor>) -> tenferro::error::Result<Vec<tenferro::Tensor>> =
290///     eval_exec_ir_unsegmented::<CpuBackend>;
291/// ```
292pub fn eval_exec_ir_unsegmented<B: TensorBackend>(
293    backend: &mut B,
294    program: &ExecProgram,
295    inputs: Vec<Tensor>,
296) -> Result<Vec<Tensor>> {
297    let mut cache = NaryEinsumCache::new(
298        std::num::NonZeroUsize::new(crate::engine::DEFAULT_EINSUM_CACHE_CAPACITY)
299            .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
300    );
301    eval_exec_ir_unsegmented_with_cache(backend, program, inputs, &mut cache)
302}
303
304pub(crate) fn eval_exec_ir_unsegmented_with_cache<B: TensorBackend>(
305    backend: &mut B,
306    program: &ExecProgram,
307    inputs: Vec<Tensor>,
308    cache: &mut NaryEinsumCache,
309) -> Result<Vec<Tensor>> {
310    let mut slots = initialize_slots(program, inputs);
311
312    for inst in &program.instructions {
313        if is_host_instruction(inst) {
314            execute_host_instruction(backend, &mut slots, inst)?;
315        } else if is_ffi_instruction(inst) {
316            execute_ffi_instruction(backend, &mut slots, inst, DispatchMode::Unsegmented, cache)?;
317        } else {
318            let result =
319                backend.with_exec_session(|exec| execute_backend_op(exec, &slots, inst))?;
320            slots[inst.output_slots[0]] = Some(result);
321        }
322        reclaim_last_use_inputs_backend(&mut slots, inst, backend);
323    }
324
325    collect_outputs(program, slots)
326}
327
/// Execute a single non-host, non-FFI instruction inside a backend exec
/// session and return its (single) output tensor.
///
/// `Reshape` / `BroadcastInDim` shape expressions are resolved against the
/// current runtime input shapes before dispatch.
///
/// # Errors
///
/// Returns an internal error if a host or FFI op is routed here, and
/// propagates any backend or missing-slot error.
pub(crate) fn execute_backend_op(
    exec: &mut dyn TensorExec,
    slots: &[Option<Tensor>],
    inst: &ExecInstruction,
) -> Result<Tensor> {
    let result = match &inst.op {
        // Structural ops; shapes carried as DimExpr are evaluated first.
        ExecOp::Transpose { perm } => exec.transpose(get(slots, &inst.input_slots, 0)?, perm)?,
        ExecOp::Reshape { shape } => {
            let shape = resolve_tensor_shape_exprs(slots, &inst.input_slots, shape)?;
            exec.reshape(get(slots, &inst.input_slots, 0)?, &shape)?
        }
        ExecOp::BroadcastInDim { shape, dims } => {
            let shape = resolve_tensor_shape_exprs(slots, &inst.input_slots, shape)?;
            exec.broadcast_in_dim(get(slots, &inst.input_slots, 0)?, &shape, dims)?
        }
        ExecOp::Convert { to } => exec.convert(get(slots, &inst.input_slots, 0)?, *to)?,
        ExecOp::ReduceSum { axes } => exec.reduce_sum(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ExtractDiag { axis_a, axis_b } => {
            exec.extract_diagonal(get(slots, &inst.input_slots, 0)?, *axis_a, *axis_b)?
        }
        ExecOp::EmbedDiag { axis_a, axis_b } => {
            exec.embed_diagonal(get(slots, &inst.input_slots, 0)?, *axis_a, *axis_b)?
        }
        ExecOp::Tril { k } => exec.tril(get(slots, &inst.input_slots, 0)?, *k)?,
        ExecOp::Triu { k } => exec.triu(get(slots, &inst.input_slots, 0)?, *k)?,
        // Element-wise binary ops.
        ExecOp::Add => exec.add(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Multiply => exec.mul(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Negate => exec.neg(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Conj => exec.conj(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Divide => exec.div(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Abs => exec.abs(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sign => exec.sign(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Maximum => exec.maximum(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Minimum => exec.minimum(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Compare(dir) => exec.compare(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            dir,
        )?,
        // Ternary ops: (pred, on_true, on_false) and (x, lo, hi) respectively.
        ExecOp::Select => exec.select(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
        )?,
        ExecOp::Clamp => exec.clamp(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
        )?,
        // Element-wise unary transcendental ops.
        ExecOp::Exp => exec.exp(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Log => exec.log(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sin => exec.sin(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Cos => exec.cos(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Tanh => exec.tanh(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sqrt => exec.sqrt(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Rsqrt => exec.rsqrt(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Pow => exec.pow(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Expm1 => exec.expm1(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Log1p => exec.log1p(get(slots, &inst.input_slots, 0)?)?,
        // Indexing / layout ops.
        ExecOp::Gather(config) => exec.gather(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            config,
        )?,
        ExecOp::Scatter(config) => exec.scatter(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
            config,
        )?,
        ExecOp::Slice(config) => exec.slice(get(slots, &inst.input_slots, 0)?, config)?,
        ExecOp::DynamicSlice { slice_sizes } => exec.dynamic_slice(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            slice_sizes,
        )?,
        ExecOp::Pad(config) => exec.pad(get(slots, &inst.input_slots, 0)?, config)?,
        // Variadic: all input slots feed one concatenation.
        ExecOp::Concatenate { axis } => {
            let inputs = collect_tensor_refs(slots, &inst.input_slots)?;
            exec.concatenate(&inputs, *axis)?
        }
        ExecOp::Reverse { axes } => exec.reverse(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceProd { axes } => exec.reduce_prod(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceMax { axes } => exec.reduce_max(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceMin { axes } => exec.reduce_min(get(slots, &inst.input_slots, 0)?, axes)?,
        // Host/FFI ops must have been routed elsewhere by the caller.
        other => {
            return Err(Error::Internal(format!(
                "host or FFI op reached backend executor: {other:?}"
            )))
        }
    };
    Ok(result)
}
439
/// Execute an instruction that needs host-side data or control flow (see
/// [`is_host_instruction`]), writing its result into `slots`.
///
/// # Errors
///
/// Returns an internal error for out-of-range axes or if a non-host op is
/// routed here; propagates backend upload/download/slice/pad errors.
pub(crate) fn execute_host_instruction<B: TensorBackend>(
    backend: &mut B,
    slots: &mut [Option<Tensor>],
    inst: &ExecInstruction,
) -> Result<()> {
    match &inst.op {
        ExecOp::ShapeOf { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "ShapeOf: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            // Materialize the extent as a scalar f64 tensor and upload it so
            // downstream backend ops can consume it.
            let host = Tensor::F64(TypedTensor::from_vec(
                vec![],
                vec![input.shape()[*axis] as f64],
            ));
            slots[inst.output_slots[0]] = Some(backend.upload_host_tensor(&host)?);
        }
        ExecOp::DynamicTruncate { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "DynamicTruncate: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            // The target size arrives as a scalar tensor in input slot 1 and
            // must be read back to the host to drive the slice below.
            let size_tensor = backend.download_to_host(get(slots, &inst.input_slots, 1)?)?;
            let axis_extent = input.shape()[*axis];
            let size_f64 = scalar_size_value(&size_tensor)?;
            // Round to nearest, treating NaN/inf as 0, then clamp into
            // [0, axis_extent] before converting to usize.
            let rounded_size = if size_f64.is_finite() {
                size_f64.round()
            } else {
                0.0
            };
            let size = rounded_size.max(0.0).min(axis_extent as f64) as usize;
            let rank = input.shape().len();
            let mut limits = input.shape().to_vec();
            limits[*axis] = size;
            // Full-extent slice on every axis except the truncated one.
            let config = SliceConfig {
                starts: vec![0; rank],
                limits,
                strides: vec![1; rank],
            };
            slots[inst.output_slots[0]] = Some(backend.slice(input, &config)?);
        }
        ExecOp::PadToMatch { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            let reference = get(slots, &inst.input_slots, 1)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "PadToMatch: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            let target_size = reference.shape()[*axis];
            let current_size = input.shape()[*axis];
            if current_size >= target_size {
                // Already at least as large: pass the input through.
                slots[inst.output_slots[0]] = Some(input.clone());
            } else {
                // Zero-pad only the high edge of the chosen axis.
                let rank = input.shape().len();
                let mut high = vec![0i64; rank];
                high[*axis] = (target_size - current_size) as i64;
                let config = PadConfig {
                    edge_padding_low: vec![0i64; rank],
                    edge_padding_high: high,
                    interior_padding: vec![0i64; rank],
                };
                slots[inst.output_slots[0]] = Some(backend.pad(input, &config)?);
            }
        }
        ExecOp::Constant { dtype, bytes } => {
            // Decode the little-endian scalar on the host, then upload.
            let host = constant_tensor(*dtype, bytes);
            slots[inst.output_slots[0]] = Some(backend.upload_host_tensor(&host)?);
        }
        ExecOp::ValidateNonsingular => {
            // Host-side check; on success the input passes through unchanged.
            let input = get(slots, &inst.input_slots, 0)?;
            let host_input = backend.download_to_host(input)?;
            validate_nonsingular_u(&host_input)?;
            slots[inst.output_slots[0]] = Some(input.clone());
        }
        other => {
            return Err(Error::Internal(format!(
                "non-host op reached host executor: {other:?}"
            )))
        }
    }
    Ok(())
}
533
/// Execute an instruction that dispatches through the backend's FFI entry
/// points (see [`is_ffi_instruction`]), writing results into `slots`.
///
/// `mode` and `cache` are only consumed by `NaryEinsum`, which compiles and
/// recursively evaluates a nested program.
///
/// # Errors
///
/// Returns an internal error if a non-FFI op is routed here or a
/// decomposition yields the wrong number of outputs; propagates backend
/// errors.
pub(crate) fn execute_ffi_instruction<B: TensorBackend>(
    backend: &mut B,
    slots: &mut [Option<Tensor>],
    inst: &ExecInstruction,
    mode: DispatchMode,
    cache: &mut NaryEinsumCache,
) -> Result<()> {
    match &inst.op {
        ExecOp::DotGeneral(config) => {
            let result = backend.dot_general(
                get(slots, &inst.input_slots, 0)?,
                get(slots, &inst.input_slots, 1)?,
                config,
            )?;
            slots[inst.output_slots[0]] = Some(result);
        }
        ExecOp::NaryEinsum { subscripts } => {
            let inputs = collect_tensor_refs(slots, &inst.input_slots)?;
            let result = execute_nary_einsum(backend, &inputs, subscripts, mode, cache)?;
            slots[inst.output_slots[0]] = Some(result);
        }
        ExecOp::Cholesky => {
            let result = backend.cholesky(get(slots, &inst.input_slots, 0)?)?;
            slots[inst.output_slots[0]] = Some(result);
        }
        // Multi-output decompositions; `eps` fields are not forwarded here.
        ExecOp::Svd { .. } => {
            let results = backend.svd(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "svd")?;
        }
        ExecOp::Qr => {
            let results = backend.qr(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "qr")?;
        }
        ExecOp::Lu => {
            let results = backend.lu(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "lu")?;
        }
        ExecOp::Eigh { .. } => {
            let results = backend.eigh(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "eigh")?;
        }
        ExecOp::Eig => {
            let results = backend.eig(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "eig")?;
        }
        ExecOp::TriangularSolve {
            left_side,
            lower,
            transpose_a,
            unit_diagonal,
        } => {
            let result = backend.triangular_solve(
                get(slots, &inst.input_slots, 0)?,
                get(slots, &inst.input_slots, 1)?,
                *left_side,
                *lower,
                *transpose_a,
                *unit_diagonal,
            )?;
            slots[inst.output_slots[0]] = Some(result);
        }
        other => {
            return Err(Error::Internal(format!(
                "non-ffi op reached ffi executor: {other:?}"
            )))
        }
    }
    Ok(())
}
603
604fn assign_multi_output(
605    slots: &mut [Option<Tensor>],
606    inst: &ExecInstruction,
607    results: Vec<Tensor>,
608    op_name: &str,
609) -> Result<()> {
610    if results.len() != inst.output_slots.len() {
611        return Err(Error::Internal(format!(
612            "{op_name} produced {} outputs for {} slots",
613            results.len(),
614            inst.output_slots.len()
615        )));
616    }
617    for (slot, tensor) in inst.output_slots.iter().copied().zip(results.into_iter()) {
618        slots[slot] = Some(tensor);
619    }
620    Ok(())
621}
622
623fn collect_tensor_refs<'a>(
624    slots: &'a [Option<Tensor>],
625    input_slots: &[usize],
626) -> Result<Vec<&'a Tensor>> {
627    let mut inputs = Vec::with_capacity(input_slots.len());
628    for &slot in input_slots {
629        inputs.push(
630            slots[slot]
631                .as_ref()
632                .ok_or(TensorError::MissingValue { slot })?,
633        );
634    }
635    Ok(inputs)
636}
637
/// Contract `inputs` according to `subscripts` by building, compiling, and
/// recursively evaluating a dedicated exec program for the contraction.
///
/// Optimized contraction trees are memoized in `cache`, keyed by the
/// subscript string plus the concrete input shapes. The nested program is
/// evaluated with the caller's dispatch `mode`.
fn execute_nary_einsum<B: TensorBackend>(
    backend: &mut B,
    inputs: &[&Tensor],
    subscripts: &str,
    mode: DispatchMode,
    cache: &mut NaryEinsumCache,
) -> Result<Tensor> {
    use tenferro_einsum::{
        build_einsum_fragment, ContractionOptimizerOptions, ContractionTree, Subscripts,
    };

    if inputs.is_empty() {
        return Err(Error::ContractionError(
            "nary einsum requires at least one input tensor".into(),
        ));
    }

    let subs =
        Subscripts::parse(subscripts).map_err(|e| Error::InvalidSubscripts(format!("{e}")))?;
    let shapes: Vec<Vec<usize>> = inputs
        .iter()
        .map(|tensor| tensor.shape().to_vec())
        .collect();
    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
    // Reuse a previously optimized tree when the same subscripts were
    // contracted over identically-shaped inputs.
    let cache_key = (subscripts.to_string(), shapes.clone());
    let tree_arc = if let Some(cached) = cache.get(&cache_key) {
        cached.clone()
    } else {
        let tree = Arc::new(
            ContractionTree::optimize_with_options(
                &subs,
                &shape_refs,
                &ContractionOptimizerOptions::default(),
            )
            .map_err(|e| Error::ContractionError(format!("{e}")))?,
        );
        cache.put(cache_key, tree.clone());
        tree
    };
    let tree = tree_arc.as_ref();

    // Build a compute-graph fragment whose inputs are keyed by position so
    // they can be matched back to `inputs` after compilation.
    let mut builder = FragmentBuilder::<StdTensorOp>::new();
    let mut input_vals = Vec::with_capacity(inputs.len());
    for input_idx in 0..inputs.len() {
        let local = builder.add_input(TensorInputKey::User {
            id: input_idx as u64,
        });
        input_vals.push(ValRef::Local(local));
    }

    let result_ref = build_einsum_fragment(&mut builder, &tree, &input_vals, &shapes);
    let result_local = match result_ref {
        ValRef::Local(local) => local,
        ValRef::External(_) => {
            return Err(Error::Internal(
                "runtime nary einsum builder returned an external value".into(),
            ))
        }
    };
    builder.set_outputs(vec![result_local]);
    let fragment = Arc::new(builder.build());
    let output_key = fragment.vals()[result_local].key.clone();

    // Lower the fragment: resolve -> materialize -> compiled graph.
    let view = resolve(vec![fragment]);
    let graph = materialize_merge(&view, &[output_key]);
    let compiled = compile(&graph);

    // Map the compiled graph's input keys back to the caller's tensors,
    // collecting dtypes and concrete shapes for exec-program compilation.
    let mut program_inputs = Vec::with_capacity(graph.inputs.len());
    let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
    let mut input_shapes = Vec::with_capacity(graph.inputs.len());
    for key in &graph.inputs {
        match key {
            GlobalValKey::Input(TensorInputKey::User { id }) => {
                let input_idx = *id as usize;
                let tensor = inputs.get(input_idx).ok_or_else(|| {
                    Error::Internal(format!(
                        "runtime nary einsum input {input_idx} missing for subscripts {subscripts}"
                    ))
                })?;
                program_inputs.push((*tensor).clone());
                input_dtypes.push(tensor.dtype());
                input_shapes.push(DimExpr::from_concrete(tensor.shape()));
            }
            other => {
                return Err(Error::Internal(format!(
                    "unexpected runtime nary einsum input key: {other:?}"
                )))
            }
        }
    }
    let program = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);

    // Recursively evaluate the nested program, sharing the plan cache.
    let mut outputs = match mode {
        DispatchMode::Unsegmented => {
            eval_exec_ir_unsegmented_with_cache(backend, &program, program_inputs, cache)?
        }
        DispatchMode::Segmented => crate::segment::eval_exec_segmented_with_cache(
            backend,
            &program,
            program_inputs,
            cache,
        )?,
    };
    if outputs.len() != 1 {
        return Err(Error::Internal(format!(
            "runtime nary einsum expected 1 output, got {}",
            outputs.len()
        )));
    }
    Ok(outputs.remove(0))
}
749
750pub(crate) fn constant_tensor(dtype: DType, bytes: &[u8]) -> Tensor {
751    match dtype {
752        DType::F64 => Tensor::F64(TypedTensor::from_vec(
753            vec![],
754            vec![f64::from_le_bytes(exact_bytes::<8>(dtype, bytes))],
755        )),
756        DType::F32 => Tensor::F32(TypedTensor::from_vec(
757            vec![],
758            vec![f32::from_le_bytes(exact_bytes::<4>(dtype, bytes))],
759        )),
760        DType::C64 => {
761            let data = exact_bytes::<16>(dtype, bytes);
762            let mut re_bytes = [0u8; 8];
763            let mut im_bytes = [0u8; 8];
764            re_bytes.copy_from_slice(&data[..8]);
765            im_bytes.copy_from_slice(&data[8..]);
766            let re = f64::from_le_bytes(re_bytes);
767            let im = f64::from_le_bytes(im_bytes);
768            Tensor::C64(TypedTensor::from_vec(vec![], vec![Complex64::new(re, im)]))
769        }
770        DType::C32 => {
771            let data = exact_bytes::<8>(dtype, bytes);
772            let mut re_bytes = [0u8; 4];
773            let mut im_bytes = [0u8; 4];
774            re_bytes.copy_from_slice(&data[..4]);
775            im_bytes.copy_from_slice(&data[4..]);
776            let re = f32::from_le_bytes(re_bytes);
777            let im = f32::from_le_bytes(im_bytes);
778            Tensor::C32(TypedTensor::from_vec(vec![], vec![Complex32::new(re, im)]))
779        }
780    }
781}
782
783fn exact_bytes<const N: usize>(dtype: DType, bytes: &[u8]) -> [u8; N] {
784    if bytes.len() != N {
785        panic!(
786            "constant {:?} expected {} bytes, got {}",
787            dtype,
788            N,
789            bytes.len()
790        );
791    }
792    let mut out = [0u8; N];
793    out.copy_from_slice(bytes);
794    out
795}
796
797fn scalar_size_value(size_tensor: &Tensor) -> Result<f64> {
798    match size_tensor {
799        Tensor::F64(inner) => Ok(inner.host_data()[0]),
800        Tensor::F32(inner) => Ok(inner.host_data()[0] as f64),
801        _ => Err(Error::Internal(
802            "DynamicTruncate size must be an f32 or f64 scalar".into(),
803        )),
804    }
805}
806
807pub(crate) fn reclaim_last_use_inputs_exec(
808    slots: &mut [Option<Tensor>],
809    inst: &ExecInstruction,
810    exec: &mut dyn TensorExec,
811) {
812    for (i, &is_last) in inst.last_use.iter().enumerate() {
813        if is_last {
814            if let Some(tensor) = slots[inst.input_slots[i]].take() {
815                exec.reclaim_buffer(tensor);
816            }
817        }
818    }
819}
820
821pub(crate) fn reclaim_last_use_inputs_backend<B: TensorBackend>(
822    slots: &mut [Option<Tensor>],
823    inst: &ExecInstruction,
824    backend: &mut B,
825) {
826    for (i, &is_last) in inst.last_use.iter().enumerate() {
827        if is_last {
828            if let Some(tensor) = slots[inst.input_slots[i]].take() {
829                backend.reclaim_buffer(tensor);
830            }
831        }
832    }
833}
834
835pub fn eval_semiring_ir<B, Alg>(
836    backend: &mut B,
837    program: &ExecProgram,
838    inputs: Vec<TypedTensor<Alg::Scalar>>,
839) -> Result<Vec<TypedTensor<Alg::Scalar>>>
840where
841    Alg: Semiring,
842    B: SemiringBackend<Alg>,
843{
844    let mut slots: Vec<Option<TypedTensor<Alg::Scalar>>> = vec![None; program.n_slots];
845    for (i, tensor) in inputs.into_iter().enumerate() {
846        slots[program.input_slots[i]] = Some(tensor);
847    }
848
849    for inst in &program.instructions {
850        let result = match &inst.op {
851            ExecOp::Transpose { perm } => typed_transpose(get(&slots, &inst.input_slots, 0)?, perm),
852            ExecOp::Reshape { shape } => {
853                let shape = resolve_semiring_shape_exprs::<Alg>(&slots, &inst.input_slots, shape)?;
854                typed_reshape(get(&slots, &inst.input_slots, 0)?, &shape)
855            }
856            ExecOp::BroadcastInDim { shape, dims } => {
857                let shape = resolve_semiring_shape_exprs::<Alg>(&slots, &inst.input_slots, shape)?;
858                typed_broadcast_in_dim(get(&slots, &inst.input_slots, 0)?, &shape, dims)
859            }
860            ExecOp::DotGeneral(config) => backend.batched_gemm(
861                get(&slots, &inst.input_slots, 0)?,
862                get(&slots, &inst.input_slots, 1)?,
863                config,
864            ),
865            ExecOp::ReduceSum { axes } => {
866                backend.reduce_sum(get(&slots, &inst.input_slots, 0)?, axes)
867            }
868            ExecOp::ExtractDiag { axis_a, axis_b } => {
869                typed_extract_diagonal(get(&slots, &inst.input_slots, 0)?, *axis_a, *axis_b)
870            }
871            ExecOp::EmbedDiag { axis_a, axis_b } => {
872                typed_embed_diagonal(get(&slots, &inst.input_slots, 0)?, *axis_a, *axis_b)
873            }
874            ExecOp::Add => backend.add(
875                get(&slots, &inst.input_slots, 0)?,
876                get(&slots, &inst.input_slots, 1)?,
877            ),
878            ExecOp::Multiply => backend.mul(
879                get(&slots, &inst.input_slots, 0)?,
880                get(&slots, &inst.input_slots, 1)?,
881            ),
882            _ => panic!("non-semiring op in semiring program: {:?}", inst.op),
883        };
884        slots[inst.output_slots[0]] = Some(result?);
885    }
886
887    program
888        .output_slots
889        .iter()
890        .map(|&slot| {
891            slots[slot]
892                .take()
893                .ok_or(TensorError::MissingValue { slot }.into())
894        })
895        .collect()
896}