Skip to main content

tensor4all_core/defaults/
contract.rs

1//! Multi-tensor contraction with optimal contraction order.
2//!
3//! This module provides functions to contract multiple tensors efficiently
4//! using einsum optimization via the tensorbackend
5//! (tenferro-backed implementation).
6//!
7//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
8//!
9//! # Main Functions
10//!
11//! - [`contract`]: Contracts one connected tensor network
12//! - [`contract_with_options`]: Contracts one connected tensor network with retained indices
13//!
14//! # Diag Tensor Handling
15//!
16//! Diagonal tensors are materialized as dense native operands for contraction,
17//! so numeric einsum labels must keep uncontracted logical axes distinct.
18//! Diagonal/structured equality metadata is propagated separately onto the
19//! result when the contraction leaves equal axes behind.
20
21use std::cell::RefCell;
22use std::cmp::Reverse;
23use std::collections::{HashMap, HashSet};
24use std::env;
25use std::time::{Duration, Instant};
26
27use anyhow::Result;
28use petgraph::algo::connected_components;
29use petgraph::prelude::*;
30use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad;
31use tenferro::EinsumSubscripts;
32use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned};
33
34use crate::defaults::{DynId, DynIndex, TensorDynLen};
35
36use crate::index_like::IndexLike;
/// Shape/label fingerprint of one operand, used as part of a profiling key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ContractOperandSignature {
    /// Dimension of each axis of the operand.
    dims: Vec<usize>,
    /// Internal einsum label assigned to each axis of the operand.
    ids: Vec<usize>,
    /// Whether the operand reported `is_diag()` when the signature was taken.
    is_diag: bool,
}
43
44#[derive(Debug, Clone, Hash, PartialEq, Eq)]
45struct ContractSignature {
46    operands: Vec<ContractOperandSignature>,
47    output_ids: Vec<usize>,
48    output_dims: Vec<usize>,
49}
50
/// Running totals for every contraction sharing one [`ContractSignature`].
#[derive(Clone, Debug, Default)]
struct ContractProfileEntry {
    /// Number of recorded contractions.
    calls: usize,
    /// Sum of the wall-clock durations of those contractions.
    total_time: Duration,
}
56
57thread_local! {
58    static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
59        RefCell::new(HashMap::new());
60}
61
/// Return whether contraction profiling is enabled.
///
/// Profiling is switched on by the presence of the `T4A_PROFILE_CONTRACT`
/// environment variable. Its value is ignored. The variable is read on every
/// call, so toggling it at runtime takes effect immediately.
fn contract_profile_enabled() -> bool {
    // `var_os` is a pure presence check. Unlike `env::var`, it does not report
    // a set-but-non-UTF-8 value as "unset", and it skips UTF-8 validation and
    // the `String` allocation.
    env::var_os("T4A_PROFILE_CONTRACT").is_some()
}
65
66fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
67    if !contract_profile_enabled() {
68        return;
69    }
70    CONTRACT_PROFILE_STATE.with(|state| {
71        let mut state = state.borrow_mut();
72        let entry = state.entry(signature).or_default();
73        entry.calls += 1;
74        entry.total_time += elapsed;
75    });
76}
77
78/// Reset the aggregated multi-tensor contraction profile.
79pub fn reset_contract_profile() {
80    CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
81}
82
83/// Print and clear the aggregated multi-tensor contraction profile.
84pub fn print_and_reset_contract_profile() {
85    if !contract_profile_enabled() {
86        return;
87    }
88    CONTRACT_PROFILE_STATE.with(|state| {
89        let mut entries: Vec<_> = state
90            .borrow()
91            .iter()
92            .map(|(k, v)| (k.clone(), v.clone()))
93            .collect();
94        state.borrow_mut().clear();
95        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
96
97        eprintln!("=== contract Profile ===");
98        for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
99            let operands = signature
100                .operands
101                .iter()
102                .map(|operand| {
103                    format!(
104                        "dims={:?} ids={:?}{}",
105                        operand.dims,
106                        operand.ids,
107                        if operand.is_diag { " diag" } else { "" }
108                    )
109                })
110                .collect::<Vec<_>>()
111                .join(" ; ");
112            eprintln!(
113                "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
114                entry.calls,
115                entry.total_time.as_secs_f64(),
116                entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
117                signature.output_dims,
118                signature.output_ids,
119            );
120            eprintln!("     {operands}");
121        }
122    });
123}
124
125// ============================================================================
126// Public API
127// ============================================================================
128
129/// Options for multi-tensor contraction.
130///
131/// Use this to choose which shared indices should be retained in the output
132/// instead of summed over.
133///
134/// # Examples
135///
136/// ```
137/// use tensor4all_core::{ContractionOptions, DynIndex};
138///
139/// let batch = DynIndex::new_dyn(2);
140/// let retain = [batch.clone()];
141/// let options = ContractionOptions::new().with_retain_indices(&retain);
142///
143/// assert_eq!(options.retain_indices, &[batch]);
144/// ```
145#[derive(Clone, Copy, Debug)]
146pub struct ContractionOptions<'a> {
147    /// Indices that should remain in the result even if they appear more than once.
148    pub retain_indices: &'a [DynIndex],
149}
150
151impl<'a> ContractionOptions<'a> {
152    /// Create contraction options with no retained indices.
153    pub fn new() -> Self {
154        Self {
155            retain_indices: &[],
156        }
157    }
158
159    /// Set the indices that should be retained in the output.
160    pub fn with_retain_indices(mut self, retain_indices: &'a [DynIndex]) -> Self {
161        self.retain_indices = retain_indices;
162        self
163    }
164}
165
166impl Default for ContractionOptions<'_> {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
/// Options for pairwise tensor contraction.
///
/// Each flag asks for the corresponding operand to be conjugated before
/// contraction. This is semantically equivalent to contracting `lhs.conj()` or
/// `rhs.conj()`, but lets implementations forward the conjugation to the
/// backend instead of materializing a conjugated tensor.
///
/// # Examples
///
/// ```
/// use num_complex::Complex64;
/// use tensor4all_core::{
///     contract_pair, contract_pair_with_operand_options, DynIndex,
///     PairwiseContractionOptions, TensorDynLen,
/// };
///
/// let i = DynIndex::new_dyn(2);
/// let lhs = TensorDynLen::from_dense(
///     vec![i.clone()],
///     vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -1.0)],
/// ).unwrap();
/// let rhs = TensorDynLen::from_dense(
///     vec![i],
///     vec![Complex64::new(2.0, 0.5), Complex64::new(-1.0, 4.0)],
/// ).unwrap();
///
/// let options = PairwiseContractionOptions::new().with_lhs_conj(true);
/// let flagged = contract_pair_with_operand_options(&lhs, &rhs, options).unwrap();
/// let materialized = contract_pair(&lhs.conj(), &rhs).unwrap();
///
/// assert!((flagged.sum().unwrap() - materialized.sum().unwrap()).abs() < 1e-12);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct PairwiseContractionOptions {
    /// Whether to conjugate the left operand before contraction.
    pub lhs_conj: bool,
    /// Whether to conjugate the right operand before contraction.
    pub rhs_conj: bool,
}

