1use std::sync::Arc;
2
3use crate::compiler::compile_std_to_exec;
4use crate::engine::NaryEinsumCache;
5use crate::error::{Error, Result};
6use computegraph::compile::compile;
7use computegraph::fragment::FragmentBuilder;
8use computegraph::materialize::materialize_merge;
9use computegraph::resolve::resolve;
10use computegraph::types::{GlobalValKey, ValRef};
11use num_complex::{Complex32, Complex64};
12use tenferro_algebra::Semiring;
13use tenferro_ops::dim_expr::DimExpr;
14use tenferro_ops::input_key::TensorInputKey;
15use tenferro_ops::std_tensor_op::StdTensorOp;
16use tenferro_tensor::cpu::structural::{
17 typed_broadcast_in_dim, typed_embed_diagonal, typed_extract_diagonal, typed_reshape,
18 typed_transpose,
19};
20use tenferro_tensor::validate::validate_nonsingular_u;
21use tenferro_tensor::Error as TensorError;
22use tenferro_tensor::{
23 CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SemiringBackend,
24 SliceConfig, Tensor, TensorBackend, TensorExec, TypedTensor,
25};
26
/// Operation set of the executable IR.
///
/// Each instruction is routed to one of three executors (see
/// `is_host_instruction` / `is_ffi_instruction`): host-side ops, FFI-backed
/// ops (contractions and matrix decompositions), or the remaining ops, which
/// run through the backend's `TensorExec` session.
#[derive(Clone, Debug)]
pub enum ExecOp {
    /// Permute the input's axes by `perm`.
    Transpose {
        perm: Vec<usize>,
    },
    /// Reshape to `shape`; dimensions are symbolic (`DimExpr`) and evaluated
    /// against the runtime input shapes.
    Reshape {
        shape: Vec<DimExpr>,
    },
    /// Broadcast the input into `shape`, mapping its axes to `dims`.
    BroadcastInDim {
        shape: Vec<DimExpr>,
        dims: Vec<usize>,
    },
    /// Cast elements to dtype `to`.
    Convert {
        to: DType,
    },
    /// Host op: scalar constant decoded from little-endian `bytes` of
    /// `dtype` (see `constant_tensor`).
    Constant {
        dtype: DType,
        bytes: Vec<u8>,
    },
    /// FFI op: generalized dot product per `DotGeneralConfig`.
    DotGeneral(DotGeneralConfig),
    /// FFI op: n-ary einsum — planned with a contraction tree and executed
    /// as a nested exec program (see `execute_nary_einsum`).
    NaryEinsum {
        subscripts: String,
    },
    /// Sum-reduce over `axes`.
    ReduceSum {
        axes: Vec<usize>,
    },
    /// Extract the diagonal shared by axes `axis_a` and `axis_b`.
    ExtractDiag {
        axis_a: usize,
        axis_b: usize,
    },
    /// Embed the input along the diagonal of axes `axis_a` and `axis_b`.
    EmbedDiag {
        axis_a: usize,
        axis_b: usize,
    },
    /// Lower-triangular selection; offset `k` is forwarded to the backend's
    /// `tril`.
    Tril {
        k: i64,
    },
    /// Upper-triangular selection; offset `k` is forwarded to the backend's
    /// `triu`.
    Triu {
        k: i64,
    },
    // Elementwise ops below are forwarded 1:1 to the matching backend call.
    Add,
    Multiply,
    Negate,
    Conj,
    Divide,
    Abs,
    Sign,
    Maximum,
    Minimum,
    /// Elementwise comparison in direction `CompareDir`.
    Compare(CompareDir),
    /// Ternary op over inputs 0..3 — operand roles per the backend `select`.
    Select,
    /// Ternary op over inputs 0..3 — operand roles per the backend `clamp`.
    Clamp,
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,
    /// Gather from input 0 driven by input 1 per `GatherConfig`.
    Gather(GatherConfig),
    /// Scatter combining inputs 0..3 per `ScatterConfig`.
    Scatter(ScatterConfig),
    /// Static strided slice.
    Slice(SliceConfig),
    /// Slice with a runtime second input (presumably start offsets — per the
    /// backend `dynamic_slice`) and static `slice_sizes`.
    DynamicSlice {
        slice_sizes: Vec<usize>,
    },
    /// Pad per `PadConfig`.
    Pad(PadConfig),
    /// Concatenate all inputs along `axis`.
    Concatenate {
        axis: usize,
    },
    /// Reverse the input along each axis in `axes`.
    Reverse {
        axes: Vec<usize>,
    },
    /// Host op: produce a scalar f64 tensor holding the input's extent
    /// along `axis`.
    ShapeOf {
        axis: usize,
    },
    /// Host op: slice input 0 along `axis` to the size given by the scalar
    /// second input (rounded, clamped to `[0, extent]`).
    DynamicTruncate {
        axis: usize,
    },
    /// Host op: pad input 0 along `axis` until its extent matches input 1's
    /// extent on that axis (pass-through when already large enough).
    PadToMatch {
        axis: usize,
    },
    /// Product-reduce over `axes`.
    ReduceProd {
        axes: Vec<usize>,
    },
    /// Max-reduce over `axes`.
    ReduceMax {
        axes: Vec<usize>,
    },
    /// Min-reduce over `axes`.
    ReduceMin {
        axes: Vec<usize>,
    },
    // Matrix decompositions (FFI ops); multi-output results are spread over
    // `output_slots` by `assign_multi_output`.
    Cholesky,
    /// SVD; `eps` is carried in the IR but not consulted by the FFI
    /// dispatch in this file.
    Svd {
        eps: f64,
    },
    Qr,
    Lu,
    /// Hermitian eigendecomposition; `eps` is carried in the IR but not
    /// consulted by the FFI dispatch in this file.
    Eigh {
        eps: f64,
    },
    Eig,
    /// Host op: download the tensor, check it with `validate_nonsingular_u`,
    /// and pass the input through unchanged.
    ValidateNonsingular,
    /// FFI op: triangular solve; all flags are forwarded to the backend.
    TriangularSolve {
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    },
}
138
139#[derive(Clone, Debug)]
140pub struct ExecInstruction {
141 pub op: ExecOp,
142 pub input_slots: Vec<usize>,
143 pub output_slots: Vec<usize>,
144 pub dtype: tenferro_tensor::DType,
145 pub output_shapes: Vec<Vec<tenferro_ops::dim_expr::DimExpr>>,
146 pub last_use: Vec<bool>,
147}
148
/// A compiled, slot-based program: a linear instruction sequence operating
/// on `n_slots` tensor slots.
#[derive(Clone, Debug)]
pub struct ExecProgram {
    /// Instructions executed in order.
    pub instructions: Vec<ExecInstruction>,
    /// Slot index for each program input, in input order.
    pub input_slots: Vec<usize>,
    /// Slot index for each program output, in output order.
    pub output_slots: Vec<usize>,
    /// Total number of slots the program addresses.
    pub n_slots: usize,
}
156
/// How nested programs (spawned by runtime n-ary einsum) are evaluated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DispatchMode {
    /// Evaluate instruction-by-instruction (`eval_exec_ir_unsegmented_with_cache`).
    Unsegmented,
    /// Evaluate through the segmenting executor (`crate::segment`).
    Segmented,
}
162
163pub(crate) fn get<'a, T>(
164 slots: &'a [Option<T>],
165 input_slots: &[usize],
166 idx: usize,
167) -> Result<&'a T> {
168 let slot = input_slots[idx];
169 slots[slot]
170 .as_ref()
171 .ok_or(TensorError::MissingValue { slot }.into())
172}
173
174pub(crate) fn initialize_slots(program: &ExecProgram, inputs: Vec<Tensor>) -> Vec<Option<Tensor>> {
175 let mut slots: Vec<Option<Tensor>> = vec![None; program.n_slots];
176 for (i, tensor) in inputs.into_iter().enumerate() {
177 slots[program.input_slots[i]] = Some(tensor);
178 }
179 slots
180}
181
182pub(crate) fn collect_outputs(
183 program: &ExecProgram,
184 mut slots: Vec<Option<Tensor>>,
185) -> Result<Vec<Tensor>> {
186 program
187 .output_slots
188 .iter()
189 .map(|&slot| {
190 slots[slot]
191 .take()
192 .ok_or(TensorError::MissingValue { slot }.into())
193 })
194 .collect()
195}
196
197pub(crate) fn is_host_instruction(inst: &ExecInstruction) -> bool {
198 matches!(
199 &inst.op,
200 ExecOp::ShapeOf { .. }
201 | ExecOp::DynamicTruncate { .. }
202 | ExecOp::PadToMatch { .. }
203 | ExecOp::Constant { .. }
204 | ExecOp::ValidateNonsingular
205 )
206}
207
208pub(crate) fn is_ffi_instruction(inst: &ExecInstruction) -> bool {
209 matches!(
210 &inst.op,
211 ExecOp::DotGeneral(_)
212 | ExecOp::NaryEinsum { .. }
213 | ExecOp::Cholesky
214 | ExecOp::Svd { .. }
215 | ExecOp::Qr
216 | ExecOp::Lu
217 | ExecOp::Eigh { .. }
218 | ExecOp::Eig
219 | ExecOp::TriangularSolve { .. }
220 )
221}
222
223pub(crate) fn resolve_tensor_shape_exprs(
224 slots: &[Option<Tensor>],
225 input_slots: &[usize],
226 exprs: &[DimExpr],
227) -> Result<Vec<usize>> {
228 let mut input_shapes = Vec::with_capacity(input_slots.len());
229 for &slot in input_slots {
230 input_shapes.push(
231 slots[slot]
232 .as_ref()
233 .ok_or(TensorError::MissingValue { slot })?
234 .shape(),
235 );
236 }
237 Ok(DimExpr::eval_all(exprs, &input_shapes))
238}
239
240fn resolve_semiring_shape_exprs<Alg: Semiring>(
241 slots: &[Option<TypedTensor<Alg::Scalar>>],
242 input_slots: &[usize],
243 exprs: &[DimExpr],
244) -> Result<Vec<usize>> {
245 let mut input_shapes = Vec::with_capacity(input_slots.len());
246 for &slot in input_slots {
247 input_shapes.push(
248 slots[slot]
249 .as_ref()
250 .ok_or(TensorError::MissingValue { slot })?
251 .shape
252 .as_slice(),
253 );
254 }
255 Ok(DimExpr::eval_all(exprs, &input_shapes))
256}
257
/// Evaluates an [`ExecProgram`] on `backend` — the public entry point.
///
/// Delegates to the segmented evaluator in `crate::segment`.
pub fn eval_exec_ir<B: TensorBackend>(
    backend: &mut B,
    program: &ExecProgram,
    inputs: Vec<Tensor>,
) -> Result<Vec<Tensor>> {
    crate::segment::eval_exec_segmented(backend, program, inputs)
}
278
279pub fn eval_exec_ir_unsegmented<B: TensorBackend>(
293 backend: &mut B,
294 program: &ExecProgram,
295 inputs: Vec<Tensor>,
296) -> Result<Vec<Tensor>> {
297 let mut cache = NaryEinsumCache::new(
298 std::num::NonZeroUsize::new(crate::engine::DEFAULT_EINSUM_CACHE_CAPACITY)
299 .expect("DEFAULT_EINSUM_CACHE_CAPACITY must be non-zero"),
300 );
301 eval_exec_ir_unsegmented_with_cache(backend, program, inputs, &mut cache)
302}
303
/// Evaluates `program` instruction-by-instruction (no segmentation), sharing
/// `cache` so nested n-ary einsum evaluations reuse contraction plans.
///
/// Each instruction is routed to the host, FFI, or backend executor; after
/// it runs, input buffers whose last reader was this instruction are
/// reclaimed — so reclamation must stay *after* the dispatch.
pub(crate) fn eval_exec_ir_unsegmented_with_cache<B: TensorBackend>(
    backend: &mut B,
    program: &ExecProgram,
    inputs: Vec<Tensor>,
    cache: &mut NaryEinsumCache,
) -> Result<Vec<Tensor>> {
    let mut slots = initialize_slots(program, inputs);

    for inst in &program.instructions {
        if is_host_instruction(inst) {
            execute_host_instruction(backend, &mut slots, inst)?;
        } else if is_ffi_instruction(inst) {
            execute_ffi_instruction(backend, &mut slots, inst, DispatchMode::Unsegmented, cache)?;
        } else {
            // Backend ops are single-output: the result goes to output_slots[0].
            let result =
                backend.with_exec_session(|exec| execute_backend_op(exec, &slots, inst))?;
            slots[inst.output_slots[0]] = Some(result);
        }
        // Free buffers whose final use was this instruction.
        reclaim_last_use_inputs_backend(&mut slots, inst, backend);
    }

    collect_outputs(program, slots)
}
327
/// Executes one non-host, non-FFI instruction through the backend's
/// `TensorExec` session and returns the produced tensor (the caller stores
/// it into `inst.output_slots[0]`).
///
/// # Errors
/// Propagates backend errors and `MissingValue` for empty input slots;
/// returns `Error::Internal` if a host/FFI op is routed here by mistake.
pub(crate) fn execute_backend_op(
    exec: &mut dyn TensorExec,
    slots: &[Option<Tensor>],
    inst: &ExecInstruction,
) -> Result<Tensor> {
    let result = match &inst.op {
        ExecOp::Transpose { perm } => exec.transpose(get(slots, &inst.input_slots, 0)?, perm)?,
        ExecOp::Reshape { shape } => {
            // Symbolic dims are resolved against the runtime input shapes.
            let shape = resolve_tensor_shape_exprs(slots, &inst.input_slots, shape)?;
            exec.reshape(get(slots, &inst.input_slots, 0)?, &shape)?
        }
        ExecOp::BroadcastInDim { shape, dims } => {
            let shape = resolve_tensor_shape_exprs(slots, &inst.input_slots, shape)?;
            exec.broadcast_in_dim(get(slots, &inst.input_slots, 0)?, &shape, dims)?
        }
        ExecOp::Convert { to } => exec.convert(get(slots, &inst.input_slots, 0)?, *to)?,
        ExecOp::ReduceSum { axes } => exec.reduce_sum(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ExtractDiag { axis_a, axis_b } => {
            exec.extract_diagonal(get(slots, &inst.input_slots, 0)?, *axis_a, *axis_b)?
        }
        ExecOp::EmbedDiag { axis_a, axis_b } => {
            exec.embed_diagonal(get(slots, &inst.input_slots, 0)?, *axis_a, *axis_b)?
        }
        ExecOp::Tril { k } => exec.tril(get(slots, &inst.input_slots, 0)?, *k)?,
        ExecOp::Triu { k } => exec.triu(get(slots, &inst.input_slots, 0)?, *k)?,
        // Elementwise ops: operands are forwarded in input-slot order.
        ExecOp::Add => exec.add(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Multiply => exec.mul(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Negate => exec.neg(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Conj => exec.conj(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Divide => exec.div(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Abs => exec.abs(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sign => exec.sign(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Maximum => exec.maximum(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Minimum => exec.minimum(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Compare(dir) => exec.compare(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            dir,
        )?,
        ExecOp::Select => exec.select(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
        )?,
        ExecOp::Clamp => exec.clamp(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
        )?,
        ExecOp::Exp => exec.exp(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Log => exec.log(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sin => exec.sin(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Cos => exec.cos(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Tanh => exec.tanh(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Sqrt => exec.sqrt(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Rsqrt => exec.rsqrt(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Pow => exec.pow(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
        )?,
        ExecOp::Expm1 => exec.expm1(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Log1p => exec.log1p(get(slots, &inst.input_slots, 0)?)?,
        ExecOp::Gather(config) => exec.gather(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            config,
        )?,
        ExecOp::Scatter(config) => exec.scatter(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            get(slots, &inst.input_slots, 2)?,
            config,
        )?,
        ExecOp::Slice(config) => exec.slice(get(slots, &inst.input_slots, 0)?, config)?,
        ExecOp::DynamicSlice { slice_sizes } => exec.dynamic_slice(
            get(slots, &inst.input_slots, 0)?,
            get(slots, &inst.input_slots, 1)?,
            slice_sizes,
        )?,
        ExecOp::Pad(config) => exec.pad(get(slots, &inst.input_slots, 0)?, config)?,
        ExecOp::Concatenate { axis } => {
            // Variadic: every input slot contributes one operand.
            let inputs = collect_tensor_refs(slots, &inst.input_slots)?;
            exec.concatenate(&inputs, *axis)?
        }
        ExecOp::Reverse { axes } => exec.reverse(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceProd { axes } => exec.reduce_prod(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceMax { axes } => exec.reduce_max(get(slots, &inst.input_slots, 0)?, axes)?,
        ExecOp::ReduceMin { axes } => exec.reduce_min(get(slots, &inst.input_slots, 0)?, axes)?,
        // Misrouted host/FFI ops are a compiler/dispatch bug, not a user error.
        other => {
            return Err(Error::Internal(format!(
                "host or FFI op reached backend executor: {other:?}"
            )))
        }
    };
    Ok(result)
}
439
/// Executes host-side instructions: ops that need host control flow or
/// host-visible data (shape scalars, dynamic truncation/padding, constants,
/// validation). Results are written into `slots` at `inst.output_slots[0]`.
///
/// # Errors
/// Returns `Error::Internal` for out-of-bounds axes or a misrouted op, and
/// propagates backend upload/download/slice/pad failures.
pub(crate) fn execute_host_instruction<B: TensorBackend>(
    backend: &mut B,
    slots: &mut [Option<Tensor>],
    inst: &ExecInstruction,
) -> Result<()> {
    match &inst.op {
        ExecOp::ShapeOf { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "ShapeOf: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            // Emit the extent as a scalar (rank-0) f64 tensor on the backend.
            let host = Tensor::F64(TypedTensor::from_vec(
                vec![],
                vec![input.shape()[*axis] as f64],
            ));
            slots[inst.output_slots[0]] = Some(backend.upload_host_tensor(&host)?);
        }
        ExecOp::DynamicTruncate { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "DynamicTruncate: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            // The requested size arrives as a scalar tensor (input 1) that
            // must be read back to the host.
            let size_tensor = backend.download_to_host(get(slots, &inst.input_slots, 1)?)?;
            let axis_extent = input.shape()[*axis];
            let size_f64 = scalar_size_value(&size_tensor)?;
            // Non-finite sizes (NaN/inf) are treated as zero.
            let rounded_size = if size_f64.is_finite() {
                size_f64.round()
            } else {
                0.0
            };
            // Clamp to [0, axis_extent] before converting to usize.
            let size = rounded_size.max(0.0).min(axis_extent as f64) as usize;
            let rank = input.shape().len();
            let mut limits = input.shape().to_vec();
            limits[*axis] = size;
            // Full-extent slice on every axis except the truncated one.
            let config = SliceConfig {
                starts: vec![0; rank],
                limits,
                strides: vec![1; rank],
            };
            slots[inst.output_slots[0]] = Some(backend.slice(input, &config)?);
        }
        ExecOp::PadToMatch { axis } => {
            let input = get(slots, &inst.input_slots, 0)?;
            let reference = get(slots, &inst.input_slots, 1)?;
            if *axis >= input.shape().len() {
                return Err(Error::Internal(format!(
                    "PadToMatch: axis {} out of bounds for rank {}",
                    axis,
                    input.shape().len()
                )));
            }
            let target_size = reference.shape()[*axis];
            let current_size = input.shape()[*axis];
            if current_size >= target_size {
                // Already at least as large as the reference: pass through.
                slots[inst.output_slots[0]] = Some(input.clone());
            } else {
                // Pad only the high edge of `axis` up to the target extent.
                let rank = input.shape().len();
                let mut high = vec![0i64; rank];
                high[*axis] = (target_size - current_size) as i64;
                let config = PadConfig {
                    edge_padding_low: vec![0i64; rank],
                    edge_padding_high: high,
                    interior_padding: vec![0i64; rank],
                };
                slots[inst.output_slots[0]] = Some(backend.pad(input, &config)?);
            }
        }
        ExecOp::Constant { dtype, bytes } => {
            // Decode the scalar on the host, then upload to the backend.
            let host = constant_tensor(*dtype, bytes);
            slots[inst.output_slots[0]] = Some(backend.upload_host_tensor(&host)?);
        }
        ExecOp::ValidateNonsingular => {
            // Validation runs on host data; the tensor itself passes through.
            let input = get(slots, &inst.input_slots, 0)?;
            let host_input = backend.download_to_host(input)?;
            validate_nonsingular_u(&host_input)?;
            slots[inst.output_slots[0]] = Some(input.clone());
        }
        other => {
            return Err(Error::Internal(format!(
                "non-host op reached host executor: {other:?}"
            )))
        }
    }
    Ok(())
}
533
/// Executes FFI-backed instructions — contractions (`DotGeneral`,
/// `NaryEinsum`) and matrix decompositions — directly on the backend,
/// writing results into `slots`.
///
/// `mode` controls how nested einsum programs are evaluated; `cache` shares
/// contraction plans across nested evaluations.
pub(crate) fn execute_ffi_instruction<B: TensorBackend>(
    backend: &mut B,
    slots: &mut [Option<Tensor>],
    inst: &ExecInstruction,
    mode: DispatchMode,
    cache: &mut NaryEinsumCache,
) -> Result<()> {
    match &inst.op {
        ExecOp::DotGeneral(config) => {
            let result = backend.dot_general(
                get(slots, &inst.input_slots, 0)?,
                get(slots, &inst.input_slots, 1)?,
                config,
            )?;
            slots[inst.output_slots[0]] = Some(result);
        }
        ExecOp::NaryEinsum { subscripts } => {
            // Variadic: all input slots participate in the contraction.
            let inputs = collect_tensor_refs(slots, &inst.input_slots)?;
            let result = execute_nary_einsum(backend, &inputs, subscripts, mode, cache)?;
            slots[inst.output_slots[0]] = Some(result);
        }
        ExecOp::Cholesky => {
            let result = backend.cholesky(get(slots, &inst.input_slots, 0)?)?;
            slots[inst.output_slots[0]] = Some(result);
        }
        // Decompositions below return multiple tensors that are spread over
        // the instruction's output slots. Note: `eps` on Svd/Eigh is not
        // consulted here.
        ExecOp::Svd { .. } => {
            let results = backend.svd(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "svd")?;
        }
        ExecOp::Qr => {
            let results = backend.qr(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "qr")?;
        }
        ExecOp::Lu => {
            let results = backend.lu(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "lu")?;
        }
        ExecOp::Eigh { .. } => {
            let results = backend.eigh(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "eigh")?;
        }
        ExecOp::Eig => {
            let results = backend.eig(get(slots, &inst.input_slots, 0)?)?;
            assign_multi_output(slots, inst, results, "eig")?;
        }
        ExecOp::TriangularSolve {
            left_side,
            lower,
            transpose_a,
            unit_diagonal,
        } => {
            let result = backend.triangular_solve(
                get(slots, &inst.input_slots, 0)?,
                get(slots, &inst.input_slots, 1)?,
                *left_side,
                *lower,
                *transpose_a,
                *unit_diagonal,
            )?;
            slots[inst.output_slots[0]] = Some(result);
        }
        other => {
            return Err(Error::Internal(format!(
                "non-ffi op reached ffi executor: {other:?}"
            )))
        }
    }
    Ok(())
}
603
604fn assign_multi_output(
605 slots: &mut [Option<Tensor>],
606 inst: &ExecInstruction,
607 results: Vec<Tensor>,
608 op_name: &str,
609) -> Result<()> {
610 if results.len() != inst.output_slots.len() {
611 return Err(Error::Internal(format!(
612 "{op_name} produced {} outputs for {} slots",
613 results.len(),
614 inst.output_slots.len()
615 )));
616 }
617 for (slot, tensor) in inst.output_slots.iter().copied().zip(results.into_iter()) {
618 slots[slot] = Some(tensor);
619 }
620 Ok(())
621}
622
623fn collect_tensor_refs<'a>(
624 slots: &'a [Option<Tensor>],
625 input_slots: &[usize],
626) -> Result<Vec<&'a Tensor>> {
627 let mut inputs = Vec::with_capacity(input_slots.len());
628 for &slot in input_slots {
629 inputs.push(
630 slots[slot]
631 .as_ref()
632 .ok_or(TensorError::MissingValue { slot })?,
633 );
634 }
635 Ok(inputs)
636}
637
/// Executes an n-ary einsum at runtime.
///
/// Pipeline: parse subscripts → look up (or compute and cache) a contraction
/// tree for the concrete input shapes → build a compute-graph fragment from
/// the tree → resolve/materialize/compile it → lower to an [`ExecProgram`]
/// → recursively evaluate that program per `mode`, sharing `cache`.
fn execute_nary_einsum<B: TensorBackend>(
    backend: &mut B,
    inputs: &[&Tensor],
    subscripts: &str,
    mode: DispatchMode,
    cache: &mut NaryEinsumCache,
) -> Result<Tensor> {
    use tenferro_einsum::{
        build_einsum_fragment, ContractionOptimizerOptions, ContractionTree, Subscripts,
    };

    if inputs.is_empty() {
        return Err(Error::ContractionError(
            "nary einsum requires at least one input tensor".into(),
        ));
    }

    let subs =
        Subscripts::parse(subscripts).map_err(|e| Error::InvalidSubscripts(format!("{e}")))?;
    let shapes: Vec<Vec<usize>> = inputs
        .iter()
        .map(|tensor| tensor.shape().to_vec())
        .collect();
    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
    // Plans are keyed by (subscripts, concrete shapes): the optimal tree
    // depends on both.
    let cache_key = (subscripts.to_string(), shapes.clone());
    let tree_arc = if let Some(cached) = cache.get(&cache_key) {
        cached.clone()
    } else {
        let tree = Arc::new(
            ContractionTree::optimize_with_options(
                &subs,
                &shape_refs,
                &ContractionOptimizerOptions::default(),
            )
            .map_err(|e| Error::ContractionError(format!("{e}")))?,
        );
        cache.put(cache_key, tree.clone());
        tree
    };
    let tree = tree_arc.as_ref();

    // Build a fragment whose inputs are keyed by positional user ids, so
    // they can be matched back to `inputs` after materialization.
    let mut builder = FragmentBuilder::<StdTensorOp>::new();
    let mut input_vals = Vec::with_capacity(inputs.len());
    for input_idx in 0..inputs.len() {
        let local = builder.add_input(TensorInputKey::User {
            id: input_idx as u64,
        });
        input_vals.push(ValRef::Local(local));
    }

    let result_ref = build_einsum_fragment(&mut builder, &tree, &input_vals, &shapes);
    let result_local = match result_ref {
        ValRef::Local(local) => local,
        ValRef::External(_) => {
            return Err(Error::Internal(
                "runtime nary einsum builder returned an external value".into(),
            ))
        }
    };
    builder.set_outputs(vec![result_local]);
    let fragment = Arc::new(builder.build());
    let output_key = fragment.vals()[result_local].key.clone();

    let view = resolve(vec![fragment]);
    let graph = materialize_merge(&view, &[output_key]);
    let compiled = compile(&graph);

    // The materialized graph may order inputs differently from `inputs`;
    // rebuild the argument list (and dtypes/shapes) in graph-input order
    // using the user ids assigned above.
    let mut program_inputs = Vec::with_capacity(graph.inputs.len());
    let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
    let mut input_shapes = Vec::with_capacity(graph.inputs.len());
    for key in &graph.inputs {
        match key {
            GlobalValKey::Input(TensorInputKey::User { id }) => {
                let input_idx = *id as usize;
                let tensor = inputs.get(input_idx).ok_or_else(|| {
                    Error::Internal(format!(
                        "runtime nary einsum input {input_idx} missing for subscripts {subscripts}"
                    ))
                })?;
                program_inputs.push((*tensor).clone());
                input_dtypes.push(tensor.dtype());
                input_shapes.push(DimExpr::from_concrete(tensor.shape()));
            }
            other => {
                return Err(Error::Internal(format!(
                    "unexpected runtime nary einsum input key: {other:?}"
                )))
            }
        }
    }
    let program = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);

    // Recurse with the same cache so nested contractions reuse plans.
    let mut outputs = match mode {
        DispatchMode::Unsegmented => {
            eval_exec_ir_unsegmented_with_cache(backend, &program, program_inputs, cache)?
        }
        DispatchMode::Segmented => crate::segment::eval_exec_segmented_with_cache(
            backend,
            &program,
            program_inputs,
            cache,
        )?,
    };
    if outputs.len() != 1 {
        return Err(Error::Internal(format!(
            "runtime nary einsum expected 1 output, got {}",
            outputs.len()
        )));
    }
    Ok(outputs.remove(0))
}
749
750pub(crate) fn constant_tensor(dtype: DType, bytes: &[u8]) -> Tensor {
751 match dtype {
752 DType::F64 => Tensor::F64(TypedTensor::from_vec(
753 vec![],
754 vec![f64::from_le_bytes(exact_bytes::<8>(dtype, bytes))],
755 )),
756 DType::F32 => Tensor::F32(TypedTensor::from_vec(
757 vec![],
758 vec![f32::from_le_bytes(exact_bytes::<4>(dtype, bytes))],
759 )),
760 DType::C64 => {
761 let data = exact_bytes::<16>(dtype, bytes);
762 let mut re_bytes = [0u8; 8];
763 let mut im_bytes = [0u8; 8];
764 re_bytes.copy_from_slice(&data[..8]);
765 im_bytes.copy_from_slice(&data[8..]);
766 let re = f64::from_le_bytes(re_bytes);
767 let im = f64::from_le_bytes(im_bytes);
768 Tensor::C64(TypedTensor::from_vec(vec![], vec![Complex64::new(re, im)]))
769 }
770 DType::C32 => {
771 let data = exact_bytes::<8>(dtype, bytes);
772 let mut re_bytes = [0u8; 4];
773 let mut im_bytes = [0u8; 4];
774 re_bytes.copy_from_slice(&data[..4]);
775 im_bytes.copy_from_slice(&data[4..]);
776 let re = f32::from_le_bytes(re_bytes);
777 let im = f32::from_le_bytes(im_bytes);
778 Tensor::C32(TypedTensor::from_vec(vec![], vec![Complex32::new(re, im)]))
779 }
780 }
781}
782
783fn exact_bytes<const N: usize>(dtype: DType, bytes: &[u8]) -> [u8; N] {
784 if bytes.len() != N {
785 panic!(
786 "constant {:?} expected {} bytes, got {}",
787 dtype,
788 N,
789 bytes.len()
790 );
791 }
792 let mut out = [0u8; N];
793 out.copy_from_slice(bytes);
794 out
795}
796
797fn scalar_size_value(size_tensor: &Tensor) -> Result<f64> {
798 match size_tensor {
799 Tensor::F64(inner) => Ok(inner.host_data()[0]),
800 Tensor::F32(inner) => Ok(inner.host_data()[0] as f64),
801 _ => Err(Error::Internal(
802 "DynamicTruncate size must be an f32 or f64 scalar".into(),
803 )),
804 }
805}
806
807pub(crate) fn reclaim_last_use_inputs_exec(
808 slots: &mut [Option<Tensor>],
809 inst: &ExecInstruction,
810 exec: &mut dyn TensorExec,
811) {
812 for (i, &is_last) in inst.last_use.iter().enumerate() {
813 if is_last {
814 if let Some(tensor) = slots[inst.input_slots[i]].take() {
815 exec.reclaim_buffer(tensor);
816 }
817 }
818 }
819}
820
821pub(crate) fn reclaim_last_use_inputs_backend<B: TensorBackend>(
822 slots: &mut [Option<Tensor>],
823 inst: &ExecInstruction,
824 backend: &mut B,
825) {
826 for (i, &is_last) in inst.last_use.iter().enumerate() {
827 if is_last {
828 if let Some(tensor) = slots[inst.input_slots[i]].take() {
829 backend.reclaim_buffer(tensor);
830 }
831 }
832 }
833}
834
835pub fn eval_semiring_ir<B, Alg>(
836 backend: &mut B,
837 program: &ExecProgram,
838 inputs: Vec<TypedTensor<Alg::Scalar>>,
839) -> Result<Vec<TypedTensor<Alg::Scalar>>>
840where
841 Alg: Semiring,
842 B: SemiringBackend<Alg>,
843{
844 let mut slots: Vec<Option<TypedTensor<Alg::Scalar>>> = vec![None; program.n_slots];
845 for (i, tensor) in inputs.into_iter().enumerate() {
846 slots[program.input_slots[i]] = Some(tensor);
847 }
848
849 for inst in &program.instructions {
850 let result = match &inst.op {
851 ExecOp::Transpose { perm } => typed_transpose(get(&slots, &inst.input_slots, 0)?, perm),
852 ExecOp::Reshape { shape } => {
853 let shape = resolve_semiring_shape_exprs::<Alg>(&slots, &inst.input_slots, shape)?;
854 typed_reshape(get(&slots, &inst.input_slots, 0)?, &shape)
855 }
856 ExecOp::BroadcastInDim { shape, dims } => {
857 let shape = resolve_semiring_shape_exprs::<Alg>(&slots, &inst.input_slots, shape)?;
858 typed_broadcast_in_dim(get(&slots, &inst.input_slots, 0)?, &shape, dims)
859 }
860 ExecOp::DotGeneral(config) => backend.batched_gemm(
861 get(&slots, &inst.input_slots, 0)?,
862 get(&slots, &inst.input_slots, 1)?,
863 config,
864 ),
865 ExecOp::ReduceSum { axes } => {
866 backend.reduce_sum(get(&slots, &inst.input_slots, 0)?, axes)
867 }
868 ExecOp::ExtractDiag { axis_a, axis_b } => {
869 typed_extract_diagonal(get(&slots, &inst.input_slots, 0)?, *axis_a, *axis_b)
870 }
871 ExecOp::EmbedDiag { axis_a, axis_b } => {
872 typed_embed_diagonal(get(&slots, &inst.input_slots, 0)?, *axis_a, *axis_b)
873 }
874 ExecOp::Add => backend.add(
875 get(&slots, &inst.input_slots, 0)?,
876 get(&slots, &inst.input_slots, 1)?,
877 ),
878 ExecOp::Multiply => backend.mul(
879 get(&slots, &inst.input_slots, 0)?,
880 get(&slots, &inst.input_slots, 1)?,
881 ),
882 _ => panic!("non-semiring op in semiring program: {:?}", inst.op),
883 };
884 slots[inst.output_slots[0]] = Some(result?);
885 }
886
887 program
888 .output_slots
889 .iter()
890 .map(|&slot| {
891 slots[slot]
892 .take()
893 .ok_or(TensorError::MissingValue { slot }.into())
894 })
895 .collect()
896}