Skip to main content

tenferro_einsum/
eager_ad.rs

1//! EagerTensor einsum extension API.
2
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::mem::size_of;
6use std::sync::Arc;
7
8use computegraph::compile::{compile, CompiledProgram, Instruction};
9use computegraph::graph::GraphBuilder;
10use computegraph::materialize::materialize_merge;
11use computegraph::resolve::resolve;
12use computegraph::types::{ValueKey, ValueRef};
13use tenferro_ad::error::{Error, Result};
14use tenferro_ad::extension::{adopt_untracked_eager_value, apply_eager};
15use tenferro_ad::{EagerRuntime, EagerTensor};
16use tenferro_ops::dim_expr::DimExpr;
17use tenferro_ops::input_key::TensorInputKey;
18use tenferro_ops::std_tensor_op::StdTensorOp;
19use tenferro_runtime::ExtensionCacheKey;
20use tenferro_tensor::TensorFusion;
21
22use crate::binary_dot::{try_build_exact_output_binary_dot_plan, BinaryDotOperandOrder};
23use crate::builder::build_einsum_graph;
24use crate::cache::{
25    saturating_sum, vec_retained_bytes, EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE,
26    EINSUM_EXTENSION_FAMILY_ID,
27};
28use crate::extension::{register_runtime, EinsumExtensionOp};
29use crate::optimize::{
30    default_auto_options, hash_einsum_plan_spec, resolve_plan_spec, EinsumPlanSpec,
31};
32use crate::{parse_einsum_subscripts, EinsumSubscripts, Subscripts, TensorDotAxes};
33
/// Eager einsum extension methods for slices or arrays of [`EagerTensor`] refs.
///
/// `self` is the ordered list of operands; both methods forward to the free
/// functions of the same name in this module.
pub trait EagerEinsumExt {
    /// Contracts the operands using a textual subscript specification such as
    /// `"ij,jk->ik"`.
    fn einsum(&self, subscripts: &str) -> Result<EagerTensor>;
    /// Contracts the operands using already-parsed [`EinsumSubscripts`].
    fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor>;
}
39
40impl EagerEinsumExt for [&EagerTensor] {
41    fn einsum(&self, subscripts: &str) -> Result<EagerTensor> {
42        einsum(self, subscripts)
43    }
44
45    fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor> {
46        einsum_subscripts(self, subscripts)
47    }
48}
49
50impl<const N: usize> EagerEinsumExt for [&EagerTensor; N] {
51    fn einsum(&self, subscripts: &str) -> Result<EagerTensor> {
52        einsum(self.as_slice(), subscripts)
53    }
54
55    fn einsum_subscripts(&self, subscripts: &EinsumSubscripts) -> Result<EagerTensor> {
56        einsum_subscripts(self.as_slice(), subscripts)
57    }
58}
59
/// Eager tensor contraction-sugar methods.
pub trait EagerTensorEinsumExt {
    /// Contracts `self` (left operand) with `rhs` over the axes selected by
    /// `axes`; see the free function [`tensordot`].
    fn tensordot(&self, rhs: &EagerTensor, axes: TensorDotAxes<'_>) -> Result<EagerTensor>;
}
64
65impl EagerTensorEinsumExt for EagerTensor {
66    fn tensordot(&self, rhs: &EagerTensor, axes: TensorDotAxes<'_>) -> Result<EagerTensor> {
67        tensordot(self, rhs, axes)
68    }
69}
70
71/// Execute an einsum eagerly on [`EagerTensor`] values.
72///
73/// # Examples
74///
75/// ```
76/// use tenferro_ad::{EagerRuntime, EagerTensor};
77/// use tenferro_cpu::CpuBackend;
78/// use tenferro_einsum::EagerEinsumExt;
79/// use tenferro_tensor::Tensor;
80///
81/// let runtime = EagerRuntime::with_cpu_backend(CpuBackend::new());
82/// let a = EagerTensor::from_tensor_in(
83///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap(),
84///     runtime.clone(),
85/// ).unwrap();
86/// let b = EagerTensor::from_tensor_in(
87///     Tensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap(),
88///     runtime,
89/// ).unwrap();
90/// let out = [&a, &b].einsum("ij,jk->ik")?;
91/// assert_eq!(out.shape(), &[2, 4]);
92/// # Ok::<(), tenferro_ad::error::Error>(())
93/// ```
94pub fn einsum(inputs: &[&EagerTensor], subscripts: &str) -> Result<EagerTensor> {
95    let subscripts = parse_einsum_subscripts(subscripts)
96        .map_err(|err| Error::ContractionError(err.to_string()))?;
97    einsum_subscripts(inputs, &subscripts)
98}
99
100/// Execute an einsum eagerly from integer labels.
101///
102/// # Examples
103///
104/// ```
105/// use tenferro_ad::{EagerRuntime, EagerTensor};
106/// use tenferro_cpu::CpuBackend;
107/// use tenferro_einsum::{EagerEinsumExt, parse_einsum_subscripts};
108/// use tenferro_tensor::Tensor;
109///
110/// let runtime = EagerRuntime::with_cpu_backend(CpuBackend::new());
111/// let a = EagerTensor::from_tensor_in(
112///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap(),
113///     runtime.clone(),
114/// ).unwrap();
115/// let b = EagerTensor::from_tensor_in(
116///     Tensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap(),
117///     runtime,
118/// ).unwrap();
119/// let subscripts = parse_einsum_subscripts("ij,jk->ik").unwrap();
120/// let out = [&a, &b].einsum_subscripts(&subscripts)?;
121/// assert_eq!(out.shape(), &[2, 4]);
122/// # Ok::<(), tenferro_ad::error::Error>(())
123/// ```
124pub fn einsum_subscripts(
125    inputs: &[&EagerTensor],
126    subscripts: &EinsumSubscripts,
127) -> Result<EagerTensor> {
128    if let Some(result) = try_direct_binary_dot_general(inputs, subscripts) {
129        return result;
130    }
131
132    if let Some(result) = try_whole_program_untracked(inputs, subscripts)? {
133        return Ok(result);
134    }
135
136    let output_shape_hint = infer_eager_output_shape(subscripts, inputs)?;
137    if let Some(result) = try_expand_eager_einsum(inputs, subscripts)? {
138        return Ok(result);
139    }
140
141    if let Some(first) = inputs.first() {
142        first
143            .runtime()
144            .register_extension(register_runtime)
145            .map_err(|err| Error::Internal(err.to_string()))?;
146    }
147
148    let op = Arc::new(EinsumExtensionOp::with_output_shape_hint(
149        subscripts.clone(),
150        output_shape_hint,
151        EinsumPlanSpec::Auto(default_auto_options()),
152    ));
153    let mut outputs = apply_eager(op, inputs)?;
154    outputs
155        .pop()
156        .ok_or_else(|| Error::Internal("einsum extension produced no eager output".to_string()))
157}
158
159fn try_direct_binary_dot_general(
160    inputs: &[&EagerTensor],
161    subscripts: &EinsumSubscripts,
162) -> Option<Result<EagerTensor>> {
163    if inputs.len() != 2 || subscripts.inputs.len() != 2 {
164        return None;
165    }
166
167    let lhs_labels = &subscripts.inputs[0];
168    let rhs_labels = &subscripts.inputs[1];
169    if lhs_labels.len() != inputs[0].shape().len() || rhs_labels.len() != inputs[1].shape().len() {
170        return None;
171    }
172
173    if let Some(plan) =
174        try_build_exact_output_binary_dot_plan(lhs_labels, rhs_labels, &subscripts.output)
175    {
176        return Some(match plan.operand_order {
177            BinaryDotOperandOrder::Original => inputs[0].dot_general(inputs[1], plan.config),
178            BinaryDotOperandOrder::Swapped => inputs[1].dot_general(inputs[0], plan.config),
179        });
180    }
181    None
182}
183
/// Whether the untracked whole-program eager einsum executor is enabled.
///
/// Prototype gate (issue #1060 follow-up): when set, untracked N-ary eager
/// einsum runs the whole contraction in one backend session via
/// [`crate::eager::eager_einsum_subscripts`] instead of executing the expanded
/// program one standard op at a time. Tracked (`requires_grad`) inputs keep the
/// existing per-op path so eager AD recording semantics are unchanged.
///
/// Only the *presence* of `TENFERRO_EAGER_WHOLE_PROGRAM` matters: any value,
/// including `0`, turns the gate on. The environment is consulted on every
/// call, so the gate can be toggled while the process is running.
fn whole_program_untracked_enabled() -> bool {
    const GATE_VAR: &str = "TENFERRO_EAGER_WHOLE_PROGRAM";
    matches!(std::env::var_os(GATE_VAR), Some(_))
}
194
195/// Run an untracked eager einsum as a single backend-session program.
196///
197/// Returns `None` (so the caller falls back to the per-op expanded path) when
198/// the gate is off, there are no inputs, any input tracks gradients, or the
199/// inputs do not all share one runtime.
200fn try_whole_program_untracked(
201    inputs: &[&EagerTensor],
202    subscripts: &EinsumSubscripts,
203) -> Result<Option<EagerTensor>> {
204    if !whole_program_untracked_enabled() {
205        return Ok(None);
206    }
207    let Some(first) = inputs.first() else {
208        return Ok(None);
209    };
210    if inputs.iter().any(|tensor| tensor.tracks_grad()) {
211        return Ok(None);
212    }
213    let runtime = first.runtime();
214    if inputs
215        .iter()
216        .any(|tensor| !Arc::ptr_eq(tensor.runtime(), runtime))
217    {
218        return Ok(None);
219    }
220
221    let subs = Subscripts::from(subscripts);
222    let tensor_arcs = inputs
223        .iter()
224        .map(|tensor| tensor.materialized())
225        .collect::<Result<Vec<_>>>()?;
226    let tensors: Vec<_> = tensor_arcs.iter().map(|tensor| tensor.as_ref()).collect();
227    let result = runtime.with_backend_mut(|backend| {
228        crate::eager::eager_einsum_subscripts(backend, &tensors, &subs)
229    })??;
230    Ok(Some(EagerTensor::from_tensor_in(result, runtime.clone())?))
231}
232
/// Run an untracked whole-program eager einsum on an explicit contraction tree.
///
/// Test-only prototype helper (issue #1060 follow-up): the function is private
/// and compiled only under `cfg(test)`. It executes the whole contraction in
/// one backend session on the caller-provided path (e.g. an externally
/// optimized `opt_flops` order via [`crate::ContractionTree::from_pairs`]),
/// instead of one eager op per expanded step.
///
/// # Errors
///
/// * [`Error::ContractionError`] when `inputs` is empty.
/// * [`Error::Internal`] when any input tracks gradients (tracked inputs should
///   use the per-op path to keep eager AD semantics) or when the inputs do not
///   all share one runtime.
/// * Any error from materializing the inputs, from the backend session, or
///   from wrapping the result.
///
/// # Examples
///
/// The snippet is illustrative only: this function is not part of the public
/// API, so it cannot be compiled as a documentation test.
///
/// ```ignore
/// use tenferro_ad::{EagerRuntime, EagerTensor};
/// use tenferro_cpu::CpuBackend;
/// use tenferro_einsum::{ContractionTree, Subscripts};
/// use tenferro_tensor::Tensor;
///
/// let runtime = EagerRuntime::with_cpu_backend(CpuBackend::new());
/// let a = EagerTensor::from_tensor_in(
///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap(),
///     runtime.clone(),
/// ).unwrap();
/// let b = EagerTensor::from_tensor_in(
///     Tensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap(),
///     runtime,
/// ).unwrap();
/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
/// let out = einsum_whole_program_untracked(&[&a, &b], &tree)?;
/// assert_eq!(out.shape(), &[2, 4]);
/// # Ok::<(), tenferro_ad::error::Error>(())
/// ```
#[cfg(test)]
fn einsum_whole_program_untracked(
    inputs: &[&EagerTensor],
    tree: &crate::ContractionTree,
) -> Result<EagerTensor> {
    let first = inputs.first().ok_or_else(|| {
        Error::ContractionError("einsum requires at least one input tensor".into())
    })?;
    if inputs.iter().any(|tensor| tensor.tracks_grad()) {
        return Err(Error::Internal(
            "whole-program eager einsum requires untracked inputs".into(),
        ));
    }
    let runtime = first.runtime();
    if inputs
        .iter()
        .any(|tensor| !Arc::ptr_eq(tensor.runtime(), runtime))
    {
        return Err(Error::Internal(
            "whole-program eager einsum requires inputs from one runtime".into(),
        ));
    }
    // The owning handles must outlive the borrowed views passed to the backend.
    let tensor_arcs = inputs
        .iter()
        .map(|tensor| tensor.materialized())
        .collect::<Result<Vec<_>>>()?;
    let tensors: Vec<_> = tensor_arcs.iter().map(|tensor| tensor.as_ref()).collect();
    let result = runtime.with_backend_mut(|backend| {
        crate::eager::eager_einsum_with_tree(backend, &tensors, tree)
    })??;
    EagerTensor::from_tensor_in(result, runtime.clone())
}
297
298fn try_expand_eager_einsum(
299    inputs: &[&EagerTensor],
300    subscripts: &EinsumSubscripts,
301) -> Result<Option<EagerTensor>> {
302    if inputs.len() <= 1 {
303        return Ok(None);
304    }
305
306    let shapes: Vec<Vec<usize>> = inputs
307        .iter()
308        .map(|tensor| tensor.shape().to_vec())
309        .collect();
310    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
311    let subs = Subscripts::from(subscripts);
312    let plan_spec = EinsumPlanSpec::Auto(default_auto_options());
313
314    let program = cached_expanded_eager_program(
315        inputs[0].runtime(),
316        subscripts,
317        &subs,
318        &plan_spec,
319        &shape_refs,
320        &shapes,
321    )?;
322    execute_eager_einsum_program(inputs, &program)
323}
324
/// A compiled expanded einsum program together with the mapping needed to
/// feed it the caller's operands. Instances are shared through `Arc` and kept
/// in the runtime's extension caches.
struct ExpandedEagerProgram {
    /// Compiled instruction sequence over [`StdTensorOp`]s.
    compiled: CompiledProgram<StdTensorOp>,
    /// `(program slot, caller input index)` pairs used to seed the slot table
    /// before the instructions run.
    input_slots: Vec<(usize, usize)>,
}
329
330fn cached_expanded_eager_program(
331    runtime: &Arc<EagerRuntime>,
332    subscripts: &EinsumSubscripts,
333    subs: &Subscripts,
334    plan_spec: &EinsumPlanSpec,
335    shape_refs: &[&[usize]],
336    shapes: &[Vec<usize>],
337) -> Result<Arc<ExpandedEagerProgram>> {
338    runtime.with_extension_caches_mut(|caches| {
339        let key = expanded_eager_program_cache_key(subscripts, plan_spec, shapes);
340        if let Some(cached) = caches.get::<Arc<ExpandedEagerProgram>>(&key) {
341            return Ok(Arc::clone(cached));
342        }
343
344        let tree = resolve_plan_spec(plan_spec, subs, shape_refs)
345            .map_err(|err| Error::ContractionError(err.to_string()))?;
346        let program = Arc::new(build_expanded_eager_program(&tree, shapes)?);
347        let retained_bytes = expanded_eager_program_retained_bytes(&program);
348        caches.put(key, Arc::clone(&program), retained_bytes);
349        Ok(program)
350    })?
351}
352
353fn expanded_eager_program_cache_key(
354    subscripts: &EinsumSubscripts,
355    plan_spec: &EinsumPlanSpec,
356    shapes: &[Vec<usize>],
357) -> ExtensionCacheKey {
358    let mut hasher = DefaultHasher::new();
359    subscripts.hash(&mut hasher);
360    shapes.hash(&mut hasher);
361    hash_einsum_plan_spec(plan_spec, &mut hasher);
362    ExtensionCacheKey::new(
363        EINSUM_EXTENSION_FAMILY_ID,
364        EINSUM_EAGER_EXPANDED_PROGRAMS_CACHE,
365        hasher.finish(),
366    )
367}
368
369fn build_expanded_eager_program(
370    tree: &crate::ContractionTree,
371    shapes: &[Vec<usize>],
372) -> Result<ExpandedEagerProgram> {
373    let mut builder = GraphBuilder::<StdTensorOp>::new();
374    let mut input_vals = Vec::with_capacity(shapes.len());
375    for input_idx in 0..shapes.len() {
376        let local = builder.add_input(TensorInputKey::User {
377            id: input_idx as u64,
378        });
379        input_vals.push(ValueRef::Local(local));
380    }
381
382    let result_ref = build_einsum_graph(&mut builder, tree, &input_vals, shapes)
383        .map_err(|err| Error::ContractionError(err.to_string()))?;
384    let ValueRef::Local(result_local) = result_ref else {
385        return Err(Error::Internal(
386            "expanded eager einsum returned an external value".into(),
387        ));
388    };
389    builder.set_outputs(vec![result_local]);
390    let graph = Arc::new(builder.build());
391    let output_key = graph.values()[result_local].key.clone();
392    let view = resolve(vec![graph]);
393    let graph = materialize_merge(&view, &[output_key]);
394    let compiled = compile(&graph);
395    let input_slots = compiled
396        .input_slots
397        .iter()
398        .zip(graph.inputs.iter())
399        .map(|(&slot, key)| {
400            let ValueKey::Input(TensorInputKey::User { id }) = key else {
401                return Err(Error::Internal(format!(
402                    "expanded eager einsum saw unexpected input key: {key:?}"
403                )));
404            };
405            Ok((slot, *id as usize))
406        })
407        .collect::<Result<_>>()?;
408
409    Ok(ExpandedEagerProgram {
410        compiled,
411        input_slots,
412    })
413}
414
415fn execute_eager_einsum_program(
416    inputs: &[&EagerTensor],
417    program: &ExpandedEagerProgram,
418) -> Result<Option<EagerTensor>> {
419    let mut slots: Vec<Option<EagerTensor>> = vec![None; program.compiled.n_slots];
420    for &(slot, input_idx) in &program.input_slots {
421        let tensor = inputs.get(input_idx).ok_or_else(|| {
422            Error::Internal(format!(
423                "expanded eager einsum input {input_idx} is missing"
424            ))
425        })?;
426        slots[slot] = Some((*tensor).clone());
427    }
428
429    let mut instruction_idx = 0;
430    while instruction_idx < program.compiled.instructions.len() {
431        if let Some((output_slot, output)) = try_execute_eager_broadcast_multiply_pattern(
432            &program.compiled.instructions,
433            instruction_idx,
434            &slots,
435            &program.compiled.output_slots,
436        )? {
437            slots[output_slot] = Some(output);
438            instruction_idx += 3;
439            continue;
440        }
441
442        let instr = &program.compiled.instructions[instruction_idx];
443        if instr.outputs.len() != 1 {
444            return Err(Error::Internal(format!(
445                "expanded eager einsum expected single-output op, got {} outputs",
446                instr.outputs.len()
447            )));
448        }
449        let input_values: Vec<EagerTensor> = instr
450            .inputs
451            .iter()
452            .map(|&slot| {
453                slots
454                    .get(slot)
455                    .and_then(Option::as_ref)
456                    .cloned()
457                    .ok_or_else(|| {
458                        Error::Internal(format!(
459                            "expanded eager einsum missing value for slot {slot}"
460                        ))
461                    })
462            })
463            .collect::<Result<_>>()?;
464        let input_refs: Vec<&EagerTensor> = input_values.iter().collect();
465        let output =
466            tenferro_ad::extension::apply_standard_op(instr.operation.clone(), &input_refs)?;
467        slots[instr.outputs[0]] = Some(output);
468        instruction_idx += 1;
469    }
470
471    let [output_slot] = program.compiled.output_slots.as_slice() else {
472        return Err(Error::Internal(format!(
473            "expanded eager einsum expected one graph output, got {}",
474            program.compiled.output_slots.len()
475        )));
476    };
477    slots
478        .get_mut(*output_slot)
479        .and_then(Option::take)
480        .map(Some)
481        .ok_or_else(|| Error::Internal("expanded eager einsum output slot is missing".into()))
482}
483
484fn expanded_eager_program_retained_bytes(program: &ExpandedEagerProgram) -> usize {
485    saturating_sum([
486        size_of::<ExpandedEagerProgram>(),
487        vec_retained_bytes(&program.input_slots),
488        compiled_program_retained_bytes(&program.compiled),
489    ])
490}
491
492fn compiled_program_retained_bytes(program: &CompiledProgram<StdTensorOp>) -> usize {
493    saturating_sum([
494        size_of::<CompiledProgram<StdTensorOp>>(),
495        vec_retained_bytes(&program.instructions),
496        vec_retained_bytes(&program.input_slots),
497        vec_retained_bytes(&program.output_slots),
498        saturating_sum(program.instructions.iter().map(instruction_retained_bytes)),
499    ])
500}
501
502fn instruction_retained_bytes(instruction: &Instruction<StdTensorOp>) -> usize {
503    saturating_sum([
504        size_of::<Instruction<StdTensorOp>>(),
505        std_tensor_op_retained_bytes(&instruction.operation),
506        vec_retained_bytes(&instruction.inputs),
507        vec_retained_bytes(&instruction.outputs),
508    ])
509}
510
511fn std_tensor_op_retained_bytes(op: &StdTensorOp) -> usize {
512    match op {
513        StdTensorOp::DotGeneral { config } => saturating_sum([
514            vec_retained_bytes(&config.lhs_contracting_dims),
515            vec_retained_bytes(&config.rhs_contracting_dims),
516            vec_retained_bytes(&config.lhs_batch_dims),
517            vec_retained_bytes(&config.rhs_batch_dims),
518        ]),
519        StdTensorOp::Transpose { perm } => vec_retained_bytes(perm),
520        StdTensorOp::Reshape { to_shape } => vec_retained_bytes(to_shape),
521        StdTensorOp::BroadcastInDim { shape, dims } => {
522            saturating_sum([vec_retained_bytes(shape), vec_retained_bytes(dims)])
523        }
524        StdTensorOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
525        StdTensorOp::ReduceSum { axes }
526        | StdTensorOp::ReduceProd { axes }
527        | StdTensorOp::ReduceMax { axes }
528        | StdTensorOp::ReduceMin { axes }
529        | StdTensorOp::Reverse { axes } => vec_retained_bytes(axes),
530        StdTensorOp::DynamicSlice { slice_sizes } => vec_retained_bytes(slice_sizes),
531        StdTensorOp::GatherDynamicSliceSizes {
532            offset_dims,
533            collapsed_slice_dims,
534            start_index_map,
535            slice_sizes,
536            ..
537        } => saturating_sum([
538            vec_retained_bytes(offset_dims),
539            vec_retained_bytes(collapsed_slice_dims),
540            vec_retained_bytes(start_index_map),
541            vec_retained_bytes(slice_sizes),
542        ]),
543        _ => 0,
544    }
545}
546
547fn try_execute_eager_broadcast_multiply_pattern(
548    instructions: &[Instruction<StdTensorOp>],
549    instruction_idx: usize,
550    slots: &[Option<EagerTensor>],
551    output_slots: &[usize],
552) -> Result<Option<(usize, EagerTensor)>> {
553    if instruction_idx + 2 >= instructions.len() {
554        return Ok(None);
555    }
556    let lhs_bc = &instructions[instruction_idx];
557    let rhs_bc = &instructions[instruction_idx + 1];
558    let multiply = &instructions[instruction_idx + 2];
559
560    let StdTensorOp::BroadcastInDim {
561        shape: lhs_shape_exprs,
562        dims: lhs_dims,
563    } = &lhs_bc.operation
564    else {
565        return Ok(None);
566    };
567    let StdTensorOp::BroadcastInDim {
568        shape: rhs_shape_exprs,
569        dims: rhs_dims,
570    } = &rhs_bc.operation
571    else {
572        return Ok(None);
573    };
574    if !matches!(multiply.operation, StdTensorOp::Mul)
575        || lhs_bc.outputs.len() != 1
576        || rhs_bc.outputs.len() != 1
577        || multiply.outputs.len() != 1
578        || multiply.inputs.len() != 2
579        || lhs_bc.inputs.is_empty()
580        || rhs_bc.inputs.is_empty()
581        || multiply.inputs[0] != lhs_bc.outputs[0]
582        || multiply.inputs[1] != rhs_bc.outputs[0]
583    {
584        return Ok(None);
585    }
586
587    let lhs_bc_slot = lhs_bc.outputs[0];
588    let rhs_bc_slot = rhs_bc.outputs[0];
589    if output_slots.contains(&lhs_bc_slot)
590        || output_slots.contains(&rhs_bc_slot)
591        || instructions[instruction_idx + 3..]
592            .iter()
593            .any(|instr| instr.inputs.contains(&lhs_bc_slot) || instr.inputs.contains(&rhs_bc_slot))
594    {
595        return Ok(None);
596    }
597
598    let lhs = slot_tensor(slots, lhs_bc.inputs[0])?;
599    let rhs = slot_tensor(slots, rhs_bc.inputs[0])?;
600    let lhs_shape = eval_shape_exprs(slots, &lhs_bc.inputs, lhs_shape_exprs)?;
601    let rhs_shape = eval_shape_exprs(slots, &rhs_bc.inputs, rhs_shape_exprs)?;
602    let Some(output) =
603        backend_broadcast_multiply_untracked(lhs, &lhs_shape, lhs_dims, rhs, &rhs_shape, rhs_dims)?
604    else {
605        return Ok(None);
606    };
607
608    Ok(Some((multiply.outputs[0], output)))
609}
610
611#[allow(clippy::too_many_arguments)]
612fn backend_broadcast_multiply_untracked(
613    lhs: &EagerTensor,
614    lhs_shape: &[usize],
615    lhs_dims: &[usize],
616    rhs: &EagerTensor,
617    rhs_shape: &[usize],
618    rhs_dims: &[usize],
619) -> Result<Option<EagerTensor>> {
620    if !Arc::ptr_eq(lhs.runtime(), rhs.runtime()) {
621        return Err(Error::ContextMismatch {
622            lhs: lhs.ctx_id(),
623            rhs: rhs.ctx_id(),
624        });
625    }
626    if lhs.tracks_grad() || rhs.tracks_grad() {
627        return Ok(None);
628    }
629
630    let runtime = lhs.runtime();
631    let value = runtime.with_backend_mut(|backend| {
632        backend.execute_broadcast_multiply_value(
633            lhs.tensor_read(),
634            lhs_shape,
635            lhs_dims,
636            rhs.tensor_read(),
637            rhs_shape,
638            rhs_dims,
639        )
640    })??;
641
642    Ok(value.map(|value| adopt_untracked_eager_value(runtime.clone(), value)))
643}
644
645fn eval_shape_exprs(
646    slots: &[Option<EagerTensor>],
647    input_slots: &[usize],
648    shape: &[DimExpr],
649) -> Result<Vec<usize>> {
650    let inputs = input_slots
651        .iter()
652        .map(|&slot| slot_tensor(slots, slot))
653        .collect::<Result<Vec<_>>>()?;
654    let input_shapes = inputs
655        .iter()
656        .map(|tensor| tensor.shape())
657        .collect::<Vec<_>>();
658    DimExpr::eval_all(shape, &input_shapes).map_err(|err| Error::InvalidCompiledGraph {
659        message: format!("invalid eager einsum shape expression: {err}"),
660    })
661}
662
663fn slot_tensor(slots: &[Option<EagerTensor>], slot: usize) -> Result<&EagerTensor> {
664    slots.get(slot).and_then(Option::as_ref).ok_or_else(|| {
665        Error::Internal(format!(
666            "expanded eager einsum missing value for slot {slot}"
667        ))
668    })
669}
670
671fn infer_eager_output_shape(
672    subscripts: &EinsumSubscripts,
673    inputs: &[&EagerTensor],
674) -> Result<Vec<tenferro_runtime::SymDim>> {
675    if inputs.is_empty() {
676        return Err(Error::ContractionError(
677            "einsum requires at least one input tensor".into(),
678        ));
679    }
680    if subscripts.inputs.len() != inputs.len() {
681        return Err(Error::ContractionError(format!(
682            "einsum subscripts expect {} inputs, got {}",
683            subscripts.inputs.len(),
684            inputs.len()
685        )));
686    }
687
688    let mut label_dims = std::collections::HashMap::new();
689    for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
690        let shape = tensor.shape();
691        if labels.len() != shape.len() {
692            return Err(Error::ContractionError(format!(
693                "einsum input rank mismatch: labels={}, shape={}",
694                labels.len(),
695                shape.len()
696            )));
697        }
698        for (&label, &dim) in labels.iter().zip(shape.iter()) {
699            if let Some(existing) = label_dims.insert(label, dim) {
700                if existing != dim {
701                    return Err(Error::ContractionError(format!(
702                        "einsum label {label} has inconsistent dimensions {existing} and {dim}"
703                    )));
704                }
705            }
706        }
707    }
708
709    subscripts
710        .output
711        .iter()
712        .map(|label| {
713            label_dims
714                .get(label)
715                .copied()
716                .map(tenferro_runtime::SymDim::from)
717                .ok_or_else(|| {
718                    Error::ContractionError(format!(
719                        "einsum output label {label} is missing from input labels"
720                    ))
721                })
722        })
723        .collect()
724}
725
726/// Execute a NumPy-style tensor contraction on [`EagerTensor`] values.
727///
728/// This helper lives in the einsum extension trait surface because it is
729/// contraction sugar over `dot_general`, not a linear algebra facade.
730///
731/// # Examples
732///
733/// ```
734/// use tenferro_tensor::Tensor;
735/// use tenferro_cpu::CpuBackend;
736/// use tenferro_ad::{EagerRuntime, EagerTensor};
737/// use tenferro_einsum::{EagerTensorEinsumExt, TensorDotAxes};
738///
739/// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
740/// let lhs = EagerTensor::from_tensor_in(
741///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap(),
742///     ctx.clone(),
743/// ).unwrap();
744/// let rhs = EagerTensor::from_tensor_in(
745///     Tensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap(),
746///     ctx,
747/// ).unwrap();
748/// let out = lhs.tensordot(&rhs, TensorDotAxes::Count(1)).unwrap();
749///
750/// assert_eq!(out.shape(), &[2, 4]);
751/// ```
752pub fn tensordot(
753    lhs: &EagerTensor,
754    rhs: &EagerTensor,
755    axes: TensorDotAxes<'_>,
756) -> Result<EagerTensor> {
757    let config = crate::tensordot::dot_general_config(axes, lhs.shape().len(), rhs.shape().len())?;
758    crate::tensordot::validate_concrete_contract_dims(lhs.shape(), rhs.shape(), &config)?;
759    lhs.dot_general(rhs, config)
760}
761
762#[cfg(test)]
763mod tests;