impl PairwiseContractionOptions {
    /// Create pairwise contraction options that conjugate neither operand.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::PairwiseContractionOptions;
    ///
    /// let options = PairwiseContractionOptions::new();
    /// assert!(!options.lhs_conj);
    /// assert!(!options.rhs_conj);
    /// ```
    pub fn new() -> Self {
        Self {
            lhs_conj: false,
            rhs_conj: false,
        }
    }

    /// Return a copy of these options with the left-operand flag set to `lhs_conj`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::PairwiseContractionOptions;
    ///
    /// let options = PairwiseContractionOptions::new().with_lhs_conj(true);
    /// assert!(options.lhs_conj);
    /// assert!(!options.rhs_conj);
    /// ```
    pub fn with_lhs_conj(self, lhs_conj: bool) -> Self {
        Self { lhs_conj, ..self }
    }

    /// Return a copy of these options with the right-operand flag set to `rhs_conj`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::PairwiseContractionOptions;
    ///
    /// let options = PairwiseContractionOptions::new().with_rhs_conj(true);
    /// assert!(!options.lhs_conj);
    /// assert!(options.rhs_conj);
    /// ```
    pub fn with_rhs_conj(self, rhs_conj: bool) -> Self {
        Self { rhs_conj, ..self }
    }

    /// Return `true` if at least one operand is to be conjugated.
    pub(crate) fn has_conj(self) -> bool {
        self.rhs_conj || self.lhs_conj
    }
}
263
264/// Contract a connected tensor network with the default semantics.
265///
266/// This is the normal public entry point for N-ary tensor contraction. It
267/// contracts all common contractable indices and requires the input tensors to
268/// form one connected tensor graph. Disconnected inputs are rejected so missing
269/// links do not silently become outer products.
270///
271/// Use explicit [`outer_product`] calls when an outer product of disconnected
272/// components is intentional.
273pub fn contract(tensors: &[&TensorDynLen]) -> Result<TensorDynLen> {
274    contract_with_options(tensors, ContractionOptions::new())
275}
276
277/// Contract a connected tensor network with advanced options.
278pub fn contract_with_options(
279    tensors: &[&TensorDynLen],
280    options: ContractionOptions<'_>,
281) -> Result<TensorDynLen> {
282    contract_with_options_impl(tensors, options)
283}
284
285/// Contract owned tensors with the default connected-network semantics.
286pub fn contract_owned(tensors: Vec<TensorDynLen>) -> Result<TensorDynLen> {
287    contract_owned_with_options(tensors, ContractionOptions::new())
288}
289
290/// Contract owned tensors with advanced connected-network options.
291pub fn contract_owned_with_options(
292    tensors: Vec<TensorDynLen>,
293    options: ContractionOptions<'_>,
294) -> Result<TensorDynLen> {
295    let tensor_refs = tensors.iter().collect::<Vec<_>>();
296    let components =
297        find_tensor_connected_components_with_retained(&tensor_refs, options.retain_indices);
298    if components.len() > 1 {
299        return Err(anyhow::anyhow!(
300            "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product"
301        ));
302    }
303    drop(tensor_refs);
304    contract_owned_with_options_impl(tensors, options)
305}
306
307/// Contract two tensors with the default pairwise semantics.
308///
309/// This is the concrete `TensorDynLen` entry point for binary contraction. It
310/// contracts all common indices and preserves the pairwise structured fast
311/// paths used by [`TensorContractionLike::contract_pair`].
312pub fn contract_pair(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result<TensorDynLen> {
313    lhs.try_contract_pairwise_default_with_options(rhs, PairwiseContractionOptions::new())
314}
315
316/// Contract two tensors with operand-level conjugation options.
317///
318/// This has the same index semantics as [`contract_pair`], with optional
319/// conjugation applied to either operand before matching and contracting common
320/// indices. Implementations may pass conjugation to the backend to avoid
321/// materializing conjugated payloads.
322///
323/// # Examples
324///
325/// ```
326/// use num_complex::Complex64;
327/// use tensor4all_core::{
328///     contract_pair, contract_pair_with_operand_options, DynIndex,
329///     PairwiseContractionOptions, TensorDynLen,
330/// };
331///
332/// let i = DynIndex::new_dyn(2);
333/// let lhs = TensorDynLen::from_dense(
334///     vec![i.clone()],
335///     vec![Complex64::new(1.0, 1.0), Complex64::new(0.0, 2.0)],
336/// ).unwrap();
337/// let rhs = TensorDynLen::from_dense(
338///     vec![i],
339///     vec![Complex64::new(2.0, 0.0), Complex64::new(3.0, -1.0)],
340/// ).unwrap();
341///
342/// let flagged = contract_pair_with_operand_options(
343///     &lhs,
344///     &rhs,
345///     PairwiseContractionOptions::new().with_lhs_conj(true),
346/// ).unwrap();
347/// let materialized = contract_pair(&lhs.conj(), &rhs).unwrap();
348///
349/// assert!((flagged.sum().unwrap() - materialized.sum().unwrap()).abs() < 1e-12);
350/// ```
351pub fn contract_pair_with_operand_options(
352    lhs: &TensorDynLen,
353    rhs: &TensorDynLen,
354    options: PairwiseContractionOptions,
355) -> Result<TensorDynLen> {
356    lhs.try_contract_pairwise_default_with_options(rhs, options)
357}
358
359/// Contract two tensors with explicit contraction options.
360pub fn contract_pair_with_options(
361    lhs: &TensorDynLen,
362    rhs: &TensorDynLen,
363    options: ContractionOptions<'_>,
364) -> Result<TensorDynLen> {
365    contract_with_options(&[lhs, rhs], options)
366}
367
368/// Contract two tensors along explicitly specified index pairs.
369pub fn tensordot(
370    lhs: &TensorDynLen,
371    rhs: &TensorDynLen,
372    pairs: &[(DynIndex, DynIndex)],
373) -> Result<TensorDynLen> {
374    lhs.try_tensordot_pairwise_explicit(rhs, pairs)
375}
376
377/// Compute the outer product of two tensors.
378///
379/// This is an explicit tensor product, not a dense-only operation. Compact
380/// structured storage is preserved when the operand layouts allow it.
381pub fn outer_product(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result<TensorDynLen> {
382    lhs.try_outer_product_pairwise(rhs)
383}
384
385/// Contract multiple owned tensors into a single tensor.
386///
387/// This is the consuming implementation for [`contract_owned_with_options`]. It
388/// preserves the same contraction semantics while allowing eligible non-AD
389/// dense inputs to use tenferro's owned eager einsum executor. When any input
390/// tracks gradients, or when compact structured metadata needs the borrowed
391/// path, this function falls back to the shared borrowed execution so semantics
392/// and reverse-mode AD remain intact.
393fn contract_owned_with_options_impl(
394    tensors: Vec<TensorDynLen>,
395    options: ContractionOptions<'_>,
396) -> Result<TensorDynLen> {
397    match tensors.len() {
398        0 => Err(anyhow::anyhow!("No tensors to contract")),
399        _ => {
400            let tensor_refs = tensors.iter().collect::<Vec<_>>();
401            validate_retained_indices_exist(&tensor_refs, options.retain_indices)?;
402
403            if tensors.len() == 1 {
404                drop(tensor_refs);
405                let Some(tensor) = tensors.into_iter().next() else {
406                    return Err(anyhow::anyhow!("No tensors to contract"));
407                };
408                return Ok(tensor);
409            }
410
411            let requires_borrowed_path = tensor_refs.iter().any(|tensor| tensor.tracks_grad())
412                || tensor_refs
413                    .iter()
414                    .any(|tensor| !has_dense_axis_classes(tensor));
415            if requires_borrowed_path {
416                return contract_with_options(&tensor_refs, options);
417            }
418
419            let components = find_tensor_connected_components_with_retained(
420                &tensor_refs,
421                options.retain_indices,
422            );
423            if components.len() > 1 {
424                return Err(anyhow::anyhow!(
425                    "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product"
426                ));
427            }
428
429            let mut diag_uf = AxisUnionFind::new();
430            let plan = build_contraction_plan(&tensor_refs, options, &mut diag_uf)?;
431            drop(tensor_refs);
432            let native_operands = tensors
433                .into_iter()
434                .enumerate()
435                .map(|(tensor_idx, tensor)| {
436                    Ok((
437                        tensor.as_native()?.clone(),
438                        plan.input_ids[tensor_idx].clone(),
439                    ))
440                })
441                .collect::<Result<Vec<_>>>()?;
442            let result_native = einsum_native_tensors_owned(native_operands, &plan.output_ids)?;
443            TensorDynLen::from_native_with_axis_classes(
444                plan.result_indices,
445                result_native,
446                plan.result_axis_classes,
447            )
448        }
449    }
450}
451
452fn has_dense_axis_classes(tensor: &TensorDynLen) -> bool {
453    let storage = tensor.storage();
454    storage
455        .axis_classes()
456        .iter()
457        .copied()
458        .eq(0..tensor.indices().len())
459}
460
461fn contract_with_options_impl(
462    tensors: &[&TensorDynLen],
463    options: ContractionOptions<'_>,
464) -> Result<TensorDynLen> {
465    match tensors.len() {
466        0 => Err(anyhow::anyhow!("No tensors to contract")),
467        _ => {
468            validate_retained_indices_exist(tensors, options.retain_indices)?;
469            if tensors.len() == 1 {
470                return Ok((*tensors[0]).clone());
471            }
472
473            // Check connectivity first
474            let components =
475                find_tensor_connected_components_with_retained(tensors, options.retain_indices);
476            if components.len() > 1 {
477                return Err(anyhow::anyhow!(
478                    "Disconnected tensor network: {} components found",
479                    components.len()
480                ));
481            }
482            // Connectivity verified - skip check in impl
483            contract_impl(tensors, options)
484        }
485    }
486}
487
488// ============================================================================
489// Union-Find for Diag axis grouping
490// ============================================================================
491
492/// Union-Find data structure for grouping axis IDs.
493///
494/// Used to merge diagonal axes from Diag tensors so that they share
495/// the same representative ID when passed to einsum.
496#[derive(Debug, Clone)]
497pub struct AxisUnionFind {
498    /// Maps each ID to its parent. If parent[id] == id, it's a root.
499    parent: HashMap<DynId, DynId>,
500    /// Rank for union by rank optimization.
501    rank: HashMap<DynId, usize>,
502}
503
504impl AxisUnionFind {
505    /// Create a new empty union-find structure.
506    pub fn new() -> Self {
507        Self {
508            parent: HashMap::new(),
509            rank: HashMap::new(),
510        }
511    }
512
513    /// Add an ID to the structure (as its own set).
514    pub fn make_set(&mut self, id: DynId) {
515        use std::collections::hash_map::Entry;
516        if let Entry::Vacant(e) = self.parent.entry(id) {
517            e.insert(id);
518            self.rank.insert(id, 0);
519        }
520    }
521
522    /// Find the representative (root) of the set containing `id`.
523    /// Uses path compression for efficiency.
524    pub fn find(&mut self, id: DynId) -> DynId {
525        self.make_set(id);
526        if self.parent[&id] != id {
527            let root = self.find(self.parent[&id]);
528            self.parent.insert(id, root);
529        }
530        self.parent[&id]
531    }
532
533    /// Union the sets containing `a` and `b`.
534    /// Uses union by rank for efficiency.
535    pub fn union(&mut self, a: DynId, b: DynId) {
536        let root_a = self.find(a);
537        let root_b = self.find(b);
538
539        if root_a == root_b {
540            return;
541        }
542
543        let rank_a = self.rank[&root_a];
544        let rank_b = self.rank[&root_b];
545
546        if rank_a < rank_b {
547            self.parent.insert(root_a, root_b);
548        } else if rank_a > rank_b {
549            self.parent.insert(root_b, root_a);
550        } else {
551            self.parent.insert(root_b, root_a);
552            if let Some(rank) = self.rank.get_mut(&root_a) {
553                *rank += 1;
554            }
555        }
556    }
557
558    /// Remap an ID to its representative.
559    pub fn remap(&mut self, id: DynId) -> DynId {
560        self.find(id)
561    }
562
563    /// Remap a slice of IDs to their representatives.
564    pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
565        ids.iter().map(|id| self.find(*id)).collect()
566    }
567}
568
569impl Default for AxisUnionFind {
570    fn default() -> Self {
571        Self::new()
572    }
573}
574
575// ============================================================================
576// Axis helper builders
577// ============================================================================
578
579/// Build a union-find structure from a collection of tensors.
580///
581/// This helper is kept for callers that need to group diagonal axes by index ID.
582/// Numeric contraction currently keeps dense logical axes distinct and propagates
583/// diagonal result metadata separately.
584pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
585    let mut uf = AxisUnionFind::new();
586
587    for tensor in tensors {
588        for idx in tensor.indices() {
589            uf.make_set(*idx.id());
590        }
591
592        if tensor.is_diag() && tensor.indices().len() >= 2 {
593            let first_id = *tensor.indices()[0].id();
594            for idx in tensor.indices().iter().skip(1) {
595                uf.union(first_id, *idx.id());
596            }
597        }
598    }
599
600    uf
601}
602
603/// Remap tensor indices using the union-find structure.
604///
605/// Returns a vector of remapped IDs for each tensor, suitable for passing
606/// to einsum. The original tensors are not modified.
607pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
608    tensors
609        .iter()
610        .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
611        .collect()
612}
613
614/// Remap output IDs using the union-find structure.
615pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
616    output.iter().map(|idx| uf.find(*idx.id())).collect()
617}
618
619/// Collect dimension sizes for remapped IDs.
620///
621/// For unified IDs (from Diag tensors), all axes must have the same dimension,
622/// so we just take the first occurrence.
623pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
624    let mut sizes = HashMap::new();
625
626    for tensor in tensors {
627        let dims = tensor.dims();
628        for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
629            let rep = uf.find(*idx.id());
630            sizes.entry(rep).or_insert(dim);
631        }
632    }
633
634    sizes
635}
636
637// ============================================================================
638// Contraction implementation
639// ============================================================================
640
641/// Internal implementation of multi-tensor contraction.
642///
643/// Diagonal tensors are passed as dense native operands for numeric contraction.
644/// Their compact equality metadata is propagated separately onto the result.
645///
646/// This implementation preserves storage type: if all inputs are F64, the result
647/// is F64; if any input is C64, the result is C64.
648fn contract_impl(
649    tensors: &[&TensorDynLen],
650    options: ContractionOptions<'_>,
651) -> Result<TensorDynLen> {
652    // 1. Build union-find over exact matching index IDs. Diagonal equality is
653    // encoded in the dense native values and should not collapse uncontracted
654    // logical axes in the numeric einsum.
655    let mut diag_uf = AxisUnionFind::new();
656
657    // 2. Build the contraction plan from internal labels.
658    let plan = build_contraction_plan(tensors, options, &mut diag_uf)?;
659
660    // Note: Connectivity check is done by caller.
661    // via find_tensor_connected_components before calling this function
662
663    // 3. Build sizes from unique internal IDs.
664    let mut sizes: HashMap<usize, usize> = HashMap::new();
665    for (tensor_idx, tensor) in tensors.iter().enumerate() {
666        let dims = tensor.dims();
667        for (pos, &dim) in dims.iter().enumerate() {
668            let internal_id = plan.input_ids[tensor_idx][pos];
669            match sizes.entry(internal_id) {
670                std::collections::hash_map::Entry::Vacant(entry) => {
671                    entry.insert(dim);
672                }
673                std::collections::hash_map::Entry::Occupied(entry) => {
674                    if *entry.get() != dim {
675                        return Err(anyhow::anyhow!(
676                            "Internal label dimension mismatch: label {} has dimensions {} and {}",
677                            internal_id,
678                            entry.get(),
679                            dim
680                        ));
681                    }
682                }
683            }
684        }
685    }
686
687    let profile_signature = contract_profile_enabled().then(|| ContractSignature {
688        operands: tensors
689            .iter()
690            .enumerate()
691            .map(|(tensor_idx, tensor)| ContractOperandSignature {
692                dims: tensor.dims().to_vec(),
693                ids: plan.input_ids[tensor_idx].clone(),
694                is_diag: tensor.is_diag(),
695            })
696            .collect(),
697        output_ids: plan.output_ids.clone(),
698        output_dims: plan.output_ids.iter().map(|id| sizes[id]).collect(),
699    });
700    let profile_started = contract_profile_enabled().then(Instant::now);
701
702    let result = execute_contraction_plan(tensors, &plan, !options.retain_indices.is_empty())?;
703    if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
704        record_contract_profile(signature, started.elapsed());
705    }
706    Ok(result)
707}
708
709fn execute_contraction_plan(
710    tensors: &[&TensorDynLen],
711    plan: &ContractionPlan,
712    has_retained_indices: bool,
713) -> Result<TensorDynLen> {
714    let any_grad = tensors.iter().any(|tensor| tensor.tracks_grad());
715    let first_dtype = tensors[0].as_native()?.dtype();
716    let same_dtype = tensors
717        .iter()
718        .map(|tensor| {
719            tensor
720                .as_native()
721                .map(|native| native.dtype() == first_dtype)
722        })
723        .collect::<Result<Vec<_>>>()?
724        .into_iter()
725        .all(|same| same);
726    let has_non_dense_axis_classes = tensors.iter().any(|tensor| {
727        tensor
728            .storage()
729            .axis_classes()
730            .iter()
731            .copied()
732            .enumerate()
733            .any(|(axis, class)| axis != class)
734    });
735
736    if any_grad && same_dtype && has_non_dense_axis_classes {
737        if has_retained_indices {
738            return Err(anyhow::anyhow!(
739                "Retained AD contraction with structured storage is not yet supported"
740            ));
741        }
742
743        // Structured payload AD still relies on the existing pairwise structured
744        // path until structured N-ary planning is implemented.
745        let mut iter = tensors.iter();
746        let Some(first) = iter.next() else {
747            return Err(anyhow::anyhow!("No tensors to contract"));
748        };
749        let mut result = (*first).clone();
750        for tensor in iter {
751            result = contract_pair(&result, tensor)?;
752        }
753        return Ok(result);
754    }
755
756    if any_grad && same_dtype {
757        let operands = tensors
758            .iter()
759            .map(|tensor| tensor.as_inner())
760            .collect::<Result<Vec<_>>>()?;
761        let subscripts = build_einsum_subscripts_from_usize_ids(&plan.input_ids, &plan.output_ids)?;
762        let result = eager_einsum_ad(&operands, &subscripts)?;
763        return TensorDynLen::from_inner_with_axis_classes(
764            plan.result_indices.clone(),
765            result,
766            plan.result_axis_classes.clone(),
767        );
768    }
769
770    let native_operands: Vec<_> = tensors
771        .iter()
772        .enumerate()
773        .map(|(tensor_idx, tensor)| {
774            Ok((tensor.as_native()?, plan.input_ids[tensor_idx].as_slice()))
775        })
776        .collect::<Result<Vec<_>>>()?;
777    let result_native = einsum_native_tensors(&native_operands, &plan.output_ids)?;
778    TensorDynLen::from_native_with_axis_classes(
779        plan.result_indices.clone(),
780        result_native,
781        plan.result_axis_classes.clone(),
782    )
783}
784
785fn build_einsum_subscripts_from_usize_ids(
786    input_ids: &[Vec<usize>],
787    output_ids: &[usize],
788) -> Result<EinsumSubscripts> {
789    let inputs = input_ids
790        .iter()
791        .map(|ids| {
792            ids.iter()
793                .map(|&id| {
794                    u32::try_from(id)
795                        .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
796                })
797                .collect::<Result<Vec<_>>>()
798        })
799        .collect::<Result<Vec<_>>>()?;
800    let output = output_ids
801        .iter()
802        .map(|&id| {
803            u32::try_from(id).map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
804        })
805        .collect::<Result<Vec<_>>>()?;
806    let input_refs = inputs.iter().map(Vec::as_slice).collect::<Vec<_>>();
807    Ok(EinsumSubscripts::new(&input_refs, &output))
808}
809
810/// A contraction plan with internal labels and result ordering.
811#[derive(Debug, Clone)]
812struct ContractionPlan {
813    input_ids: Vec<Vec<usize>>,
814    output_ids: Vec<usize>,
815    result_indices: Vec<DynIndex>,
816    result_axis_classes: Vec<usize>,
817}
818
819fn build_contraction_plan(
820    tensors: &[&TensorDynLen],
821    options: ContractionOptions<'_>,
822    diag_uf: &mut AxisUnionFind,
823) -> Result<ContractionPlan> {
824    let retained_indices: HashSet<DynIndex> = options.retain_indices.iter().cloned().collect();
825    let (input_ids, internal_id_to_original) =
826        build_internal_ids(tensors, diag_uf, &retained_indices)?;
827
828    let mut counts: HashMap<usize, usize> = HashMap::new();
829    for ids in &input_ids {
830        for &internal_id in ids {
831            *counts.entry(internal_id).or_insert(0) += 1;
832        }
833    }
834    let mut output_ids = Vec::new();
835    let mut seen_output = HashSet::new();
836    let mut found_retained = HashSet::new();
837
838    for (tensor_idx, tensor) in tensors.iter().enumerate() {
839        for (axis, idx) in tensor.indices.iter().enumerate() {
840            let internal_id = input_ids[tensor_idx][axis];
841            let should_output = counts[&internal_id] == 1 || retained_indices.contains(idx);
842            if should_output && seen_output.insert(internal_id) {
843                output_ids.push(internal_id);
844            }
845            if retained_indices.contains(idx) {
846                found_retained.insert(idx.clone());
847            }
848        }
849    }
850
851    for retained in retained_indices {
852        if !found_retained.contains(&retained) {
853            return Err(anyhow::anyhow!(
854                "Retained index {:?} does not appear in the input tensors",
855                retained
856            ));
857        }
858    }
859
860    let result_indices: Vec<DynIndex> = output_ids
861        .iter()
862        .map(|&internal_id| {
863            let (tensor_idx, pos) = internal_id_to_original[&internal_id];
864            tensors[tensor_idx].indices[pos].clone()
865        })
866        .collect();
867    validate_unique_output_indices(&result_indices)?;
868    let result_axis_classes =
869        output_axis_classes(tensors, &input_ids, &output_ids, &internal_id_to_original);
870
871    Ok(ContractionPlan {
872        input_ids,
873        output_ids,
874        result_indices,
875        result_axis_classes,
876    })
877}
878
879fn validate_retained_indices_exist(
880    tensors: &[&TensorDynLen],
881    retain_indices: &[DynIndex],
882) -> Result<()> {
883    for retain in retain_indices {
884        let found = tensors
885            .iter()
886            .any(|tensor| tensor.indices().iter().any(|idx| idx == retain));
887        if !found {
888            return Err(anyhow::anyhow!(
889                "Retained index {:?} does not appear in the input tensors",
890                retain
891            ));
892        }
893    }
894    Ok(())
895}
896
897fn validate_unique_output_indices(indices: &[DynIndex]) -> Result<()> {
898    let mut seen = HashSet::new();
899    for idx in indices {
900        if !seen.insert(idx.clone()) {
901            return Err(anyhow::anyhow!(
902                "Contraction result would contain duplicate output indices"
903            ));
904        }
905    }
906    Ok(())
907}
908
909fn output_axis_classes(
910    tensors: &[&TensorDynLen],
911    ixs: &[Vec<usize>],
912    output: &[usize],
913    internal_id_to_original: &HashMap<usize, (usize, usize)>,
914) -> Vec<usize> {
915    fn find(parent: &mut [usize], value: usize) -> usize {
916        if parent[value] != value {
917            parent[value] = find(parent, parent[value]);
918        }
919        parent[value]
920    }
921
922    fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
923        let lhs_root = find(parent, lhs);
924        let rhs_root = find(parent, rhs);
925        if lhs_root != rhs_root {
926            parent[rhs_root] = lhs_root;
927        }
928    }
929
930    let mut class_offsets = Vec::with_capacity(tensors.len());
931    let mut next_node = 0usize;
932    for tensor in tensors {
933        class_offsets.push(next_node);
934        let payload_rank = tensor
935            .storage()
936            .axis_classes()
937            .iter()
938            .copied()
939            .max()
940            .map(|value| value + 1)
941            .unwrap_or(0);
942        next_node += payload_rank;
943    }
944    let mut parent: Vec<usize> = (0..next_node).collect();
945    let mut axes_by_internal_id: HashMap<usize, Vec<usize>> = HashMap::new();
946
947    for (tensor_idx, tensor) in tensors.iter().enumerate() {
948        for (axis, &internal_id) in ixs[tensor_idx].iter().enumerate() {
949            let class_id = tensor.storage().axis_classes()[axis];
950            let node = class_offsets[tensor_idx] + class_id;
951            axes_by_internal_id
952                .entry(internal_id)
953                .or_default()
954                .push(node);
955        }
956    }
957
958    for nodes in axes_by_internal_id.values() {
959        if let Some((&first, rest)) = nodes.split_first() {
960            for &node in rest {
961                union(&mut parent, first, node);
962            }
963        }
964    }
965
966    let mut root_to_class = HashMap::new();
967    let mut next_class = 0usize;
968    output
969        .iter()
970        .map(|internal_id| {
971            let (tensor_idx, axis) = internal_id_to_original[internal_id];
972            let class_id = tensors[tensor_idx].storage().axis_classes()[axis];
973            let node = class_offsets[tensor_idx] + class_id;
974            let root = find(&mut parent, node);
975            *root_to_class.entry(root).or_insert_with(|| {
976                let class = next_class;
977                next_class += 1;
978                class
979            })
980        })
981        .collect()
982}
983
/// Build internal IDs (integer einsum labels) for numeric contraction.
///
/// Every axis of every tensor receives a label in two phases:
///
/// 1. For each pair of tensors `ti < tj`, and each pair of their axes whose
///    indices satisfy `is_contractable`, the two axes share one label:
///    - If neither axis is labeled yet, a label already recorded for the first
///      axis's `DynIndex` is reused; otherwise a fresh label is allocated.
///    - If only one axis is labeled, the other adopts that label.
///    - If both axes are already labeled, nothing changes. Two *different*
///      labels are not merged here.
/// 2. Each axis still unlabeled afterwards gets a fresh label of its own.
///    The exception is retained indices (those in `retained_indices`): all
///    unlabeled occurrences of the same retained index share one label.
///
/// The `_diag_uf` argument is currently ignored, so no union-find merging
/// happens in this function. Diagonal logical-axis metadata is intentionally
/// handled outside this numeric labeling step.
///
/// Returns `(ixs, internal_id_to_original)`:
/// - `ixs[t][axis]` is the label of axis `axis` of tensor `t`.
/// - `internal_id_to_original` maps each label to the `(tensor, axis)`
///   position that first allocated it.
///
/// The visible body never produces `Err`.
// The `(Vec<Vec<usize>>, HashMap<..>)` return tuple trips clippy's
// `type_complexity` lint.
#[allow(clippy::type_complexity)]
fn build_internal_ids(
    tensors: &[&TensorDynLen],
    _diag_uf: &mut AxisUnionFind,
    retained_indices: &HashSet<DynIndex>,
) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
    let mut next_id = 0usize;
    // Last label recorded for each index that took part in a contractable pair.
    let mut index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // One shared label per retained index that stayed unlabeled in phase 1.
    let mut retained_index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // (tensor, axis) -> label.
    let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
    let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();

    // Phase 1: give axes linked by a contractable index pair a common label.
    for ti in 0..tensors.len() {
        for tj in (ti + 1)..tensors.len() {
            for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
                for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
                    if idx_i.is_contractable(idx_j) {
                        let key_i = (ti, pi);
                        let key_j = (tj, pj);

                        match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
                            (None, None) => {
                                let internal_id = if let Some(&id) = index_to_internal.get(idx_i) {
                                    id
                                } else {
                                    let id = next_id;
                                    next_id += 1;
                                    index_to_internal.insert(idx_i.clone(), id);
                                    internal_id_to_original.insert(id, key_i);
                                    id
                                };
                                assigned.insert(key_i, internal_id);
                                assigned.insert(key_j, internal_id);
                                if idx_i != idx_j {
                                    index_to_internal.insert(idx_j.clone(), internal_id);
                                }
                            }
                            (Some(id), None) => {
                                assigned.insert(key_j, id);
                                index_to_internal.insert(idx_j.clone(), id);
                            }
                            (None, Some(id)) => {
                                assigned.insert(key_i, id);
                                index_to_internal.insert(idx_i.clone(), id);
                            }
                            (Some(_id_i), Some(_id_j)) => {
                                // Both already labeled; the labels are left as-is,
                                // even if they differ.
                            }
                        }
                    }
                }
            }
        }
    }

    // Phase 2: assign IDs for still-unlabeled (external) axes. Repeated
    // occurrences of one retained index share a label; every other axis
    // gets a fresh label.
    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (pos, idx) in tensor.indices.iter().enumerate() {
            let key = (tensor_idx, pos);
            if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
                let internal_id = if retained_indices.contains(idx) {
                    if let Some(&id) = retained_index_to_internal.get(idx) {
                        id
                    } else {
                        let id = next_id;
                        next_id += 1;
                        retained_index_to_internal.insert(idx.clone(), id);
                        internal_id_to_original.insert(id, key);
                        id
                    }
                } else {
                    let id = next_id;
                    next_id += 1;
                    internal_id_to_original.insert(id, key);
                    id
                };
                e.insert(internal_id);
            }
        }
    }

    // Build ixs: per-tensor label lists in axis order. Every axis is labeled
    // by now, so the map lookups below cannot miss.
    let ixs: Vec<Vec<usize>> = tensors
        .iter()
        .enumerate()
        .map(|(tensor_idx, tensor)| {
            (0..tensor.indices.len())
                .map(|pos| assigned[&(tensor_idx, pos)])
                .collect()
        })
        .collect();

    Ok((ixs, internal_id_to_original))
}
1085
1086// ============================================================================
1087// Helper functions for connected component detection
1088// ============================================================================
1089
1090/// Check if two tensors have any contractable indices.
1091fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
1092    a.indices
1093        .iter()
1094        .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
1095}
1096
1097/// Find connected components of tensors based on contractable indices.
1098///
1099/// Uses petgraph for O(V+E) connected component detection.
1100#[allow(dead_code)]
1101fn find_tensor_connected_components(tensors: &[&TensorDynLen]) -> Vec<Vec<usize>> {
1102    find_tensor_connected_components_with_retained(tensors, &[])
1103}
1104
1105fn find_tensor_connected_components_with_retained(
1106    tensors: &[&TensorDynLen],
1107    retain_indices: &[DynIndex],
1108) -> Vec<Vec<usize>> {
1109    let n = tensors.len();
1110    if n == 0 {
1111        return vec![];
1112    }
1113    if n == 1 {
1114        return vec![vec![0]];
1115    }
1116
1117    // Build undirected graph
1118    let mut graph = UnGraph::<(), ()>::new_undirected();
1119    let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
1120
1121    for i in 0..n {
1122        for j in (i + 1)..n {
1123            if has_contractable_indices(tensors[i], tensors[j]) {
1124                graph.add_edge(nodes[i], nodes[j], ());
1125            }
1126        }
1127    }
1128
1129    if !retain_indices.is_empty() {
1130        for i in 0..n {
1131            for j in (i + 1)..n {
1132                if shares_retained_index(tensors[i], tensors[j], retain_indices) {
1133                    graph.add_edge(nodes[i], nodes[j], ());
1134                }
1135            }
1136        }
1137    }
1138
1139    // Find connected components using petgraph
1140    let num_components = connected_components(&graph);
1141
1142    if num_components == 1 {
1143        return vec![(0..n).collect()];
1144    }
1145
1146    // Multiple components - group by component ID
1147    use petgraph::visit::Dfs;
1148    let mut visited = vec![false; n];
1149    let mut components = Vec::new();
1150
1151    for start in 0..n {
1152        if !visited[start] {
1153            let mut component = Vec::new();
1154            let mut dfs = Dfs::new(&graph, nodes[start]);
1155            while let Some(node) = dfs.next(&graph) {
1156                let idx = node.index();
1157                if !visited[idx] {
1158                    visited[idx] = true;
1159                    component.push(idx);
1160                }
1161            }
1162            component.sort();
1163            components.push(component);
1164        }
1165    }
1166
1167    components.sort_by_key(|c| c[0]);
1168    components
1169}
1170
1171fn shares_retained_index(a: &TensorDynLen, b: &TensorDynLen, retain_indices: &[DynIndex]) -> bool {
1172    retain_indices.iter().any(|retain| {
1173        a.indices().iter().any(|idx_a| idx_a == retain)
1174            && b.indices().iter().any(|idx_b| idx_b == retain)
1175    })
1176}
1177
1178#[cfg(test)]
1179mod tests;