1use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::mem::size_of;
6use std::sync::Arc;
7
8use computegraph::compile::{compile, CompiledProgram, Instruction};
9use computegraph::graph::GraphBuilder;
10use computegraph::materialize::materialize_merge;
11use computegraph::resolve::resolve;
12use computegraph::types::{ValueKey, ValueRef};
13use tenferro_ad::error::{Error, Result};
14use tenferro_ad::extension::{adopt_untracked_eager_value, apply_eager};
15use tenferro_ad::{EagerRuntime, EagerTensor};
16use tenferro_ops::dim_expr::DimExpr;
17use tenferro_ops::input_key::TensorInputKey;
18use tenferro_ops::std_tensor_op::StdTensorOp;
19use tenferro_runtime::ExtensionCacheKey;
20use tenferro_tensor::TensorFusion;
21
22use crate::binary_dot::{try_build_exact_output_binary_dot_plan, BinaryDotOperandOrder};
23use crate::builder::build_einsum_graph;
24use crate::cache::{
25 saturating_sum, vec_retained_bytes, EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE,
26 EINSUM_EXTENSION_FAMILY_ID,
27};
28use crate::extension::{register_runtime, EinsumExtensionOp};
29use crate::optimize::{
30 default_auto_options, hash_einsum_plan_spec, resolve_plan_spec, EinsumPlanSpec,
31};
32use crate::{parse_einsum_subscripts, EinsumSubscripts, Subscripts, TensorDotAxes};
33
34pub trait EagerEinsumExt {
36 fn einsum(&self, subscripts: &str) -> Result<EagerTensor>;
37 fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor>;
38}
39
40impl EagerEinsumExt for [&EagerTensor] {
41 fn einsum(&self, subscripts: &str) -> Result<EagerTensor> {
42 einsum(self, subscripts)
43 }
44
45 fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor> {
46 einsum_subscripts(self, subscripts)
47 }
48}
49
50impl<const N: usize> EagerEinsumExt for [&EagerTensor; N] {
51 fn einsum(&self, subscripts: &str) -> Result<EagerTensor> {
52 einsum(self.as_slice(), subscripts)
53 }
54
55 fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor> {
56 einsum_subscripts(self.as_slice(), subscripts)
57 }
58}
59
60pub trait EagerTensorEinsumExt {
62 fn tensordot(&self, rhs: &EagerTensor, axes: TensorDotAxes<'_>) -> Result<EagerTensor>;
63}
64
65impl EagerTensorEinsumExt for EagerTensor {
66 fn tensordot(&self, rhs: &EagerTensor, axes: TensorDotAxes<'_>) -> Result<EagerTensor> {
67 tensordot(self, rhs, axes)
68 }
69}
70
71pub fn einsum(inputs: &[&EagerTensor], subscripts: &str) -> Result<EagerTensor> {
95 let subscripts = parse_einsum_subscripts(subscripts)
96 .map_err(|err| Error::ContractionError(err.to_string()))?;
97 einsum_subscripts(inputs, &subscripts)
98}
99
100pub fn einsum_subscripts(
125 inputs: &[&EagerTensor],
126 subscripts: &EinsumSubscripts,
127) -> Result<EagerTensor> {
128 if let Some(result) = try_direct_binary_dot_general(inputs, subscripts) {
129 return result;
130 }
131
132 if let Some(result) = try_whole_program_untracked(inputs, subscripts)? {
133 return Ok(result);
134 }
135
136 let output_shape_hint = infer_eager_output_shape(subscripts, inputs)?;
137 if let Some(result) = try_expand_eager_einsum(inputs, subscripts)? {
138 return Ok(result);
139 }
140
141 if let Some(first) = inputs.first() {
142 first
143 .runtime()
144 .register_extension(register_runtime)
145 .map_err(|err| Error::Internal(err.to_string()))?;
146 }
147
148 let op = Arc::new(EinsumExtensionOp::with_output_shape_hint(
149 subscripts.clone(),
150 output_shape_hint,
151 EinsumPlanSpec::Auto(default_auto_options()),
152 ));
153 let mut outputs = apply_eager(op, inputs)?;
154 outputs
155 .pop()
156 .ok_or_else(|| Error::Internal("einsum extension produced no eager output".to_string()))
157}
158
159fn try_direct_binary_dot_general(
160 inputs: &[&EagerTensor],
161 subscripts: &EinsumSubscripts,
162) -> Option<Result<EagerTensor>> {
163 if inputs.len() != 2 || subscripts.inputs.len() != 2 {
164 return None;
165 }
166
167 let lhs_labels = &subscripts.inputs[0];
168 let rhs_labels = &subscripts.inputs[1];
169 if lhs_labels.len() != inputs[0].shape().len() || rhs_labels.len() != inputs[1].shape().len() {
170 return None;
171 }
172
173 if let Some(plan) =
174 try_build_exact_output_binary_dot_plan(lhs_labels, rhs_labels, &subscripts.output)
175 {
176 return Some(match plan.operand_order {
177 BinaryDotOperandOrder::Original => inputs[0].dot_general(inputs[1], plan.config),
178 BinaryDotOperandOrder::Swapped => inputs[1].dot_general(inputs[0], plan.config),
179 });
180 }
181 None
182}
183
/// Reports whether the `TENFERRO_EAGER_WHOLE_PROGRAM` environment variable is
/// present.
///
/// Only presence is tested: any value, including `0`, enables the path.
fn whole_program_untracked_enabled() -> bool {
    const FLAG: &str = "TENFERRO_EAGER_WHOLE_PROGRAM";
    std::env::var_os(FLAG).is_some()
}
194
195fn try_whole_program_untracked(
201 inputs: &[&EagerTensor],
202 subscripts: &EinsumSubscripts,
203) -> Result<Option<EagerTensor>> {
204 if !whole_program_untracked_enabled() {
205 return Ok(None);
206 }
207 let Some(first) = inputs.first() else {
208 return Ok(None);
209 };
210 if inputs.iter().any(|tensor| tensor.tracks_grad()) {
211 return Ok(None);
212 }
213 let runtime = first.runtime();
214 if inputs
215 .iter()
216 .any(|tensor| !Arc::ptr_eq(tensor.runtime(), runtime))
217 {
218 return Ok(None);
219 }
220
221 let subs = Subscripts::from(subscripts);
222 let tensor_arcs = inputs
223 .iter()
224 .map(|tensor| tensor.materialized())
225 .collect::<Result<Vec<_>>>()?;
226 let tensors: Vec<_> = tensor_arcs.iter().map(|tensor| tensor.as_ref()).collect();
227 let result = runtime.with_backend_mut(|backend| {
228 crate::eager::eager_einsum_subscripts(backend, &tensors, &subs)
229 })??;
230 Ok(Some(EagerTensor::from_tensor_in(result, runtime.clone())?))
231}
232
/// Test-only helper: evaluates an einsum on the backend with an explicit
/// contraction `tree`, without gradient tracking.
///
/// # Errors
///
/// Returns `Error::ContractionError` when `inputs` is empty and
/// `Error::Internal` when any input tracks gradients or the inputs do not all
/// share the first input's runtime `Arc`. Also propagates failures from
/// materialization, the backend call, and result wrapping.
#[cfg(test)]
fn einsum_whole_program_untracked(
    inputs: &[&EagerTensor],
    tree: &crate::ContractionTree,
) -> Result<EagerTensor> {
    let Some(first) = inputs.first() else {
        return Err(Error::ContractionError(
            "einsum requires at least one input tensor".into(),
        ));
    };
    if !inputs.iter().all(|tensor| !tensor.tracks_grad()) {
        return Err(Error::Internal(
            "whole-program eager einsum requires untracked inputs".into(),
        ));
    }
    let runtime = first.runtime();
    let same_runtime = inputs
        .iter()
        .all(|tensor| Arc::ptr_eq(tensor.runtime(), runtime));
    if !same_runtime {
        return Err(Error::Internal(
            "whole-program eager einsum requires inputs from one runtime".into(),
        ));
    }

    // Keep the materialized handles alive while the backend borrows them.
    let mut tensor_arcs = Vec::with_capacity(inputs.len());
    for tensor in inputs {
        tensor_arcs.push(tensor.materialized()?);
    }
    let tensors: Vec<_> = tensor_arcs.iter().map(|arc| arc.as_ref()).collect();
    let contracted = runtime.with_backend_mut(|backend| {
        crate::eager::eager_einsum_with_tree(backend, &tensors, tree)
    })??;
    EagerTensor::from_tensor_in(contracted, runtime.clone())
}
297
298fn try_expand_eager_einsum(
299 inputs: &[&EagerTensor],
300 subscripts: &EinsumSubscripts,
301) -> Result<Option<EagerTensor>> {
302 if inputs.len() <= 1 {
303 return Ok(None);
304 }
305
306 let shapes: Vec<Vec<usize>> = inputs
307 .iter()
308 .map(|tensor| tensor.shape().to_vec())
309 .collect();
310 let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
311 let subs = Subscripts::from(subscripts);
312 let plan_spec = EinsumPlanSpec::Auto(default_auto_options());
313
314 let program = cached_expanded_eager_program(
315 inputs[0].runtime(),
316 subscripts,
317 &subs,
318 &plan_spec,
319 &shape_refs,
320 &shapes,
321 )?;
322 execute_eager_einsum_program(inputs, &program)
323}
324
325struct ExpandedEagerProgram {
326 compiled: CompiledProgram<StdTensorOp>,
327 input_slots: Vec<(usize, usize)>,
328}
329
330fn cached_expanded_eager_program(
331 runtime: &Arc<EagerRuntime>,
332 subscripts: &EinsumSubscripts,
333 subs: &Subscripts,
334 plan_spec: &EinsumPlanSpec,
335 shape_refs: &[&[usize]],
336 shapes: &[Vec<usize>],
337) -> Result<Arc<ExpandedEagerProgram>> {
338 runtime.with_extension_caches_mut(|caches| {
339 let key = expanded_eager_program_cache_key(subscripts, plan_spec, shapes);
340 if let Some(cached) = caches.get::<Arc<ExpandedEagerProgram>>(&key) {
341 return Ok(Arc::clone(cached));
342 }
343
344 let tree = resolve_plan_spec(plan_spec, subs, shape_refs)
345 .map_err(|err| Error::ContractionError(err.to_string()))?;
346 let program = Arc::new(build_expanded_eager_program(&tree, shapes)?);
347 let retained_bytes = expanded_eager_program_retained_bytes(&program);
348 caches.put(key, Arc::clone(&program), retained_bytes);
349 Ok(program)
350 })?
351}
352
353fn expanded_eager_program_cache_key(
354 subscripts: &EinsumSubscripts,
355 plan_spec: &EinsumPlanSpec,
356 shapes: &[Vec<usize>],
357) -> ExtensionCacheKey {
358 let mut hasher = DefaultHasher::new();
359 subscripts.hash(&mut hasher);
360 shapes.hash(&mut hasher);
361 hash_einsum_plan_spec(plan_spec, &mut hasher);
362 ExtensionCacheKey::new(
363 EINSUM_EXTENSION_FAMILY_ID,
364 EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE,
365 hasher.finish(),
366 )
367}
368
369fn build_expanded_eager_program(
370 tree: &crate::ContractionTree,
371 shapes: &[Vec<usize>],
372) -> Result<ExpandedEagerProgram> {
373 let mut builder = GraphBuilder::<StdTensorOp>::new();
374 let mut input_vals = Vec::with_capacity(shapes.len());
375 for input_idx in 0..shapes.len() {
376 let local = builder.add_input(TensorInputKey::User {
377 id: input_idx as u64,
378 });
379 input_vals.push(ValueRef::Local(local));
380 }
381
382 let result_ref = build_einsum_graph(&mut builder, tree, &input_vals, shapes)
383 .map_err(|err| Error::ContractionError(err.to_string()))?;
384 let ValueRef::Local(result_local) = result_ref else {
385 return Err(Error::Internal(
386 "expanded eager einsum returned an external value".into(),
387 ));
388 };
389 builder.set_outputs(vec![result_local]);
390 let graph = Arc::new(builder.build());
391 let output_key = graph.values()[result_local].key.clone();
392 let view = resolve(vec![graph]);
393 let graph = materialize_merge(&view, &[output_key]);
394 let compiled = compile(&graph);
395 let input_slots = compiled
396 .input_slots
397 .iter()
398 .zip(graph.inputs.iter())
399 .map(|(&slot, key)| {
400 let ValueKey::Input(TensorInputKey::User { id }) = key else {
401 return Err(Error::Internal(format!(
402 "expanded eager einsum saw unexpected input key: {key:?}"
403 )));
404 };
405 Ok((slot, *id as usize))
406 })
407 .collect::<Result<_>>()?;
408
409 Ok(ExpandedEagerProgram {
410 compiled,
411 input_slots,
412 })
413}
414
415fn execute_eager_einsum_program(
416 inputs: &[&EagerTensor],
417 program: &ExpandedEagerProgram,
418) -> Result<Option<EagerTensor>> {
419 let mut slots: Vec<Option<EagerTensor>> = vec![None; program.compiled.n_slots];
420 for &(slot, input_idx) in &program.input_slots {
421 let tensor = inputs.get(input_idx).ok_or_else(|| {
422 Error::Internal(format!(
423 "expanded eager einsum input {input_idx} is missing"
424 ))
425 })?;
426 slots[slot] = Some((*tensor).clone());
427 }
428
429 let mut instruction_idx = 0;
430 while instruction_idx < program.compiled.instructions.len() {
431 if let Some((output_slot, output)) = try_execute_eager_broadcast_multiply_pattern(
432 &program.compiled.instructions,
433 instruction_idx,
434 &slots,
435 &program.compiled.output_slots,
436 )? {
437 slots[output_slot] = Some(output);
438 instruction_idx += 3;
439 continue;
440 }
441
442 let instr = &program.compiled.instructions[instruction_idx];
443 if instr.outputs.len() != 1 {
444 return Err(Error::Internal(format!(
445 "expanded eager einsum expected single-output op, got {} outputs",
446 instr.outputs.len()
447 )));
448 }
449 let input_values: Vec<EagerTensor> = instr
450 .inputs
451 .iter()
452 .map(|&slot| {
453 slots
454 .get(slot)
455 .and_then(Option::as_ref)
456 .cloned()
457 .ok_or_else(|| {
458 Error::Internal(format!(
459 "expanded eager einsum missing value for slot {slot}"
460 ))
461 })
462 })
463 .collect::<Result<_>>()?;
464 let input_refs: Vec<&EagerTensor> = input_values.iter().collect();
465 let output =
466 tenferro_ad::extension::apply_standard_op(instr.operation.clone(), &input_refs)?;
467 slots[instr.outputs[0]] = Some(output);
468 instruction_idx += 1;
469 }
470
471 let [output_slot] = program.compiled.output_slots.as_slice() else {
472 return Err(Error::Internal(format!(
473 "expanded eager einsum expected one graph output, got {}",
474 program.compiled.output_slots.len()
475 )));
476 };
477 slots
478 .get_mut(*output_slot)
479 .and_then(Option::take)
480 .map(Some)
481 .ok_or_else(|| Error::Internal("expanded eager einsum output slot is missing".into()))
482}
483
484fn expanded_eager_program_retained_bytes(program: &ExpandedEagerProgram) -> usize {
485 saturating_sum([
486 size_of::<ExpandedEagerProgram>(),
487 vec_retained_bytes(&program.input_slots),
488 compiled_program_retained_bytes(&program.compiled),
489 ])
490}
491
492fn compiled_program_retained_bytes(program: &CompiledProgram<StdTensorOp>) -> usize {
493 saturating_sum([
494 size_of::<CompiledProgram<StdTensorOp>>(),
495 vec_retained_bytes(&program.instructions),
496 vec_retained_bytes(&program.input_slots),
497 vec_retained_bytes(&program.output_slots),
498 saturating_sum(program.instructions.iter().map(instruction_retained_bytes)),
499 ])
500}
501
502fn instruction_retained_bytes(instruction: &Instruction<StdTensorOp>) -> usize {
503 saturating_sum([
504 size_of::<Instruction<StdTensorOp>>(),
505 std_tensor_op_retained_bytes(&instruction.operation),
506 vec_retained_bytes(&instruction.inputs),
507 vec_retained_bytes(&instruction.outputs),
508 ])
509}
510
511fn std_tensor_op_retained_bytes(op: &StdTensorOp) -> usize {
512 match op {
513 StdTensorOp::DotGeneral { config } => saturating_sum([
514 vec_retained_bytes(&config.lhs_contracting_dims),
515 vec_retained_bytes(&config.rhs_contracting_dims),
516 vec_retained_bytes(&config.lhs_batch_dims),
517 vec_retained_bytes(&config.rhs_batch_dims),
518 ]),
519 StdTensorOp::Transpose { perm } => vec_retained_bytes(perm),
520 StdTensorOp::Reshape { to_shape } => vec_retained_bytes(to_shape),
521 StdTensorOp::BroadcastInDim { shape, dims } => {
522 saturating_sum([vec_retained_bytes(shape), vec_retained_bytes(dims)])
523 }
524 StdTensorOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
525 StdTensorOp::ReduceSum { axes }
526 | StdTensorOp::ReduceProd { axes }
527 | StdTensorOp::ReduceMax { axes }
528 | StdTensorOp::ReduceMin { axes }
529 | StdTensorOp::Reverse { axes } => vec_retained_bytes(axes),
530 StdTensorOp::DynamicSlice { slice_sizes } => vec_retained_bytes(slice_sizes),
531 StdTensorOp::GatherDynamicSliceSizes {
532 offset_dims,
533 collapsed_slice_dims,
534 start_index_map,
535 slice_sizes,
536 ..
537 } => saturating_sum([
538 vec_retained_bytes(offset_dims),
539 vec_retained_bytes(collapsed_slice_dims),
540 vec_retained_bytes(start_index_map),
541 vec_retained_bytes(slice_sizes),
542 ]),
543 _ => 0,
544 }
545}
546
547fn try_execute_eager_broadcast_multiply_pattern(
548 instructions: &[Instruction<StdTensorOp>],
549 instruction_idx: usize,
550 slots: &[Option<EagerTensor>],
551 output_slots: &[usize],
552) -> Result<Option<(usize, EagerTensor)>> {
553 if instruction_idx + 2 >= instructions.len() {
554 return Ok(None);
555 }
556 let lhs_bc = &instructions[instruction_idx];
557 let rhs_bc = &instructions[instruction_idx + 1];
558 let multiply = &instructions[instruction_idx + 2];
559
560 let StdTensorOp::BroadcastInDim {
561 shape: lhs_shape_exprs,
562 dims: lhs_dims,
563 } = &lhs_bc.operation
564 else {
565 return Ok(None);
566 };
567 let StdTensorOp::BroadcastInDim {
568 shape: rhs_shape_exprs,
569 dims: rhs_dims,
570 } = &rhs_bc.operation
571 else {
572 return Ok(None);
573 };
574 if !matches!(multiply.operation, StdTensorOp::Mul)
575 || lhs_bc.outputs.len() != 1
576 || rhs_bc.outputs.len() != 1
577 || multiply.outputs.len() != 1
578 || multiply.inputs.len() != 2
579 || lhs_bc.inputs.is_empty()
580 || rhs_bc.inputs.is_empty()
581 || multiply.inputs[0] != lhs_bc.outputs[0]
582 || multiply.inputs[1] != rhs_bc.outputs[0]
583 {
584 return Ok(None);
585 }
586
587 let lhs_bc_slot = lhs_bc.outputs[0];
588 let rhs_bc_slot = rhs_bc.outputs[0];
589 if output_slots.contains(&lhs_bc_slot)
590 || output_slots.contains(&rhs_bc_slot)
591 || instructions[instruction_idx + 3..]
592 .iter()
593 .any(|instr| instr.inputs.contains(&lhs_bc_slot) || instr.inputs.contains(&rhs_bc_slot))
594 {
595 return Ok(None);
596 }
597
598 let lhs = slot_tensor(slots, lhs_bc.inputs[0])?;
599 let rhs = slot_tensor(slots, rhs_bc.inputs[0])?;
600 let lhs_shape = eval_shape_exprs(slots, &lhs_bc.inputs, lhs_shape_exprs)?;
601 let rhs_shape = eval_shape_exprs(slots, &rhs_bc.inputs, rhs_shape_exprs)?;
602 let Some(output) =
603 backend_broadcast_multiply_untracked(lhs, &lhs_shape, lhs_dims, rhs, &rhs_shape, rhs_dims)?
604 else {
605 return Ok(None);
606 };
607
608 Ok(Some((multiply.outputs[0], output)))
609}
610
611#[allow(clippy::too_many_arguments)]
612fn backend_broadcast_multiply_untracked(
613 lhs: &EagerTensor,
614 lhs_shape: &[usize],
615 lhs_dims: &[usize],
616 rhs: &EagerTensor,
617 rhs_shape: &[usize],
618 rhs_dims: &[usize],
619) -> Result<Option<EagerTensor>> {
620 if !Arc::ptr_eq(lhs.runtime(), rhs.runtime()) {
621 return Err(Error::ContextMismatch {
622 lhs: lhs.ctx_id(),
623 rhs: rhs.ctx_id(),
624 });
625 }
626 if lhs.tracks_grad() || rhs.tracks_grad() {
627 return Ok(None);
628 }
629
630 let runtime = lhs.runtime();
631 let value = runtime.with_backend_mut(|backend| {
632 backend.execute_broadcast_multiply_value(
633 lhs.tensor_read(),
634 lhs_shape,
635 lhs_dims,
636 rhs.tensor_read(),
637 rhs_shape,
638 rhs_dims,
639 )
640 })??;
641
642 Ok(value.map(|value| adopt_untracked_eager_value(runtime.clone(), value)))
643}
644
645fn eval_shape_exprs(
646 slots: &[Option<EagerTensor>],
647 input_slots: &[usize],
648 shape: &[DimExpr],
649) -> Result<Vec<usize>> {
650 let inputs = input_slots
651 .iter()
652 .map(|&slot| slot_tensor(slots, slot))
653 .collect::<Result<Vec<_>>>()?;
654 let input_shapes = inputs
655 .iter()
656 .map(|tensor| tensor.shape())
657 .collect::<Vec<_>>();
658 DimExpr::eval_all(shape, &input_shapes).map_err(|err| Error::InvalidCompiledGraph {
659 message: format!("invalid eager einsum shape expression: {err}"),
660 })
661}
662
663fn slot_tensor(slots: &[Option<EagerTensor>], slot: usize) -> Result<&EagerTensor> {
664 slots.get(slot).and_then(Option::as_ref).ok_or_else(|| {
665 Error::Internal(format!(
666 "expanded eager einsum missing value for slot {slot}"
667 ))
668 })
669}
670
671fn infer_eager_output_shape(
672 subscripts: &EinsumSubscripts,
673 inputs: &[&EagerTensor],
674) -> Result<Vec<tenferro_runtime::SymDim>> {
675 if inputs.is_empty() {
676 return Err(Error::ContractionError(
677 "einsum requires at least one input tensor".into(),
678 ));
679 }
680 if subscripts.inputs.len() != inputs.len() {
681 return Err(Error::ContractionError(format!(
682 "einsum subscripts expect {} inputs, got {}",
683 subscripts.inputs.len(),
684 inputs.len()
685 )));
686 }
687
688 let mut label_dims = std::collections::HashMap::new();
689 for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
690 let shape = tensor.shape();
691 if labels.len() != shape.len() {
692 return Err(Error::ContractionError(format!(
693 "einsum input rank mismatch: labels={}, shape={}",
694 labels.len(),
695 shape.len()
696 )));
697 }
698 for (&label, &dim) in labels.iter().zip(shape.iter()) {
699 if let Some(existing) = label_dims.insert(label, dim) {
700 if existing != dim {
701 return Err(Error::ContractionError(format!(
702 "einsum label {label} has inconsistent dimensions {existing} and {dim}"
703 )));
704 }
705 }
706 }
707 }
708
709 subscripts
710 .output
711 .iter()
712 .map(|label| {
713 label_dims
714 .get(label)
715 .copied()
716 .map(tenferro_runtime::SymDim::from)
717 .ok_or_else(|| {
718 Error::ContractionError(format!(
719 "einsum output label {label} is missing from input labels"
720 ))
721 })
722 })
723 .collect()
724}
725
726pub fn tensordot(
753 lhs: &EagerTensor,
754 rhs: &EagerTensor,
755 axes: TensorDotAxes<'_>,
756) -> Result<EagerTensor> {
757 let config = crate::tensordot::dot_general_config(axes, lhs.shape().len(), rhs.shape().len())?;
758 crate::tensordot::validate_concrete_contract_dims(lhs.shape(), rhs.shape(), &config)?;
759 lhs.dot_general(rhs, config)
760}
761
762#[cfg(test)]
763mod tests;