1use std::any::Any;
2use std::collections::HashMap;
3#[cfg(feature = "autodiff")]
4use std::collections::HashSet;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use computegraph::compile::compile;
9use computegraph::graph::GraphBuilder;
10use computegraph::materialize::materialize_merge;
11use computegraph::resolve::resolve;
12#[cfg(feature = "autodiff")]
13use computegraph::types::{LocalValueId, OperationRole};
14use computegraph::types::{ValueKey, ValueRef};
15use smallvec::SmallVec;
16use tenferro_extension_macros::define_extension_runtime;
17#[cfg(feature = "autodiff")]
18use tenferro_ops::ad::context::ShapeGuardContext;
19#[cfg(feature = "autodiff")]
20use tenferro_ops::ad::PrimitiveRuleBuilder;
21#[cfg(feature = "autodiff")]
22use tenferro_ops::dim_expr::DimExpr;
23#[cfg(feature = "autodiff")]
24use tenferro_ops::ext_op::ExtensionAdRule;
25use tenferro_ops::ext_op::{ExtensionLoweringError, ExtensionLoweringResult, ExtensionOp};
26use tenferro_ops::input_key::TensorInputKey;
27use tenferro_ops::std_tensor_op::StdTensorOp;
28use tenferro_ops::sym_dim::SymDim;
29#[cfg(feature = "autodiff")]
30use tenferro_ops::{ExtensionRegistryError, ExtensionRuleSet};
31use tenferro_runtime::extension::{
32 ExecInstruction, ExecOp, ExecProgram, ExtensionCacheKey, ExtensionExecutionContext,
33};
34use tenferro_tensor::{DType, RuntimeCacheControl, Tensor, TensorBackend, TensorRead};
35#[cfg(feature = "autodiff")]
36use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
37
38use crate::builder::build_einsum_graph;
39use crate::cache::{
40 einsum_subscripts_retained_bytes, saturating_sum, vec_of_vec_retained_bytes,
41 vec_retained_bytes, EINSUM_EXTENSION_FAMILY_ID, EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE,
42 EINSUM_RUNTIME_PLANS_CACHE,
43};
44#[cfg(test)]
45use crate::optimize::default_auto_options;
46#[cfg(feature = "autodiff")]
47use crate::optimize::jax_path_to_v1_pairs;
48use crate::optimize::{hash_einsum_plan_spec, plan_specs_equal, resolve_plan_spec, EinsumPlanSpec};
49use crate::{
50 ContractionTree, EinsumSubscripts, Error as EinsumError, Result as EinsumResult, Subscripts,
51};
52
/// Small list of caller input positions consumed by a cached runtime program
/// (see `CachedRuntimeExecProgram::input_indices`); up to 8 entries are stored
/// inline without a heap allocation.
type InputIndexVec = SmallVec<[usize; 8]>;
54
/// Graph-level extension op representing one einsum contraction.
///
/// Carries the integer-labelled subscripts, the strategy used to choose a
/// contraction order, an optional pre-resolved contraction tree, and an
/// optional output shape hint. `static_tree` is not part of `payload_eq` /
/// `payload_hash`.
#[derive(Clone)]
pub(crate) struct EinsumExtensionOp {
    /// Input and output label lists of the contraction.
    subscripts: EinsumSubscripts,
    /// Planning strategy resolved into a contraction tree when no static tree is set.
    plan_spec: EinsumPlanSpec,
    /// Pre-resolved contraction tree; when present, runtime execution uses it
    /// instead of resolving `plan_spec`.
    static_tree: Option<Arc<ContractionTree>>,
    /// Optional output shape. `infer_output_meta` uses it when every dim is a
    /// constant, and the VJP rule requires it to shape the cotangent.
    output_shape_hint: Option<Vec<SymDim>>,
}
71
72impl std::fmt::Debug for EinsumExtensionOp {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("EinsumExtensionOp")
75 .field("subscripts", &self.subscripts)
76 .field("plan_spec", &self.plan_spec)
77 .field("has_static_tree", &self.static_tree.is_some())
78 .field("output_shape_hint", &self.output_shape_hint)
79 .finish()
80 }
81}
82
83impl EinsumExtensionOp {
84 #[must_use]
86 #[cfg(test)]
87 pub(crate) fn new(subscripts: EinsumSubscripts) -> Self {
88 Self::with_plan_spec(subscripts, EinsumPlanSpec::Auto(default_auto_options()))
89 }
90
91 #[must_use]
92 pub(crate) fn with_plan_spec(subscripts: EinsumSubscripts, plan_spec: EinsumPlanSpec) -> Self {
93 Self {
94 subscripts,
95 plan_spec,
96 static_tree: None,
97 output_shape_hint: None,
98 }
99 }
100
101 #[must_use]
103 #[cfg(test)]
104 pub(crate) fn with_static_tree(
105 subscripts: EinsumSubscripts,
106 tree: Arc<ContractionTree>,
107 ) -> Self {
108 Self::new(subscripts).with_static_tree_hint(tree)
109 }
110
111 #[must_use]
113 pub(crate) fn with_output_shape_hint(
114 subscripts: EinsumSubscripts,
115 output_shape_hint: Vec<SymDim>,
116 plan_spec: EinsumPlanSpec,
117 ) -> Self {
118 let mut op = Self::with_plan_spec(subscripts, plan_spec);
119 op.output_shape_hint = Some(output_shape_hint);
120 op
121 }
122
123 #[must_use]
125 #[cfg(any(test, feature = "autodiff"))]
126 pub(crate) fn with_static_tree_hint(mut self, tree: Arc<ContractionTree>) -> Self {
127 self.static_tree = Some(tree);
128 self
129 }
130
131 #[must_use]
133 pub(crate) fn subscripts(&self) -> &EinsumSubscripts {
134 &self.subscripts
135 }
136
137 #[must_use]
139 pub(crate) fn plan_spec(&self) -> &EinsumPlanSpec {
140 &self.plan_spec
141 }
142
143 #[must_use]
145 pub(crate) fn static_tree(&self) -> Option<&Arc<ContractionTree>> {
146 self.static_tree.as_ref()
147 }
148}
149
150impl ExtensionOp for EinsumExtensionOp {
151 fn family_id(&self) -> &'static str {
152 EINSUM_EXTENSION_FAMILY_ID
153 }
154
155 fn payload_hash(&self, hasher: &mut dyn Hasher) {
156 hasher.write_usize(self.subscripts.inputs.len());
157 for input in &self.subscripts.inputs {
158 hasher.write_usize(input.len());
159 for label in input {
160 hasher.write_u32(*label);
161 }
162 }
163 hasher.write_usize(self.subscripts.output.len());
164 for label in &self.subscripts.output {
165 hasher.write_u32(*label);
166 }
167 hash_einsum_plan_spec(self.plan_spec(), hasher);
168 if let Some(shape) = &self.output_shape_hint {
169 hasher.write_usize(shape.len());
170 for dim in shape {
171 match dim.constant_value() {
172 Some(value) => {
173 hasher.write_u8(1);
174 hasher.write_usize(value);
175 }
176 None => hasher.write_u8(0),
177 }
178 }
179 } else {
180 hasher.write_usize(usize::MAX);
181 }
182 }
183
184 fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
185 other.as_any().downcast_ref::<Self>().is_some_and(|that| {
186 self.subscripts == that.subscripts
187 && plan_specs_equal(self.plan_spec(), that.plan_spec())
188 && self.output_shape_hint == that.output_shape_hint
189 })
190 }
191
192 fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
193 Arc::new(self.clone())
194 }
195
196 fn as_any(&self) -> &dyn Any {
197 self
198 }
199
200 fn input_count(&self) -> usize {
201 self.subscripts.inputs.len()
202 }
203
204 fn output_count(&self) -> usize {
205 1
206 }
207
208 fn infer_output_meta(
209 &self,
210 input_dtypes: &[DType],
211 input_shapes: &[&[SymDim]],
212 ) -> Vec<(DType, Vec<SymDim>)> {
213 if input_shapes.len() != self.subscripts.inputs.len()
214 || input_dtypes.len() != input_shapes.len()
215 {
216 return Vec::new();
217 }
218
219 let mut label_dims: HashMap<u32, SymDim> = HashMap::new();
220 for (labels, shape) in self.subscripts.inputs.iter().zip(input_shapes.iter()) {
221 if labels.len() != shape.len() {
222 return Vec::new();
223 }
224 for (&label, dim) in labels.iter().zip(shape.iter()) {
225 if let Some(existing) = label_dims.get(&label) {
226 if let (Some(lhs), Some(rhs)) =
227 (existing.constant_value(), dim.constant_value())
228 {
229 if lhs != rhs {
230 return Vec::new();
231 }
232 }
233 } else {
234 label_dims.insert(label, dim.clone());
235 }
236 }
237 }
238
239 let output_shape = match &self.output_shape_hint {
240 Some(shape) if shape.iter().all(|dim| dim.constant_value().is_some()) => shape.clone(),
241 _ => self
242 .subscripts
243 .output
244 .iter()
245 .map(|label| label_dims.get(label).cloned())
246 .collect::<Option<Vec<_>>>()
247 .unwrap_or_default(),
248 };
249 if output_shape.len() != self.subscripts.output.len() {
250 return Vec::new();
251 }
252 vec![(promote_dtypes(input_dtypes.iter().copied()), output_shape)]
253 }
254
255 fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
256 let mut backend = tenferro_cpu::CpuBackend::new();
257 let subscripts = Subscripts::from(&self.subscripts);
258 crate::eager::eager_einsum_subscripts(&mut backend, inputs, &subscripts)
259 .map(|output| vec![output])
260 }
261
262 fn lower_to_standard_ops(
263 &self,
264 builder: &mut GraphBuilder<StdTensorOp>,
265 inputs: &[ValueRef<StdTensorOp>],
266 input_dtypes: &[DType],
267 input_shapes: &[&[SymDim]],
268 ) -> ExtensionLoweringResult {
269 if inputs.len() != self.input_count()
270 || input_dtypes.len() != self.input_count()
271 || input_shapes.len() != self.input_count()
272 {
273 return Err(ExtensionLoweringError::new(format!(
274 "einsum extension expects {} inputs, got values={}, dtypes={}, shapes={}",
275 self.input_count(),
276 inputs.len(),
277 input_dtypes.len(),
278 input_shapes.len()
279 )));
280 }
281
282 let Some(shapes) = concrete_sym_shape_slices(input_shapes) else {
283 return Ok(None);
284 };
285 let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
286 let subs = Subscripts::from(&self.subscripts);
287 let tree = resolve_plan_spec(self.plan_spec(), &subs, &shape_refs)
288 .map_err(|err| ExtensionLoweringError::new(err.to_string()))?;
289 let output = build_einsum_graph(builder, &tree, inputs, &shapes)
290 .map_err(|err| ExtensionLoweringError::new(err.to_string()))?;
291 Ok(Some(vec![output]))
292 }
293}
294
295fn concrete_sym_shape_slices(input_shapes: &[&[SymDim]]) -> Option<Vec<Vec<usize>>> {
296 input_shapes
297 .iter()
298 .map(|shape| {
299 shape
300 .iter()
301 .map(SymDim::constant_value)
302 .collect::<Option<Vec<_>>>()
303 })
304 .collect()
305}
306
/// Builds the extension AD rule set that registers the einsum JVP/VJP rule.
///
/// # Errors
/// Propagates any [`ExtensionRegistryError`] reported by `with_rule`.
#[cfg(feature = "autodiff")]
pub fn ad_rules() -> Result<ExtensionRuleSet, ExtensionRegistryError> {
    let rules = ExtensionRuleSet::new();
    rules.with_rule(Arc::new(EinsumAdRule))
}
312
/// Stateless AD rule providing the einsum JVP (`linearize`) and VJP
/// (`transpose_rule`).
#[cfg(feature = "autodiff")]
#[derive(Debug)]
struct EinsumAdRule;
316
#[cfg(feature = "autodiff")]
impl ExtensionAdRule for EinsumAdRule {
    fn family_id(&self) -> &'static str {
        EINSUM_EXTENSION_FAMILY_ID
    }

    /// JVP rule. Einsum is multilinear, so the tangent is the sum, over every
    /// input that has a tangent, of the same einsum with that input replaced by
    /// its tangent.
    ///
    /// # Errors
    /// Fails when `op` is not an einsum op, or when more tangents than primal
    /// inputs are supplied (such a tangent could never be substituted).
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        primal_in: &[ValueKey<StdTensorOp>],
        _primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let op = downcast_ad_op(op, ADRuleKind::Jvp)?;
        // A tangent at an index with no primal slot would never be placed into
        // the operand list below, silently emitting the primal product as a
        // tangent term.
        if tangent_in.len() > primal_in.len() {
            return Err(ADRuleError::unsupported(
                format!(
                    "einsum JVP received {} tangents for {} primal inputs",
                    tangent_in.len(),
                    primal_in.len()
                ),
                ADRuleKind::Jvp,
            ));
        }
        let mut terms = Vec::new();

        for (active_idx, tangent) in tangent_in.iter().enumerate() {
            let Some(dt) = tangent else {
                continue;
            };

            let mut inputs = Vec::with_capacity(primal_in.len());
            for (input_idx, key) in primal_in.iter().enumerate() {
                if input_idx == active_idx {
                    inputs.push(ValueRef::Local(*dt));
                } else {
                    inputs.push(ValueRef::External(key.clone()));
                }
            }

            let out = builder.add_operation(
                StdTensorOp::Extension(Arc::new(op.clone())),
                inputs,
                OperationRole::Linearized {
                    active_mask: (0..primal_in.len()).map(|idx| idx == active_idx).collect(),
                },
            );
            terms.push(out[0]);
        }

        Ok(vec![sum_terms(builder, terms)])
    }

    /// VJP rule. For each active input `k`, emits an einsum contracting the
    /// cotangent with the (conjugated) other primal inputs into the labels of
    /// input `k` that are still present in the output or the other inputs.
    /// Labels that appear only in input `k` are restored by broadcasting, and
    /// repeated labels are projected onto their diagonal.
    ///
    /// # Errors
    /// Fails when `op` is not an einsum op, when the primal input count does
    /// not match the subscripts, when the op has no output shape hint, or when
    /// an inherited explicit plan cannot be derived.
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[ValueRef<StdTensorOp>],
        mode: &OperationRole,
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let op = downcast_ad_op(op, ADRuleKind::Transpose)?;
        let input_labels = &op.subscripts.inputs;
        let output_labels = &op.subscripts.output;
        let input_count = input_labels.len();

        let Some(ct) = cotangent_out.first().copied().flatten() else {
            return Ok(vec![None; input_count]);
        };
        let active_mask = match mode {
            OperationRole::Linearized { active_mask } => active_mask,
            OperationRole::Primary => return Ok(vec![None; input_count]),
        };
        // Every primal input is indexed below; reject a mismatched operand list
        // up front instead of panicking on an out-of-bounds index.
        if inputs.len() != input_count {
            return Err(ADRuleError::unsupported(
                format!(
                    "einsum VJP expects {input_count} primal inputs, got {}",
                    inputs.len()
                ),
                ADRuleKind::Transpose,
            ));
        }
        let primal_input_shapes: Vec<Vec<SymDim>> = inputs
            .iter()
            .map(|input| ctx.shape_of(input).map(|shape| shape.to_vec()))
            .collect::<Result<_, _>>()?;
        let cotangent_shape = op.output_shape_hint.clone().ok_or_else(|| {
            ADRuleError::unsupported(
                "einsum VJP requires an output shape hint for cotangent planning",
                ADRuleKind::Transpose,
            )
        })?;

        let mut result = Vec::with_capacity(input_count);
        for active_idx in 0..input_count {
            if !active_mask.get(active_idx).copied().unwrap_or(false) {
                result.push(None);
                continue;
            }

            // Labels of the active input that can be produced by the VJP einsum:
            // those present in the output or in any other input.
            let mut available_labels: HashSet<u32> = output_labels.iter().copied().collect();
            for (input_idx, labels) in input_labels.iter().enumerate() {
                if input_idx != active_idx {
                    available_labels.extend(labels.iter().copied());
                }
            }
            let vjp_output_labels: Vec<u32> = input_labels[active_idx]
                .iter()
                .copied()
                .filter(|label| available_labels.contains(label))
                .collect();
            // VJP operand 0 is the cotangent; the other primal inputs follow in order.
            let mut vjp_input_labels = Vec::with_capacity(input_count);
            let mut vjp_inputs = Vec::with_capacity(input_count);
            let mut vjp_input_shapes = Vec::with_capacity(input_count);
            vjp_input_labels.push(output_labels.clone());
            vjp_inputs.push(ValueRef::Local(ct));
            vjp_input_shapes.push(cotangent_shape.clone());

            for input_idx in 0..input_count {
                if input_idx == active_idx {
                    continue;
                }
                vjp_input_labels.push(input_labels[input_idx].clone());
                vjp_input_shapes.push(primal_input_shapes[input_idx].clone());
                vjp_inputs.push(conjugate_primal_if_complex(
                    builder,
                    inputs[input_idx].clone(),
                    ctx,
                )?);
            }

            let output_shape_hint = primal_input_shapes[active_idx].clone();
            let vjp_op = vjp_einsum_op_with_inherited_plan(
                op,
                active_idx,
                EinsumSubscripts {
                    inputs: vjp_input_labels,
                    output: vjp_output_labels.clone(),
                },
                output_shape_hint.clone(),
                &vjp_input_shapes,
            )?;
            let out = builder.add_operation(
                StdTensorOp::Extension(Arc::new(vjp_op)),
                vjp_inputs,
                OperationRole::Linearized {
                    active_mask: std::iter::once(true)
                        .chain(std::iter::repeat_n(false, input_count.saturating_sub(1)))
                        .collect(),
                },
            );
            let mut cotangent = out[0];
            // Restore labels that only the active input carries (and repeated labels).
            if vjp_output_labels != input_labels[active_idx] {
                let remapped = broadcast_einsum_vjp_to_input_shape(
                    builder,
                    cotangent,
                    &vjp_output_labels,
                    &input_labels[active_idx],
                    inputs[active_idx].clone(),
                    &output_shape_hint,
                )?;
                cotangent = remapped;
            }
            result.push(Some(cotangent));
        }

        Ok(result)
    }
}
470
/// Builds the einsum op computing the VJP for primal input `active_idx`.
///
/// The VJP op inherits a plan derived from the primal op's plan spec. When
/// every VJP input shape is concrete, the plan is resolved eagerly and
/// attached as a static tree.
///
/// # Errors
/// Fails when the inherited plan cannot be derived or resolved.
#[cfg(feature = "autodiff")]
fn vjp_einsum_op_with_inherited_plan(
    primal_op: &EinsumExtensionOp,
    active_idx: usize,
    subscripts: EinsumSubscripts,
    output_shape_hint: Vec<SymDim>,
    input_shapes: &[Vec<SymDim>],
) -> ADRuleResult<EinsumExtensionOp> {
    let primal_arity = primal_op.input_count();
    let plan_spec = vjp_plan_spec_for_active(primal_op.plan_spec(), primal_arity, active_idx)?;
    let base = EinsumExtensionOp::with_output_shape_hint(
        subscripts.clone(),
        output_shape_hint,
        plan_spec.clone(),
    );

    // Symbolic shapes: leave planning to runtime.
    let Some(concrete_shapes) = concrete_sym_shapes(input_shapes) else {
        return Ok(base);
    };

    let shape_views: Vec<&[usize]> = concrete_shapes.iter().map(Vec::as_slice).collect();
    let vjp_subscripts = Subscripts::from(&subscripts);
    let resolved = resolve_plan_spec(&plan_spec, &vjp_subscripts, &shape_views);
    let tree = resolved.map_err(|err| {
        ADRuleError::unsupported(
            format!(
                "failed to resolve inherited einsum VJP plan for active input {active_idx}: {err}"
            ),
            ADRuleKind::Transpose,
        )
    })?;
    Ok(base.with_static_tree_hint(Arc::new(tree)))
}
502
/// Derives the plan spec for the VJP einsum of primal input `active_idx`.
///
/// Strategy-based plans (`Auto`, `LeftToRight`) are reused as-is. Explicit
/// plans (`Path`, `FixedPairs`) are rewritten into a `FixedPairs` plan over
/// the VJP operands.
///
/// # Errors
/// Fails when `active_idx` is out of range or the explicit plan cannot be
/// converted.
#[cfg(feature = "autodiff")]
fn vjp_plan_spec_for_active(
    primal_plan: &EinsumPlanSpec,
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<EinsumPlanSpec> {
    if active_idx >= input_count {
        return Err(ADRuleError::unsupported(
            format!("einsum VJP active input {active_idx} is outside {input_count} inputs"),
            ADRuleKind::Transpose,
        ));
    }

    let derive = |pairs: &[(usize, usize)]| {
        derive_vjp_fixed_pairs(pairs, input_count, active_idx).map(EinsumPlanSpec::FixedPairs)
    };

    match primal_plan {
        EinsumPlanSpec::Auto(options) => Ok(EinsumPlanSpec::Auto(options.clone())),
        EinsumPlanSpec::LeftToRight => Ok(EinsumPlanSpec::LeftToRight),
        EinsumPlanSpec::FixedPairs(pairs) => derive(pairs),
        EinsumPlanSpec::Path(path) => {
            let converted = jax_path_to_v1_pairs(path, input_count).map_err(|err| {
                ADRuleError::unsupported(
                    format!(
                        "failed to inherit einsum Path plan for VJP active input {active_idx}: {err}"
                    ),
                    ADRuleKind::Transpose,
                )
            })?;
            derive(&converted)
        }
    }
}
535
/// Rewrites a primal fixed-pair contraction plan into one for the VJP einsum
/// of input `active_idx`.
///
/// The VJP einsum has operand 0 = cotangent, followed by the non-active primal
/// inputs in their original order. The derived plan walks the primal tree from
/// the root towards the active leaf, first contracting each sibling subtree and
/// then combining it with the running cotangent.
///
/// # Errors
/// Fails for zero inputs, an out-of-range `active_idx`, a plan with the wrong
/// number of steps, or a malformed plan tree.
#[cfg(feature = "autodiff")]
fn derive_vjp_fixed_pairs(
    primal_pairs: &[(usize, usize)],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<Vec<(usize, usize)>> {
    if input_count == 0 {
        return Err(ADRuleError::unsupported(
            "einsum VJP cannot derive a plan for zero primal inputs",
            ADRuleKind::Transpose,
        ));
    }
    if active_idx >= input_count {
        return Err(ADRuleError::unsupported(
            format!("einsum VJP active input {active_idx} is outside {input_count} inputs"),
            ADRuleKind::Transpose,
        ));
    }
    let required_steps = input_count.saturating_sub(1);
    if primal_pairs.len() != required_steps {
        return Err(ADRuleError::unsupported(
            format!(
                "einsum VJP cannot inherit explicit plan for active input {active_idx}: \
                 expected {required_steps} primal steps for {input_count} inputs, got {}",
                primal_pairs.len()
            ),
            ADRuleKind::Transpose,
        ));
    }
    if input_count == 1 {
        // Single input: the VJP einsum has only the cotangent, so nothing to pair.
        return Ok(Vec::new());
    }

    let children = fixed_pair_children(primal_pairs, input_count, active_idx)?;
    // Inputs before the active one shift up by one (slot 0 is the cotangent);
    // inputs after it keep their index; the active input has no VJP operand.
    let primal_to_vjp: Vec<Option<usize>> = (0..input_count)
        .map(|idx| {
            if idx == active_idx {
                None
            } else if idx < active_idx {
                Some(idx + 1)
            } else {
                Some(idx)
            }
        })
        .collect();

    let root = input_count + primal_pairs.len() - 1;
    let mut pairs = Vec::with_capacity(required_steps);
    let final_id = emit_vjp_adjoint(
        root,
        0,
        &children,
        input_count,
        active_idx,
        &primal_to_vjp,
        &mut pairs,
    )?;
    let expected_final = input_count + pairs.len() - 1;
    let well_formed = final_id == expected_final && pairs.len() == required_steps;
    if !well_formed {
        return Err(ADRuleError::unsupported(
            format!(
                "einsum VJP plan derivation for active input {active_idx} produced an invalid \
                 tree: final id {final_id}, expected {expected_final}, steps {}",
                pairs.len()
            ),
            ADRuleKind::Transpose,
        ));
    }
    Ok(pairs)
}
603
/// Validates a fixed-pair plan and records the two children of every
/// intermediate node.
///
/// Node ids `0..input_count` are inputs; step `s` creates node
/// `input_count + s`. Each step must consume two distinct, existing, still-live
/// operands, and exactly one operand must remain live at the end.
///
/// # Errors
/// Returns an `invalid_vjp_plan_error` describing the first violation found.
#[cfg(feature = "autodiff")]
fn fixed_pair_children(
    pairs: &[(usize, usize)],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<Vec<Option<(usize, usize)>>> {
    let node_count = input_count + pairs.len();
    let mut live: Vec<bool> = (0..node_count).map(|id| id < input_count).collect();
    let mut children: Vec<Option<(usize, usize)>> = vec![None; node_count];

    for (next_idx, &(left, right)) in (input_count..).zip(pairs) {
        let problem = if left == right {
            Some("references the same operand")
        } else if left >= next_idx || right >= next_idx {
            Some("references a non-existent operand")
        } else if !live[left] || !live[right] {
            Some("references an operand that is no longer live")
        } else {
            None
        };
        if let Some(problem) = problem {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("pair ({left}, {right}) {problem}"),
            ));
        }

        live[left] = false;
        live[right] = false;
        live[next_idx] = true;
        children[next_idx] = Some((left, right));
    }

    let live_count = live.iter().filter(|is_live| **is_live).count();
    if live_count != 1 {
        return Err(invalid_vjp_plan_error(
            active_idx,
            format!("explicit plan leaves {live_count} live operands"),
        ));
    }

    Ok(children)
}
653
/// Walks from `node` down to the active leaf, emitting VJP pairs as it goes.
///
/// At each intermediate node the sibling subtree (the child without the
/// active input) is contracted first. Its result is then paired with the
/// running cotangent, and the walk continues into the active child. Returns
/// the VJP node id holding the final cotangent.
///
/// # Errors
/// Fails when a node lacks children, when the active input is found in both or
/// neither child, or when the walk ends at an inactive leaf.
#[cfg(feature = "autodiff")]
fn emit_vjp_adjoint(
    node: usize,
    cotangent_id: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
    primal_to_vjp: &[Option<usize>],
    pairs: &mut Vec<(usize, usize)>,
) -> ADRuleResult<usize> {
    let mut current = node;
    let mut adjoint = cotangent_id;

    while current >= input_count {
        let Some((left, right)) = children.get(current).copied().flatten() else {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("missing children for node {current}"),
            ));
        };
        let left_has_active = subtree_contains_active(left, children, input_count, active_idx)?;
        let right_has_active = subtree_contains_active(right, children, input_count, active_idx)?;

        let (active_child, sibling) = match (left_has_active, right_has_active) {
            (true, false) => (left, right),
            (false, true) => (right, left),
            (true, true) => {
                return Err(invalid_vjp_plan_error(
                    active_idx,
                    format!("both children of node {current} contain the active input"),
                ))
            }
            (false, false) => {
                return Err(invalid_vjp_plan_error(
                    active_idx,
                    format!("neither child of node {current} contains the active input"),
                ))
            }
        };

        let sibling_id = emit_vjp_subtree(
            sibling,
            children,
            input_count,
            active_idx,
            primal_to_vjp,
            pairs,
        )?;
        adjoint = push_vjp_pair(adjoint, sibling_id, input_count, pairs);
        current = active_child;
    }

    if current == active_idx {
        Ok(adjoint)
    } else {
        Err(invalid_vjp_plan_error(
            active_idx,
            format!("adjoint walk reached inactive leaf {current}"),
        ))
    }
}
731
/// Emits VJP pairs that rebuild a primal subtree free of the active input,
/// in post-order (left subtree, right subtree, then their pair).
///
/// Returns the VJP node id of the subtree's result.
///
/// # Errors
/// Fails when the subtree reaches the active leaf or a node has no children.
#[cfg(feature = "autodiff")]
fn emit_vjp_subtree(
    node: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
    primal_to_vjp: &[Option<usize>],
    pairs: &mut Vec<(usize, usize)>,
) -> ADRuleResult<usize> {
    if node < input_count {
        return match primal_to_vjp[node] {
            Some(vjp_id) => Ok(vjp_id),
            None => Err(invalid_vjp_plan_error(
                active_idx,
                format!("sibling subtree unexpectedly reached active leaf {node}"),
            )),
        };
    }

    let Some((left, right)) = children.get(node).copied().flatten() else {
        return Err(invalid_vjp_plan_error(
            active_idx,
            format!("missing children for node {node}"),
        ));
    };

    let mut operand_ids = [0_usize; 2];
    for (slot, child) in operand_ids.iter_mut().zip([left, right]) {
        *slot = emit_vjp_subtree(child, children, input_count, active_idx, primal_to_vjp, pairs)?;
    }
    let [left_id, right_id] = operand_ids;
    Ok(push_vjp_pair(left_id, right_id, input_count, pairs))
}
771
/// Appends a contraction step and returns the id of the node it creates.
///
/// Ids `0..n_vjp_inputs` are operands, so the `k`-th step (0-based) produces
/// node `n_vjp_inputs + k`.
#[cfg(feature = "autodiff")]
fn push_vjp_pair(
    left: usize,
    right: usize,
    n_vjp_inputs: usize,
    pairs: &mut Vec<(usize, usize)>,
) -> usize {
    let new_id = n_vjp_inputs + pairs.len();
    pairs.push((left, right));
    new_id
}
782
/// Reports whether the subtree rooted at `node` contains leaf `active_idx`.
///
/// # Errors
/// Fails when an intermediate node on the search path has no recorded children.
#[cfg(feature = "autodiff")]
fn subtree_contains_active(
    node: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<bool> {
    if node < input_count {
        return Ok(node == active_idx);
    }
    let Some((left, right)) = children.get(node).copied().flatten() else {
        return Err(invalid_vjp_plan_error(
            active_idx,
            format!("missing children for node {node}"),
        ));
    };
    // Short-circuit: only inspect the right child when the left one lacks it.
    if subtree_contains_active(left, children, input_count, active_idx)? {
        return Ok(true);
    }
    subtree_contains_active(right, children, input_count, active_idx)
}
801
/// Wraps `reason` in the standard "cannot inherit explicit plan" transpose error.
#[cfg(feature = "autodiff")]
fn invalid_vjp_plan_error(active_idx: usize, reason: String) -> ADRuleError {
    let message = format!(
        "einsum VJP cannot inherit explicit plan for active input {active_idx}: {reason}"
    );
    ADRuleError::unsupported(message, ADRuleKind::Transpose)
}
809
/// Converts owned symbolic shapes to concrete shapes, or `None` if any dim is
/// symbolic.
#[cfg(feature = "autodiff")]
fn concrete_sym_shapes(shapes: &[Vec<SymDim>]) -> Option<Vec<Vec<usize>>> {
    let mut concrete = Vec::with_capacity(shapes.len());
    for shape in shapes {
        let dims = shape
            .iter()
            .map(SymDim::constant_value)
            .collect::<Option<Vec<usize>>>()?;
        concrete.push(dims);
    }
    Some(concrete)
}
817
/// Broadcasts a VJP result labelled `cotangent_labels` to the full rank of the
/// active input (labelled `input_labels`), then projects axes with repeated
/// labels onto their diagonal.
///
/// The target extents are read from `shape_source` (operand 1 of the
/// broadcast) via `DimExpr::InputDim`. `input_shape` only supplies the rank.
///
/// # Errors
/// Fails when some cotangent label cannot be matched to a distinct input axis.
#[cfg(feature = "autodiff")]
fn broadcast_einsum_vjp_to_input_shape(
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent: LocalValueId,
    cotangent_labels: &[u32],
    input_labels: &[u32],
    shape_source: ValueRef<StdTensorOp>,
    input_shape: &[SymDim],
) -> ADRuleResult<LocalValueId> {
    let shape: Vec<DimExpr> = (0..input_shape.len())
        .map(|axis| DimExpr::InputDim { input_idx: 1, axis })
        .collect();
    let dims = map_label_occurrences(cotangent_labels, input_labels).ok_or_else(|| {
        ADRuleError::unsupported(
            format!(
                "einsum VJP broadcast remap failed for cotangent labels {cotangent_labels:?} \
                 into active input labels {input_labels:?}"
            ),
            ADRuleKind::Transpose,
        )
    })?;
    // The shape source is only an operand when some output extent is read from
    // it; keep `active_mask` the same length as the operand list in both cases.
    let mut inputs = vec![ValueRef::Local(cotangent)];
    let mut active_mask = vec![true];
    if !shape.is_empty() {
        inputs.push(shape_source);
        active_mask.push(false);
    }
    let broadcast = builder.add_operation(
        StdTensorOp::BroadcastInDim { shape, dims },
        inputs,
        OperationRole::Linearized { active_mask },
    )[0];
    Ok(project_repeated_labels_to_diagonal(
        builder,
        broadcast,
        input_labels,
    ))
}
858
/// Maps each source label to a distinct target axis carrying the same label.
///
/// Repeated labels consume target axes left to right. Returns `None` if any
/// source label has no unused matching target axis.
#[cfg(feature = "autodiff")]
fn map_label_occurrences(source_labels: &[u32], target_labels: &[u32]) -> Option<Vec<usize>> {
    let mut claimed = vec![false; target_labels.len()];
    let mut axes = Vec::with_capacity(source_labels.len());
    for label in source_labels {
        let axis = (0..target_labels.len())
            .find(|&axis| !claimed[axis] && target_labels[axis] == *label)?;
        claimed[axis] = true;
        axes.push(axis);
    }
    Some(axes)
}
874
/// For every axis whose label already appeared on an earlier axis, masks the
/// cotangent to the diagonal of that axis pair by extracting the diagonal and
/// re-embedding it.
///
/// Returns the projected value, or `cotangent` unchanged if no label repeats.
#[cfg(feature = "autodiff")]
fn project_repeated_labels_to_diagonal(
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent: LocalValueId,
    labels: &[u32],
) -> LocalValueId {
    let mut first_axis_by_label: HashMap<u32, usize> = HashMap::new();
    let mut projected = cotangent;

    for (axis_b, &label) in labels.iter().enumerate() {
        if let Some(&axis_a) = first_axis_by_label.get(&label) {
            let diagonal = builder.add_operation(
                StdTensorOp::ExtractDiag { axis_a, axis_b },
                vec![ValueRef::Local(projected)],
                OperationRole::Linearized {
                    active_mask: vec![true],
                },
            )[0];
            projected = builder.add_operation(
                StdTensorOp::EmbedDiag { axis_a, axis_b },
                vec![ValueRef::Local(diagonal)],
                OperationRole::Linearized {
                    active_mask: vec![true],
                },
            )[0];
        } else {
            first_axis_by_label.insert(label, axis_b);
        }
    }

    projected
}
905
// Declares the `EinsumRuntime` extension runtime for `EinsumExtensionOp` under
// the einsum family id: `execute` routes to `execute_einsum_extension`,
// `execute_reads` to `execute_einsum_extension_reads`, and `register_fn` names
// the generated registration function.
// NOTE(review): exact generated items come from the macro; confirm in
// `tenferro_extension_macros`.
define_extension_runtime! {
    runtime = EinsumRuntime,
    family_id = EINSUM_EXTENSION_FAMILY_ID,
    op_type = EinsumExtensionOp,
    execute = execute_einsum_extension,
    execute_reads = execute_einsum_extension_reads,
    register_fn = register_runtime,
}
914
/// Runtime executor for [`EinsumExtensionOp`] over owned tensor inputs.
///
/// Chooses the contraction tree: the op's static tree if present, otherwise
/// one resolved from the plan spec through `cached_runtime_tree`. Two-input
/// einsums in which no shared label is summed away run eagerly. Every other
/// einsum is compiled into an exec program that is cached per
/// `runtime_exec_program_cache_key` and executed with its own backend cache.
///
/// # Errors
/// Returns `InvalidConfig` for zero inputs. Propagates tree-resolution,
/// program-construction, and execution errors, and reports a backend failure
/// if the cache entry is missing after insertion or the program does not
/// yield exactly one output.
fn execute_einsum_extension<B: TensorBackend + 'static>(
    op: &EinsumExtensionOp,
    inputs: &[&Tensor],
    ctx: &mut ExtensionExecutionContext<'_, B>,
) -> tenferro_tensor::Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(tenferro_tensor::Error::InvalidConfig {
            op: "einsum_extension",
            message: "einsum requires at least one input tensor".into(),
        });
    }

    let shapes: Vec<Vec<usize>> = inputs
        .iter()
        .map(|tensor| tensor.shape().to_vec())
        .collect();
    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
    let subs = Subscripts::from(op.subscripts());
    // Prefer the pre-resolved tree; otherwise resolve the plan spec for these
    // concrete shapes via `cached_runtime_tree`.
    let tree = if let Some(tree) = op.static_tree() {
        Arc::clone(tree)
    } else {
        cached_runtime_tree(ctx, op.subscripts(), op.plan_spec(), &shapes, || {
            resolve_plan_spec(op.plan_spec(), &subs, &shape_refs)
        })?
    };

    // Two-input einsums where no label shared by both operands is summed away
    // run eagerly instead of through a compiled program.
    if is_binary_non_contracting(&subs) {
        let output = ctx
            .backend_mut()
            .with_backend_session(|exec| crate::eager::eager_einsum_exec(exec, inputs, &tree))?;
        return Ok(vec![output]);
    }

    let (backend, caches) = ctx.parts_mut();
    let compiler_options = tenferro_runtime::extension::CompilerOptions::default();
    let optimizer_fingerprint = compiler_options.optimizer.fingerprint();
    let key = runtime_exec_program_cache_key(op, inputs, &shapes, optimizer_fingerprint);
    // On a cache miss, build the program and insert it, accounting retained
    // bytes for both the key data and the cached program.
    if caches
        .get_mut::<CachedRuntimeExecProgram<B::RuntimeCache>>(&key)
        .is_none()
    {
        let cached =
            build_runtime_exec_program::<B>(tree.as_ref(), inputs, &shapes, compiler_options)?;
        let key_retained_bytes = runtime_exec_program_key_retained_bytes(op, inputs, &shapes);
        caches.put_with_retained_bytes(key, cached, move |cached| {
            saturating_sum([
                key_retained_bytes,
                cached_runtime_exec_program_retained_bytes(cached),
            ])
        });
    }
    // Second lookup after the insert above. NOTE(review): presumably structured
    // this way to avoid holding the first `get_mut` borrow across the insert.
    let cached = caches
        .get_mut::<CachedRuntimeExecProgram<B::RuntimeCache>>(&key)
        .ok_or_else(|| {
            tenferro_tensor::Error::backend_failure(
                "einsum_extension",
                "runtime exec program cache entry missing after insertion",
            )
        })?;
    let program_inputs = runtime_program_inputs(inputs, cached.input_indices.as_slice())?;
    let mut outputs = tenferro_runtime::extension::execute_lowered_program_with_backend_cache(
        backend,
        &cached.program,
        program_inputs,
        &mut cached.backend_cache,
    )
    .map_err(|err| tenferro_tensor::Error::backend_failure("einsum_extension", err.to_string()))?;
    if outputs.len() != 1 {
        return Err(tenferro_tensor::Error::backend_failure(
            "einsum_extension",
            format!("expected 1 output, got {}", outputs.len()),
        ));
    }
    Ok(vec![outputs.remove(0)])
}
990
991fn execute_einsum_extension_reads<B: TensorBackend + 'static>(
992 op: &EinsumExtensionOp,
993 inputs: &[TensorRead<'_>],
994 ctx: &mut ExtensionExecutionContext<'_, B>,
995) -> tenferro_tensor::Result<Vec<Tensor>> {
996 if inputs
997 .iter()
998 .all(|input| matches!(input, TensorRead::Tensor(_)))
999 {
1000 let input_refs: Vec<&Tensor> = inputs
1001 .iter()
1002 .map(|input| match input {
1003 TensorRead::Tensor(tensor) => *tensor,
1004 TensorRead::View(_) => unreachable!("view input filtered above"),
1005 })
1006 .collect();
1007 return execute_einsum_extension(op, &input_refs, ctx);
1008 }
1009
1010 if inputs.is_empty() {
1011 return Err(tenferro_tensor::Error::InvalidConfig {
1012 op: "einsum_extension",
1013 message: "einsum requires at least one input tensor".into(),
1014 });
1015 }
1016
1017 let shapes: Vec<Vec<usize>> = inputs.iter().map(|input| input.shape().to_vec()).collect();
1018 let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
1019 let subs = Subscripts::from(op.subscripts());
1020 let tree = if let Some(tree) = op.static_tree() {
1021 Arc::clone(tree)
1022 } else {
1023 cached_runtime_tree(ctx, op.subscripts(), op.plan_spec(), &shapes, || {
1024 resolve_plan_spec(op.plan_spec(), &subs, &shape_refs)
1025 })?
1026 };
1027 let output = ctx
1028 .backend_mut()
1029 .with_backend_session(|exec| crate::eager::eager_einsum_exec_read(exec, inputs, &tree))?;
1030 Ok(vec![output])
1031}
1032
1033fn is_binary_non_contracting(subs: &Subscripts) -> bool {
1034 if subs.inputs.len() != 2 {
1035 return false;
1036 }
1037
1038 let lhs = &subs.inputs[0];
1039 let rhs = &subs.inputs[1];
1040 let output = &subs.output;
1041 !lhs.iter()
1042 .any(|label| rhs.contains(label) && !output.contains(label))
1043}
1044
/// Cached lowered exec program for one einsum configuration, stored in the
/// extension cache under `runtime_exec_program_cache_key`.
struct CachedRuntimeExecProgram<C> {
    /// Lowered program passed to `execute_lowered_program_with_backend_cache`.
    program: ExecProgram,
    /// Caller input positions consumed by `runtime_program_inputs` to assemble
    /// the program's inputs. NOTE(review): presumably in program-input order.
    input_indices: InputIndexVec,
    /// Backend-specific cache passed mutably to every execution of `program`.
    backend_cache: C,
    /// Optimizer fingerprint associated with this program. The cache key also
    /// mixes in the optimizer fingerprint; this field is not read in the code
    /// visible here.
    optimizer_fingerprint: u64,
}
1051
1052fn runtime_exec_program_cache_key(
1053 op: &EinsumExtensionOp,
1054 inputs: &[&Tensor],
1055 shapes: &[Vec<usize>],
1056 optimizer_fingerprint: u64,
1057) -> ExtensionCacheKey {
1058 let input_dtypes: Vec<DType> = inputs.iter().map(|tensor| tensor.dtype()).collect();
1059 let mut plan_hasher = std::collections::hash_map::DefaultHasher::new();
1060 hash_einsum_plan_spec(op.plan_spec(), &mut plan_hasher);
1061 let key_data = (
1062 op.subscripts().clone(),
1063 shapes.to_vec(),
1064 input_dtypes.clone(),
1065 plan_hasher.finish(),
1066 optimizer_fingerprint,
1067 );
1068 ExtensionCacheKey::new(
1069 EINSUM_EXTENSION_FAMILY_ID,
1070 EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE,
1071 hash_value(&key_data),
1072 )
1073}
1074
1075fn runtime_exec_program_key_retained_bytes(
1076 op: &EinsumExtensionOp,
1077 inputs: &[&Tensor],
1078 shapes: &[Vec<usize>],
1079) -> usize {
1080 saturating_sum([
1081 einsum_subscripts_retained_bytes(op.subscripts()),
1082 saturating_sum(shapes.iter().map(vec_retained_bytes)),
1083 inputs.len().saturating_mul(std::mem::size_of::<DType>()),
1084 std::mem::size_of::<u64>(),
1085 std::mem::size_of::<u64>(),
1086 ])
1087}
1088
/// Lowers `tree` into a compiled exec program for the given runtime inputs.
///
/// Steps:
/// 1. Build a `StdTensorOp` graph.
///    - It has one `TensorInputKey::User { id }` input per element of
///      `inputs`, where `id` is the element's position.
///    - The einsum result built by `build_einsum_graph` is its single output.
/// 2. Resolve the graph, materialize the output, and compile it.
/// 3. Visit each input key of the materialized graph, in the graph's order.
///    For each one, record the caller-side input index and that tensor's
///    dtype and concrete shape.
/// 4. Compile to an `ExecProgram` with `compiler_options`.
///
/// The returned entry starts with a `Default` backend cache and records the
/// optimizer fingerprint of `compiler_options`.
///
/// # Errors
/// Fails in any of these cases:
/// - einsum graph construction fails;
/// - the builder returns an external value;
/// - the materialized graph names an input index not present in `inputs`, or
///   an input key that is not `User`;
/// - exec compilation fails.
fn build_runtime_exec_program<B: TensorBackend>(
    tree: &ContractionTree,
    inputs: &[&Tensor],
    shapes: &[Vec<usize>],
    compiler_options: tenferro_runtime::extension::CompilerOptions,
) -> tenferro_tensor::Result<CachedRuntimeExecProgram<B::RuntimeCache>> {
    let mut builder = GraphBuilder::<StdTensorOp>::new();
    // One graph input per runtime tensor. The `User` id is the tensor's
    // position in `inputs` and is mapped back to that position below.
    let mut input_vals = Vec::with_capacity(inputs.len());
    for input_idx in 0..inputs.len() {
        let local = builder.add_input(TensorInputKey::User {
            id: input_idx as u64,
        });
        input_vals.push(ValueRef::Local(local));
    }

    let result_ref = build_einsum_graph(&mut builder, tree, &input_vals, shapes)
        .map_err(einsum_runtime_error)?;
    // `set_outputs` takes local value ids, so an external result is rejected.
    let result_local = match result_ref {
        ValueRef::Local(local) => local,
        ValueRef::External(_) => {
            return Err(tenferro_tensor::Error::backend_failure(
                "einsum_extension",
                "einsum builder returned an external value at runtime",
            ))
        }
    };
    builder.set_outputs(vec![result_local]);
    let graph = Arc::new(builder.build());
    let output_key = graph.values()[result_local].key.clone();

    // Resolve the graph, materialize only the einsum output, and compile it.
    let view = resolve(vec![graph]);
    let graph = materialize_merge(&view, &[output_key]);
    let compiled = compile(&graph);

    // Walk the materialized graph's inputs in the graph's own order. Map each
    // one back to the caller's tensor and collect the dtype/shape signature
    // that exec compilation needs.
    let mut input_indices = InputIndexVec::new();
    let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
    let mut input_shapes = Vec::with_capacity(graph.inputs.len());
    for key in &graph.inputs {
        match key {
            ValueKey::Input(TensorInputKey::User { id }) => {
                let input_idx = *id as usize;
                let tensor = inputs.get(input_idx).ok_or_else(|| {
                    tenferro_tensor::Error::backend_failure(
                        "einsum_extension",
                        format!("runtime input {input_idx} missing"),
                    )
                })?;
                input_indices.push(input_idx);
                input_dtypes.push(tensor.dtype());
                input_shapes.push(tenferro_ops::dim_expr::DimExpr::from_concrete(
                    tensor.shape(),
                ));
            }
            other => {
                return Err(tenferro_tensor::Error::backend_failure(
                    "einsum_extension",
                    format!("unexpected runtime input key: {other:?}"),
                ))
            }
        }
    }

    let program = tenferro_runtime::extension::compile_std_to_exec_with_options(
        &compiled,
        &input_dtypes,
        &input_shapes,
        compiler_options,
    )
    .map_err(|err| tenferro_tensor::Error::backend_failure("einsum_extension", err.to_string()))?;
    Ok(CachedRuntimeExecProgram {
        program,
        input_indices,
        backend_cache: B::RuntimeCache::default(),
        // `compiler_options` was passed by value above. Reading it here relies
        // on `CompilerOptions` being `Copy`.
        optimizer_fingerprint: compiler_options.optimizer.fingerprint(),
    })
}
1165
1166fn runtime_program_inputs(
1167 inputs: &[&Tensor],
1168 input_indices: &[usize],
1169) -> tenferro_tensor::Result<Vec<Tensor>> {
1170 let mut program_inputs = Vec::with_capacity(input_indices.len());
1171 for &input_idx in input_indices {
1172 let tensor = inputs.get(input_idx).ok_or_else(|| {
1173 tenferro_tensor::Error::backend_failure(
1174 "einsum_extension",
1175 format!("runtime input {input_idx} missing"),
1176 )
1177 })?;
1178 program_inputs.push((*tensor).clone());
1179 }
1180 Ok(program_inputs)
1181}
1182
1183fn cached_runtime_exec_program_retained_bytes<C: RuntimeCacheControl>(
1184 cached: &CachedRuntimeExecProgram<C>,
1185) -> usize {
1186 saturating_sum([
1187 std::mem::size_of::<CachedRuntimeExecProgram<C>>(),
1188 exec_program_retained_bytes(&cached.program),
1189 smallvec_retained_bytes(&cached.input_indices),
1190 cached.backend_cache.stats().retained_bytes,
1191 std::mem::size_of_val(&cached.optimizer_fingerprint),
1192 ])
1193}
1194
1195fn smallvec_retained_bytes<A: smallvec::Array>(values: &SmallVec<A>) -> usize {
1196 if values.spilled() {
1197 values
1198 .capacity()
1199 .saturating_mul(std::mem::size_of::<A::Item>())
1200 } else {
1201 0
1202 }
1203}
1204
1205fn exec_program_retained_bytes(program: &ExecProgram) -> usize {
1206 saturating_sum([
1207 std::mem::size_of::<ExecProgram>(),
1208 vec_retained_bytes(&program.instructions),
1209 saturating_sum(
1210 program
1211 .instructions
1212 .iter()
1213 .map(exec_instruction_retained_bytes),
1214 ),
1215 vec_retained_bytes(&program.input_slots),
1216 vec_retained_bytes(&program.output_slots),
1217 ])
1218}
1219
1220fn exec_instruction_retained_bytes(inst: &ExecInstruction) -> usize {
1221 saturating_sum([
1222 std::mem::size_of::<ExecInstruction>(),
1223 exec_op_retained_bytes(&inst.op),
1224 vec_retained_bytes(&inst.input_slots),
1225 vec_retained_bytes(&inst.output_slots),
1226 vec_of_vec_retained_bytes(&inst.output_shapes),
1227 vec_of_vec_retained_bytes(&inst.output_extents),
1228 vec_retained_bytes(&inst.last_use),
1229 ])
1230}
1231
1232fn exec_op_retained_bytes(op: &ExecOp) -> usize {
1233 match op {
1234 ExecOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
1235 ExecOp::Extension(extension) => std::mem::size_of_val(extension),
1236 _ => 0,
1237 }
1238}
1239
1240fn cached_runtime_tree<B: TensorBackend>(
1241 ctx: &mut ExtensionExecutionContext<'_, B>,
1242 subscripts: &EinsumSubscripts,
1243 plan_spec: &EinsumPlanSpec,
1244 shapes: &[Vec<usize>],
1245 build: impl FnOnce() -> EinsumResult<ContractionTree>,
1246) -> tenferro_tensor::Result<Arc<ContractionTree>> {
1247 let mut plan_hasher = std::collections::hash_map::DefaultHasher::new();
1248 hash_einsum_plan_spec(plan_spec, &mut plan_hasher);
1249 let key_data = (subscripts.clone(), shapes.to_vec(), plan_hasher.finish());
1250 let key = ExtensionCacheKey::new(
1251 EINSUM_EXTENSION_FAMILY_ID,
1252 EINSUM_RUNTIME_PLANS_CACHE,
1253 hash_value(&key_data),
1254 );
1255 if let Some(cached) = ctx.caches_mut().get::<Arc<ContractionTree>>(&key) {
1256 return Ok(Arc::clone(cached));
1257 }
1258
1259 let tree = Arc::new(build().map_err(einsum_runtime_error)?);
1260 let retained_bytes = saturating_sum([
1261 einsum_subscripts_retained_bytes(subscripts),
1262 saturating_sum(shapes.iter().map(vec_retained_bytes)),
1263 std::mem::size_of::<u64>(),
1264 tree.retained_bytes_for_cache_stats(),
1265 ]);
1266 ctx.caches_mut().put(key, Arc::clone(&tree), retained_bytes);
1267 Ok(tree)
1268}
1269
1270fn einsum_runtime_error(error: EinsumError) -> tenferro_tensor::Error {
1271 error.to_tensor_error("einsum_extension")
1272}
1273
/// Hashes `value` with a fresh `DefaultHasher` and returns the 64-bit digest.
///
/// Every `DefaultHasher::new()` instance produces the same digest for the same
/// input. The underlying algorithm may change between Rust releases, so the
/// digest should not be persisted.
fn hash_value<T: Hash>(value: &T) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
1279
/// Recovers the concrete `EinsumExtensionOp` behind a type-erased extension op
/// for use in an AD rule.
///
/// # Errors
/// Returns `ADRuleError::unsupported` for `kind` when `op` is not an
/// `EinsumExtensionOp`.
#[cfg(feature = "autodiff")]
fn downcast_ad_op(op: &dyn ExtensionOp, kind: ADRuleKind) -> ADRuleResult<&EinsumExtensionOp> {
    match op.as_any().downcast_ref::<EinsumExtensionOp>() {
        Some(einsum_op) => Ok(einsum_op),
        None => Err(ADRuleError::unsupported(
            "tenferro.einsum.v1 payload type mismatch",
            kind,
        )),
    }
}
1286
/// Sums `terms` into one value by chaining binary `StdTensorOp::Add`
/// operations from left to right.
///
/// - An empty list yields `None`.
/// - A single term is returned as-is, and no operation is emitted.
/// - Otherwise each emitted `Add` uses `OperationRole::Linearized` with both
///   inputs marked active.
#[cfg(feature = "autodiff")]
fn sum_terms(
    builder: &mut dyn PrimitiveRuleBuilder,
    terms: Vec<LocalValueId>,
) -> Option<LocalValueId> {
    let mut remaining = terms.into_iter();
    let first = remaining.next()?;
    let total = remaining.fold(first, |acc, term| {
        let outputs = builder.add_operation(
            StdTensorOp::Add,
            vec![ValueRef::Local(acc), ValueRef::Local(term)],
            OperationRole::Linearized {
                active_mask: vec![true, true],
            },
        );
        outputs[0]
    });
    Some(total)
}
1311
/// Conjugates `input` when its dtype is complex, otherwise returns it as-is.
///
/// For `C32`/`C64` this emits a `StdTensorOp::Conj` with
/// `OperationRole::Primary` and returns that op's first output. Real, integer
/// and boolean dtypes are passed through without emitting anything.
///
/// # Errors
/// Propagates any failure from `ctx.dtype_of`.
#[cfg(feature = "autodiff")]
fn conjugate_primal_if_complex(
    builder: &mut dyn PrimitiveRuleBuilder,
    input: ValueRef<StdTensorOp>,
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<ValueRef<StdTensorOp>> {
    let dtype = ctx.dtype_of(&input)?;
    let value = match dtype {
        DType::C32 | DType::C64 => {
            let outputs =
                builder.add_operation(StdTensorOp::Conj, vec![input], OperationRole::Primary);
            ValueRef::Local(outputs[0])
        }
        DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => input,
    };
    Ok(value)
}
1325
1326fn promote_dtypes(dtypes: impl IntoIterator<Item = DType>) -> DType {
1327 dtypes
1328 .into_iter()
1329 .reduce(promote_dtype)
1330 .unwrap_or(DType::F64)
1331}
1332
/// Returns the dtype that combining `lhs` and `rhs` promotes to.
///
/// The table is symmetric in its arguments. The rules are:
/// - `Bool` with any dtype yields the other dtype.
/// - Two integers yield `I64` if either side is `I64`, otherwise `I32`.
/// - An integer with a float yields `F64`, and an integer with a complex
///   dtype yields `C64`, whatever the widths.
/// - Two floats yield `F64` if either side is `F64`, otherwise `F32`.
/// - Excluding the `Bool` rule, the result is `C32` only for `C32` paired with
///   `F32` or `C32`. Every other pairing with a complex dtype yields `C64`.
///
/// NOTE(review): `F64` with `C32` yielding `C64` suggests `C32`/`C64` are
/// complex numbers with `f32`/`f64` components; confirm against `DType`.
fn promote_dtype(lhs: DType, rhs: DType) -> DType {
    use DType::*;
    match (lhs, rhs) {
        // Bool is the identity of promotion.
        (Bool, Bool) => Bool,
        (Bool, other) | (other, Bool) => other,
        // Integer-only combinations widen to the larger integer.
        (I32, I32) => I32,
        (I32, I64) | (I64, I32) | (I64, I64) => I64,
        // Mixed integer/float and integer/complex go to the widest type.
        (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
        (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
        // Float-only combinations widen to the larger float.
        (F32, F32) => F32,
        (F32, F64) | (F64, F32) | (F64, F64) => F64,
        // Real/complex and complex/complex: C32 only when both sides are 32-bit.
        (F32, C32) | (C32, F32) | (C32, C32) => C32,
        (F32, C64) | (C64, F32) => C64,
        (F64, C32 | C64) | (C32 | C64, F64) => C64,
        (C32, C64) | (C64, C32) | (C64, C64) => C64,
    }
}
1350
1351#[cfg(test)]
1352mod tests;