Skip to main content

tensor4all_core/defaults/
contract.rs

1//! Multi-tensor contraction with optimal contraction order.
2//!
3//! This module provides functions to contract multiple tensors efficiently
4//! using einsum optimization via the tensorbackend
5//! (tenferro-backed implementation).
6//!
7//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
8//!
9//! # Main Functions
10//!
11//! - [`contract_multi`]: Contracts tensors, handling disconnected components via outer product
12//! - [`contract_connected`]: Contracts tensors that must form a connected graph
13//!
14//! # Diag Tensor Handling
15//!
16//! Diagonal tensors are materialized as dense native operands for contraction,
17//! so numeric einsum labels must keep uncontracted logical axes distinct.
18//! Diagonal/structured equality metadata is propagated separately onto the
19//! result when the contraction leaves equal axes behind.
20
21use std::cell::RefCell;
22use std::cmp::Reverse;
23use std::collections::{HashMap, HashSet};
24use std::env;
25use std::time::{Duration, Instant};
26
27use anyhow::Result;
28use petgraph::algo::connected_components;
29use petgraph::prelude::*;
30use tenferro::eager_einsum::eager_einsum_ad;
31use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned};
32
33use crate::defaults::{DynId, DynIndex, TensorDynLen};
34
35use crate::index_like::IndexLike;
36use crate::tensor_like::AllowedPairs;
37
/// Hashable description of a single contraction operand.
///
/// Used as part of the [`ContractSignature`] key that aggregates profiling
/// data for calls with the same operand layout.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ContractOperandSignature {
    /// Logical dimensions of the operand, as reported by `tensor.dims()`.
    dims: Vec<usize>,
    /// Internal einsum labels the contraction plan assigned to the operand's axes.
    ids: Vec<usize>,
    /// Whether the operand reported `is_diag()`.
    is_diag: bool,
}
44
/// Profiling key for one multi-tensor contraction shape.
///
/// Two contractions share a key when their operands and output labels/dims
/// match exactly; timings for equal keys are accumulated together.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ContractSignature {
    /// Per-operand layout, in the order the operands were passed.
    operands: Vec<ContractOperandSignature>,
    /// Internal einsum labels of the result axes.
    output_ids: Vec<usize>,
    /// Dimension of each result axis, aligned with `output_ids`.
    output_dims: Vec<usize>,
}
51
/// Accumulated timing for every contraction that shared one [`ContractSignature`].
#[derive(Debug, Default, Clone)]
struct ContractProfileEntry {
    /// Number of recorded contractions.
    calls: usize,
    /// Sum of the wall-clock durations of those contractions.
    total_time: Duration,
}
57
// Per-thread profiling table mapping each contraction signature to its
// accumulated timing. It is filled by `record_contract_profile` and cleared by
// `reset_contract_profile` / `print_and_reset_contract_profile`. Because it is
// thread-local, each thread only sees (and clears) the entries it recorded.
thread_local! {
    static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
62
/// Report whether multi-tensor contraction profiling is switched on.
///
/// Profiling is on whenever `T4A_PROFILE_CONTRACT` can be read from the
/// environment as a valid Unicode string. The variable is re-read on every
/// call, so toggling it at runtime takes effect immediately.
fn contract_profile_enabled() -> bool {
    const PROFILE_ENV_VAR: &str = "T4A_PROFILE_CONTRACT";
    env::var(PROFILE_ENV_VAR).is_ok()
}
66
67fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
68    if !contract_profile_enabled() {
69        return;
70    }
71    CONTRACT_PROFILE_STATE.with(|state| {
72        let mut state = state.borrow_mut();
73        let entry = state.entry(signature).or_default();
74        entry.calls += 1;
75        entry.total_time += elapsed;
76    });
77}
78
79/// Reset the aggregated multi-tensor contraction profile.
80pub fn reset_contract_profile() {
81    CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
82}
83
84/// Print and clear the aggregated multi-tensor contraction profile.
85pub fn print_and_reset_contract_profile() {
86    if !contract_profile_enabled() {
87        return;
88    }
89    CONTRACT_PROFILE_STATE.with(|state| {
90        let mut entries: Vec<_> = state
91            .borrow()
92            .iter()
93            .map(|(k, v)| (k.clone(), v.clone()))
94            .collect();
95        state.borrow_mut().clear();
96        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
97
98        eprintln!("=== contract_multi Profile ===");
99        for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
100            let operands = signature
101                .operands
102                .iter()
103                .map(|operand| {
104                    format!(
105                        "dims={:?} ids={:?}{}",
106                        operand.dims,
107                        operand.ids,
108                        if operand.is_diag { " diag" } else { "" }
109                    )
110                })
111                .collect::<Vec<_>>()
112                .join(" ; ");
113            eprintln!(
114                "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
115                entry.calls,
116                entry.total_time.as_secs_f64(),
117                entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
118                signature.output_dims,
119                signature.output_ids,
120            );
121            eprintln!("     {operands}");
122        }
123    });
124}
125
126// ============================================================================
127// Public API
128// ============================================================================
129
130/// Options for multi-tensor contraction.
131///
132/// Use this to choose which tensor pairs may contract and which shared indices
133/// should be retained in the output instead of summed over.
134///
135/// # Examples
136///
137/// ```
138/// use tensor4all_core::{AllowedPairs, ContractionOptions, DynIndex};
139///
140/// let batch = DynIndex::new_dyn(2);
141/// let retain = [batch.clone()];
142/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain);
143///
144/// assert!(matches!(options.allowed, AllowedPairs::All));
145/// assert_eq!(options.retain_indices, &[batch]);
146/// ```
147#[derive(Clone, Copy, Debug)]
148pub struct ContractionOptions<'a> {
149    /// Contractability policy for tensor pairs.
150    pub allowed: AllowedPairs<'a>,
151    /// Indices that should remain in the result even if they appear more than once.
152    pub retain_indices: &'a [DynIndex],
153}
154
155impl<'a> ContractionOptions<'a> {
156    /// Create contraction options with no retained indices.
157    pub fn new(allowed: AllowedPairs<'a>) -> Self {
158        Self {
159            allowed,
160            retain_indices: &[],
161        }
162    }
163
164    /// Set the indices that should be retained in the output.
165    pub fn with_retain_indices(mut self, retain_indices: &'a [DynIndex]) -> Self {
166        self.retain_indices = retain_indices;
167        self
168    }
169}
170
171/// Contract multiple tensors into a single tensor, handling disconnected components.
172///
173/// This function automatically handles disconnected tensor graphs by:
174/// 1. Finding connected components based on contractable indices
175/// 2. Contracting each connected component separately
176/// 3. Combining results using outer product
177///
178/// # Arguments
179/// * `tensors` - Slice of tensors to contract
180/// * `allowed` - Specifies which tensor pairs can have their indices contracted
181///
182/// # Returns
183/// The result of contracting all tensors over allowed contractable indices.
184/// If tensors form disconnected components, they are combined via outer product.
185///
186/// # Behavior by N
187/// - N=0: Error
188/// - N=1: Clone of input
189/// - N>=2: Contract connected components, combine with outer product
190///
191/// # Errors
192/// - `AllowedPairs::Specified` contains a pair with no contractable indices
193///
194/// # Examples
195///
196/// ```
197/// use tensor4all_core::{TensorDynLen, DynIndex, contract_multi, AllowedPairs};
198///
199/// // A[i, j] and B[j, k] share index j — contract to get C[i, k]
200/// let i = DynIndex::new_dyn(2);
201/// let j = DynIndex::new_dyn(3);
202/// let k = DynIndex::new_dyn(4);
203///
204/// let a = TensorDynLen::from_dense(
205///     vec![i.clone(), j.clone()],
206///     vec![1.0_f64; 6],
207/// ).unwrap();
208/// let b = TensorDynLen::from_dense(
209///     vec![j.clone(), k.clone()],
210///     vec![1.0_f64; 12],
211/// ).unwrap();
212///
213/// let c = contract_multi(&[&a, &b], AllowedPairs::All).unwrap();
214/// assert_eq!(c.dims(), vec![2, 4]);
215/// ```
216pub fn contract_multi(
217    tensors: &[&TensorDynLen],
218    allowed: AllowedPairs<'_>,
219) -> Result<TensorDynLen> {
220    contract_multi_with_options(tensors, ContractionOptions::new(allowed))
221}
222
223/// Contract multiple tensors into a single tensor with additional options.
224///
225/// This behaves like [`contract_multi`] but also allows selected shared indices
226/// to be retained in the output.
227///
228/// # Arguments
229/// * `tensors` - Slice of tensors to contract
230/// * `options` - Pair-selection policy and retained indices
231///
232/// # Returns
233/// The contracted tensor, possibly with retained shared indices in the result.
234///
235/// # Errors
236/// Returns an error if:
237/// - no tensors are provided
238/// - `AllowedPairs::Specified` contains a pair with no contractable indices
239/// - a retained index does not appear in the inputs
240/// - a shared internal label has inconsistent dimensions
241///
242/// # Examples
243///
244/// ```
245/// use tensor4all_core::{contract_multi_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen};
246///
247/// let i = DynIndex::new_dyn(2);
248/// let j = DynIndex::new_dyn(3);
249/// let k = DynIndex::new_dyn(4);
250///
251/// let a = TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![1.0_f64; 6]).unwrap();
252/// let b = TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![1.0_f64; 12]).unwrap();
253/// let retain_indices = [j.clone()];
254/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices);
255/// let c = contract_multi_with_options(&[&a, &b], options).unwrap();
256/// assert_eq!(c.dims(), vec![2, 3, 4]);
257/// ```
258pub fn contract_multi_with_options(
259    tensors: &[&TensorDynLen],
260    options: ContractionOptions<'_>,
261) -> Result<TensorDynLen> {
262    match tensors.len() {
263        0 => Err(anyhow::anyhow!("No tensors to contract")),
264        _ => {
265            validate_retained_indices_exist(tensors, options.retain_indices)?;
266            if tensors.len() == 1 {
267                return Ok((*tensors[0]).clone());
268            }
269
270            // Validate AllowedPairs::Specified pairs have contractable indices
271            if let AllowedPairs::Specified(pairs) = options.allowed {
272                for &(i, j) in pairs {
273                    if !has_contractable_indices(tensors[i], tensors[j]) {
274                        return Err(anyhow::anyhow!(
275                            "Specified pair ({}, {}) has no contractable indices",
276                            i,
277                            j
278                        ));
279                    }
280                }
281            }
282
283            // Find connected components
284            let components = find_tensor_connected_components_with_retained(
285                tensors,
286                options.allowed,
287                options.retain_indices,
288            );
289
290            if components.len() == 1 {
291                // All tensors connected - use optimized contraction (skip connectivity check)
292                contract_multi_impl(tensors, options)
293            } else {
294                // Multiple components - contract each and combine with outer product
295                let mut results: Vec<TensorDynLen> = Vec::new();
296                for component in &components {
297                    let component_tensors: Vec<&TensorDynLen> =
298                        component.iter().map(|&i| tensors[i]).collect();
299                    let component_retain_indices =
300                        retained_indices_for_component(tensors, component, options.retain_indices);
301
302                    // Remap AllowedPairs for the component (connectivity already verified)
303                    let remapped_allowed = remap_allowed_pairs(options.allowed, component);
304                    let component_options = ContractionOptions {
305                        allowed: remapped_allowed.as_ref(),
306                        retain_indices: &component_retain_indices,
307                    };
308                    let contracted = contract_multi_impl(&component_tensors, component_options)?;
309                    results.push(contracted);
310                }
311
312                // Combine with outer product
313                let mut results_iter = results.into_iter();
314                let Some(mut result) = results_iter.next() else {
315                    return Err(anyhow::anyhow!("No contracted components produced"));
316                };
317                for other in results_iter {
318                    result = result.outer_product(&other)?;
319                }
320                Ok(result)
321            }
322        }
323    }
324}
325
326/// Contract multiple owned tensors into a single tensor.
327///
328/// This is the consuming counterpart to [`contract_multi_with_options`]. It
329/// preserves the same contraction semantics while allowing eligible non-AD
330/// dense inputs to use tenferro's owned eager einsum executor. When any input
331/// tracks gradients, or when compact structured metadata needs the borrowed
332/// path, this function falls back to the shared borrowed execution so semantics
333/// and reverse-mode AD remain intact.
334///
335/// # Arguments
336/// * `tensors` - Owned tensors to contract.
337/// * `options` - Pair-selection policy and retained indices.
338///
339/// # Returns
340/// The contracted tensor, with retained shared indices preserved in the output.
341///
342/// # Errors
343/// Returns an error for the same conditions as
344/// [`contract_multi_with_options`], including empty input, invalid retained
345/// indices, and incompatible contraction pairs.
346///
347/// # Examples
348///
349/// ```
350/// use tensor4all_core::{contract_multi_owned, contract_multi_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen};
351///
352/// let i = DynIndex::new_dyn(2);
353/// let j = DynIndex::new_dyn(3);
354/// let k = DynIndex::new_dyn(4);
355/// let a = TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![1.0_f64; 6]).unwrap();
356/// let b = TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![1.0_f64; 12]).unwrap();
357/// let options = ContractionOptions::new(AllowedPairs::All);
358///
359/// let owned = contract_multi_owned(vec![a.clone(), b.clone()], options).unwrap();
360/// let borrowed = contract_multi_with_options(&[&a, &b], options).unwrap();
361/// assert_eq!(owned.indices(), borrowed.indices());
362/// assert_eq!(owned.to_vec::<f64>().unwrap(), borrowed.to_vec::<f64>().unwrap());
363/// ```
364pub fn contract_multi_owned(
365    tensors: Vec<TensorDynLen>,
366    options: ContractionOptions<'_>,
367) -> Result<TensorDynLen> {
368    match tensors.len() {
369        0 => Err(anyhow::anyhow!("No tensors to contract")),
370        _ => {
371            let tensor_refs = tensors.iter().collect::<Vec<_>>();
372            validate_retained_indices_exist(&tensor_refs, options.retain_indices)?;
373
374            if tensors.len() == 1 {
375                drop(tensor_refs);
376                let Some(tensor) = tensors.into_iter().next() else {
377                    return Err(anyhow::anyhow!("No tensors to contract"));
378                };
379                return Ok(tensor);
380            }
381
382            if let AllowedPairs::Specified(pairs) = options.allowed {
383                for &(i, j) in pairs {
384                    if !has_contractable_indices(tensor_refs[i], tensor_refs[j]) {
385                        return Err(anyhow::anyhow!(
386                            "Specified pair ({}, {}) has no contractable indices",
387                            i,
388                            j
389                        ));
390                    }
391                }
392            }
393
394            let requires_borrowed_path = tensor_refs.iter().any(|tensor| tensor.tracks_grad())
395                || tensor_refs
396                    .iter()
397                    .any(|tensor| !has_dense_axis_classes(tensor));
398            if requires_borrowed_path {
399                return contract_multi_with_options(&tensor_refs, options);
400            }
401
402            let components = find_tensor_connected_components_with_retained(
403                &tensor_refs,
404                options.allowed,
405                options.retain_indices,
406            );
407            if components.len() > 1 {
408                return contract_multi_with_options(&tensor_refs, options);
409            }
410
411            let mut diag_uf = AxisUnionFind::new();
412            let plan = build_contraction_plan(&tensor_refs, options, &mut diag_uf)?;
413            drop(tensor_refs);
414            let native_operands = tensors
415                .into_iter()
416                .enumerate()
417                .map(|(tensor_idx, tensor)| {
418                    (
419                        tensor.as_native().clone(),
420                        plan.input_ids[tensor_idx].clone(),
421                    )
422                })
423                .collect::<Vec<_>>();
424            let result_native = einsum_native_tensors_owned(native_operands, &plan.output_ids)?;
425            TensorDynLen::from_native_with_axis_classes(
426                plan.result_indices,
427                result_native,
428                plan.result_axis_classes,
429            )
430        }
431    }
432}
433
434fn has_dense_axis_classes(tensor: &TensorDynLen) -> bool {
435    let storage = tensor.storage();
436    storage
437        .axis_classes()
438        .iter()
439        .copied()
440        .eq(0..tensor.indices().len())
441}
442
443/// Contract multiple tensors that form a connected graph.
444///
445/// Uses einsum optimization via tensorbackend.
446///
447/// # Arguments
448/// * `tensors` - Slice of tensors to contract (must form a connected graph)
449/// * `allowed` - Specifies which tensor pairs can have their indices contracted
450///
451/// # Returns
452/// The result of contracting all tensors over allowed contractable indices.
453///
454/// # Connectivity Requirement
455/// All tensors must form a connected graph through contractable indices.
456/// Two tensors are connected if they share a contractable index (same ID, dual direction).
457/// If the tensors form disconnected components, this function returns an error.
458///
459/// Use [`contract_multi`] if you want automatic handling of disconnected components.
460///
461/// # Behavior by N
462/// - N=0: Error
463/// - N=1: Clone of input
464/// - N>=2: Optimized order via the tensorbackend einsum path
465///
466/// # Examples
467///
468/// ```
469/// use tensor4all_core::{TensorDynLen, DynIndex, contract_connected, AllowedPairs};
470///
471/// // A[i, j] contracted with B[j, k]
472/// let i = DynIndex::new_dyn(2);
473/// let j = DynIndex::new_dyn(3);
474/// let k = DynIndex::new_dyn(4);
475///
476/// let a = TensorDynLen::from_dense(
477///     vec![i.clone(), j.clone()],
478///     vec![1.0_f64; 6],
479/// ).unwrap();
480/// let b = TensorDynLen::from_dense(
481///     vec![j.clone(), k.clone()],
482///     vec![1.0_f64; 12],
483/// ).unwrap();
484///
485/// let c = contract_connected(&[&a, &b], AllowedPairs::All).unwrap();
486/// assert_eq!(c.dims(), vec![2, 4]);
487/// ```
488pub fn contract_connected(
489    tensors: &[&TensorDynLen],
490    allowed: AllowedPairs<'_>,
491) -> Result<TensorDynLen> {
492    contract_connected_with_options(tensors, ContractionOptions::new(allowed))
493}
494
495/// Contract a connected tensor network with additional options.
496///
497/// This behaves like [`contract_connected`] but also allows selected shared
498/// indices to be retained in the output.
499///
500/// # Arguments
501/// * `tensors` - Slice of tensors to contract
502/// * `options` - Pair-selection policy and retained indices
503///
504/// # Returns
505/// The contracted tensor.
506///
507/// # Errors
508/// Returns an error if the tensors are disconnected, no tensors are provided,
509/// or retained indices are invalid.
510///
511/// # Examples
512///
513/// ```
514/// use tensor4all_core::{
515///     contract_connected_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen,
516/// };
517///
518/// let batch = DynIndex::new_dyn(2);
519/// let i = DynIndex::new_dyn(2);
520/// let k = DynIndex::new_dyn(3);
521/// let j = DynIndex::new_dyn(2);
522///
523/// let a = TensorDynLen::from_dense(
524///     vec![batch.clone(), i.clone(), k.clone()],
525///     vec![1.0_f64; 12],
526/// )
527/// .unwrap();
528/// let b = TensorDynLen::from_dense(
529///     vec![batch.clone(), k, j.clone()],
530///     vec![1.0_f64; 12],
531/// )
532/// .unwrap();
533/// let retain = [batch.clone()];
534/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain);
535///
536/// let c = contract_connected_with_options(&[&a, &b], options).unwrap();
537/// assert_eq!(c.indices(), &[batch, i, j]);
538/// assert_eq!(c.to_vec::<f64>().unwrap(), vec![3.0; 8]);
539/// ```
540pub fn contract_connected_with_options(
541    tensors: &[&TensorDynLen],
542    options: ContractionOptions<'_>,
543) -> Result<TensorDynLen> {
544    match tensors.len() {
545        0 => Err(anyhow::anyhow!("No tensors to contract")),
546        _ => {
547            validate_retained_indices_exist(tensors, options.retain_indices)?;
548            if tensors.len() == 1 {
549                return Ok((*tensors[0]).clone());
550            }
551
552            // Check connectivity first
553            let components = find_tensor_connected_components_with_retained(
554                tensors,
555                options.allowed,
556                options.retain_indices,
557            );
558            if components.len() > 1 {
559                return Err(anyhow::anyhow!(
560                    "Disconnected tensor network: {} components found",
561                    components.len()
562                ));
563            }
564            // Connectivity verified - skip check in impl
565            contract_multi_impl(tensors, options)
566        }
567    }
568}
569
570// ============================================================================
571// Union-Find for Diag axis grouping
572// ============================================================================
573
574/// Union-Find data structure for grouping axis IDs.
575///
576/// Used to merge diagonal axes from Diag tensors so that they share
577/// the same representative ID when passed to einsum.
578#[derive(Debug, Clone)]
579pub struct AxisUnionFind {
580    /// Maps each ID to its parent. If parent[id] == id, it's a root.
581    parent: HashMap<DynId, DynId>,
582    /// Rank for union by rank optimization.
583    rank: HashMap<DynId, usize>,
584}
585
586impl AxisUnionFind {
587    /// Create a new empty union-find structure.
588    pub fn new() -> Self {
589        Self {
590            parent: HashMap::new(),
591            rank: HashMap::new(),
592        }
593    }
594
595    /// Add an ID to the structure (as its own set).
596    pub fn make_set(&mut self, id: DynId) {
597        use std::collections::hash_map::Entry;
598        if let Entry::Vacant(e) = self.parent.entry(id) {
599            e.insert(id);
600            self.rank.insert(id, 0);
601        }
602    }
603
604    /// Find the representative (root) of the set containing `id`.
605    /// Uses path compression for efficiency.
606    pub fn find(&mut self, id: DynId) -> DynId {
607        self.make_set(id);
608        if self.parent[&id] != id {
609            let root = self.find(self.parent[&id]);
610            self.parent.insert(id, root);
611        }
612        self.parent[&id]
613    }
614
615    /// Union the sets containing `a` and `b`.
616    /// Uses union by rank for efficiency.
617    pub fn union(&mut self, a: DynId, b: DynId) {
618        let root_a = self.find(a);
619        let root_b = self.find(b);
620
621        if root_a == root_b {
622            return;
623        }
624
625        let rank_a = self.rank[&root_a];
626        let rank_b = self.rank[&root_b];
627
628        if rank_a < rank_b {
629            self.parent.insert(root_a, root_b);
630        } else if rank_a > rank_b {
631            self.parent.insert(root_b, root_a);
632        } else {
633            self.parent.insert(root_b, root_a);
634            if let Some(rank) = self.rank.get_mut(&root_a) {
635                *rank += 1;
636            }
637        }
638    }
639
640    /// Remap an ID to its representative.
641    pub fn remap(&mut self, id: DynId) -> DynId {
642        self.find(id)
643    }
644
645    /// Remap a slice of IDs to their representatives.
646    pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
647        ids.iter().map(|id| self.find(*id)).collect()
648    }
649}
650
651impl Default for AxisUnionFind {
652    fn default() -> Self {
653        Self::new()
654    }
655}
656
657// ============================================================================
658// Axis helper builders
659// ============================================================================
660
661/// Build a union-find structure from a collection of tensors.
662///
663/// This helper is kept for callers that need to group diagonal axes by index ID.
664/// Numeric contraction currently keeps dense logical axes distinct and propagates
665/// diagonal result metadata separately.
666pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
667    let mut uf = AxisUnionFind::new();
668
669    for tensor in tensors {
670        for idx in tensor.indices() {
671            uf.make_set(*idx.id());
672        }
673
674        if tensor.is_diag() && tensor.indices().len() >= 2 {
675            let first_id = *tensor.indices()[0].id();
676            for idx in tensor.indices().iter().skip(1) {
677                uf.union(first_id, *idx.id());
678            }
679        }
680    }
681
682    uf
683}
684
685/// Remap tensor indices using the union-find structure.
686///
687/// Returns a vector of remapped IDs for each tensor, suitable for passing
688/// to einsum. The original tensors are not modified.
689pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
690    tensors
691        .iter()
692        .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
693        .collect()
694}
695
696/// Remap output IDs using the union-find structure.
697pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
698    output.iter().map(|idx| uf.find(*idx.id())).collect()
699}
700
701/// Collect dimension sizes for remapped IDs.
702///
703/// For unified IDs (from Diag tensors), all axes must have the same dimension,
704/// so we just take the first occurrence.
705pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
706    let mut sizes = HashMap::new();
707
708    for tensor in tensors {
709        let dims = tensor.dims();
710        for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
711            let rep = uf.find(*idx.id());
712            sizes.entry(rep).or_insert(dim);
713        }
714    }
715
716    sizes
717}
718
719// ============================================================================
720// Contraction implementation
721// ============================================================================
722
723/// Internal implementation of multi-tensor contraction.
724///
725/// Diagonal tensors are passed as dense native operands for numeric contraction.
726/// Their compact equality metadata is propagated separately onto the result.
727///
728/// This implementation preserves storage type: if all inputs are F64, the result
729/// is F64; if any input is C64, the result is C64.
730fn contract_multi_impl(
731    tensors: &[&TensorDynLen],
732    options: ContractionOptions<'_>,
733) -> Result<TensorDynLen> {
734    // 1. Build union-find over exact matching index IDs. Diagonal equality is
735    // encoded in the dense native values and should not collapse uncontracted
736    // logical axes in the numeric einsum.
737    let mut diag_uf = AxisUnionFind::new();
738
739    // 2. Build the contraction plan from internal labels.
740    let plan = build_contraction_plan(tensors, options, &mut diag_uf)?;
741
742    // Note: Connectivity check is done by caller (contract_multi or contract_connected)
743    // via find_tensor_connected_components before calling this function
744
745    // 3. Build sizes from unique internal IDs.
746    let mut sizes: HashMap<usize, usize> = HashMap::new();
747    for (tensor_idx, tensor) in tensors.iter().enumerate() {
748        let dims = tensor.dims();
749        for (pos, &dim) in dims.iter().enumerate() {
750            let internal_id = plan.input_ids[tensor_idx][pos];
751            match sizes.entry(internal_id) {
752                std::collections::hash_map::Entry::Vacant(entry) => {
753                    entry.insert(dim);
754                }
755                std::collections::hash_map::Entry::Occupied(entry) => {
756                    if *entry.get() != dim {
757                        return Err(anyhow::anyhow!(
758                            "Internal label dimension mismatch: label {} has dimensions {} and {}",
759                            internal_id,
760                            entry.get(),
761                            dim
762                        ));
763                    }
764                }
765            }
766        }
767    }
768
769    let profile_signature = contract_profile_enabled().then(|| ContractSignature {
770        operands: tensors
771            .iter()
772            .enumerate()
773            .map(|(tensor_idx, tensor)| ContractOperandSignature {
774                dims: tensor.dims().to_vec(),
775                ids: plan.input_ids[tensor_idx].clone(),
776                is_diag: tensor.is_diag(),
777            })
778            .collect(),
779        output_ids: plan.output_ids.clone(),
780        output_dims: plan.output_ids.iter().map(|id| sizes[id]).collect(),
781    });
782    let profile_started = contract_profile_enabled().then(Instant::now);
783
784    let result = execute_contraction_plan(tensors, &plan, !options.retain_indices.is_empty())?;
785    if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
786        record_contract_profile(signature, started.elapsed());
787    }
788    Ok(result)
789}
790
/// Run a prepared `ContractionPlan` on `tensors` and wrap the result.
///
/// One of three execution routes is chosen from the inputs:
/// 1. Some input tracks gradients, all inputs share one dtype, and at least
///    one input has non-identity axis classes: contract pairwise, left to
///    right, with `contract_pairwise_default`. This route rejects retained
///    indices (`has_retained_indices == true`) with an error.
/// 2. Some input tracks gradients and all inputs share one dtype: run
///    `eager_einsum_ad` on the inner tensors with a letter subscript string
///    built from the plan's labels.
/// 3. Otherwise: run `einsum_native_tensors` on the native operands.
///
/// # Errors
/// Returns an error for route 1 with retained indices, and propagates errors
/// from subscript construction (route 2 only supports label values below 52),
/// from the einsum backends, and from rebuilding the result tensor.
///
/// NOTE(review): route 1 never consults `plan`, so any pair restrictions or
/// output ordering encoded in the plan are not applied on that route. Confirm
/// that `contract_pairwise_default` gives the same result.
/// NOTE(review): when inputs track gradients but have mixed dtypes, control
/// falls through to route 3, which uses `as_native()` operands. Confirm
/// whether gradient tracking survives that route.
fn execute_contraction_plan(
    tensors: &[&TensorDynLen],
    plan: &ContractionPlan,
    has_retained_indices: bool,
) -> Result<TensorDynLen> {
    let any_grad = tensors.iter().any(|tensor| tensor.tracks_grad());
    let first_dtype = tensors[0].as_native().dtype();
    let same_dtype = tensors
        .iter()
        .all(|tensor| tensor.as_native().dtype() == first_dtype);
    // True when some axis `k` is not in class `k`. Unlike
    // `has_dense_axis_classes`, this does not compare the number of axis
    // classes against the tensor rank.
    let has_non_dense_axis_classes = tensors.iter().any(|tensor| {
        tensor
            .storage()
            .axis_classes()
            .iter()
            .copied()
            .enumerate()
            .any(|(axis, class)| axis != class)
    });

    if any_grad && same_dtype && has_non_dense_axis_classes {
        if has_retained_indices {
            return Err(anyhow::anyhow!(
                "Retained AD contraction with structured storage is not yet supported"
            ));
        }

        // Structured payload AD still relies on the existing pairwise structured
        // path until structured N-ary planning is implemented.
        let mut iter = tensors.iter();
        let Some(first) = iter.next() else {
            return Err(anyhow::anyhow!("No tensors to contract"));
        };
        let mut result = (*first).clone();
        for tensor in iter {
            result = result.contract_pairwise_default(tensor);
        }
        return Ok(result);
    }

    if any_grad && same_dtype {
        // AD-aware einsum on the inner tensors, driven by a string subscript.
        let operands = tensors
            .iter()
            .map(|tensor| tensor.as_inner())
            .collect::<Vec<_>>();
        let subscripts = build_einsum_subscripts_from_usize_ids(&plan.input_ids, &plan.output_ids)?;
        let result = eager_einsum_ad(&operands, &subscripts)?;
        return TensorDynLen::from_inner_with_axis_classes(
            plan.result_indices.clone(),
            result,
            plan.result_axis_classes.clone(),
        );
    }

    // Borrowed native einsum on (operand, labels) pairs.
    let native_operands: Vec<_> = tensors
        .iter()
        .enumerate()
        .map(|(tensor_idx, tensor)| (tensor.as_native(), plan.input_ids[tensor_idx].as_slice()))
        .collect();
    let result_native = einsum_native_tensors(&native_operands, &plan.output_ids)?;
    TensorDynLen::from_native_with_axis_classes(
        plan.result_indices.clone(),
        result_native,
        plan.result_axis_classes.clone(),
    )
}
857
858fn build_einsum_subscripts_from_usize_ids(
859    input_ids: &[Vec<usize>],
860    output_ids: &[usize],
861) -> Result<String> {
862    fn ids_to_subscript(ids: &[usize]) -> Result<String> {
863        const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
864        let mut out = String::with_capacity(ids.len());
865        for &id in ids {
866            let letter = LETTERS.get(id).ok_or_else(|| {
867                anyhow::anyhow!("einsum label {id} exceeds supported label range")
868            })?;
869            out.push(char::from(*letter));
870        }
871        Ok(out)
872    }
873
874    let inputs = input_ids
875        .iter()
876        .map(|ids| ids_to_subscript(ids))
877        .collect::<Result<Vec<_>>>()?;
878    Ok(format!(
879        "{}->{}",
880        inputs.join(","),
881        ids_to_subscript(output_ids)?
882    ))
883}
884
/// A contraction plan with internal labels and result ordering.
///
/// Produced by `build_contraction_plan` and consumed by
/// `execute_contraction_plan`.
#[derive(Debug, Clone)]
struct ContractionPlan {
    /// Internal einsum labels for each operand, one per axis
    /// (`input_ids[t][axis]` labels axis `axis` of operand `t`).
    input_ids: Vec<Vec<usize>>,
    /// Labels of the result axes, in result-axis order.
    output_ids: Vec<usize>,
    /// `DynIndex` of each result axis, parallel to `output_ids`.
    result_indices: Vec<DynIndex>,
    /// Storage axis class of each result axis, parallel to `output_ids`
    /// (computed by `output_axis_classes`).
    result_axis_classes: Vec<usize>,
}
893
894fn build_contraction_plan(
895    tensors: &[&TensorDynLen],
896    options: ContractionOptions<'_>,
897    diag_uf: &mut AxisUnionFind,
898) -> Result<ContractionPlan> {
899    let retained_indices: HashSet<DynIndex> = options.retain_indices.iter().cloned().collect();
900    let (input_ids, internal_id_to_original) =
901        build_internal_ids(tensors, options.allowed, diag_uf, &retained_indices)?;
902
903    let mut counts: HashMap<usize, usize> = HashMap::new();
904    for ids in &input_ids {
905        for &internal_id in ids {
906            *counts.entry(internal_id).or_insert(0) += 1;
907        }
908    }
909    let mut output_ids = Vec::new();
910    let mut seen_output = HashSet::new();
911    let mut found_retained = HashSet::new();
912
913    for (tensor_idx, tensor) in tensors.iter().enumerate() {
914        for (axis, idx) in tensor.indices.iter().enumerate() {
915            let internal_id = input_ids[tensor_idx][axis];
916            let should_output = counts[&internal_id] == 1 || retained_indices.contains(idx);
917            if should_output && seen_output.insert(internal_id) {
918                output_ids.push(internal_id);
919            }
920            if retained_indices.contains(idx) {
921                found_retained.insert(idx.clone());
922            }
923        }
924    }
925
926    for retained in retained_indices {
927        if !found_retained.contains(&retained) {
928            return Err(anyhow::anyhow!(
929                "Retained index {:?} does not appear in the input tensors",
930                retained
931            ));
932        }
933    }
934
935    let result_indices: Vec<DynIndex> = output_ids
936        .iter()
937        .map(|&internal_id| {
938            let (tensor_idx, pos) = internal_id_to_original[&internal_id];
939            tensors[tensor_idx].indices[pos].clone()
940        })
941        .collect();
942    validate_unique_output_indices(&result_indices)?;
943    let result_axis_classes =
944        output_axis_classes(tensors, &input_ids, &output_ids, &internal_id_to_original);
945
946    Ok(ContractionPlan {
947        input_ids,
948        output_ids,
949        result_indices,
950        result_axis_classes,
951    })
952}
953
954fn validate_retained_indices_exist(
955    tensors: &[&TensorDynLen],
956    retain_indices: &[DynIndex],
957) -> Result<()> {
958    for retain in retain_indices {
959        let found = tensors
960            .iter()
961            .any(|tensor| tensor.indices().iter().any(|idx| idx == retain));
962        if !found {
963            return Err(anyhow::anyhow!(
964                "Retained index {:?} does not appear in the input tensors",
965                retain
966            ));
967        }
968    }
969    Ok(())
970}
971
972fn retained_indices_for_component(
973    tensors: &[&TensorDynLen],
974    component: &[usize],
975    retain_indices: &[DynIndex],
976) -> Vec<DynIndex> {
977    let mut seen = HashSet::new();
978    let mut retained = Vec::new();
979    for retain in retain_indices {
980        if seen.insert(retain.clone())
981            && component.iter().any(|&tensor_idx| {
982                tensors[tensor_idx]
983                    .indices()
984                    .iter()
985                    .any(|idx| idx == retain)
986            })
987        {
988            retained.push(retain.clone());
989        }
990    }
991    retained
992}
993
994fn validate_unique_output_indices(indices: &[DynIndex]) -> Result<()> {
995    let mut seen = HashSet::new();
996    for idx in indices {
997        if !seen.insert(idx.clone()) {
998            return Err(anyhow::anyhow!(
999                "Contraction result would contain duplicate output indices"
1000            ));
1001        }
1002    }
1003    Ok(())
1004}
1005
/// Compute the storage axis class of every result axis.
///
/// Each operand's storage axis classes (`storage().axis_classes()`) are lifted
/// into one global node space: class `c` of operand `t` becomes node
/// `class_offsets[t] + c`. All nodes reached by axes that share an internal
/// label are unioned. Axes tied together (same class on one operand, or same
/// label across operands) therefore end up with a common root. Output axes are
/// then numbered by the order in which their roots first appear in `output`,
/// giving dense class ids starting at 0.
///
/// `ixs[t][axis]` is the internal label of axis `axis` of `tensors[t]`, and
/// `internal_id_to_original` maps each label in `output` to a representative
/// `(tensor, axis)` position.
///
/// NOTE(review): assumes every label in `output` is a key of
/// `internal_id_to_original` and that `ixs[t]` is no longer than
/// `tensors[t].storage().axis_classes()`; the indexing panics otherwise.
fn output_axis_classes(
    tensors: &[&TensorDynLen],
    ixs: &[Vec<usize>],
    output: &[usize],
    internal_id_to_original: &HashMap<usize, (usize, usize)>,
) -> Vec<usize> {
    // Union-find root lookup with (recursive) path compression.
    fn find(parent: &mut [usize], value: usize) -> usize {
        if parent[value] != value {
            parent[value] = find(parent, parent[value]);
        }
        parent[value]
    }

    // Merge two sets by attaching the rhs root under the lhs root.
    fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
        let lhs_root = find(parent, lhs);
        let rhs_root = find(parent, rhs);
        if lhs_root != rhs_root {
            parent[rhs_root] = lhs_root;
        }
    }

    // Give each operand a contiguous block of nodes, one per storage class.
    // The block size is `max class + 1`, which assumes class ids are dense
    // from 0 — TODO confirm against the storage invariants.
    let mut class_offsets = Vec::with_capacity(tensors.len());
    let mut next_node = 0usize;
    for tensor in tensors {
        class_offsets.push(next_node);
        let payload_rank = tensor
            .storage()
            .axis_classes()
            .iter()
            .copied()
            .max()
            .map(|value| value + 1)
            .unwrap_or(0);
        next_node += payload_rank;
    }
    let mut parent: Vec<usize> = (0..next_node).collect();
    // For every label, the global nodes of all axes carrying it.
    let mut axes_by_internal_id: HashMap<usize, Vec<usize>> = HashMap::new();

    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (axis, &internal_id) in ixs[tensor_idx].iter().enumerate() {
            let class_id = tensor.storage().axis_classes()[axis];
            let node = class_offsets[tensor_idx] + class_id;
            axes_by_internal_id
                .entry(internal_id)
                .or_default()
                .push(node);
        }
    }

    // Axes sharing a label are the same logical axis, so their classes merge.
    // HashMap iteration order only affects which node becomes a root, not the
    // resulting partition, so the output below is deterministic.
    for nodes in axes_by_internal_id.values() {
        if let Some((&first, rest)) = nodes.split_first() {
            for &node in rest {
                union(&mut parent, first, node);
            }
        }
    }

    // Number roots in order of first appearance among the output axes.
    let mut root_to_class = HashMap::new();
    let mut next_class = 0usize;
    output
        .iter()
        .map(|internal_id| {
            let (tensor_idx, axis) = internal_id_to_original[internal_id];
            let class_id = tensors[tensor_idx].storage().axis_classes()[axis];
            let node = class_offsets[tensor_idx] + class_id;
            let root = find(&mut parent, node);
            *root_to_class.entry(root).or_insert_with(|| {
                let class = next_class;
                next_class += 1;
                class
            })
        })
        .collect()
}
1080
/// Build internal IDs for numeric contraction.
///
/// Assigns a `usize` einsum label to every axis of every tensor:
///
/// 1. Candidate tensor pairs are every `ti < tj` for `AllowedPairs::All`, and
///    exactly the listed pairs otherwise. For each such pair, every pair of
///    contractable indices receives one shared label:
///    - If neither axis is labelled yet, reuse the label previously given to
///      the same `DynIndex`, or mint a fresh one.
///    - If exactly one side is labelled, copy its label to the other side.
/// 2. Every axis still unlabelled gets a fresh label. The exception is that
///    all remaining occurrences of the same retained index share one label.
///
/// Labels are minted densely from 0 in the order above.
///
/// The `_diag_uf` union-find argument is currently not read here. Diagonal
/// logical-axis metadata is intentionally handled outside this numeric
/// labeling step.
///
/// Returns: (ixs, internal_id_to_original). `ixs[t][axis]` is the label of
/// axis `axis` of `tensors[t]`. `internal_id_to_original` maps each minted
/// label to the `(tensor, axis)` position where it was first created.
///
/// This function currently never returns `Err`.
///
/// # Panics
///
/// Panics if an `AllowedPairs::Specified` pair refers to a position outside
/// `tensors`.
#[allow(clippy::type_complexity)]
fn build_internal_ids(
    tensors: &[&TensorDynLen],
    allowed: AllowedPairs<'_>,
    _diag_uf: &mut AxisUnionFind,
    retained_indices: &HashSet<DynIndex>,
) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
    let mut next_id = 0usize;
    // Label already used for a contracted `DynIndex`, reused by later pairs.
    let mut index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // Label shared by all unpaired occurrences of one retained index.
    let mut retained_index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // (tensor, axis) -> label.
    let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
    let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();

    // Process contractable pairs
    let pairs_to_process: Vec<(usize, usize)> = match allowed {
        AllowedPairs::All => {
            let mut pairs = Vec::new();
            for ti in 0..tensors.len() {
                for tj in (ti + 1)..tensors.len() {
                    pairs.push((ti, tj));
                }
            }
            pairs
        }
        AllowedPairs::Specified(pairs) => pairs.to_vec(),
    };

    for (ti, tj) in pairs_to_process {
        for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
            for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
                if idx_i.is_contractable(idx_j) {
                    let key_i = (ti, pi);
                    let key_j = (tj, pj);

                    match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
                        (None, None) => {
                            let internal_id = if let Some(&id) = index_to_internal.get(idx_i) {
                                id
                            } else {
                                let id = next_id;
                                next_id += 1;
                                index_to_internal.insert(idx_i.clone(), id);
                                internal_id_to_original.insert(id, key_i);
                                id
                            };
                            assigned.insert(key_i, internal_id);
                            assigned.insert(key_j, internal_id);
                            // Also register `idx_j` so later pairs that involve
                            // it reuse this label.
                            if idx_i != idx_j {
                                index_to_internal.insert(idx_j.clone(), internal_id);
                            }
                        }
                        (Some(id), None) => {
                            assigned.insert(key_j, id);
                            index_to_internal.insert(idx_j.clone(), id);
                        }
                        (None, Some(id)) => {
                            assigned.insert(key_i, id);
                            index_to_internal.insert(idx_i.clone(), id);
                        }
                        (Some(_id_i), Some(_id_j)) => {
                            // Both already assigned
                            // NOTE(review): if `_id_i != _id_j` the two labels
                            // are not merged here, so the axes stay distinct in
                            // the einsum — confirm this cannot occur for valid
                            // inputs.
                        }
                    }
                }
            }
        }
    }

    // Assign IDs for unassigned indices (external indices)
    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (pos, idx) in tensor.indices.iter().enumerate() {
            let key = (tensor_idx, pos);
            if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
                let internal_id = if retained_indices.contains(idx) {
                    if let Some(&id) = retained_index_to_internal.get(idx) {
                        id
                    } else {
                        let id = next_id;
                        next_id += 1;
                        retained_index_to_internal.insert(idx.clone(), id);
                        internal_id_to_original.insert(id, key);
                        id
                    }
                } else {
                    let id = next_id;
                    next_id += 1;
                    internal_id_to_original.insert(id, key);
                    id
                };
                e.insert(internal_id);
            }
        }
    }

    // Build ixs
    let ixs: Vec<Vec<usize>> = tensors
        .iter()
        .enumerate()
        .map(|(tensor_idx, tensor)| {
            (0..tensor.indices.len())
                .map(|pos| assigned[&(tensor_idx, pos)])
                .collect()
        })
        .collect();

    Ok((ixs, internal_id_to_original))
}
1195
1196// ============================================================================
1197// Helper functions for connected component detection
1198// ============================================================================
1199
1200/// Check if two tensors have any contractable indices.
1201fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
1202    a.indices
1203        .iter()
1204        .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
1205}
1206
1207/// Find connected components of tensors based on contractable indices.
1208///
1209/// Uses petgraph for O(V+E) connected component detection.
1210#[allow(dead_code)]
1211fn find_tensor_connected_components(
1212    tensors: &[&TensorDynLen],
1213    allowed: AllowedPairs<'_>,
1214) -> Vec<Vec<usize>> {
1215    find_tensor_connected_components_with_retained(tensors, allowed, &[])
1216}
1217
1218fn find_tensor_connected_components_with_retained(
1219    tensors: &[&TensorDynLen],
1220    allowed: AllowedPairs<'_>,
1221    retain_indices: &[DynIndex],
1222) -> Vec<Vec<usize>> {
1223    let n = tensors.len();
1224    if n == 0 {
1225        return vec![];
1226    }
1227    if n == 1 {
1228        return vec![vec![0]];
1229    }
1230
1231    // Build undirected graph
1232    let mut graph = UnGraph::<(), ()>::new_undirected();
1233    let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
1234
1235    // Add edges based on connectivity
1236    match allowed {
1237        AllowedPairs::All => {
1238            for i in 0..n {
1239                for j in (i + 1)..n {
1240                    if has_contractable_indices(tensors[i], tensors[j]) {
1241                        graph.add_edge(nodes[i], nodes[j], ());
1242                    }
1243                }
1244            }
1245        }
1246        AllowedPairs::Specified(pairs) => {
1247            for &(i, j) in pairs {
1248                if has_contractable_indices(tensors[i], tensors[j]) {
1249                    graph.add_edge(nodes[i], nodes[j], ());
1250                }
1251            }
1252        }
1253    }
1254
1255    if !retain_indices.is_empty() {
1256        for i in 0..n {
1257            for j in (i + 1)..n {
1258                if shares_retained_index(tensors[i], tensors[j], retain_indices) {
1259                    graph.add_edge(nodes[i], nodes[j], ());
1260                }
1261            }
1262        }
1263    }
1264
1265    // Find connected components using petgraph
1266    let num_components = connected_components(&graph);
1267
1268    if num_components == 1 {
1269        return vec![(0..n).collect()];
1270    }
1271
1272    // Multiple components - group by component ID
1273    use petgraph::visit::Dfs;
1274    let mut visited = vec![false; n];
1275    let mut components = Vec::new();
1276
1277    for start in 0..n {
1278        if !visited[start] {
1279            let mut component = Vec::new();
1280            let mut dfs = Dfs::new(&graph, nodes[start]);
1281            while let Some(node) = dfs.next(&graph) {
1282                let idx = node.index();
1283                if !visited[idx] {
1284                    visited[idx] = true;
1285                    component.push(idx);
1286                }
1287            }
1288            component.sort();
1289            components.push(component);
1290        }
1291    }
1292
1293    components.sort_by_key(|c| c[0]);
1294    components
1295}
1296
1297fn shares_retained_index(a: &TensorDynLen, b: &TensorDynLen, retain_indices: &[DynIndex]) -> bool {
1298    retain_indices.iter().any(|retain| {
1299        a.indices().iter().any(|idx_a| idx_a == retain)
1300            && b.indices().iter().any(|idx_b| idx_b == retain)
1301    })
1302}
1303
1304/// Remap AllowedPairs for a subset of tensors.
1305fn remap_allowed_pairs(allowed: AllowedPairs<'_>, component: &[usize]) -> RemappedAllowedPairs {
1306    match allowed {
1307        AllowedPairs::All => RemappedAllowedPairs::All,
1308        AllowedPairs::Specified(pairs) => {
1309            let orig_to_local: HashMap<usize, usize> = component
1310                .iter()
1311                .enumerate()
1312                .map(|(local, &orig)| (orig, local))
1313                .collect();
1314
1315            let remapped: Vec<(usize, usize)> = pairs
1316                .iter()
1317                .filter_map(
1318                    |&(i, j)| match (orig_to_local.get(&i), orig_to_local.get(&j)) {
1319                        (Some(&li), Some(&lj)) => Some((li, lj)),
1320                        _ => None,
1321                    },
1322                )
1323                .collect();
1324
1325            RemappedAllowedPairs::Specified(remapped)
1326        }
1327    }
1328}
1329
/// Owned version of AllowedPairs for remapped components.
///
/// Built by `remap_allowed_pairs`; `as_ref` borrows it back as an
/// `AllowedPairs` view.
enum RemappedAllowedPairs {
    /// Mirrors `AllowedPairs::All`: no pair restriction.
    All,
    /// Only these pairs, expressed in component-local tensor positions.
    Specified(Vec<(usize, usize)>),
}
1335
1336impl RemappedAllowedPairs {
1337    fn as_ref(&self) -> AllowedPairs<'_> {
1338        match self {
1339            RemappedAllowedPairs::All => AllowedPairs::All,
1340            RemappedAllowedPairs::Specified(pairs) => AllowedPairs::Specified(pairs),
1341        }
1342    }
1343}
1344
1345#[cfg(test)]
1346mod tests;