Skip to main content

tenferro_einsum/
traced.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::sync::Arc;
4
5use computegraph::types::ValueRef;
6use tenferro_ops::dim_expr::DimExpr;
7use tenferro_ops::ext_op::ExtensionOp;
8use tenferro_runtime::error::{Error, Result};
9use tenferro_runtime::extension::{self, ExtensionCacheKey, ExtensionCacheStore};
10use tenferro_runtime::{GraphCompiler, SymDim, TracedTensor};
11
12use crate::binary_dot::{try_build_exact_output_binary_dot_plan, BinaryDotOperandOrder};
13use crate::builder::build_einsum_graph_dim_expr;
14use crate::cache::{
15    einsum_subscripts_retained_bytes, saturating_sum, vec_retained_bytes, ParsedEinsum,
16    EINSUM_EXTENSION_FAMILY_ID, EINSUM_PARSE_CACHE, EINSUM_STATIC_PLANS_CACHE,
17};
18use crate::extension::EinsumExtensionOp;
19use crate::optimize::{
20    hash_einsum_plan_spec, plan_spec_from_optimize, resolve_einsum_strategy_with_spec,
21    resolve_plan_spec, EinsumPlanSpec,
22};
23use crate::{
24    parse_einsum_subscripts, ContractionTree, EinsumOptimize, EinsumSubscripts,
25    Error as EinsumError, Result as EinsumResult, Subscripts, TensorDotAxes,
26};
27
28/// Traced einsum extension methods for [`GraphCompiler`].
29pub trait GraphCompilerEinsumExt {
30    fn einsum(&mut self, inputs: &[&TracedTensor], subscripts: &str) -> Result<TracedTensor>;
31    fn einsum_subscripts(
32        &mut self,
33        inputs: &[&TracedTensor],
34        subscripts: &EinsumSubscripts,
35    ) -> Result<TracedTensor>;
36    fn einsum_with(
37        &mut self,
38        inputs: &[&TracedTensor],
39        subscripts: &str,
40        optimize: EinsumOptimize,
41    ) -> Result<TracedTensor>;
42    fn einsum_subscripts_with(
43        &mut self,
44        inputs: &[&TracedTensor],
45        subscripts: &EinsumSubscripts,
46        optimize: EinsumOptimize,
47    ) -> Result<TracedTensor>;
48}
49
50impl GraphCompilerEinsumExt for GraphCompiler {
51    fn einsum(&mut self, inputs: &[&TracedTensor], subscripts: &str) -> Result<TracedTensor> {
52        einsum(self, inputs, subscripts)
53    }
54
55    fn einsum_subscripts(
56        &mut self,
57        inputs: &[&TracedTensor],
58        subscripts: &EinsumSubscripts,
59    ) -> Result<TracedTensor> {
60        einsum_subscripts(self, inputs, subscripts)
61    }
62
63    fn einsum_with(
64        &mut self,
65        inputs: &[&TracedTensor],
66        subscripts: &str,
67        optimize: EinsumOptimize,
68    ) -> Result<TracedTensor> {
69        einsum_with(self, inputs, subscripts, optimize)
70    }
71
72    fn einsum_subscripts_with(
73        &mut self,
74        inputs: &[&TracedTensor],
75        subscripts: &EinsumSubscripts,
76        optimize: EinsumOptimize,
77    ) -> Result<TracedTensor> {
78        einsum_subscripts_with(self, inputs, subscripts, optimize)
79    }
80}
81
82/// Traced tensor contraction-sugar methods.
83pub trait TracedTensorEinsumExt {
84    fn tensordot(&self, rhs: &TracedTensor, axes: TensorDotAxes<'_>) -> Result<TracedTensor>;
85}
86
87impl TracedTensorEinsumExt for TracedTensor {
88    fn tensordot(&self, rhs: &TracedTensor, axes: TensorDotAxes<'_>) -> Result<TracedTensor> {
89        tensordot(self, rhs, axes)
90    }
91}
92
93/// N-ary einsum with default time-optimized automatic planning.
94///
95/// The default optimizer is resolved into a shape-independent plan
96/// specification stored in the extension payload. That payload identity
97/// participates in traced extension-op equality and in compile/runtime einsum
98/// plan caches.
99pub fn einsum(
100    compiler: &mut GraphCompiler,
101    inputs: &[&TracedTensor],
102    subscripts: &str,
103) -> Result<TracedTensor> {
104    einsum_with(compiler, inputs, subscripts, EinsumOptimize::default())
105}
106
107/// N-ary einsum using integer labels and the default contraction strategy.
108///
109/// The default optimizer is resolved into a shape-independent plan
110/// specification stored in the extension payload. That payload identity
111/// participates in traced extension-op equality and in compile/runtime einsum
112/// plan caches.
113pub fn einsum_subscripts(
114    compiler: &mut GraphCompiler,
115    inputs: &[&TracedTensor],
116    subscripts: &EinsumSubscripts,
117) -> Result<TracedTensor> {
118    einsum_subscripts_with(compiler, inputs, subscripts, EinsumOptimize::default())
119}
120
121/// N-ary einsum with explicit contraction strategy.
122///
123/// `optimize` is converted to a shape-independent plan specification carried
124/// by the extension payload. `EinsumOptimize::Path` uses JAX-style positions
125/// over the current shrinking operand list, so it works with symbolic traced
126/// inputs. `EinsumOptimize::Tree` requires concrete shapes for N-ary extension
127/// execution; binary trees that lower exactly to `dot_general` bypass the
128/// extension path and may use symbolic traced inputs.
129///
130/// Planner options, explicit paths, and fixed plan identities affect traced
131/// extension payload identity and the einsum compile/runtime plan caches.
132/// Different options or paths are therefore not treated as identical extension
133/// ops.
134pub fn einsum_with(
135    compiler: &mut GraphCompiler,
136    inputs: &[&TracedTensor],
137    subscripts: &str,
138    optimize: EinsumOptimize,
139) -> Result<TracedTensor> {
140    let parsed = cached_subscripts(compiler.extension_caches_mut(), subscripts)?;
141    einsum_subscripts_with(compiler, inputs, &parsed.subscripts, optimize)
142}
143
144/// N-ary einsum with integer labels and explicit contraction strategy.
145///
146/// `optimize` is converted to a shape-independent plan specification carried
147/// by the extension payload. `EinsumOptimize::Path` uses JAX-style positions
148/// over the current shrinking operand list, so it works with symbolic traced
149/// inputs. `EinsumOptimize::Tree` requires concrete shapes for N-ary extension
150/// execution; binary trees that lower exactly to `dot_general` bypass the
151/// extension path and may use symbolic traced inputs.
152///
153/// Planner options, explicit paths, and fixed plan identities affect traced
154/// extension payload identity and the einsum compile/runtime plan caches.
155/// Different options or paths are therefore not treated as identical extension
156/// ops.
157pub fn einsum_subscripts_with(
158    compiler: &mut GraphCompiler,
159    inputs: &[&TracedTensor],
160    subscripts: &EinsumSubscripts,
161    optimize: EinsumOptimize,
162) -> Result<TracedTensor> {
163    if inputs.is_empty() {
164        return Err(Error::ContractionError(
165            "einsum requires at least one input tensor".into(),
166        ));
167    }
168    if subscripts.inputs.len() != inputs.len() {
169        return Err(Error::ContractionError(format!(
170            "einsum subscripts expect {} inputs, got {}",
171            subscripts.inputs.len(),
172            inputs.len()
173        )));
174    }
175
176    let output_shape_hint = infer_symbolic_output_shape(subscripts, inputs)?;
177    if let Some(result) = try_direct_binary_dot_general(inputs, subscripts, &optimize)? {
178        return Ok(result);
179    }
180
181    let subs = Subscripts::from(subscripts);
182
183    let (plan_spec, static_tree) = if let Some(shapes) = concrete_shapes(inputs) {
184        let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
185        let (plan_spec, tree) = match optimize {
186            EinsumOptimize::Tree(tree) => {
187                let (plan_spec, tree) = resolve_einsum_strategy_with_spec(
188                    EinsumOptimize::Tree(tree),
189                    &subs,
190                    &shape_refs,
191                )
192                .map_err(to_tenferro_error)?;
193                let tree = cached_static_tree(
194                    compiler.extension_caches_mut(),
195                    subscripts,
196                    &plan_spec,
197                    &shapes,
198                    || Ok(tree),
199                )?;
200                (plan_spec, tree)
201            }
202            optimize => {
203                let plan_spec =
204                    plan_spec_from_optimize(optimize, &subs).map_err(to_tenferro_error)?;
205                let tree = cached_static_tree(
206                    compiler.extension_caches_mut(),
207                    subscripts,
208                    &plan_spec,
209                    &shapes,
210                    || resolve_plan_spec(&plan_spec, &subs, &shape_refs),
211                )?;
212                (plan_spec, tree)
213            }
214        };
215        (plan_spec, Some(tree))
216    } else {
217        let plan_spec = plan_spec_from_optimize(optimize, &subs).map_err(to_tenferro_error)?;
218        let tree = symbolic_fixed_path_tree(&plan_spec, &subs, inputs)?;
219        (plan_spec, tree.map(Arc::new))
220    };
221
222    if let Some(tree) = static_tree {
223        return expand_traced_einsum_graph(inputs, subscripts, tree.as_ref(), output_shape_hint);
224    }
225
226    let op =
227        EinsumExtensionOp::with_output_shape_hint(subscripts.clone(), output_shape_hint, plan_spec);
228    let outputs = extension::apply(Arc::new(op), inputs)?;
229    outputs
230        .into_iter()
231        .next()
232        .ok_or_else(|| Error::Internal("einsum extension produced no output".into()))
233}
234
235fn tensordot(
236    lhs: &TracedTensor,
237    rhs: &TracedTensor,
238    axes: TensorDotAxes<'_>,
239) -> Result<TracedTensor> {
240    let config = crate::tensordot::dot_general_config(axes, lhs.rank, rhs.rank)?;
241    crate::tensordot::validate_traced_contract_dims(lhs, rhs, &config)?;
242    lhs.dot_general(rhs, config)
243}
244
245fn expand_traced_einsum_graph(
246    inputs: &[&TracedTensor],
247    subscripts: &EinsumSubscripts,
248    tree: &ContractionTree,
249    output_shape_hint: Vec<SymDim>,
250) -> Result<TracedTensor> {
251    let op = EinsumExtensionOp::with_output_shape_hint(
252        subscripts.clone(),
253        output_shape_hint,
254        EinsumPlanSpec::LeftToRight,
255    );
256    let input_dtypes: Vec<_> = inputs.iter().map(|tensor| tensor.dtype).collect();
257    let input_sym_shapes: Vec<Vec<SymDim>> = inputs
258        .iter()
259        .map(|tensor| match tensor.sym_shape() {
260            Some(shape) => Ok(shape.to_vec()),
261            None => (0..tensor.rank)
262                .map(|axis| tensor.axis_sym_dim(axis))
263                .collect(),
264        })
265        .collect::<Result<_>>()?;
266    let input_sym_shape_refs: Vec<_> = input_sym_shapes.iter().map(Vec::as_slice).collect();
267    let output_metas = op.infer_output_meta(&input_dtypes, &input_sym_shape_refs);
268    let input_dim_shapes = traced_dim_expr_shapes(inputs);
269
270    let outputs = extension::apply_expanded_graph(inputs, output_metas, |builder, input_refs| {
271        let result = build_einsum_graph_dim_expr(builder, tree, input_refs, &input_dim_shapes)
272            .map_err(|err| Error::ContractionError(err.to_string()))?;
273        let ValueRef::Local(local) = result else {
274            return Err(Error::Internal(
275                "expanded einsum returned an external value".into(),
276            ));
277        };
278        Ok(vec![local])
279    })?;
280
281    outputs
282        .into_iter()
283        .next()
284        .ok_or_else(|| Error::Internal("expanded einsum produced no output".into()))
285}
286
287fn traced_dim_expr_shapes(inputs: &[&TracedTensor]) -> Vec<Vec<DimExpr>> {
288    inputs
289        .iter()
290        .map(|tensor| DimExpr::input_shape(0, tensor.rank))
291        .collect()
292}
293
294fn symbolic_fixed_path_tree(
295    plan_spec: &EinsumPlanSpec,
296    subs: &Subscripts,
297    inputs: &[&TracedTensor],
298) -> Result<Option<ContractionTree>> {
299    if matches!(plan_spec, EinsumPlanSpec::Auto(_)) {
300        return Ok(None);
301    }
302    let dummy_shapes = symbolic_dummy_shapes(inputs);
303    let shape_refs: Vec<&[usize]> = dummy_shapes.iter().map(Vec::as_slice).collect();
304    resolve_plan_spec(plan_spec, subs, &shape_refs)
305        .map(Some)
306        .map_err(to_tenferro_error)
307}
308
309fn symbolic_dummy_shapes(inputs: &[&TracedTensor]) -> Vec<Vec<usize>> {
310    inputs.iter().map(|tensor| vec![1; tensor.rank]).collect()
311}
312
313fn try_direct_binary_dot_general(
314    inputs: &[&TracedTensor],
315    subscripts: &EinsumSubscripts,
316    optimize: &EinsumOptimize,
317) -> Result<Option<TracedTensor>> {
318    if inputs.len() != 2 || subscripts.inputs.len() != 2 {
319        return Ok(None);
320    }
321    if !optimize_allows_direct_binary_dot(optimize)? {
322        return Ok(None);
323    }
324
325    let lhs_labels = &subscripts.inputs[0];
326    let rhs_labels = &subscripts.inputs[1];
327    if lhs_labels.len() != inputs[0].rank || rhs_labels.len() != inputs[1].rank {
328        return Ok(None);
329    }
330    validate_direct_binary_dot_label_dims(inputs, subscripts)?;
331
332    let Some(plan) =
333        try_build_exact_output_binary_dot_plan(lhs_labels, rhs_labels, &subscripts.output)
334    else {
335        return Ok(None);
336    };
337
338    let result = match plan.operand_order {
339        BinaryDotOperandOrder::Original => inputs[0].dot_general(inputs[1], plan.config)?,
340        BinaryDotOperandOrder::Swapped => inputs[1].dot_general(inputs[0], plan.config)?,
341    };
342    Ok(Some(result))
343}
344
345fn validate_direct_binary_dot_label_dims(
346    inputs: &[&TracedTensor],
347    subscripts: &EinsumSubscripts,
348) -> Result<()> {
349    let mut label_dims = std::collections::HashMap::new();
350    for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
351        let Some(shape) = tensor.sym_shape() else {
352            continue;
353        };
354        for (&label, dim) in labels.iter().zip(shape.iter()) {
355            let Some(dim) = dim.constant_value() else {
356                continue;
357            };
358            if let Some(existing) = label_dims.insert(label, dim) {
359                if existing != dim {
360                    return Err(Error::ContractionError(format!(
361                        "einsum label {label} has inconsistent dimensions {existing} and {dim}"
362                    )));
363                }
364            }
365        }
366    }
367    Ok(())
368}
369
370fn optimize_allows_direct_binary_dot(optimize: &EinsumOptimize) -> Result<bool> {
371    match optimize {
372        EinsumOptimize::Auto(options) => {
373            options.validate().map_err(to_tenferro_error)?;
374            Ok(true)
375        }
376        EinsumOptimize::False => Ok(true),
377        EinsumOptimize::Tree(tree) => {
378            Ok(tree.step_count() == 1 && matches!(tree.step_pair(0), Some((0, 1)) | Some((1, 0))))
379        }
380        EinsumOptimize::Nested(_) | EinsumOptimize::Path(_) => Ok(false),
381    }
382}
383
384fn cached_subscripts(
385    caches: &mut ExtensionCacheStore,
386    notation: &str,
387) -> Result<Arc<ParsedEinsum>> {
388    let key = ExtensionCacheKey::new(
389        EINSUM_EXTENSION_FAMILY_ID,
390        EINSUM_PARSE_CACHE,
391        hash_value(&notation),
392    );
393    if let Some(cached) = caches.get::<Arc<ParsedEinsum>>(&key) {
394        return Ok(Arc::clone(cached));
395    }
396
397    let parsed = Arc::new(ParsedEinsum {
398        subscripts: parse_einsum_subscripts(notation).map_err(to_tenferro_error)?,
399    });
400    let retained_bytes = saturating_sum([
401        notation.len(),
402        einsum_subscripts_retained_bytes(&parsed.subscripts),
403    ]);
404    caches.put(key, Arc::clone(&parsed), retained_bytes);
405    Ok(parsed)
406}
407
408fn cached_static_tree(
409    caches: &mut ExtensionCacheStore,
410    subscripts: &EinsumSubscripts,
411    plan_spec: &EinsumPlanSpec,
412    shapes: &[Vec<usize>],
413    build: impl FnOnce() -> EinsumResult<ContractionTree>,
414) -> Result<Arc<ContractionTree>> {
415    let mut plan_hasher = DefaultHasher::new();
416    hash_einsum_plan_spec(plan_spec, &mut plan_hasher);
417    let key_data = (subscripts.clone(), shapes.to_vec(), plan_hasher.finish());
418    let key = ExtensionCacheKey::new(
419        EINSUM_EXTENSION_FAMILY_ID,
420        EINSUM_STATIC_PLANS_CACHE,
421        hash_value(&key_data),
422    );
423    if let Some(cached) = caches.get::<Arc<ContractionTree>>(&key) {
424        return Ok(Arc::clone(cached));
425    }
426
427    let tree = Arc::new(build().map_err(to_tenferro_error)?);
428    let retained_bytes = saturating_sum([
429        einsum_subscripts_retained_bytes(subscripts),
430        saturating_sum(shapes.iter().map(vec_retained_bytes)),
431        std::mem::size_of::<u64>(),
432        tree.retained_bytes_for_cache_stats(),
433    ]);
434    caches.put(key, Arc::clone(&tree), retained_bytes);
435    Ok(tree)
436}
437
438fn concrete_shapes(inputs: &[&TracedTensor]) -> Option<Vec<Vec<usize>>> {
439    inputs
440        .iter()
441        .map(|tensor| {
442            tensor
443                .sym_shape()?
444                .iter()
445                .map(|dim| dim.constant_value())
446                .collect::<Option<Vec<_>>>()
447        })
448        .collect()
449}
450
451fn infer_symbolic_output_shape(
452    subscripts: &EinsumSubscripts,
453    inputs: &[&TracedTensor],
454) -> Result<Vec<SymDim>> {
455    let mut label_dims = std::collections::HashMap::new();
456    for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
457        let shape: Vec<_> = match tensor.sym_shape() {
458            Some(shape) => shape.to_vec(),
459            None => (0..tensor.rank)
460                .map(|axis| tensor.axis_sym_dim(axis))
461                .collect::<Result<_>>()?,
462        };
463        if labels.len() != shape.len() {
464            return Err(Error::ContractionError(format!(
465                "einsum input rank mismatch: labels={}, shape={}",
466                labels.len(),
467                shape.len()
468            )));
469        }
470        for (&label, dim) in labels.iter().zip(shape) {
471            label_dims.entry(label).or_insert(dim);
472        }
473    }
474    subscripts
475        .output
476        .iter()
477        .map(|label| {
478            label_dims.get(label).cloned().ok_or_else(|| {
479                Error::ContractionError(format!(
480                    "einsum output label {label} is missing from inputs"
481                ))
482            })
483        })
484        .collect()
485}
486
487fn to_tenferro_error(error: EinsumError) -> Error {
488    Error::ContractionError(error.to_string())
489}
490
/// Hashes `value` with a freshly created `DefaultHasher` and returns the
/// 64-bit digest.
fn hash_value<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
496
497#[cfg(test)]
498mod tests;