tensor4all_core/defaults/
contract.rs

1//! Multi-tensor contraction with optimal contraction order.
2//!
3//! This module provides functions to contract multiple tensors efficiently
4//! using hyperedge-aware einsum optimization via the tensorbackend
5//! (tenferro-backed implementation).
6//!
7//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
8//!
9//! # Main Functions
10//!
11//! - [`contract_multi`]: Contracts tensors, handling disconnected components via outer product
12//! - [`contract_connected`]: Contracts tensors that must form a connected graph
13//!
14//! # Diag Tensor Handling
15//!
16//! When Diag tensors share indices, their diagonal axes are unified to create
17//! hyperedges in the einsum optimizer.
18//!
19//! Example: `Diag(i,j) * Diag(j,k)`:
20//! - Diag(i,j) has diagonal axes i and j (same index)
21//! - Diag(j,k) has diagonal axes j and k (same index)
22//! - After union-find: i, j, k all map to the same representative ID
23//! - This creates a hyperedge that the einsum optimizer handles correctly
24
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::env;
28use std::time::{Duration, Instant};
29
30use anyhow::Result;
31use petgraph::algo::connected_components;
32use petgraph::prelude::*;
33use tensor4all_tensorbackend::einsum_native_tensors;
34
35use crate::defaults::{DynId, DynIndex, TensorDynLen};
36
37use crate::index_like::IndexLike;
38use crate::tensor_like::AllowedPairs;
39
/// Profiling description of a single operand of a contraction call.
///
/// Hashable and comparable so that it can be embedded in a `ContractSignature`,
/// which is used as a map key when aggregating timings.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ContractOperandSignature {
    /// Dimensions of the operand, as returned by the tensor's `dims()`.
    dims: Vec<usize>,
    /// Internal einsum IDs assigned to the operand's axes.
    ids: Vec<usize>,
    /// Whether the operand reported `is_diag()`.
    is_diag: bool,
}
46
/// Profiling key identifying the "shape" of one multi-tensor contraction.
///
/// Two calls with equal signatures are aggregated into the same profile entry.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ContractSignature {
    /// One signature per input tensor, in input order.
    operands: Vec<ContractOperandSignature>,
    /// Internal einsum IDs of the output axes.
    output_ids: Vec<usize>,
    /// Dimension of each output axis, parallel to `output_ids`.
    output_dims: Vec<usize>,
}
53
/// Aggregated timing data for all contraction calls sharing one signature.
#[derive(Debug, Default, Clone)]
struct ContractProfileEntry {
    /// Number of recorded calls.
    calls: usize,
    /// Sum of the elapsed times of all recorded calls.
    total_time: Duration,
}
59
// Per-thread table of aggregated contraction timings, keyed by call signature.
// Updated by `record_contract_profile`; read and cleared by
// `print_and_reset_contract_profile`, cleared by `reset_contract_profile`.
thread_local! {
    static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
64
/// Report whether contraction profiling is switched on.
///
/// Profiling counts as enabled when the `T4A_PROFILE_CONTRACT` environment
/// variable can be read through `std::env::var`; the value itself is ignored.
/// The environment is consulted on every call.
fn contract_profile_enabled() -> bool {
    const FLAG: &str = "T4A_PROFILE_CONTRACT";
    let lookup = env::var(FLAG);
    lookup.is_ok()
}
68
69fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
70    if !contract_profile_enabled() {
71        return;
72    }
73    CONTRACT_PROFILE_STATE.with(|state| {
74        let mut state = state.borrow_mut();
75        let entry = state.entry(signature).or_default();
76        entry.calls += 1;
77        entry.total_time += elapsed;
78    });
79}
80
81/// Reset the aggregated multi-tensor contraction profile.
82pub fn reset_contract_profile() {
83    CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
84}
85
86/// Print and clear the aggregated multi-tensor contraction profile.
87pub fn print_and_reset_contract_profile() {
88    if !contract_profile_enabled() {
89        return;
90    }
91    CONTRACT_PROFILE_STATE.with(|state| {
92        let mut entries: Vec<_> = state
93            .borrow()
94            .iter()
95            .map(|(k, v)| (k.clone(), v.clone()))
96            .collect();
97        state.borrow_mut().clear();
98        entries.sort_by(|(_, lhs), (_, rhs)| rhs.total_time.cmp(&lhs.total_time));
99
100        eprintln!("=== contract_multi Profile ===");
101        for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
102            let operands = signature
103                .operands
104                .iter()
105                .map(|operand| {
106                    format!(
107                        "dims={:?} ids={:?}{}",
108                        operand.dims,
109                        operand.ids,
110                        if operand.is_diag { " diag" } else { "" }
111                    )
112                })
113                .collect::<Vec<_>>()
114                .join(" ; ");
115            eprintln!(
116                "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
117                entry.calls,
118                entry.total_time.as_secs_f64(),
119                entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
120                signature.output_dims,
121                signature.output_ids,
122            );
123            eprintln!("     {operands}");
124        }
125    });
126}
127
128// ============================================================================
129// Public API
130// ============================================================================
131
132/// Contract multiple tensors into a single tensor, handling disconnected components.
133///
134/// This function automatically handles disconnected tensor graphs by:
135/// 1. Finding connected components based on contractable indices
136/// 2. Contracting each connected component separately
137/// 3. Combining results using outer product
138///
139/// # Arguments
140/// * `tensors` - Slice of tensors to contract
141/// * `allowed` - Specifies which tensor pairs can have their indices contracted
142///
143/// # Returns
144/// The result of contracting all tensors over allowed contractable indices.
145/// If tensors form disconnected components, they are combined via outer product.
146///
147/// # Behavior by N
148/// - N=0: Error
149/// - N=1: Clone of input
150/// - N>=2: Contract connected components, combine with outer product
151///
152/// # Errors
153/// - `AllowedPairs::Specified` contains a pair with no contractable indices
154///
155/// # Examples
156///
157/// ```
158/// use tensor4all_core::{TensorDynLen, DynIndex, contract_multi, AllowedPairs};
159///
160/// // A[i, j] and B[j, k] share index j — contract to get C[i, k]
161/// let i = DynIndex::new_dyn(2);
162/// let j = DynIndex::new_dyn(3);
163/// let k = DynIndex::new_dyn(4);
164///
165/// let a = TensorDynLen::from_dense(
166///     vec![i.clone(), j.clone()],
167///     vec![1.0_f64; 6],
168/// ).unwrap();
169/// let b = TensorDynLen::from_dense(
170///     vec![j.clone(), k.clone()],
171///     vec![1.0_f64; 12],
172/// ).unwrap();
173///
174/// let c = contract_multi(&[&a, &b], AllowedPairs::All).unwrap();
175/// assert_eq!(c.dims(), vec![2, 4]);
176/// ```
177pub fn contract_multi(
178    tensors: &[&TensorDynLen],
179    allowed: AllowedPairs<'_>,
180) -> Result<TensorDynLen> {
181    match tensors.len() {
182        0 => Err(anyhow::anyhow!("No tensors to contract")),
183        1 => Ok((*tensors[0]).clone()),
184        _ => {
185            // Validate AllowedPairs::Specified pairs have contractable indices
186            if let AllowedPairs::Specified(pairs) = allowed {
187                for &(i, j) in pairs {
188                    if !has_contractable_indices(tensors[i], tensors[j]) {
189                        return Err(anyhow::anyhow!(
190                            "Specified pair ({}, {}) has no contractable indices",
191                            i,
192                            j
193                        ));
194                    }
195                }
196            }
197
198            // Find connected components
199            let components = find_tensor_connected_components(tensors, allowed);
200
201            if components.len() == 1 {
202                // All tensors connected - use optimized contraction (skip connectivity check)
203                contract_multi_impl(tensors, allowed, true)
204            } else {
205                // Multiple components - contract each and combine with outer product
206                let mut results: Vec<TensorDynLen> = Vec::new();
207                for component in &components {
208                    let component_tensors: Vec<&TensorDynLen> =
209                        component.iter().map(|&i| tensors[i]).collect();
210
211                    // Remap AllowedPairs for the component (connectivity already verified)
212                    let remapped_allowed = remap_allowed_pairs(allowed, component);
213                    let contracted =
214                        contract_multi_impl(&component_tensors, remapped_allowed.as_ref(), true)?;
215                    results.push(contracted);
216                }
217
218                // Combine with outer product
219                let mut results_iter = results.into_iter();
220                let mut result = results_iter.next().unwrap();
221                for other in results_iter {
222                    result = result.outer_product(&other)?;
223                }
224                Ok(result)
225            }
226        }
227    }
228}
229
230/// Contract multiple tensors that form a connected graph.
231///
232/// Uses hyperedge-aware einsum optimization via tensorbackend.
233/// This correctly handles Diag tensors by treating their diagonal axes as hyperedges.
234///
235/// # Arguments
236/// * `tensors` - Slice of tensors to contract (must form a connected graph)
237/// * `allowed` - Specifies which tensor pairs can have their indices contracted
238///
239/// # Returns
240/// The result of contracting all tensors over allowed contractable indices.
241///
242/// # Connectivity Requirement
243/// All tensors must form a connected graph through contractable indices.
244/// Two tensors are connected if they share a contractable index (same ID, dual direction).
245/// If the tensors form disconnected components, this function returns an error.
246///
247/// Use [`contract_multi`] if you want automatic handling of disconnected components.
248///
249/// # Behavior by N
250/// - N=0: Error
251/// - N=1: Clone of input
252/// - N>=2: Optimal order via hyperedge-aware greedy optimizer
253///
254/// # Examples
255///
256/// ```
257/// use tensor4all_core::{TensorDynLen, DynIndex, contract_connected, AllowedPairs};
258///
259/// // A[i, j] contracted with B[j, k]
260/// let i = DynIndex::new_dyn(2);
261/// let j = DynIndex::new_dyn(3);
262/// let k = DynIndex::new_dyn(4);
263///
264/// let a = TensorDynLen::from_dense(
265///     vec![i.clone(), j.clone()],
266///     vec![1.0_f64; 6],
267/// ).unwrap();
268/// let b = TensorDynLen::from_dense(
269///     vec![j.clone(), k.clone()],
270///     vec![1.0_f64; 12],
271/// ).unwrap();
272///
273/// let c = contract_connected(&[&a, &b], AllowedPairs::All).unwrap();
274/// assert_eq!(c.dims(), vec![2, 4]);
275/// ```
276pub fn contract_connected(
277    tensors: &[&TensorDynLen],
278    allowed: AllowedPairs<'_>,
279) -> Result<TensorDynLen> {
280    match tensors.len() {
281        0 => Err(anyhow::anyhow!("No tensors to contract")),
282        1 => Ok((*tensors[0]).clone()),
283        _ => {
284            // Check connectivity first
285            let components = find_tensor_connected_components(tensors, allowed);
286            if components.len() > 1 {
287                return Err(anyhow::anyhow!(
288                    "Disconnected tensor network: {} components found",
289                    components.len()
290                ));
291            }
292            // Connectivity verified - skip check in impl
293            contract_multi_impl(tensors, allowed, true)
294        }
295    }
296}
297
298// ============================================================================
299// Union-Find for Diag axis grouping
300// ============================================================================
301
302/// Union-Find data structure for grouping axis IDs.
303///
304/// Used to merge diagonal axes from Diag tensors so that they share
305/// the same representative ID when passed to einsum.
306#[derive(Debug, Clone)]
307pub struct AxisUnionFind {
308    /// Maps each ID to its parent. If parent[id] == id, it's a root.
309    parent: HashMap<DynId, DynId>,
310    /// Rank for union by rank optimization.
311    rank: HashMap<DynId, usize>,
312}
313
314impl AxisUnionFind {
315    /// Create a new empty union-find structure.
316    pub fn new() -> Self {
317        Self {
318            parent: HashMap::new(),
319            rank: HashMap::new(),
320        }
321    }
322
323    /// Add an ID to the structure (as its own set).
324    pub fn make_set(&mut self, id: DynId) {
325        use std::collections::hash_map::Entry;
326        if let Entry::Vacant(e) = self.parent.entry(id) {
327            e.insert(id);
328            self.rank.insert(id, 0);
329        }
330    }
331
332    /// Find the representative (root) of the set containing `id`.
333    /// Uses path compression for efficiency.
334    pub fn find(&mut self, id: DynId) -> DynId {
335        self.make_set(id);
336        if self.parent[&id] != id {
337            let root = self.find(self.parent[&id]);
338            self.parent.insert(id, root);
339        }
340        self.parent[&id]
341    }
342
343    /// Union the sets containing `a` and `b`.
344    /// Uses union by rank for efficiency.
345    pub fn union(&mut self, a: DynId, b: DynId) {
346        let root_a = self.find(a);
347        let root_b = self.find(b);
348
349        if root_a == root_b {
350            return;
351        }
352
353        let rank_a = self.rank[&root_a];
354        let rank_b = self.rank[&root_b];
355
356        if rank_a < rank_b {
357            self.parent.insert(root_a, root_b);
358        } else if rank_a > rank_b {
359            self.parent.insert(root_b, root_a);
360        } else {
361            self.parent.insert(root_b, root_a);
362            *self.rank.get_mut(&root_a).unwrap() += 1;
363        }
364    }
365
366    /// Remap an ID to its representative.
367    pub fn remap(&mut self, id: DynId) -> DynId {
368        self.find(id)
369    }
370
371    /// Remap a slice of IDs to their representatives.
372    pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
373        ids.iter().map(|id| self.find(*id)).collect()
374    }
375}
376
377impl Default for AxisUnionFind {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383// ============================================================================
384// Diag union-find builders
385// ============================================================================
386
387/// Build a union-find structure from a collection of tensors.
388///
389/// For each Diag tensor component, all its indices are unified (they share the same
390/// diagonal dimension). This creates hyperedges when multiple Diag tensors
391/// share indices.
392pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
393    let mut uf = AxisUnionFind::new();
394
395    for tensor in tensors {
396        for idx in tensor.indices() {
397            uf.make_set(*idx.id());
398        }
399
400        if tensor.is_diag() && tensor.indices().len() >= 2 {
401            let first_id = *tensor.indices()[0].id();
402            for idx in tensor.indices().iter().skip(1) {
403                uf.union(first_id, *idx.id());
404            }
405        }
406    }
407
408    uf
409}
410
411/// Remap tensor indices using the union-find structure.
412///
413/// Returns a vector of remapped IDs for each tensor, suitable for passing
414/// to einsum. The original tensors are not modified.
415pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
416    tensors
417        .iter()
418        .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
419        .collect()
420}
421
422/// Remap output IDs using the union-find structure.
423pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
424    output.iter().map(|idx| uf.find(*idx.id())).collect()
425}
426
427/// Collect dimension sizes for remapped IDs.
428///
429/// For unified IDs (from Diag tensors), all axes must have the same dimension,
430/// so we just take the first occurrence.
431pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
432    let mut sizes = HashMap::new();
433
434    for tensor in tensors {
435        let dims = tensor.dims();
436        for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
437            let rep = uf.find(*idx.id());
438            sizes.entry(rep).or_insert(dim);
439        }
440    }
441
442    sizes
443}
444
445// ============================================================================
446// Contraction implementation
447// ============================================================================
448
449/// Internal implementation of multi-tensor contraction.
450///
451/// For Diag tensors, we pass them as 1D tensors (the diagonal elements) with
452/// a single hyperedge ID. The einsum hyperedge optimizer will handle them correctly.
453///
454/// This implementation preserves storage type: if all inputs are F64, the result
455/// is F64; if any input is C64, the result is C64.
456///
457/// # Arguments
458/// * `skip_connectivity_check` - If true, assumes connectivity was already verified by caller
459fn contract_multi_impl(
460    tensors: &[&TensorDynLen],
461    allowed: AllowedPairs<'_>,
462    _skip_connectivity_check: bool,
463) -> Result<TensorDynLen> {
464    // 1. Build union-find from Diag tensors to unify diagonal axes
465    let mut diag_uf = build_diag_union(tensors);
466
467    // 2. Build internal IDs with Diag-awareness
468    let (ixs, internal_id_to_original) = build_internal_ids(tensors, allowed, &mut diag_uf)?;
469
470    // 3. Output = count == 1 internal IDs (external indices)
471    let mut idx_count: HashMap<usize, usize> = HashMap::new();
472    for ix in &ixs {
473        for &i in ix {
474            *idx_count.entry(i).or_insert(0) += 1;
475        }
476    }
477    let mut output: Vec<usize> = idx_count
478        .iter()
479        .filter(|(_, &count)| count == 1)
480        .map(|(&idx, _)| idx)
481        .collect();
482    output.sort(); // deterministic order
483
484    // Note: Connectivity check is done by caller (contract_multi or contract_connected)
485    // via find_tensor_connected_components before calling this function
486
487    // 4. Build sizes from unique internal IDs
488    let mut sizes: HashMap<usize, usize> = HashMap::new();
489    for (tensor_idx, tensor) in tensors.iter().enumerate() {
490        let dims = tensor.dims();
491        for (pos, &dim) in dims.iter().enumerate() {
492            let internal_id = ixs[tensor_idx][pos];
493            sizes.entry(internal_id).or_insert(dim);
494        }
495    }
496
497    let profile_signature = contract_profile_enabled().then(|| ContractSignature {
498        operands: tensors
499            .iter()
500            .enumerate()
501            .map(|(tensor_idx, tensor)| ContractOperandSignature {
502                dims: tensor.dims().to_vec(),
503                ids: ixs[tensor_idx].clone(),
504                is_diag: tensor.is_diag(),
505            })
506            .collect(),
507        output_ids: output.clone(),
508        output_dims: output.iter().map(|id| sizes[id]).collect(),
509    });
510    let profile_started = contract_profile_enabled().then(Instant::now);
511
512    let native_operands: Vec<_> = tensors
513        .iter()
514        .enumerate()
515        .map(|(tensor_idx, tensor)| (tensor.as_native(), ixs[tensor_idx].as_slice()))
516        .collect();
517
518    let result_native = einsum_native_tensors(&native_operands, &output)?;
519    if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
520        record_contract_profile(signature, started.elapsed());
521    }
522    let final_indices = if output.is_empty() {
523        vec![]
524    } else {
525        output
526            .iter()
527            .map(|&internal_id| {
528                let (tensor_idx, pos) = internal_id_to_original[&internal_id];
529                tensors[tensor_idx].indices[pos].clone()
530            })
531            .collect()
532    };
533    TensorDynLen::from_native(final_indices, result_native)
534}
535
/// Build internal IDs with Diag-awareness.
///
/// Assigns a small integer ("internal ID") to every axis of every tensor.
/// Axes joined by a contractable index pair of an allowed tensor pair share an
/// ID, and the union-find is consulted so that axes whose `DynId`s have the
/// same representative (diagonal axes of Diag tensors) reuse an already
/// assigned ID.
///
/// # Arguments
/// * `tensors` - Operands whose axes are to be labelled
/// * `allowed` - `All` considers every pair `(ti, tj)` with `ti < tj`;
///   `Specified` considers exactly the listed pairs, in the listed order
/// * `diag_uf` - Union-find used to map each index ID to its representative
///
/// # Returns
/// `(ixs, internal_id_to_original)` where
/// - `ixs[t][p]` is the internal ID of axis `p` of tensor `t`
/// - `internal_id_to_original[id]` is the `(tensor, axis)` position for which
///   `id` was first created
///
/// The visible body has no failing path and always returns `Ok`.
///
/// # Panics
/// Indexes `tensors` with the positions from `AllowedPairs::Specified`, so an
/// out-of-range position panics.
#[allow(clippy::type_complexity)]
fn build_internal_ids(
    tensors: &[&TensorDynLen],
    allowed: AllowedPairs<'_>,
    diag_uf: &mut AxisUnionFind,
) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
    // Next fresh internal ID to hand out.
    let mut next_id = 0usize;
    // Representative DynId -> internal ID already associated with it.
    let mut dynid_to_internal: HashMap<DynId, usize> = HashMap::new();
    // (tensor position, axis position) -> internal ID.
    let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
    // Internal ID -> axis for which it was created.
    let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();

    // Process contractable pairs
    let pairs_to_process: Vec<(usize, usize)> = match allowed {
        AllowedPairs::All => {
            let mut pairs = Vec::new();
            for ti in 0..tensors.len() {
                for tj in (ti + 1)..tensors.len() {
                    pairs.push((ti, tj));
                }
            }
            pairs
        }
        AllowedPairs::Specified(pairs) => pairs.to_vec(),
    };

    for (ti, tj) in pairs_to_process {
        for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
            for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
                if idx_i.is_contractable(idx_j) {
                    let key_i = (ti, pi);
                    let key_j = (tj, pj);

                    let remapped_i = diag_uf.find(*idx_i.id());
                    let remapped_j = diag_uf.find(*idx_j.id());

                    match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
                        (None, None) => {
                            // Neither axis is labelled yet: reuse the ID of
                            // `remapped_i` if one exists, otherwise create a new
                            // ID and remember `key_i` as its origin.
                            let internal_id = if let Some(&id) = dynid_to_internal.get(&remapped_i)
                            {
                                id
                            } else {
                                let id = next_id;
                                next_id += 1;
                                dynid_to_internal.insert(remapped_i, id);
                                internal_id_to_original.insert(id, key_i);
                                id
                            };
                            assigned.insert(key_i, internal_id);
                            assigned.insert(key_j, internal_id);
                            if remapped_i != remapped_j {
                                dynid_to_internal.insert(remapped_j, internal_id);
                            }
                        }
                        (Some(id), None) => {
                            // Propagate the existing ID to the unlabelled axis.
                            assigned.insert(key_j, id);
                            dynid_to_internal.insert(remapped_j, id);
                        }
                        (None, Some(id)) => {
                            // Propagate the existing ID to the unlabelled axis.
                            assigned.insert(key_i, id);
                            dynid_to_internal.insert(remapped_i, id);
                        }
                        (Some(_id_i), Some(_id_j)) => {
                            // Both already assigned
                            // NOTE(review): the two IDs are not compared or merged
                            // here — confirm they can never differ.
                        }
                    }
                }
            }
        }
    }

    // Assign IDs for unassigned indices (external indices)
    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (pos, idx) in tensor.indices.iter().enumerate() {
            let key = (tensor_idx, pos);
            if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
                let remapped_id = diag_uf.find(*idx.id());

                // An axis whose representative already has an ID shares it;
                // otherwise a fresh ID is created with this axis as origin.
                let internal_id = if let Some(&id) = dynid_to_internal.get(&remapped_id) {
                    id
                } else {
                    let id = next_id;
                    next_id += 1;
                    dynid_to_internal.insert(remapped_id, id);
                    internal_id_to_original.insert(id, key);
                    id
                };
                e.insert(internal_id);
            }
        }
    }

    // Build ixs
    // Every (tensor, axis) key was filled by one of the two passes above, so
    // the map lookups below cannot miss.
    let ixs: Vec<Vec<usize>> = tensors
        .iter()
        .enumerate()
        .map(|(tensor_idx, tensor)| {
            (0..tensor.indices.len())
                .map(|pos| assigned[&(tensor_idx, pos)])
                .collect()
        })
        .collect();

    Ok((ixs, internal_id_to_original))
}
645
646// ============================================================================
647// Helper functions for connected component detection
648// ============================================================================
649
650/// Check if two tensors have any contractable indices.
651fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
652    a.indices
653        .iter()
654        .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
655}
656
657/// Find connected components of tensors based on contractable indices.
658///
659/// Uses petgraph for O(V+E) connected component detection.
660fn find_tensor_connected_components(
661    tensors: &[&TensorDynLen],
662    allowed: AllowedPairs<'_>,
663) -> Vec<Vec<usize>> {
664    let n = tensors.len();
665    if n == 0 {
666        return vec![];
667    }
668    if n == 1 {
669        return vec![vec![0]];
670    }
671
672    // Build undirected graph
673    let mut graph = UnGraph::<(), ()>::new_undirected();
674    let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
675
676    // Add edges based on connectivity
677    match allowed {
678        AllowedPairs::All => {
679            for i in 0..n {
680                for j in (i + 1)..n {
681                    if has_contractable_indices(tensors[i], tensors[j]) {
682                        graph.add_edge(nodes[i], nodes[j], ());
683                    }
684                }
685            }
686        }
687        AllowedPairs::Specified(pairs) => {
688            for &(i, j) in pairs {
689                if has_contractable_indices(tensors[i], tensors[j]) {
690                    graph.add_edge(nodes[i], nodes[j], ());
691                }
692            }
693        }
694    }
695
696    // Find connected components using petgraph
697    let num_components = connected_components(&graph);
698
699    if num_components == 1 {
700        return vec![(0..n).collect()];
701    }
702
703    // Multiple components - group by component ID
704    use petgraph::visit::Dfs;
705    let mut visited = vec![false; n];
706    let mut components = Vec::new();
707
708    for start in 0..n {
709        if !visited[start] {
710            let mut component = Vec::new();
711            let mut dfs = Dfs::new(&graph, nodes[start]);
712            while let Some(node) = dfs.next(&graph) {
713                let idx = node.index();
714                if !visited[idx] {
715                    visited[idx] = true;
716                    component.push(idx);
717                }
718            }
719            component.sort();
720            components.push(component);
721        }
722    }
723
724    components.sort_by_key(|c| c[0]);
725    components
726}
727
728/// Remap AllowedPairs for a subset of tensors.
729fn remap_allowed_pairs(allowed: AllowedPairs<'_>, component: &[usize]) -> RemappedAllowedPairs {
730    match allowed {
731        AllowedPairs::All => RemappedAllowedPairs::All,
732        AllowedPairs::Specified(pairs) => {
733            let orig_to_local: HashMap<usize, usize> = component
734                .iter()
735                .enumerate()
736                .map(|(local, &orig)| (orig, local))
737                .collect();
738
739            let remapped: Vec<(usize, usize)> = pairs
740                .iter()
741                .filter_map(
742                    |&(i, j)| match (orig_to_local.get(&i), orig_to_local.get(&j)) {
743                        (Some(&li), Some(&lj)) => Some((li, lj)),
744                        _ => None,
745                    },
746                )
747                .collect();
748
749            RemappedAllowedPairs::Specified(remapped)
750        }
751    }
752}
753
754/// Owned version of AllowedPairs for remapped components.
755enum RemappedAllowedPairs {
756    All,
757    Specified(Vec<(usize, usize)>),
758}
759
760impl RemappedAllowedPairs {
761    fn as_ref(&self) -> AllowedPairs<'_> {
762        match self {
763            RemappedAllowedPairs::All => AllowedPairs::All,
764            RemappedAllowedPairs::Specified(pairs) => AllowedPairs::Specified(pairs),
765        }
766    }
767}
768
769#[cfg(test)]
770mod tests;