1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::sync::Arc;
4
5use computegraph::types::ValueRef;
6use tenferro_ops::dim_expr::DimExpr;
7use tenferro_ops::ext_op::ExtensionOp;
8use tenferro_runtime::error::{Error, Result};
9use tenferro_runtime::extension::{self, ExtensionCacheKey, ExtensionCacheStore};
10use tenferro_runtime::{GraphCompiler, SymDim, TracedTensor};
11
12use crate::binary_dot::{try_build_exact_output_binary_dot_plan, BinaryDotOperandOrder};
13use crate::builder::build_einsum_graph_dim_expr;
14use crate::cache::{
15 einsum_subscripts_retained_bytes, saturating_sum, vec_retained_bytes, ParsedEinsum,
16 EINSUM_EXTENSION_FAMILY_ID, EINSUM_PARSE_CACHE, EINSUM_STATIC_PLANS_CACHE,
17};
18use crate::extension::EinsumExtensionOp;
19use crate::optimize::{
20 hash_einsum_plan_spec, plan_spec_from_optimize, resolve_einsum_strategy_with_spec,
21 resolve_plan_spec, EinsumPlanSpec,
22};
23use crate::{
24 parse_einsum_subscripts, ContractionTree, EinsumOptimize, EinsumSubscripts,
25 Error as EinsumError, Result as EinsumResult, Subscripts, TensorDotAxes,
26};
27
28pub trait GraphCompilerEinsumExt {
30 fn einsum(&mut self, inputs: &[&TracedTensor], subscripts: &str) -> Result<TracedTensor>;
31 fn einsum_subscripts(
32 &mut self,
33 inputs: &[&TracedTensor],
34 subscripts: &EinsumSubscripts,
35 ) -> Result<TracedTensor>;
36 fn einsum_with(
37 &mut self,
38 inputs: &[&TracedTensor],
39 subscripts: &str,
40 optimize: EinsumOptimize,
41 ) -> Result<TracedTensor>;
42 fn einsum_subscripts_with(
43 &mut self,
44 inputs: &[&TracedTensor],
45 subscripts: &EinsumSubscripts,
46 optimize: EinsumOptimize,
47 ) -> Result<TracedTensor>;
48}
49
50impl GraphCompilerEinsumExt for GraphCompiler {
51 fn einsum(&mut self, inputs: &[&TracedTensor], subscripts: &str) -> Result<TracedTensor> {
52 einsum(self, inputs, subscripts)
53 }
54
55 fn einsum_subscripts(
56 &mut self,
57 inputs: &[&TracedTensor],
58 subscripts: &EinsumSubscripts,
59 ) -> Result<TracedTensor> {
60 einsum_subscripts(self, inputs, subscripts)
61 }
62
63 fn einsum_with(
64 &mut self,
65 inputs: &[&TracedTensor],
66 subscripts: &str,
67 optimize: EinsumOptimize,
68 ) -> Result<TracedTensor> {
69 einsum_with(self, inputs, subscripts, optimize)
70 }
71
72 fn einsum_subscripts_with(
73 &mut self,
74 inputs: &[&TracedTensor],
75 subscripts: &EinsumSubscripts,
76 optimize: EinsumOptimize,
77 ) -> Result<TracedTensor> {
78 einsum_subscripts_with(self, inputs, subscripts, optimize)
79 }
80}
81
82pub trait TracedTensorEinsumExt {
84 fn tensordot(&self, rhs: &TracedTensor, axes: TensorDotAxes<'_>) -> Result<TracedTensor>;
85}
86
87impl TracedTensorEinsumExt for TracedTensor {
88 fn tensordot(&self, rhs: &TracedTensor, axes: TensorDotAxes<'_>) -> Result<TracedTensor> {
89 tensordot(self, rhs, axes)
90 }
91}
92
93pub fn einsum(
100 compiler: &mut GraphCompiler,
101 inputs: &[&TracedTensor],
102 subscripts: &str,
103) -> Result<TracedTensor> {
104 einsum_with(compiler, inputs, subscripts, EinsumOptimize::default())
105}
106
107pub fn einsum_subscripts(
114 compiler: &mut GraphCompiler,
115 inputs: &[&TracedTensor],
116 subscripts: &EinsumSubscripts,
117) -> Result<TracedTensor> {
118 einsum_subscripts_with(compiler, inputs, subscripts, EinsumOptimize::default())
119}
120
121pub fn einsum_with(
135 compiler: &mut GraphCompiler,
136 inputs: &[&TracedTensor],
137 subscripts: &str,
138 optimize: EinsumOptimize,
139) -> Result<TracedTensor> {
140 let parsed = cached_subscripts(compiler.extension_caches_mut(), subscripts)?;
141 einsum_subscripts_with(compiler, inputs, &parsed.subscripts, optimize)
142}
143
144pub fn einsum_subscripts_with(
158 compiler: &mut GraphCompiler,
159 inputs: &[&TracedTensor],
160 subscripts: &EinsumSubscripts,
161 optimize: EinsumOptimize,
162) -> Result<TracedTensor> {
163 if inputs.is_empty() {
164 return Err(Error::ContractionError(
165 "einsum requires at least one input tensor".into(),
166 ));
167 }
168 if subscripts.inputs.len() != inputs.len() {
169 return Err(Error::ContractionError(format!(
170 "einsum subscripts expect {} inputs, got {}",
171 subscripts.inputs.len(),
172 inputs.len()
173 )));
174 }
175
176 let output_shape_hint = infer_symbolic_output_shape(subscripts, inputs)?;
177 if let Some(result) = try_direct_binary_dot_general(inputs, subscripts, &optimize)? {
178 return Ok(result);
179 }
180
181 let subs = Subscripts::from(subscripts);
182
183 let (plan_spec, static_tree) = if let Some(shapes) = concrete_shapes(inputs) {
184 let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
185 let (plan_spec, tree) = match optimize {
186 EinsumOptimize::Tree(tree) => {
187 let (plan_spec, tree) = resolve_einsum_strategy_with_spec(
188 EinsumOptimize::Tree(tree),
189 &subs,
190 &shape_refs,
191 )
192 .map_err(to_tenferro_error)?;
193 let tree = cached_static_tree(
194 compiler.extension_caches_mut(),
195 subscripts,
196 &plan_spec,
197 &shapes,
198 || Ok(tree),
199 )?;
200 (plan_spec, tree)
201 }
202 optimize => {
203 let plan_spec =
204 plan_spec_from_optimize(optimize, &subs).map_err(to_tenferro_error)?;
205 let tree = cached_static_tree(
206 compiler.extension_caches_mut(),
207 subscripts,
208 &plan_spec,
209 &shapes,
210 || resolve_plan_spec(&plan_spec, &subs, &shape_refs),
211 )?;
212 (plan_spec, tree)
213 }
214 };
215 (plan_spec, Some(tree))
216 } else {
217 let plan_spec = plan_spec_from_optimize(optimize, &subs).map_err(to_tenferro_error)?;
218 let tree = symbolic_fixed_path_tree(&plan_spec, &subs, inputs)?;
219 (plan_spec, tree.map(Arc::new))
220 };
221
222 if let Some(tree) = static_tree {
223 return expand_traced_einsum_graph(inputs, subscripts, tree.as_ref(), output_shape_hint);
224 }
225
226 let op =
227 EinsumExtensionOp::with_output_shape_hint(subscripts.clone(), output_shape_hint, plan_spec);
228 let outputs = extension::apply(Arc::new(op), inputs)?;
229 outputs
230 .into_iter()
231 .next()
232 .ok_or_else(|| Error::Internal("einsum extension produced no output".into()))
233}
234
235fn tensordot(
236 lhs: &TracedTensor,
237 rhs: &TracedTensor,
238 axes: TensorDotAxes<'_>,
239) -> Result<TracedTensor> {
240 let config = crate::tensordot::dot_general_config(axes, lhs.rank, rhs.rank)?;
241 crate::tensordot::validate_traced_contract_dims(lhs, rhs, &config)?;
242 lhs.dot_general(rhs, config)
243}
244
245fn expand_traced_einsum_graph(
246 inputs: &[&TracedTensor],
247 subscripts: &EinsumSubscripts,
248 tree: &ContractionTree,
249 output_shape_hint: Vec<SymDim>,
250) -> Result<TracedTensor> {
251 let op = EinsumExtensionOp::with_output_shape_hint(
252 subscripts.clone(),
253 output_shape_hint,
254 EinsumPlanSpec::LeftToRight,
255 );
256 let input_dtypes: Vec<_> = inputs.iter().map(|tensor| tensor.dtype).collect();
257 let input_sym_shapes: Vec<Vec<SymDim>> = inputs
258 .iter()
259 .map(|tensor| match tensor.sym_shape() {
260 Some(shape) => Ok(shape.to_vec()),
261 None => (0..tensor.rank)
262 .map(|axis| tensor.axis_sym_dim(axis))
263 .collect(),
264 })
265 .collect::<Result<_>>()?;
266 let input_sym_shape_refs: Vec<_> = input_sym_shapes.iter().map(Vec::as_slice).collect();
267 let output_metas = op.infer_output_meta(&input_dtypes, &input_sym_shape_refs);
268 let input_dim_shapes = traced_dim_expr_shapes(inputs);
269
270 let outputs = extension::apply_expanded_graph(inputs, output_metas, |builder, input_refs| {
271 let result = build_einsum_graph_dim_expr(builder, tree, input_refs, &input_dim_shapes)
272 .map_err(|err| Error::ContractionError(err.to_string()))?;
273 let ValueRef::Local(local) = result else {
274 return Err(Error::Internal(
275 "expanded einsum returned an external value".into(),
276 ));
277 };
278 Ok(vec![local])
279 })?;
280
281 outputs
282 .into_iter()
283 .next()
284 .ok_or_else(|| Error::Internal("expanded einsum produced no output".into()))
285}
286
287fn traced_dim_expr_shapes(inputs: &[&TracedTensor]) -> Vec<Vec<DimExpr>> {
288 inputs
289 .iter()
290 .map(|tensor| DimExpr::input_shape(0, tensor.rank))
291 .collect()
292}
293
294fn symbolic_fixed_path_tree(
295 plan_spec: &EinsumPlanSpec,
296 subs: &Subscripts,
297 inputs: &[&TracedTensor],
298) -> Result<Option<ContractionTree>> {
299 if matches!(plan_spec, EinsumPlanSpec::Auto(_)) {
300 return Ok(None);
301 }
302 let dummy_shapes = symbolic_dummy_shapes(inputs);
303 let shape_refs: Vec<&[usize]> = dummy_shapes.iter().map(Vec::as_slice).collect();
304 resolve_plan_spec(plan_spec, subs, &shape_refs)
305 .map(Some)
306 .map_err(to_tenferro_error)
307}
308
309fn symbolic_dummy_shapes(inputs: &[&TracedTensor]) -> Vec<Vec<usize>> {
310 inputs.iter().map(|tensor| vec![1; tensor.rank]).collect()
311}
312
313fn try_direct_binary_dot_general(
314 inputs: &[&TracedTensor],
315 subscripts: &EinsumSubscripts,
316 optimize: &EinsumOptimize,
317) -> Result<Option<TracedTensor>> {
318 if inputs.len() != 2 || subscripts.inputs.len() != 2 {
319 return Ok(None);
320 }
321 if !optimize_allows_direct_binary_dot(optimize)? {
322 return Ok(None);
323 }
324
325 let lhs_labels = &subscripts.inputs[0];
326 let rhs_labels = &subscripts.inputs[1];
327 if lhs_labels.len() != inputs[0].rank || rhs_labels.len() != inputs[1].rank {
328 return Ok(None);
329 }
330 validate_direct_binary_dot_label_dims(inputs, subscripts)?;
331
332 let Some(plan) =
333 try_build_exact_output_binary_dot_plan(lhs_labels, rhs_labels, &subscripts.output)
334 else {
335 return Ok(None);
336 };
337
338 let result = match plan.operand_order {
339 BinaryDotOperandOrder::Original => inputs[0].dot_general(inputs[1], plan.config)?,
340 BinaryDotOperandOrder::Swapped => inputs[1].dot_general(inputs[0], plan.config)?,
341 };
342 Ok(Some(result))
343}
344
345fn validate_direct_binary_dot_label_dims(
346 inputs: &[&TracedTensor],
347 subscripts: &EinsumSubscripts,
348) -> Result<()> {
349 let mut label_dims = std::collections::HashMap::new();
350 for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
351 let Some(shape) = tensor.sym_shape() else {
352 continue;
353 };
354 for (&label, dim) in labels.iter().zip(shape.iter()) {
355 let Some(dim) = dim.constant_value() else {
356 continue;
357 };
358 if let Some(existing) = label_dims.insert(label, dim) {
359 if existing != dim {
360 return Err(Error::ContractionError(format!(
361 "einsum label {label} has inconsistent dimensions {existing} and {dim}"
362 )));
363 }
364 }
365 }
366 }
367 Ok(())
368}
369
370fn optimize_allows_direct_binary_dot(optimize: &EinsumOptimize) -> Result<bool> {
371 match optimize {
372 EinsumOptimize::Auto(options) => {
373 options.validate().map_err(to_tenferro_error)?;
374 Ok(true)
375 }
376 EinsumOptimize::False => Ok(true),
377 EinsumOptimize::Tree(tree) => {
378 Ok(tree.step_count() == 1 && matches!(tree.step_pair(0), Some((0, 1)) | Some((1, 0))))
379 }
380 EinsumOptimize::Nested(_) | EinsumOptimize::Path(_) => Ok(false),
381 }
382}
383
384fn cached_subscripts(
385 caches: &mut ExtensionCacheStore,
386 notation: &str,
387) -> Result<Arc<ParsedEinsum>> {
388 let key = ExtensionCacheKey::new(
389 EINSUM_EXTENSION_FAMILY_ID,
390 EINSUM_PARSE_CACHE,
391 hash_value(¬ation),
392 );
393 if let Some(cached) = caches.get::<Arc<ParsedEinsum>>(&key) {
394 return Ok(Arc::clone(cached));
395 }
396
397 let parsed = Arc::new(ParsedEinsum {
398 subscripts: parse_einsum_subscripts(notation).map_err(to_tenferro_error)?,
399 });
400 let retained_bytes = saturating_sum([
401 notation.len(),
402 einsum_subscripts_retained_bytes(&parsed.subscripts),
403 ]);
404 caches.put(key, Arc::clone(&parsed), retained_bytes);
405 Ok(parsed)
406}
407
408fn cached_static_tree(
409 caches: &mut ExtensionCacheStore,
410 subscripts: &EinsumSubscripts,
411 plan_spec: &EinsumPlanSpec,
412 shapes: &[Vec<usize>],
413 build: impl FnOnce() -> EinsumResult<ContractionTree>,
414) -> Result<Arc<ContractionTree>> {
415 let mut plan_hasher = DefaultHasher::new();
416 hash_einsum_plan_spec(plan_spec, &mut plan_hasher);
417 let key_data = (subscripts.clone(), shapes.to_vec(), plan_hasher.finish());
418 let key = ExtensionCacheKey::new(
419 EINSUM_EXTENSION_FAMILY_ID,
420 EINSUM_STATIC_PLANS_CACHE,
421 hash_value(&key_data),
422 );
423 if let Some(cached) = caches.get::<Arc<ContractionTree>>(&key) {
424 return Ok(Arc::clone(cached));
425 }
426
427 let tree = Arc::new(build().map_err(to_tenferro_error)?);
428 let retained_bytes = saturating_sum([
429 einsum_subscripts_retained_bytes(subscripts),
430 saturating_sum(shapes.iter().map(vec_retained_bytes)),
431 std::mem::size_of::<u64>(),
432 tree.retained_bytes_for_cache_stats(),
433 ]);
434 caches.put(key, Arc::clone(&tree), retained_bytes);
435 Ok(tree)
436}
437
438fn concrete_shapes(inputs: &[&TracedTensor]) -> Option<Vec<Vec<usize>>> {
439 inputs
440 .iter()
441 .map(|tensor| {
442 tensor
443 .sym_shape()?
444 .iter()
445 .map(|dim| dim.constant_value())
446 .collect::<Option<Vec<_>>>()
447 })
448 .collect()
449}
450
451fn infer_symbolic_output_shape(
452 subscripts: &EinsumSubscripts,
453 inputs: &[&TracedTensor],
454) -> Result<Vec<SymDim>> {
455 let mut label_dims = std::collections::HashMap::new();
456 for (labels, tensor) in subscripts.inputs.iter().zip(inputs.iter()) {
457 let shape: Vec<_> = match tensor.sym_shape() {
458 Some(shape) => shape.to_vec(),
459 None => (0..tensor.rank)
460 .map(|axis| tensor.axis_sym_dim(axis))
461 .collect::<Result<_>>()?,
462 };
463 if labels.len() != shape.len() {
464 return Err(Error::ContractionError(format!(
465 "einsum input rank mismatch: labels={}, shape={}",
466 labels.len(),
467 shape.len()
468 )));
469 }
470 for (&label, dim) in labels.iter().zip(shape) {
471 label_dims.entry(label).or_insert(dim);
472 }
473 }
474 subscripts
475 .output
476 .iter()
477 .map(|label| {
478 label_dims.get(label).cloned().ok_or_else(|| {
479 Error::ContractionError(format!(
480 "einsum output label {label} is missing from inputs"
481 ))
482 })
483 })
484 .collect()
485}
486
487fn to_tenferro_error(error: EinsumError) -> Error {
488 Error::ContractionError(error.to_string())
489}
490
/// Hashes `value` with a freshly created [`DefaultHasher`] and returns the
/// resulting 64-bit digest.
fn hash_value<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
496
497#[cfg(test)]
498mod tests;