Skip to main content

tenferro_einsum/
extension.rs

1use std::any::Any;
2use std::collections::HashMap;
3#[cfg(feature = "autodiff")]
4use std::collections::HashSet;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use computegraph::compile::compile;
9use computegraph::graph::GraphBuilder;
10use computegraph::materialize::materialize_merge;
11use computegraph::resolve::resolve;
12#[cfg(feature = "autodiff")]
13use computegraph::types::{LocalValueId, OperationRole};
14use computegraph::types::{ValueKey, ValueRef};
15use smallvec::SmallVec;
16use tenferro_extension_macros::define_extension_runtime;
17#[cfg(feature = "autodiff")]
18use tenferro_ops::ad::context::ShapeGuardContext;
19#[cfg(feature = "autodiff")]
20use tenferro_ops::ad::PrimitiveRuleBuilder;
21#[cfg(feature = "autodiff")]
22use tenferro_ops::dim_expr::DimExpr;
23#[cfg(feature = "autodiff")]
24use tenferro_ops::ext_op::ExtensionAdRule;
25use tenferro_ops::ext_op::{ExtensionLoweringError, ExtensionLoweringResult, ExtensionOp};
26use tenferro_ops::input_key::TensorInputKey;
27use tenferro_ops::std_tensor_op::StdTensorOp;
28use tenferro_ops::sym_dim::SymDim;
29#[cfg(feature = "autodiff")]
30use tenferro_ops::{ExtensionRegistryError, ExtensionRuleSet};
31use tenferro_runtime::extension::{
32    ExecInstruction, ExecOp, ExecProgram, ExtensionCacheKey, ExtensionExecutionContext,
33};
34use tenferro_tensor::{DType, RuntimeCacheControl, Tensor, TensorBackend, TensorRead};
35#[cfg(feature = "autodiff")]
36use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
37
38use crate::builder::build_einsum_graph;
39use crate::cache::{
40    einsum_subscripts_retained_bytes, saturating_sum, vec_of_vec_retained_bytes,
41    vec_retained_bytes, EINSUM_EXTENSION_FAMILY_ID, EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE,
42    EINSUM_RUNTIME_PLANS_CACHE,
43};
44#[cfg(test)]
45use crate::optimize::default_auto_options;
46#[cfg(feature = "autodiff")]
47use crate::optimize::jax_path_to_v1_pairs;
48use crate::optimize::{hash_einsum_plan_spec, plan_specs_equal, resolve_plan_spec, EinsumPlanSpec};
49use crate::{
50    ContractionTree, EinsumSubscripts, Error as EinsumError, Result as EinsumResult, Subscripts,
51};
52
/// Small vector of `usize` values with inline storage for eight elements.
///
/// NOTE(review): presumably holds operand/input indices; its users are not in
/// the visible part of this file — confirm before relying on that reading.
type InputIndexVec = SmallVec<[usize; 8]>;
54
/// Standard einsum extension payload.
///
/// This mirrors the current `tenferro.einsum.v1` payload shape. Runtime-owned
/// execution goes through [`EinsumRuntime`]; [`ExtensionOp::eager_execute`]
/// remains only as a host reference implementation for direct context-free
/// extension calls.
#[derive(Clone)]
pub(crate) struct EinsumExtensionOp {
    /// Per-operand input labels plus output labels. Part of the op identity:
    /// hashed by `payload_hash` and compared by `payload_eq`.
    subscripts: EinsumSubscripts,
    /// Shape-independent planning policy. Part of the op identity.
    plan_spec: EinsumPlanSpec,
    /// Optional execution hint. This is intentionally excluded from
    /// `ExtensionOp` identity: the shape-independent `plan_spec` carries
    /// user planning policy, while this tree is a resolved cacheable hint.
    static_tree: Option<Arc<ContractionTree>>,
    /// Optional explicit output shape. Part of the op identity;
    /// `infer_output_meta` uses it only when every dimension is constant.
    output_shape_hint: Option<Vec<SymDim>>,
}
71
72impl std::fmt::Debug for EinsumExtensionOp {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("EinsumExtensionOp")
75            .field("subscripts", &self.subscripts)
76            .field("plan_spec", &self.plan_spec)
77            .field("has_static_tree", &self.static_tree.is_some())
78            .field("output_shape_hint", &self.output_shape_hint)
79            .finish()
80    }
81}
82
83impl EinsumExtensionOp {
84    /// Create an einsum extension payload without a precomputed plan.
85    #[must_use]
86    #[cfg(test)]
87    pub(crate) fn new(subscripts: EinsumSubscripts) -> Self {
88        Self::with_plan_spec(subscripts, EinsumPlanSpec::Auto(default_auto_options()))
89    }
90
91    #[must_use]
92    pub(crate) fn with_plan_spec(subscripts: EinsumSubscripts, plan_spec: EinsumPlanSpec) -> Self {
93        Self {
94            subscripts,
95            plan_spec,
96            static_tree: None,
97            output_shape_hint: None,
98        }
99    }
100
101    /// Create an einsum extension payload with a precomputed plan.
102    #[must_use]
103    #[cfg(test)]
104    pub(crate) fn with_static_tree(
105        subscripts: EinsumSubscripts,
106        tree: Arc<ContractionTree>,
107    ) -> Self {
108        Self::new(subscripts).with_static_tree_hint(tree)
109    }
110
111    /// Create an einsum extension payload with an explicit output shape hint.
112    #[must_use]
113    pub(crate) fn with_output_shape_hint(
114        subscripts: EinsumSubscripts,
115        output_shape_hint: Vec<SymDim>,
116        plan_spec: EinsumPlanSpec,
117    ) -> Self {
118        let mut op = Self::with_plan_spec(subscripts, plan_spec);
119        op.output_shape_hint = Some(output_shape_hint);
120        op
121    }
122
123    /// Attach a precomputed contraction tree as an execution hint.
124    #[must_use]
125    #[cfg(any(test, feature = "autodiff"))]
126    pub(crate) fn with_static_tree_hint(mut self, tree: Arc<ContractionTree>) -> Self {
127        self.static_tree = Some(tree);
128        self
129    }
130
131    /// Return the canonical subscripts.
132    #[must_use]
133    pub(crate) fn subscripts(&self) -> &EinsumSubscripts {
134        &self.subscripts
135    }
136
137    /// Return the shape-independent planning policy.
138    #[must_use]
139    pub(crate) fn plan_spec(&self) -> &EinsumPlanSpec {
140        &self.plan_spec
141    }
142
143    /// Return the precomputed contraction tree, if present.
144    #[must_use]
145    pub(crate) fn static_tree(&self) -> Option<&Arc<ContractionTree>> {
146        self.static_tree.as_ref()
147    }
148}
149
150impl ExtensionOp for EinsumExtensionOp {
151    fn family_id(&self) -> &'static str {
152        EINSUM_EXTENSION_FAMILY_ID
153    }
154
155    fn payload_hash(&self, hasher: &mut dyn Hasher) {
156        hasher.write_usize(self.subscripts.inputs.len());
157        for input in &self.subscripts.inputs {
158            hasher.write_usize(input.len());
159            for label in input {
160                hasher.write_u32(*label);
161            }
162        }
163        hasher.write_usize(self.subscripts.output.len());
164        for label in &self.subscripts.output {
165            hasher.write_u32(*label);
166        }
167        hash_einsum_plan_spec(self.plan_spec(), hasher);
168        if let Some(shape) = &self.output_shape_hint {
169            hasher.write_usize(shape.len());
170            for dim in shape {
171                match dim.constant_value() {
172                    Some(value) => {
173                        hasher.write_u8(1);
174                        hasher.write_usize(value);
175                    }
176                    None => hasher.write_u8(0),
177                }
178            }
179        } else {
180            hasher.write_usize(usize::MAX);
181        }
182    }
183
184    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
185        other.as_any().downcast_ref::<Self>().is_some_and(|that| {
186            self.subscripts == that.subscripts
187                && plan_specs_equal(self.plan_spec(), that.plan_spec())
188                && self.output_shape_hint == that.output_shape_hint
189        })
190    }
191
192    fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
193        Arc::new(self.clone())
194    }
195
196    fn as_any(&self) -> &dyn Any {
197        self
198    }
199
200    fn input_count(&self) -> usize {
201        self.subscripts.inputs.len()
202    }
203
204    fn output_count(&self) -> usize {
205        1
206    }
207
208    fn infer_output_meta(
209        &self,
210        input_dtypes: &[DType],
211        input_shapes: &[&[SymDim]],
212    ) -> Vec<(DType, Vec<SymDim>)> {
213        if input_shapes.len() != self.subscripts.inputs.len()
214            || input_dtypes.len() != input_shapes.len()
215        {
216            return Vec::new();
217        }
218
219        let mut label_dims: HashMap<u32, SymDim> = HashMap::new();
220        for (labels, shape) in self.subscripts.inputs.iter().zip(input_shapes.iter()) {
221            if labels.len() != shape.len() {
222                return Vec::new();
223            }
224            for (&label, dim) in labels.iter().zip(shape.iter()) {
225                if let Some(existing) = label_dims.get(&label) {
226                    if let (Some(lhs), Some(rhs)) =
227                        (existing.constant_value(), dim.constant_value())
228                    {
229                        if lhs != rhs {
230                            return Vec::new();
231                        }
232                    }
233                } else {
234                    label_dims.insert(label, dim.clone());
235                }
236            }
237        }
238
239        let output_shape = match &self.output_shape_hint {
240            Some(shape) if shape.iter().all(|dim| dim.constant_value().is_some()) => shape.clone(),
241            _ => self
242                .subscripts
243                .output
244                .iter()
245                .map(|label| label_dims.get(label).cloned())
246                .collect::<Option<Vec<_>>>()
247                .unwrap_or_default(),
248        };
249        if output_shape.len() != self.subscripts.output.len() {
250            return Vec::new();
251        }
252        vec![(promote_dtypes(input_dtypes.iter().copied()), output_shape)]
253    }
254
255    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
256        let mut backend = tenferro_cpu::CpuBackend::new();
257        let subscripts = Subscripts::from(&self.subscripts);
258        crate::eager::eager_einsum_subscripts(&mut backend, inputs, &subscripts)
259            .map(|output| vec![output])
260    }
261
262    fn lower_to_standard_ops(
263        &self,
264        builder: &mut GraphBuilder<StdTensorOp>,
265        inputs: &[ValueRef<StdTensorOp>],
266        input_dtypes: &[DType],
267        input_shapes: &[&[SymDim]],
268    ) -> ExtensionLoweringResult {
269        if inputs.len() != self.input_count()
270            || input_dtypes.len() != self.input_count()
271            || input_shapes.len() != self.input_count()
272        {
273            return Err(ExtensionLoweringError::new(format!(
274                "einsum extension expects {} inputs, got values={}, dtypes={}, shapes={}",
275                self.input_count(),
276                inputs.len(),
277                input_dtypes.len(),
278                input_shapes.len()
279            )));
280        }
281
282        let Some(shapes) = concrete_sym_shape_slices(input_shapes) else {
283            return Ok(None);
284        };
285        let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
286        let subs = Subscripts::from(&self.subscripts);
287        let tree = resolve_plan_spec(self.plan_spec(), &subs, &shape_refs)
288            .map_err(|err| ExtensionLoweringError::new(err.to_string()))?;
289        let output = build_einsum_graph(builder, &tree, inputs, &shapes)
290            .map_err(|err| ExtensionLoweringError::new(err.to_string()))?;
291        Ok(Some(vec![output]))
292    }
293}
294
295fn concrete_sym_shape_slices(input_shapes: &[&[SymDim]]) -> Option<Vec<Vec<usize>>> {
296    input_shapes
297        .iter()
298        .map(|shape| {
299            shape
300                .iter()
301                .map(SymDim::constant_value)
302                .collect::<Option<Vec<_>>>()
303        })
304        .collect()
305}
306
/// Build the rule set that carries the einsum extension's AD rule.
///
/// # Errors
///
/// Forwards whatever `ExtensionRuleSet::with_rule` reports.
#[cfg(feature = "autodiff")]
pub fn ad_rules() -> Result<ExtensionRuleSet, ExtensionRegistryError> {
    let einsum_rule = Arc::new(EinsumAdRule);
    ExtensionRuleSet::new().with_rule(einsum_rule)
}
312
/// Stateless AD rule for the einsum extension family; its `ExtensionAdRule`
/// impl provides `linearize` (JVP) and `transpose_rule` (VJP).
#[derive(Debug)]
#[cfg(feature = "autodiff")]
struct EinsumAdRule;
316
#[cfg(feature = "autodiff")]
impl ExtensionAdRule for EinsumAdRule {
    /// This rule applies to ops of the einsum extension family.
    fn family_id(&self) -> &'static str {
        EINSUM_EXTENSION_FAMILY_ID
    }

    /// Build the JVP of an einsum op.
    ///
    /// For each input that has a tangent, one copy of the same einsum op is
    /// emitted with that input replaced by its tangent and every other input
    /// left as the external primal value. The per-input terms are combined
    /// by `sum_terms`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `downcast_ad_op`.
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        primal_in: &[ValueKey<StdTensorOp>],
        _primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let op = downcast_ad_op(op, ADRuleKind::Jvp)?;
        let mut terms = Vec::new();

        for (active_idx, tangent) in tangent_in.iter().enumerate() {
            let Some(dt) = tangent else {
                continue;
            };

            let mut inputs = Vec::with_capacity(primal_in.len());
            for (input_idx, key) in primal_in.iter().enumerate() {
                if input_idx == active_idx {
                    inputs.push(ValueRef::Local(*dt));
                } else {
                    inputs.push(ValueRef::External(key.clone()));
                }
            }

            let out = builder.add_operation(
                StdTensorOp::Extension(Arc::new(op.clone())),
                inputs,
                OperationRole::Linearized {
                    active_mask: (0..primal_in.len()).map(|idx| idx == active_idx).collect(),
                },
            );
            terms.push(out[0]);
        }

        Ok(vec![sum_terms(builder, terms)])
    }

    /// Build the VJP of a linearized einsum op.
    ///
    /// Returns one entry per primal input: `None` for inputs that are not
    /// active in `mode`, otherwise the id of that input's cotangent. When
    /// there is no output cotangent, or `mode` is `OperationRole::Primary`,
    /// every entry is `None`.
    ///
    /// For each active input a new einsum is emitted whose operands are the
    /// output cotangent followed by the remaining primal inputs (passed
    /// through `conjugate_primal_if_complex`). Its output keeps only those
    /// labels of the active input that are still available from the output
    /// or from another input. If labels were dropped, the result is
    /// broadcast back to the active input's shape.
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain one value per
    /// subscript operand, when the op carries no output shape hint, and
    /// propagates errors from `downcast_ad_op`, shape lookup, plan
    /// inheritance, conjugation, and the broadcast remap.
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[ValueRef<StdTensorOp>],
        mode: &OperationRole,
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let op = downcast_ad_op(op, ADRuleKind::Transpose)?;
        let input_labels = &op.subscripts.inputs;
        let output_labels = &op.subscripts.output;
        let input_count = input_labels.len();

        let Some(ct) = cotangent_out.first().copied().flatten() else {
            return Ok(vec![None; input_count]);
        };
        let active_mask = match mode {
            OperationRole::Linearized { active_mask } => active_mask,
            OperationRole::Primary => return Ok(vec![None; input_count]),
        };
        // Everything below indexes `inputs` by operand position; reject a
        // mismatched operand list up front instead of panicking mid-way.
        if inputs.len() != input_count {
            return Err(ADRuleError::unsupported(
                format!(
                    "einsum VJP expects {input_count} primal inputs, got {}",
                    inputs.len()
                ),
                ADRuleKind::Transpose,
            ));
        }
        let primal_input_shapes: Vec<Vec<SymDim>> = inputs
            .iter()
            .map(|input| ctx.shape_of(input).map(|shape| shape.to_vec()))
            .collect::<Result<_, _>>()?;
        let cotangent_shape = op.output_shape_hint.clone().ok_or_else(|| {
            ADRuleError::unsupported(
                "einsum VJP requires an output shape hint for cotangent planning",
                ADRuleKind::Transpose,
            )
        })?;

        let mut result = Vec::with_capacity(input_count);
        for active_idx in 0..input_count {
            if !active_mask.get(active_idx).copied().unwrap_or(false) {
                result.push(None);
                continue;
            }
            let active_labels = &input_labels[active_idx];
            let active_shape = &primal_input_shapes[active_idx];

            // Labels the VJP einsum can produce: anything present in the
            // output cotangent or in one of the other primal inputs.
            let mut available_labels: HashSet<u32> = output_labels.iter().copied().collect();
            for (input_idx, labels) in input_labels.iter().enumerate() {
                if input_idx != active_idx {
                    available_labels.extend(labels.iter().copied());
                }
            }
            let vjp_output_labels: Vec<u32> = active_labels
                .iter()
                .copied()
                .filter(|label| available_labels.contains(label))
                .collect();

            // The hint stored on the VJP op must describe that op's own
            // output. When labels were dropped this is only the retained
            // axes of the active input, not its full shape; the full shape
            // is restored by the broadcast below.
            let vjp_output_shape: Vec<SymDim> = if vjp_output_labels.len() == active_labels.len()
            {
                active_shape.clone()
            } else {
                active_labels
                    .iter()
                    .zip(active_shape.iter())
                    .filter(|(label, _)| available_labels.contains(*label))
                    .map(|(_, dim)| dim.clone())
                    .collect()
            };

            let mut vjp_input_labels = Vec::with_capacity(input_count);
            let mut vjp_inputs = Vec::with_capacity(input_count);
            let mut vjp_input_shapes = Vec::with_capacity(input_count);
            // Operand 0 of the VJP einsum is always the output cotangent.
            vjp_input_labels.push(output_labels.clone());
            vjp_inputs.push(ValueRef::Local(ct));
            vjp_input_shapes.push(cotangent_shape.clone());

            for (input_idx, labels) in input_labels.iter().enumerate() {
                if input_idx == active_idx {
                    continue;
                }
                vjp_input_labels.push(labels.clone());
                vjp_input_shapes.push(primal_input_shapes[input_idx].clone());
                vjp_inputs.push(conjugate_primal_if_complex(
                    builder,
                    inputs[input_idx].clone(),
                    ctx,
                )?);
            }

            let vjp_op = vjp_einsum_op_with_inherited_plan(
                op,
                active_idx,
                EinsumSubscripts {
                    inputs: vjp_input_labels,
                    output: vjp_output_labels.clone(),
                },
                vjp_output_shape,
                &vjp_input_shapes,
            )?;
            let out = builder.add_operation(
                StdTensorOp::Extension(Arc::new(vjp_op)),
                vjp_inputs,
                OperationRole::Linearized {
                    // Only the cotangent operand is active in the VJP einsum.
                    active_mask: std::iter::once(true)
                        .chain(std::iter::repeat_n(false, input_count.saturating_sub(1)))
                        .collect(),
                },
            );
            let mut cotangent = out[0];
            if vjp_output_labels != *active_labels {
                cotangent = broadcast_einsum_vjp_to_input_shape(
                    builder,
                    cotangent,
                    &vjp_output_labels,
                    active_labels,
                    inputs[active_idx].clone(),
                    active_shape,
                )?;
            }
            result.push(Some(cotangent));
        }

        Ok(result)
    }
}
470
/// Build the einsum op used for one VJP term, inheriting the primal plan.
///
/// The plan spec is derived from the primal op for `active_idx`. When every
/// entry of `input_shapes` is constant, the plan is also resolved into a
/// contraction tree and attached as a static hint; otherwise the op is
/// returned without a tree.
///
/// # Errors
///
/// Fails when the plan spec cannot be derived or the plan cannot be resolved
/// for the concrete shapes.
#[cfg(feature = "autodiff")]
fn vjp_einsum_op_with_inherited_plan(
    primal_op: &EinsumExtensionOp,
    active_idx: usize,
    subscripts: EinsumSubscripts,
    output_shape_hint: Vec<SymDim>,
    input_shapes: &[Vec<SymDim>],
) -> ADRuleResult<EinsumExtensionOp> {
    let plan_spec =
        vjp_plan_spec_for_active(primal_op.plan_spec(), primal_op.input_count(), active_idx)?;

    let static_tree = match concrete_sym_shapes(input_shapes) {
        None => None,
        Some(concrete_shapes) => {
            let shape_refs: Vec<&[usize]> =
                concrete_shapes.iter().map(|shape| shape.as_slice()).collect();
            let raw_subscripts = Subscripts::from(&subscripts);
            let resolved = resolve_plan_spec(&plan_spec, &raw_subscripts, &shape_refs);
            let tree = match resolved {
                Ok(tree) => tree,
                Err(err) => {
                    return Err(ADRuleError::unsupported(
                        format!(
                            "failed to resolve inherited einsum VJP plan for active input {active_idx}: {err}"
                        ),
                        ADRuleKind::Transpose,
                    ));
                }
            };
            Some(Arc::new(tree))
        }
    };

    let op = EinsumExtensionOp::with_output_shape_hint(subscripts, output_shape_hint, plan_spec);
    Ok(match static_tree {
        Some(tree) => op.with_static_tree_hint(tree),
        None => op,
    })
}
502
/// Derive the plan spec of a VJP einsum from the primal plan spec.
///
/// `Auto` and `LeftToRight` carry over unchanged. `Path` is first converted
/// to explicit pairs; it and `FixedPairs` are then re-rooted for the active
/// input by `derive_vjp_fixed_pairs` and returned as `FixedPairs`.
///
/// # Errors
///
/// Fails when `active_idx` is not a valid input index, when a `Path` plan
/// cannot be converted, or when the explicit pairs cannot be re-rooted.
#[cfg(feature = "autodiff")]
fn vjp_plan_spec_for_active(
    primal_plan: &EinsumPlanSpec,
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<EinsumPlanSpec> {
    if active_idx >= input_count {
        let message =
            format!("einsum VJP active input {active_idx} is outside {input_count} inputs");
        return Err(ADRuleError::unsupported(message, ADRuleKind::Transpose));
    }

    let vjp_pairs = match primal_plan {
        EinsumPlanSpec::Auto(options) => return Ok(EinsumPlanSpec::Auto(options.clone())),
        EinsumPlanSpec::LeftToRight => return Ok(EinsumPlanSpec::LeftToRight),
        EinsumPlanSpec::FixedPairs(pairs) => {
            derive_vjp_fixed_pairs(pairs, input_count, active_idx)?
        }
        EinsumPlanSpec::Path(path) => {
            let converted = match jax_path_to_v1_pairs(path, input_count) {
                Ok(converted) => converted,
                Err(err) => {
                    return Err(ADRuleError::unsupported(
                        format!(
                            "failed to inherit einsum Path plan for VJP active input {active_idx}: {err}"
                        ),
                        ADRuleKind::Transpose,
                    ));
                }
            };
            derive_vjp_fixed_pairs(&converted, input_count, active_idx)?
        }
    };
    Ok(EinsumPlanSpec::FixedPairs(vjp_pairs))
}
535
/// Re-root an explicit pairwise primal plan so it computes the cotangent of
/// input `active_idx`.
///
/// In the VJP numbering, operand 0 is the output cotangent and the inactive
/// primal inputs follow in their original order as operands `1..`. The
/// result has `input_count - 1` steps; a single-input primal needs none.
///
/// # Errors
///
/// Fails for zero inputs, an out-of-range `active_idx`, a primal plan with
/// the wrong number of steps, an invalid primal plan, or a derived plan that
/// does not end on the expected final id.
#[cfg(feature = "autodiff")]
fn derive_vjp_fixed_pairs(
    primal_pairs: &[(usize, usize)],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<Vec<(usize, usize)>> {
    let unsupported =
        |message: String| ADRuleError::unsupported(message, ADRuleKind::Transpose);

    if input_count == 0 {
        return Err(ADRuleError::unsupported(
            "einsum VJP cannot derive a plan for zero primal inputs",
            ADRuleKind::Transpose,
        ));
    }
    if active_idx >= input_count {
        return Err(unsupported(format!(
            "einsum VJP active input {active_idx} is outside {input_count} inputs"
        )));
    }
    let required_steps = input_count.saturating_sub(1);
    if primal_pairs.len() != required_steps {
        return Err(unsupported(format!(
            "einsum VJP cannot inherit explicit plan for active input {active_idx}: \
             expected {required_steps} primal steps for {input_count} inputs, got {}",
            primal_pairs.len()
        )));
    }
    if input_count == 1 {
        return Ok(Vec::new());
    }

    let children = fixed_pair_children(primal_pairs, input_count, active_idx)?;

    // Inactive primal inputs become VJP operands 1, 2, ... in order; the
    // active input has no VJP operand of its own.
    let mut assigned = 0usize;
    let primal_to_vjp: Vec<Option<usize>> = (0..input_count)
        .map(|input_idx| {
            if input_idx == active_idx {
                None
            } else {
                assigned += 1;
                Some(assigned)
            }
        })
        .collect();

    // The last primal step produces the root of the primal tree.
    let root = input_count + primal_pairs.len() - 1;
    let mut pairs = Vec::with_capacity(required_steps);
    let final_id = emit_vjp_adjoint(
        root,
        0,
        &children,
        input_count,
        active_idx,
        &primal_to_vjp,
        &mut pairs,
    )?;

    let expected_final = input_count + pairs.len() - 1;
    let well_formed = final_id == expected_final && pairs.len() == required_steps;
    if !well_formed {
        return Err(unsupported(format!(
            "einsum VJP plan derivation for active input {active_idx} produced an invalid \
             tree: final id {final_id}, expected {expected_final}, steps {}",
            pairs.len()
        )));
    }
    Ok(pairs)
}
603
/// Validate an explicit pairwise plan and return, for every node id, the two
/// operands it was built from (`None` for the original inputs).
///
/// Node ids `0..input_count` are the inputs; step `k` creates node
/// `input_count + k`.
///
/// # Errors
///
/// Fails when a pair names the same operand twice, names an operand that
/// does not exist yet, names an operand that an earlier step already
/// consumed, or when the plan does not end with exactly one live operand.
#[cfg(feature = "autodiff")]
fn fixed_pair_children(
    pairs: &[(usize, usize)],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<Vec<Option<(usize, usize)>>> {
    let node_count = input_count + pairs.len();
    // Only the original inputs are live before the first step.
    let mut live: Vec<bool> = (0..node_count).map(|node| node < input_count).collect();
    let mut children: Vec<Option<(usize, usize)>> = vec![None; node_count];

    for (step_idx, &(left, right)) in pairs.iter().enumerate() {
        let next_idx = input_count + step_idx;

        // The checks are ordered so that `live` is only indexed once both
        // ids are known to be in range.
        let problem = if left == right {
            Some("references the same operand")
        } else if left >= next_idx || right >= next_idx {
            Some("references a non-existent operand")
        } else if !(live[left] && live[right]) {
            Some("references an operand that is no longer live")
        } else {
            None
        };
        if let Some(problem) = problem {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("pair ({left}, {right}) {problem}"),
            ));
        }

        live[left] = false;
        live[right] = false;
        live[next_idx] = true;
        children[next_idx] = Some((left, right));
    }

    let live_count = live.iter().filter(|&&is_live| is_live).count();
    if live_count == 1 {
        Ok(children)
    } else {
        Err(invalid_vjp_plan_error(
            active_idx,
            format!("explicit plan leaves {live_count} live operands"),
        ))
    }
}
653
/// Walk from `node` down to the active leaf, emitting VJP steps on the way.
///
/// At each internal node, the child subtree that does not contain the active
/// input is emitted first via `emit_vjp_subtree`. Its result is then paired
/// with the running cotangent, and the walk continues into the child that
/// contains the active input. Returns the id of the cotangent that reaches
/// the active leaf.
///
/// # Errors
///
/// Fails when a node has no recorded children, when both or neither child
/// contains the active input, or when the walk ends on an inactive leaf.
#[cfg(feature = "autodiff")]
fn emit_vjp_adjoint(
    node: usize,
    cotangent_id: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
    primal_to_vjp: &[Option<usize>],
    pairs: &mut Vec<(usize, usize)>,
) -> ADRuleResult<usize> {
    let mut node = node;
    let mut cotangent_id = cotangent_id;

    // Ids below `input_count` are leaves; everything else is an internal node.
    while node >= input_count {
        let Some((left, right)) = children.get(node).copied().flatten() else {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("missing children for node {node}"),
            ));
        };
        let left_has_active = subtree_contains_active(left, children, input_count, active_idx)?;
        let right_has_active = subtree_contains_active(right, children, input_count, active_idx)?;

        let (toward_active, sibling) = match (left_has_active, right_has_active) {
            (true, false) => (left, right),
            (false, true) => (right, left),
            (true, true) => {
                return Err(invalid_vjp_plan_error(
                    active_idx,
                    format!("both children of node {node} contain the active input"),
                ));
            }
            (false, false) => {
                return Err(invalid_vjp_plan_error(
                    active_idx,
                    format!("neither child of node {node} contains the active input"),
                ));
            }
        };

        let sibling_id = emit_vjp_subtree(
            sibling,
            children,
            input_count,
            active_idx,
            primal_to_vjp,
            pairs,
        )?;
        cotangent_id = push_vjp_pair(cotangent_id, sibling_id, input_count, pairs);
        node = toward_active;
    }

    if node == active_idx {
        Ok(cotangent_id)
    } else {
        Err(invalid_vjp_plan_error(
            active_idx,
            format!("adjoint walk reached inactive leaf {node}"),
        ))
    }
}
731
/// Emit the VJP steps that rebuild a primal subtree containing no active
/// input, and return the VJP id of its result.
///
/// A leaf maps to its VJP operand through `primal_to_vjp`. An internal node
/// emits its left child, then its right child, then the pair joining them.
///
/// # Errors
///
/// Fails when the subtree reaches the active leaf (which has no VJP
/// operand) or when a node has no recorded children.
#[cfg(feature = "autodiff")]
fn emit_vjp_subtree(
    node: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
    primal_to_vjp: &[Option<usize>],
    pairs: &mut Vec<(usize, usize)>,
) -> ADRuleResult<usize> {
    if node >= input_count {
        let Some((left, right)) = children.get(node).copied().flatten() else {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("missing children for node {node}"),
            ));
        };
        let left_id = emit_vjp_subtree(
            left,
            children,
            input_count,
            active_idx,
            primal_to_vjp,
            pairs,
        )?;
        let right_id = emit_vjp_subtree(
            right,
            children,
            input_count,
            active_idx,
            primal_to_vjp,
            pairs,
        )?;
        return Ok(push_vjp_pair(left_id, right_id, input_count, pairs));
    }

    match primal_to_vjp[node] {
        Some(vjp_operand) => Ok(vjp_operand),
        None => Err(invalid_vjp_plan_error(
            active_idx,
            format!("sibling subtree unexpectedly reached active leaf {node}"),
        )),
    }
}
771
772#[cfg(feature = "autodiff")]
/// Append the step `(left, right)` to `pairs` and return the id of the
/// intermediate it produces.
///
/// Ids `0..n_vjp_inputs` belong to the VJP operands, so the step stored at
/// position `k` of `pairs` produces id `n_vjp_inputs + k`.
fn push_vjp_pair(
    left: usize,
    right: usize,
    n_vjp_inputs: usize,
    pairs: &mut Vec<(usize, usize)>,
) -> usize {
    let step_position = pairs.len();
    pairs.push((left, right));
    n_vjp_inputs + step_position
}
782
/// Report whether the subtree rooted at `node` contains the active input.
///
/// Nodes are visited depth-first, left child before right child, and the
/// search stops at the first active leaf.
///
/// # Errors
///
/// Fails when a visited internal node has no recorded children.
#[cfg(feature = "autodiff")]
fn subtree_contains_active(
    node: usize,
    children: &[Option<(usize, usize)>],
    input_count: usize,
    active_idx: usize,
) -> ADRuleResult<bool> {
    let mut pending = vec![node];
    while let Some(node) = pending.pop() {
        if node < input_count {
            if node == active_idx {
                return Ok(true);
            }
            continue;
        }
        let Some((left, right)) = children.get(node).copied().flatten() else {
            return Err(invalid_vjp_plan_error(
                active_idx,
                format!("missing children for node {node}"),
            ));
        };
        // Push right first so the left subtree is popped and explored first.
        pending.push(right);
        pending.push(left);
    }
    Ok(false)
}
801
/// Build the `Transpose`-kind "unsupported" error used whenever an explicit
/// primal plan cannot be inherited for `active_idx`; `reason` is appended to
/// the fixed message prefix.
#[cfg(feature = "autodiff")]
fn invalid_vjp_plan_error(active_idx: usize, reason: String) -> ADRuleError {
    let message =
        format!("einsum VJP cannot inherit explicit plan for active input {active_idx}: {reason}");
    ADRuleError::unsupported(message, ADRuleKind::Transpose)
}
809
/// Convert owned symbolic shapes to concrete ones.
///
/// Returns `None` as soon as any dimension has no constant value.
#[cfg(feature = "autodiff")]
fn concrete_sym_shapes(shapes: &[Vec<SymDim>]) -> Option<Vec<Vec<usize>>> {
    let mut concrete = Vec::with_capacity(shapes.len());
    for shape in shapes {
        let mut extents = Vec::with_capacity(shape.len());
        for dim in shape {
            extents.push(dim.constant_value()?);
        }
        concrete.push(extents);
    }
    Some(concrete)
}
817
/// Broadcast a reduced VJP result back to the shape of the active input.
///
/// `cotangent` carries `cotangent_labels`; each of its axes is mapped to the
/// first unused axis of `input_labels` with the same label. The target shape
/// is expressed symbolically as the dimensions of `shape_source`, which is
/// passed as the second operand of the broadcast. `input_shape` is only used
/// for its rank. The broadcast result is finally passed through
/// `project_repeated_labels_to_diagonal`.
///
/// # Errors
///
/// Fails when some cotangent label has no free matching axis in
/// `input_labels`.
#[cfg(feature = "autodiff")]
fn broadcast_einsum_vjp_to_input_shape(
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent: LocalValueId,
    cotangent_labels: &[u32],
    input_labels: &[u32],
    shape_source: ValueRef<StdTensorOp>,
    input_shape: &[SymDim],
) -> ADRuleResult<LocalValueId> {
    // `input_idx: 1` refers to `shape_source`, the second broadcast operand.
    let shape: Vec<DimExpr> = (0..input_shape.len())
        .map(|axis| DimExpr::InputDim { input_idx: 1, axis })
        .collect();
    let dims = map_label_occurrences(cotangent_labels, input_labels).ok_or_else(|| {
        ADRuleError::unsupported(
            format!(
                "einsum VJP broadcast remap failed for cotangent labels {cotangent_labels:?} \
                 into active input labels {input_labels:?}"
            ),
            ADRuleKind::Transpose,
        )
    })?;

    // Keep the mask in lock-step with the operand list: a rank-0 target needs
    // no shape source, and then the mask must not mention a second operand.
    let mut inputs = vec![ValueRef::Local(cotangent)];
    let mut active_mask = vec![true];
    if !shape.is_empty() {
        inputs.push(shape_source);
        active_mask.push(false);
    }

    let broadcast = builder.add_operation(
        StdTensorOp::BroadcastInDim { shape, dims },
        inputs,
        OperationRole::Linearized { active_mask },
    )[0];
    Ok(project_repeated_labels_to_diagonal(
        builder,
        broadcast,
        input_labels,
    ))
}
858
859#[cfg(feature = "autodiff")]
/// Map every source label to a distinct axis of `target_labels`.
///
/// Each source label takes the first target axis that carries the same label
/// and has not been taken by an earlier source label. Returns `None` when
/// some source label has no free matching axis.
fn map_label_occurrences(source_labels: &[u32], target_labels: &[u32]) -> Option<Vec<usize>> {
    let mut taken = vec![false; target_labels.len()];
    let mut axes = Vec::with_capacity(source_labels.len());
    for wanted in source_labels {
        let axis = (0..target_labels.len())
            .find(|&axis| !taken[axis] && target_labels[axis] == *wanted)?;
        taken[axis] = true;
        axes.push(axis);
    }
    Some(axes)
}
874
/// Restricts a cotangent to the diagonal of every repeated label.
///
/// Walks `labels` in axis order. The first axis seen for a label is remembered; for each
/// later axis carrying the same label an `ExtractDiag` followed by an `EmbedDiag` over
/// that axis pair is appended to the running value. Returns the id of the final value,
/// which is `cotangent` itself when no label repeats.
#[cfg(feature = "autodiff")]
fn project_repeated_labels_to_diagonal(
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent: LocalValueId,
    labels: &[u32],
) -> LocalValueId {
    let mut first_axis_by_label = HashMap::new();
    let mut current = cotangent;
    for (axis_b, &label) in labels.iter().enumerate() {
        if let Some(&axis_a) = first_axis_by_label.get(&label) {
            let diagonal = builder.add_operation(
                StdTensorOp::ExtractDiag { axis_a, axis_b },
                vec![ValueRef::Local(current)],
                OperationRole::Linearized {
                    active_mask: vec![true],
                },
            )[0];
            current = builder.add_operation(
                StdTensorOp::EmbedDiag { axis_a, axis_b },
                vec![ValueRef::Local(diagonal)],
                OperationRole::Linearized {
                    active_mask: vec![true],
                },
            )[0];
        } else {
            first_axis_by_label.insert(label, axis_b);
        }
    }
    current
}
905
// Wires the einsum extension into the extension runtime via `define_extension_runtime!`.
// The invocation names the op payload type (`EinsumExtensionOp`), the family id, and the
// two executors defined below in this file: `execute_einsum_extension` for materialized
// tensors and `execute_einsum_extension_reads` for `TensorRead` inputs.
// NOTE(review): `EinsumRuntime` and `register_runtime` are presumably generated by the
// macro; its definition lives in `tenferro_extension_macros` — confirm there.
906define_extension_runtime! {
907    runtime = EinsumRuntime,
908    family_id = EINSUM_EXTENSION_FAMILY_ID,
909    op_type = EinsumExtensionOp,
910    execute = execute_einsum_extension,
911    execute_reads = execute_einsum_extension_reads,
912    register_fn = register_runtime,
913}
914
915fn execute_einsum_extension<B: TensorBackend + 'static>(
916    op: &EinsumExtensionOp,
917    inputs: &[&Tensor],
918    ctx: &mut ExtensionExecutionContext<'_, B>,
919) -> tenferro_tensor::Result<Vec<Tensor>> {
920    if inputs.is_empty() {
921        return Err(tenferro_tensor::Error::InvalidConfig {
922            op: "einsum_extension",
923            message: "einsum requires at least one input tensor".into(),
924        });
925    }
926
927    let shapes: Vec<Vec<usize>> = inputs
928        .iter()
929        .map(|tensor| tensor.shape().to_vec())
930        .collect();
931    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
932    let subs = Subscripts::from(op.subscripts());
933    let tree = if let Some(tree) = op.static_tree() {
934        Arc::clone(tree)
935    } else {
936        cached_runtime_tree(ctx, op.subscripts(), op.plan_spec(), &shapes, || {
937            resolve_plan_spec(op.plan_spec(), &subs, &shape_refs)
938        })?
939    };
940
941    if is_binary_non_contracting(&subs) {
942        let output = ctx
943            .backend_mut()
944            .with_backend_session(|exec| crate::eager::eager_einsum_exec(exec, inputs, &tree))?;
945        return Ok(vec![output]);
946    }
947
948    let (backend, caches) = ctx.parts_mut();
949    let compiler_options = tenferro_runtime::extension::CompilerOptions::default();
950    let optimizer_fingerprint = compiler_options.optimizer.fingerprint();
951    let key = runtime_exec_program_cache_key(op, inputs, &shapes, optimizer_fingerprint);
952    if caches
953        .get_mut::<CachedRuntimeExecProgram<B::RuntimeCache>>(&key)
954        .is_none()
955    {
956        let cached =
957            build_runtime_exec_program::<B>(tree.as_ref(), inputs, &shapes, compiler_options)?;
958        let key_retained_bytes = runtime_exec_program_key_retained_bytes(op, inputs, &shapes);
959        caches.put_with_retained_bytes(key, cached, move |cached| {
960            saturating_sum([
961                key_retained_bytes,
962                cached_runtime_exec_program_retained_bytes(cached),
963            ])
964        });
965    }
966    let cached = caches
967        .get_mut::<CachedRuntimeExecProgram<B::RuntimeCache>>(&key)
968        .ok_or_else(|| {
969            tenferro_tensor::Error::backend_failure(
970                "einsum_extension",
971                "runtime exec program cache entry missing after insertion",
972            )
973        })?;
974    let program_inputs = runtime_program_inputs(inputs, cached.input_indices.as_slice())?;
975    let mut outputs = tenferro_runtime::extension::execute_lowered_program_with_backend_cache(
976        backend,
977        &cached.program,
978        program_inputs,
979        &mut cached.backend_cache,
980    )
981    .map_err(|err| tenferro_tensor::Error::backend_failure("einsum_extension", err.to_string()))?;
982    if outputs.len() != 1 {
983        return Err(tenferro_tensor::Error::backend_failure(
984            "einsum_extension",
985            format!("expected 1 output, got {}", outputs.len()),
986        ));
987    }
988    Ok(vec![outputs.remove(0)])
989}
990
991fn execute_einsum_extension_reads<B: TensorBackend + 'static>(
992    op: &EinsumExtensionOp,
993    inputs: &[TensorRead<'_>],
994    ctx: &mut ExtensionExecutionContext<'_, B>,
995) -> tenferro_tensor::Result<Vec<Tensor>> {
996    if inputs
997        .iter()
998        .all(|input| matches!(input, TensorRead::Tensor(_)))
999    {
1000        let input_refs: Vec<&Tensor> = inputs
1001            .iter()
1002            .map(|input| match input {
1003                TensorRead::Tensor(tensor) => *tensor,
1004                TensorRead::View(_) => unreachable!("view input filtered above"),
1005            })
1006            .collect();
1007        return execute_einsum_extension(op, &input_refs, ctx);
1008    }
1009
1010    if inputs.is_empty() {
1011        return Err(tenferro_tensor::Error::InvalidConfig {
1012            op: "einsum_extension",
1013            message: "einsum requires at least one input tensor".into(),
1014        });
1015    }
1016
1017    let shapes: Vec<Vec<usize>> = inputs.iter().map(|input| input.shape().to_vec()).collect();
1018    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
1019    let subs = Subscripts::from(op.subscripts());
1020    let tree = if let Some(tree) = op.static_tree() {
1021        Arc::clone(tree)
1022    } else {
1023        cached_runtime_tree(ctx, op.subscripts(), op.plan_spec(), &shapes, || {
1024            resolve_plan_spec(op.plan_spec(), &subs, &shape_refs)
1025        })?
1026    };
1027    let output = ctx
1028        .backend_mut()
1029        .with_backend_session(|exec| crate::eager::eager_einsum_exec_read(exec, inputs, &tree))?;
1030    Ok(vec![output])
1031}
1032
1033fn is_binary_non_contracting(subs: &Subscripts) -> bool {
1034    if subs.inputs.len() != 2 {
1035        return false;
1036    }
1037
1038    let lhs = &subs.inputs[0];
1039    let rhs = &subs.inputs[1];
1040    let output = &subs.output;
1041    !lhs.iter()
1042        .any(|label| rhs.contains(label) && !output.contains(label))
1043}
1044
/// Cache entry holding a lowered einsum program together with the state needed to run it.
///
/// Built by `build_runtime_exec_program` and stored in the runtime exec-program cache by
/// `execute_einsum_extension`. `C` is the backend's runtime cache type.
1045struct CachedRuntimeExecProgram<C> {
    /// Lowered program executed for this einsum.
1046    program: ExecProgram,
    /// For each program input, the index into the caller's `inputs` slice to pass in.
1047    input_indices: InputIndexVec,
    /// Backend cache handed mutably to the executor on every run of `program`.
1048    backend_cache: C,
    /// Fingerprint of the optimizer options the program was compiled with.
1049    optimizer_fingerprint: u64,
1050}
1051
1052fn runtime_exec_program_cache_key(
1053    op: &EinsumExtensionOp,
1054    inputs: &[&Tensor],
1055    shapes: &[Vec<usize>],
1056    optimizer_fingerprint: u64,
1057) -> ExtensionCacheKey {
1058    let input_dtypes: Vec<DType> = inputs.iter().map(|tensor| tensor.dtype()).collect();
1059    let mut plan_hasher = std::collections::hash_map::DefaultHasher::new();
1060    hash_einsum_plan_spec(op.plan_spec(), &mut plan_hasher);
1061    let key_data = (
1062        op.subscripts().clone(),
1063        shapes.to_vec(),
1064        input_dtypes.clone(),
1065        plan_hasher.finish(),
1066        optimizer_fingerprint,
1067    );
1068    ExtensionCacheKey::new(
1069        EINSUM_EXTENSION_FAMILY_ID,
1070        EINSUM_RUNTIME_EXEC_PROGRAMS_CACHE,
1071        hash_value(&key_data),
1072    )
1073}
1074
1075fn runtime_exec_program_key_retained_bytes(
1076    op: &EinsumExtensionOp,
1077    inputs: &[&Tensor],
1078    shapes: &[Vec<usize>],
1079) -> usize {
1080    saturating_sum([
1081        einsum_subscripts_retained_bytes(op.subscripts()),
1082        saturating_sum(shapes.iter().map(vec_retained_bytes)),
1083        inputs.len().saturating_mul(std::mem::size_of::<DType>()),
1084        std::mem::size_of::<u64>(),
1085        std::mem::size_of::<u64>(),
1086    ])
1087}
1088
1089fn build_runtime_exec_program<B: TensorBackend>(
1090    tree: &ContractionTree,
1091    inputs: &[&Tensor],
1092    shapes: &[Vec<usize>],
1093    compiler_options: tenferro_runtime::extension::CompilerOptions,
1094) -> tenferro_tensor::Result<CachedRuntimeExecProgram<B::RuntimeCache>> {
1095    let mut builder = GraphBuilder::<StdTensorOp>::new();
1096    let mut input_vals = Vec::with_capacity(inputs.len());
1097    for input_idx in 0..inputs.len() {
1098        let local = builder.add_input(TensorInputKey::User {
1099            id: input_idx as u64,
1100        });
1101        input_vals.push(ValueRef::Local(local));
1102    }
1103
1104    let result_ref = build_einsum_graph(&mut builder, tree, &input_vals, shapes)
1105        .map_err(einsum_runtime_error)?;
1106    let result_local = match result_ref {
1107        ValueRef::Local(local) => local,
1108        ValueRef::External(_) => {
1109            return Err(tenferro_tensor::Error::backend_failure(
1110                "einsum_extension",
1111                "einsum builder returned an external value at runtime",
1112            ))
1113        }
1114    };
1115    builder.set_outputs(vec![result_local]);
1116    let graph = Arc::new(builder.build());
1117    let output_key = graph.values()[result_local].key.clone();
1118
1119    let view = resolve(vec![graph]);
1120    let graph = materialize_merge(&view, &[output_key]);
1121    let compiled = compile(&graph);
1122
1123    let mut input_indices = InputIndexVec::new();
1124    let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
1125    let mut input_shapes = Vec::with_capacity(graph.inputs.len());
1126    for key in &graph.inputs {
1127        match key {
1128            ValueKey::Input(TensorInputKey::User { id }) => {
1129                let input_idx = *id as usize;
1130                let tensor = inputs.get(input_idx).ok_or_else(|| {
1131                    tenferro_tensor::Error::backend_failure(
1132                        "einsum_extension",
1133                        format!("runtime input {input_idx} missing"),
1134                    )
1135                })?;
1136                input_indices.push(input_idx);
1137                input_dtypes.push(tensor.dtype());
1138                input_shapes.push(tenferro_ops::dim_expr::DimExpr::from_concrete(
1139                    tensor.shape(),
1140                ));
1141            }
1142            other => {
1143                return Err(tenferro_tensor::Error::backend_failure(
1144                    "einsum_extension",
1145                    format!("unexpected runtime input key: {other:?}"),
1146                ))
1147            }
1148        }
1149    }
1150
1151    let program = tenferro_runtime::extension::compile_std_to_exec_with_options(
1152        &compiled,
1153        &input_dtypes,
1154        &input_shapes,
1155        compiler_options,
1156    )
1157    .map_err(|err| tenferro_tensor::Error::backend_failure("einsum_extension", err.to_string()))?;
1158    Ok(CachedRuntimeExecProgram {
1159        program,
1160        input_indices,
1161        backend_cache: B::RuntimeCache::default(),
1162        optimizer_fingerprint: compiler_options.optimizer.fingerprint(),
1163    })
1164}
1165
1166fn runtime_program_inputs(
1167    inputs: &[&Tensor],
1168    input_indices: &[usize],
1169) -> tenferro_tensor::Result<Vec<Tensor>> {
1170    let mut program_inputs = Vec::with_capacity(input_indices.len());
1171    for &input_idx in input_indices {
1172        let tensor = inputs.get(input_idx).ok_or_else(|| {
1173            tenferro_tensor::Error::backend_failure(
1174                "einsum_extension",
1175                format!("runtime input {input_idx} missing"),
1176            )
1177        })?;
1178        program_inputs.push((*tensor).clone());
1179    }
1180    Ok(program_inputs)
1181}
1182
1183fn cached_runtime_exec_program_retained_bytes<C: RuntimeCacheControl>(
1184    cached: &CachedRuntimeExecProgram<C>,
1185) -> usize {
1186    saturating_sum([
1187        std::mem::size_of::<CachedRuntimeExecProgram<C>>(),
1188        exec_program_retained_bytes(&cached.program),
1189        smallvec_retained_bytes(&cached.input_indices),
1190        cached.backend_cache.stats().retained_bytes,
1191        std::mem::size_of_val(&cached.optimizer_fingerprint),
1192    ])
1193}
1194
1195fn smallvec_retained_bytes<A: smallvec::Array>(values: &SmallVec<A>) -> usize {
1196    if values.spilled() {
1197        values
1198            .capacity()
1199            .saturating_mul(std::mem::size_of::<A::Item>())
1200    } else {
1201        0
1202    }
1203}
1204
1205fn exec_program_retained_bytes(program: &ExecProgram) -> usize {
1206    saturating_sum([
1207        std::mem::size_of::<ExecProgram>(),
1208        vec_retained_bytes(&program.instructions),
1209        saturating_sum(
1210            program
1211                .instructions
1212                .iter()
1213                .map(exec_instruction_retained_bytes),
1214        ),
1215        vec_retained_bytes(&program.input_slots),
1216        vec_retained_bytes(&program.output_slots),
1217    ])
1218}
1219
1220fn exec_instruction_retained_bytes(inst: &ExecInstruction) -> usize {
1221    saturating_sum([
1222        std::mem::size_of::<ExecInstruction>(),
1223        exec_op_retained_bytes(&inst.op),
1224        vec_retained_bytes(&inst.input_slots),
1225        vec_retained_bytes(&inst.output_slots),
1226        vec_of_vec_retained_bytes(&inst.output_shapes),
1227        vec_of_vec_retained_bytes(&inst.output_extents),
1228        vec_retained_bytes(&inst.last_use),
1229    ])
1230}
1231
1232fn exec_op_retained_bytes(op: &ExecOp) -> usize {
1233    match op {
1234        ExecOp::Constant { bytes, .. } => vec_retained_bytes(bytes),
1235        ExecOp::Extension(extension) => std::mem::size_of_val(extension),
1236        _ => 0,
1237    }
1238}
1239
1240fn cached_runtime_tree<B: TensorBackend>(
1241    ctx: &mut ExtensionExecutionContext<'_, B>,
1242    subscripts: &EinsumSubscripts,
1243    plan_spec: &EinsumPlanSpec,
1244    shapes: &[Vec<usize>],
1245    build: impl FnOnce() -> EinsumResult<ContractionTree>,
1246) -> tenferro_tensor::Result<Arc<ContractionTree>> {
1247    let mut plan_hasher = std::collections::hash_map::DefaultHasher::new();
1248    hash_einsum_plan_spec(plan_spec, &mut plan_hasher);
1249    let key_data = (subscripts.clone(), shapes.to_vec(), plan_hasher.finish());
1250    let key = ExtensionCacheKey::new(
1251        EINSUM_EXTENSION_FAMILY_ID,
1252        EINSUM_RUNTIME_PLANS_CACHE,
1253        hash_value(&key_data),
1254    );
1255    if let Some(cached) = ctx.caches_mut().get::<Arc<ContractionTree>>(&key) {
1256        return Ok(Arc::clone(cached));
1257    }
1258
1259    let tree = Arc::new(build().map_err(einsum_runtime_error)?);
1260    let retained_bytes = saturating_sum([
1261        einsum_subscripts_retained_bytes(subscripts),
1262        saturating_sum(shapes.iter().map(vec_retained_bytes)),
1263        std::mem::size_of::<u64>(),
1264        tree.retained_bytes_for_cache_stats(),
1265    ]);
1266    ctx.caches_mut().put(key, Arc::clone(&tree), retained_bytes);
1267    Ok(tree)
1268}
1269
1270fn einsum_runtime_error(error: EinsumError) -> tenferro_tensor::Error {
1271    error.to_tensor_error("einsum_extension")
1272}
1273
/// Hashes `value` with a freshly created `DefaultHasher` and returns the 64-bit digest.
fn hash_value<T: Hash>(value: &T) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
1279
/// Downcasts a generic extension op to the einsum payload type.
///
/// # Errors
///
/// Returns an unsupported-rule error of the given `kind` when `op` is not an
/// `EinsumExtensionOp`.
#[cfg(feature = "autodiff")]
fn downcast_ad_op(op: &dyn ExtensionOp, kind: ADRuleKind) -> ADRuleResult<&EinsumExtensionOp> {
    match op.as_any().downcast_ref::<EinsumExtensionOp>() {
        Some(einsum_op) => Ok(einsum_op),
        None => Err(ADRuleError::unsupported(
            "tenferro.einsum.v1 payload type mismatch",
            kind,
        )),
    }
}
1286
/// Adds up `terms` with a left-to-right chain of linearized `Add` operations.
///
/// Returns `None` for an empty list and the term itself for a single entry; otherwise
/// returns the id of the last `Add` emitted through `builder`.
#[cfg(feature = "autodiff")]
fn sum_terms(
    builder: &mut dyn PrimitiveRuleBuilder,
    terms: Vec<LocalValueId>,
) -> Option<LocalValueId> {
    terms.into_iter().reduce(|accumulated, term| {
        builder.add_operation(
            StdTensorOp::Add,
            vec![ValueRef::Local(accumulated), ValueRef::Local(term)],
            OperationRole::Linearized {
                active_mask: vec![true, true],
            },
        )[0]
    })
}
1311
/// Wraps `input` in a `Conj` operation when its dtype is complex.
///
/// Real, integer and boolean inputs are returned unchanged. For `C32` and `C64` a primary
/// `Conj` operation is emitted through `builder` and its result is returned.
///
/// # Errors
///
/// Propagates the error from looking up the dtype of `input` in `ctx`.
#[cfg(feature = "autodiff")]
fn conjugate_primal_if_complex(
    builder: &mut dyn PrimitiveRuleBuilder,
    input: ValueRef<StdTensorOp>,
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<ValueRef<StdTensorOp>> {
    let dtype = ctx.dtype_of(&input)?;
    match dtype {
        DType::C32 | DType::C64 => {
            let conjugated =
                builder.add_operation(StdTensorOp::Conj, vec![input], OperationRole::Primary)[0];
            Ok(ValueRef::Local(conjugated))
        }
        DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => Ok(input),
    }
}
1325
1326fn promote_dtypes(dtypes: impl IntoIterator<Item = DType>) -> DType {
1327    dtypes
1328        .into_iter()
1329        .reduce(promote_dtype)
1330        .unwrap_or(DType::F64)
1331}
1332
1333fn promote_dtype(lhs: DType, rhs: DType) -> DType {
1334    use DType::*;
1335    match (lhs, rhs) {
1336        (Bool, Bool) => Bool,
1337        (Bool, other) | (other, Bool) => other,
1338        (I32, I32) => I32,
1339        (I32, I64) | (I64, I32) | (I64, I64) => I64,
1340        (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
1341        (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
1342        (F32, F32) => F32,
1343        (F32, F64) | (F64, F32) | (F64, F64) => F64,
1344        (F32, C32) | (C32, F32) | (C32, C32) => C32,
1345        (F32, C64) | (C64, F32) => C64,
1346        (F64, C32 | C64) | (C32 | C64, F64) => C64,
1347        (C32, C64) | (C64, C32) | (C64, C64) => C64,
1348    }
1349}
1350
1351#[cfg(test)]
1352mod tests;