Skip to main content

tensor4all_core/
index_ops.rs

1use crate::IndexLike;
2use smallvec::SmallVec;
3
4const SMALL_CONTRACTION_INLINE: usize = 8;
5const LINEAR_CONTRACTION_SCAN_LIMIT: usize = 64;
6
7/// Small axis list used by contraction preparation.
8pub(crate) type AxisVec = SmallVec<[usize; SMALL_CONTRACTION_INLINE]>;
9/// Small index list used by contraction preparation.
10pub(crate) type IndexVec<I> = SmallVec<[I; SMALL_CONTRACTION_INLINE]>;
11
12type AxisPairVec = SmallVec<[(usize, usize); SMALL_CONTRACTION_INLINE]>;
13type BoolVec = SmallVec<[bool; SMALL_CONTRACTION_INLINE]>;
14
/// Failure modes of the index-replacement helpers (`replaceinds`,
/// `replaceinds_in_place`) and of `check_unique_indices`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ReplaceIndsError {
    /// A replacement index does not live in the same space as the index it
    /// replaces. Spaces are currently compared by dimension only.
    SpaceMismatch {
        /// Dimension of the index being replaced.
        from_dim: usize,
        /// Dimension of the replacement index.
        to_dim: usize,
    },
    /// The same index occurs more than once in a collection.
    DuplicateIndices {
        /// Position of the earlier occurrence.
        first_pos: usize,
        /// Position of the later, repeated occurrence.
        duplicate_pos: usize,
    },
}
33
34impl std::fmt::Display for ReplaceIndsError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            ReplaceIndsError::SpaceMismatch { from_dim, to_dim } => {
38                write!(
39                    f,
40                    "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
41                    from_dim, to_dim
42                )
43            }
44            ReplaceIndsError::DuplicateIndices {
45                first_pos,
46                duplicate_pos,
47            } => {
48                write!(
49                    f,
50                    "Duplicate indices found: index at position {} equals index at position {}",
51                    duplicate_pos, first_pos
52                )
53            }
54        }
55    }
56}
57
58impl std::error::Error for ReplaceIndsError {}
59
60/// Check if a collection of indices contains any duplicate full indices.
61///
62/// # Arguments
63/// * `indices` - Collection of indices to check
64///
65/// # Returns
66/// `Ok(())` if all indices are unique, or `Err(ReplaceIndsError::DuplicateIndices)` if duplicates are found.
67///
68/// # Example
69/// ```
70/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
71/// use tensor4all_core::index_ops::check_unique_indices;
72///
73/// let i = Index::new_dyn(2);
74/// let j = Index::new_dyn(3);
75/// let indices = vec![i.clone(), j.clone()];
76/// assert!(check_unique_indices(&indices).is_ok());
77///
78/// let duplicate = vec![i.clone(), i.clone()];
79/// assert!(check_unique_indices(&duplicate).is_err());
80/// ```
81pub fn check_unique_indices<I: IndexLike>(indices: &[I]) -> Result<(), ReplaceIndsError> {
82    use std::collections::HashMap;
83    let mut seen: HashMap<&I, usize> = HashMap::with_capacity(indices.len());
84    for (pos, idx) in indices.iter().enumerate() {
85        if let Some(&first_pos) = seen.get(idx) {
86            return Err(ReplaceIndsError::DuplicateIndices {
87                first_pos,
88                duplicate_pos: pos,
89            });
90        }
91        seen.insert(idx, pos);
92    }
93    Ok(())
94}
95
96/// Replace indices in a collection based on full-index matching.
97///
98/// This corresponds to ITensors.jl's `replaceinds` function. It replaces indices
99/// in `indices` that equal any of the `(old, new)` pairs in `replacements`.
100/// The replacement index must have the same dimension as the original.
101///
102/// # Arguments
103/// * `indices` - Collection of indices to modify
104/// * `replacements` - Pairs of `(old_index, new_index)` where indices equal to `old_index` are replaced with `new_index`
105///
106/// # Returns
107/// A new vector with replacements applied, or an error if any replacement has a dimension mismatch.
108///
109/// # Errors
110/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
111///
112/// # Example
113/// ```
114/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
115/// use tensor4all_core::index_ops::replaceinds;
116///
117/// let i = Index::new_dyn(2);
118/// let j = Index::new_dyn(3);
119/// let k = Index::new_dyn(4);
120/// let new_j = Index::new_dyn(3);  // Same size as j
121///
122/// let indices = vec![i.clone(), j.clone(), k.clone()];
123/// let replacements = vec![(j.clone(), new_j.clone())];
124///
125/// let replaced = replaceinds(indices, &replacements).unwrap();
126/// assert_eq!(replaced.len(), 3);
127/// assert_eq!(replaced[1].id, new_j.id);
128/// ```
129pub fn replaceinds<I: IndexLike>(
130    indices: Vec<I>,
131    replacements: &[(I, I)],
132) -> Result<Vec<I>, ReplaceIndsError> {
133    // Check for duplicates in input indices
134    check_unique_indices(&indices)?;
135
136    // Build a map from old index to new index for fast lookup.
137    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
138    for (old, new) in replacements {
139        // Validate dimension match
140        if old.dim() != new.dim() {
141            return Err(ReplaceIndsError::SpaceMismatch {
142                from_dim: old.dim(),
143                to_dim: new.dim(),
144            });
145        }
146        replacement_map.insert(old.clone(), new);
147    }
148
149    // Apply replacements
150    let mut result = Vec::with_capacity(indices.len());
151    for idx in indices {
152        if let Some(new_idx) = replacement_map.get(&idx) {
153            result.push((*new_idx).clone());
154        } else {
155            result.push(idx);
156        }
157    }
158
159    // Check for duplicates in result indices
160    check_unique_indices(&result)?;
161    Ok(result)
162}
163
164/// Replace indices in-place based on full-index matching.
165///
166/// This is an in-place variant of `replaceinds` that modifies the input slice directly.
167/// Useful for performance-critical code where you want to avoid allocations.
168///
169/// # Arguments
170/// * `indices` - Mutable slice of indices to modify
171/// * `replacements` - Pairs of `(old_index, new_index)` where indices equal to `old_index` are replaced with `new_index`
172///
173/// # Returns
174/// `Ok(())` on success, or an error if any replacement has a dimension mismatch.
175///
176/// # Errors
177/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
178///
179/// # Example
180/// ```
181/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
182/// use tensor4all_core::index_ops::replaceinds_in_place;
183///
184/// let i = Index::new_dyn(2);
185/// let j = Index::new_dyn(3);
186/// let k = Index::new_dyn(4);
187/// let new_j = Index::new_dyn(3);
188///
189/// let mut indices = vec![i.clone(), j.clone(), k.clone()];
190/// let replacements = vec![(j.clone(), new_j.clone())];
191///
192/// replaceinds_in_place(&mut indices, &replacements).unwrap();
193/// assert_eq!(indices[1].id, new_j.id);
194/// ```
195pub fn replaceinds_in_place<I: IndexLike>(
196    indices: &mut [I],
197    replacements: &[(I, I)],
198) -> Result<(), ReplaceIndsError> {
199    // Check for duplicates in input indices
200    check_unique_indices(indices)?;
201
202    // Build a map from old index to new index for fast lookup.
203    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
204    for (old, new) in replacements {
205        // Validate dimension match
206        if old.dim() != new.dim() {
207            return Err(ReplaceIndsError::SpaceMismatch {
208                from_dim: old.dim(),
209                to_dim: new.dim(),
210            });
211        }
212        replacement_map.insert(old.clone(), new);
213    }
214
215    // Apply replacements in-place
216    for idx in indices.iter_mut() {
217        if let Some(new_idx) = replacement_map.get(idx) {
218            *idx = (*new_idx).clone();
219        }
220    }
221
222    // Check for duplicates in result indices
223    check_unique_indices(indices)?;
224    Ok(())
225}
226
227/// Find indices that are unique to the first collection (set difference A \ B).
228///
229/// Returns indices that appear in `indices_a` but not in `indices_b` (matched by full index).
230/// This corresponds to ITensors.jl's `uniqueinds` function.
231///
232/// # Arguments
233/// * `indices_a` - First collection of indices
234/// * `indices_b` - Second collection of indices
235///
236/// # Returns
237/// A vector containing indices from `indices_a` that are not in `indices_b`.
238///
239/// # Example
240/// ```
241/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
242/// use tensor4all_core::index_ops::unique_inds;
243///
244/// let i = Index::new_dyn(2);
245/// let j = Index::new_dyn(3);
246/// let k = Index::new_dyn(4);
247///
248/// let indices_a = vec![i.clone(), j.clone()];
249/// let indices_b = vec![j.clone(), k.clone()];
250///
251/// let unique = unique_inds(&indices_a, &indices_b);
252/// assert_eq!(unique.len(), 1);
253/// assert_eq!(unique[0].id, i.id);
254/// ```
255pub fn unique_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
256    indices_a
257        .iter()
258        .filter(|idx| !indices_b.iter().any(|other| other == *idx))
259        .cloned()
260        .collect()
261}
262
263/// Find indices that are not common between two collections (symmetric difference).
264///
265/// Returns indices that appear in either `indices_a` or `indices_b` but not in both
266/// (matched by full index). This corresponds to ITensors.jl's `noncommoninds` function.
267///
268/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
269///
270/// # Arguments
271/// * `indices_a` - First collection of indices
272/// * `indices_b` - Second collection of indices
273///
274/// # Returns
275/// A vector containing indices from both collections that are not common to both.
276/// Order: indices from A first (in original order), then indices from B (in original order).
277///
278/// # Example
279/// ```
280/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
281/// use tensor4all_core::index_ops::noncommon_inds;
282///
283/// let i = Index::new_dyn(2);
284/// let j = Index::new_dyn(3);
285/// let k = Index::new_dyn(4);
286///
287/// let indices_a = vec![i.clone(), j.clone()];
288/// let indices_b = vec![j.clone(), k.clone()];
289///
290/// let noncommon = noncommon_inds(&indices_a, &indices_b);
291/// assert_eq!(noncommon.len(), 2);  // i and k
292/// ```
293pub fn noncommon_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
294    // Pre-allocate with estimated capacity (worst case: no common indices)
295    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
296
297    // Add indices from A that are not in B
298    result.extend(
299        indices_a
300            .iter()
301            .filter(|idx| !indices_b.iter().any(|other| other == *idx))
302            .cloned(),
303    );
304    // Add indices from B that are not in A
305    result.extend(
306        indices_b
307            .iter()
308            .filter(|idx| !indices_a.iter().any(|other| other == *idx))
309            .cloned(),
310    );
311    result
312}
313
314/// Find the union of two index collections.
315///
316/// Returns all unique indices from both collections (matched by full index).
317/// This corresponds to ITensors.jl's `unioninds` function.
318///
319/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
320///
321/// # Arguments
322/// * `indices_a` - First collection of indices
323/// * `indices_b` - Second collection of indices
324///
325/// # Returns
326/// A vector containing all unique indices from both collections.
327///
328/// # Example
329/// ```
330/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
331/// use tensor4all_core::index_ops::union_inds;
332///
333/// let i = Index::new_dyn(2);
334/// let j = Index::new_dyn(3);
335/// let k = Index::new_dyn(4);
336///
337/// let indices_a = vec![i.clone(), j.clone()];
338/// let indices_b = vec![j.clone(), k.clone()];
339///
340/// let union = union_inds(&indices_a, &indices_b);
341/// assert_eq!(union.len(), 3);  // i, j, k
342/// ```
343pub fn union_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
344    let mut seen: std::collections::HashSet<&I> =
345        std::collections::HashSet::with_capacity(indices_a.len() + indices_b.len());
346    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
347
348    for idx in indices_a {
349        if seen.insert(idx) {
350            result.push(idx.clone());
351        }
352    }
353    for idx in indices_b {
354        if seen.insert(idx) {
355            result.push(idx.clone());
356        }
357    }
358    result
359}
360
361/// Check if a collection contains a specific full index.
362///
363/// This corresponds to ITensors.jl's `hasind` function.
364///
365/// # Arguments
366/// * `indices` - Collection of indices to search
367/// * `index` - The index to look for
368///
369/// # Returns
370/// `true` if an index with matching ID is found, `false` otherwise.
371///
372/// # Example
373/// ```
374/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
375/// use tensor4all_core::index_ops::hasind;
376///
377/// let i = Index::new_dyn(2);
378/// let j = Index::new_dyn(3);
379/// let indices = vec![i.clone(), j.clone()];
380///
381/// assert!(hasind(&indices, &i));
382/// assert!(!hasind(&indices, &Index::new_dyn(4)));
383/// ```
384pub fn hasind<I: IndexLike>(indices: &[I], index: &I) -> bool {
385    indices.iter().any(|idx| idx == index)
386}
387
388/// Check if a collection contains all of the specified full indices.
389///
390/// This corresponds to ITensors.jl's `hasinds` function.
391///
392/// # Arguments
393/// * `indices` - Collection of indices to search
394/// * `targets` - The indices to look for
395///
396/// # Returns
397/// `true` if all target indices are found, `false` otherwise.
398///
399/// # Example
400/// ```
401/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
402/// use tensor4all_core::index_ops::hasinds;
403///
404/// let i = Index::new_dyn(2);
405/// let j = Index::new_dyn(3);
406/// let k = Index::new_dyn(4);
407/// let indices = vec![i.clone(), j.clone(), k.clone()];
408///
409/// assert!(hasinds(&indices, &[i.clone(), j.clone()]));
410/// assert!(!hasinds(&indices, &[i.clone(), Index::new_dyn(5)]));
411/// ```
412pub fn hasinds<I: IndexLike>(indices: &[I], targets: &[I]) -> bool {
413    targets
414        .iter()
415        .all(|target| indices.iter().any(|idx| idx == target))
416}
417
418/// Check if two collections have any common full indices.
419///
420/// This corresponds to ITensors.jl's `hascommoninds` function.
421///
422/// # Arguments
423/// * `indices_a` - First collection of indices
424/// * `indices_b` - Second collection of indices
425///
426/// # Returns
427/// `true` if there is at least one common index, `false` otherwise.
428///
429/// # Example
430/// ```
431/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
432/// use tensor4all_core::index_ops::hascommoninds;
433///
434/// let i = Index::new_dyn(2);
435/// let j = Index::new_dyn(3);
436/// let k = Index::new_dyn(4);
437///
438/// let indices_a = vec![i.clone(), j.clone()];
439/// let indices_b = vec![j.clone(), k.clone()];
440///
441/// assert!(hascommoninds(&indices_a, &indices_b));
442/// assert!(!hascommoninds(&[i.clone()], &[k.clone()]));
443/// ```
444pub fn hascommoninds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> bool {
445    indices_a
446        .iter()
447        .any(|idx| indices_b.iter().any(|other| other == idx))
448}
449
450/// Find common indices between two index collections.
451///
452/// Returns a vector of indices that appear in both `indices_a` and `indices_b`
453/// (set intersection). This is similar to ITensors.jl's `commoninds` function.
454///
455/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
456///
457/// # Arguments
458/// * `indices_a` - First collection of indices
459/// * `indices_b` - Second collection of indices
460///
461/// # Returns
462/// A vector containing indices that are common to both collections (matched by full index).
463///
464/// # Example
465/// ```
466/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
467/// use tensor4all_core::index_ops::common_inds;
468///
469/// let i = Index::new_dyn(2);
470/// let j = Index::new_dyn(3);
471/// let k = Index::new_dyn(4);
472///
473/// let indices_a = vec![i.clone(), j.clone()];
474/// let indices_b = vec![j.clone(), k.clone()];
475///
476/// let common = common_inds(&indices_a, &indices_b);
477/// assert_eq!(common.len(), 1);
478/// assert_eq!(common[0].id, j.id);
479/// ```
480pub fn common_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
481    indices_a
482        .iter()
483        .filter(|idx| indices_b.iter().any(|other| other == *idx))
484        .cloned()
485        .collect()
486}
487
488/// Find contractable indices between two slices and return their positions.
489///
490/// Returns a vector of `(pos_a, pos_b)` tuples where each tuple indicates
491/// that `indices_a[pos_a]` and `indices_b[pos_b]` are contractable
492/// (same ID, same dimension, and compatible ConjState).
493///
494/// # Example
495/// ```
496/// use tensor4all_core::index::DefaultIndex as Index;
497/// use tensor4all_core::index_ops::common_ind_positions;
498///
499/// let i = Index::new_dyn(2);
500/// let j = Index::new_dyn(3);
501/// let k = Index::new_dyn(4);
502///
503/// let indices_a = vec![i.clone(), j.clone()];
504/// let indices_b = vec![j.clone(), k.clone()];
505///
506/// let positions = common_ind_positions(&indices_a, &indices_b);
507/// assert_eq!(positions, vec![(1, 0)]); // j is at position 1 in a, position 0 in b
508/// ```
509pub fn common_ind_positions<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<(usize, usize)> {
510    common_ind_positions_small(indices_a, indices_b).into_vec()
511}
512
513fn common_ind_positions_small<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> AxisPairVec {
514    let scan_work = indices_a.len().saturating_mul(indices_b.len());
515    if scan_work <= LINEAR_CONTRACTION_SCAN_LIMIT {
516        return common_ind_positions_linear(indices_a, indices_b);
517    }
518    common_ind_positions_hashed(indices_a, indices_b)
519}
520
521fn common_ind_positions_linear<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> AxisPairVec {
522    let mut positions = AxisPairVec::new();
523    for (pos_a, idx_a) in indices_a.iter().enumerate() {
524        for (pos_b, idx_b) in indices_b.iter().enumerate() {
525            if idx_a.is_contractable(idx_b) {
526                positions.push((pos_a, pos_b));
527                break; // Each index in a can match at most one in b
528            }
529        }
530    }
531    positions
532}
533
534fn common_ind_positions_hashed<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> AxisPairVec {
535    use std::collections::HashMap;
536
537    let mut positions_by_id: HashMap<&I::Id, SmallVec<[usize; 2]>> =
538        HashMap::with_capacity(indices_b.len());
539    for (pos_b, idx_b) in indices_b.iter().enumerate() {
540        positions_by_id.entry(idx_b.id()).or_default().push(pos_b);
541    }
542
543    let mut positions = AxisPairVec::new();
544    for (pos_a, idx_a) in indices_a.iter().enumerate() {
545        let Some(candidate_positions) = positions_by_id.get(idx_a.id()) else {
546            continue;
547        };
548        for &pos_b in candidate_positions {
549            if idx_a.is_contractable(&indices_b[pos_b]) {
550                positions.push((pos_a, pos_b));
551                break; // Each index in a can match at most one in b
552            }
553        }
554    }
555    positions
556}
557
558/// Result of preparing a tensor contraction.
559///
560/// Contains all the information needed to perform the contraction:
561/// - Which axes to contract from each tensor
562/// - The resulting indices after contraction
563#[derive(Debug, Clone)]
564pub(crate) struct ContractionSpec<I: IndexLike> {
565    /// Axes to contract from the first tensor (positions in `indices_a`).
566    pub axes_a: AxisVec,
567    /// Axes to contract from the second tensor (positions in `indices_b`).
568    pub axes_b: AxisVec,
569    /// Indices of the result tensor (non-contracted indices from both tensors).
570    pub result_indices: IndexVec<I>,
571}
572
/// Reasons why a contraction cannot be prepared.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ContractionError {
    /// No indices were available to contract.
    NoCommonIndices,
    /// Two axes selected for contraction have different dimensions.
    DimensionMismatch {
        /// Axis position in the first tensor.
        pos_a: usize,
        /// Axis position in the second tensor.
        pos_b: usize,
        /// Dimension of that axis in the first tensor.
        dim_a: usize,
        /// Dimension of that axis in the second tensor.
        dim_b: usize,
    },
    /// The same axis was selected for contraction more than once.
    DuplicateAxis {
        /// Tensor owning the repeated axis (`"self"` or `"other"`).
        tensor: &'static str,
        /// Position of the repeated axis.
        pos: usize,
    },
    /// A requested index is absent from a tensor.
    IndexNotFound {
        /// Tensor that lacks the index (`"self"` or `"other"`).
        tensor: &'static str,
    },
    /// The request would require a batch (non-contracted shared) index,
    /// which is not supported yet.
    BatchContractionNotImplemented,
}
604
605impl std::fmt::Display for ContractionError {
606    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
607        match self {
608            ContractionError::NoCommonIndices => {
609                write!(f, "No common indices found for contraction")
610            }
611            ContractionError::DimensionMismatch {
612                pos_a,
613                pos_b,
614                dim_a,
615                dim_b,
616            } => {
617                write!(
618                    f,
619                    "Dimension mismatch: tensor_a[{}]={} != tensor_b[{}]={}",
620                    pos_a, dim_a, pos_b, dim_b
621                )
622            }
623            ContractionError::DuplicateAxis { tensor, pos } => {
624                write!(f, "Duplicate axis {} in {} tensor", pos, tensor)
625            }
626            ContractionError::IndexNotFound { tensor } => {
627                write!(f, "Index not found in {} tensor", tensor)
628            }
629            ContractionError::BatchContractionNotImplemented => {
630                write!(f, "Batch contraction not yet implemented")
631            }
632        }
633    }
634}
635
636impl std::error::Error for ContractionError {}
637
638/// Prepare contraction data for two tensors that share common indices.
639///
640/// This internal helper finds common indices and computes the axes to contract
641/// together with the resulting non-contracted indices.
642pub(crate) fn prepare_contraction<I: IndexLike>(
643    indices_a: &[I],
644    dims_a: &[usize],
645    indices_b: &[I],
646    dims_b: &[usize],
647) -> Result<ContractionSpec<I>, ContractionError> {
648    // Find common indices and their positions.
649    // If no common indices exist, this becomes an outer product (empty axes).
650    let positions = common_ind_positions_small(indices_a, indices_b);
651
652    let mut axes_a = AxisVec::with_capacity(positions.len());
653    let mut axes_b = AxisVec::with_capacity(positions.len());
654    for &(pos_a, pos_b) in &positions {
655        axes_a.push(pos_a);
656        axes_b.push(pos_b);
657    }
658
659    // Verify dimensions match
660    for &(pos_a, pos_b) in &positions {
661        if dims_a[pos_a] != dims_b[pos_b] {
662            return Err(ContractionError::DimensionMismatch {
663                pos_a,
664                pos_b,
665                dim_a: dims_a[pos_a],
666                dim_b: dims_b[pos_b],
667            });
668        }
669    }
670
671    let result_indices = build_contraction_result_indices(indices_a, &axes_a, indices_b, &axes_b);
672
673    Ok(ContractionSpec {
674        axes_a,
675        axes_b,
676        result_indices,
677    })
678}
679
680/// Prepare contraction data for explicit index pairs (like tensordot).
681///
682/// Unlike `prepare_contraction`, this internal helper takes explicit pairs of
683/// indices to contract, allowing contraction of indices with different IDs.
684pub(crate) fn prepare_contraction_pairs<I: IndexLike>(
685    indices_a: &[I],
686    dims_a: &[usize],
687    indices_b: &[I],
688    dims_b: &[usize],
689    pairs: &[(I, I)],
690) -> Result<ContractionSpec<I>, ContractionError> {
691    if pairs.is_empty() {
692        return Err(ContractionError::NoCommonIndices);
693    }
694
695    // Check for batch contraction (common indices not in pairs). The explicit
696    // pair list identifies axes by full index metadata, not by ID alone.
697    let common_positions = common_ind_positions_small(indices_a, indices_b);
698    for (pos_a, pos_b) in &common_positions {
699        let idx_a = &indices_a[*pos_a];
700        let idx_b = &indices_b[*pos_b];
701        if !pairs
702            .iter()
703            .any(|(contracted_idx, _)| contracted_idx == idx_a)
704            || !pairs
705                .iter()
706                .any(|(_, contracted_idx)| contracted_idx == idx_b)
707        {
708            return Err(ContractionError::BatchContractionNotImplemented);
709        }
710    }
711
712    // Find positions and validate
713    let mut axes_a = AxisVec::with_capacity(pairs.len());
714    let mut axes_b = AxisVec::with_capacity(pairs.len());
715    let mut contracted_a = bool_flags(indices_a.len());
716    let mut contracted_b = bool_flags(indices_b.len());
717
718    for (idx_a, idx_b) in pairs {
719        let pos_a = indices_a
720            .iter()
721            .position(|idx| idx == idx_a)
722            .ok_or(ContractionError::IndexNotFound { tensor: "self" })?;
723
724        let pos_b = indices_b
725            .iter()
726            .position(|idx| idx == idx_b)
727            .ok_or(ContractionError::IndexNotFound { tensor: "other" })?;
728
729        // Verify dimensions match
730        if dims_a[pos_a] != dims_b[pos_b] {
731            return Err(ContractionError::DimensionMismatch {
732                pos_a,
733                pos_b,
734                dim_a: dims_a[pos_a],
735                dim_b: dims_b[pos_b],
736            });
737        }
738
739        // Check for duplicate axes
740        if contracted_a[pos_a] {
741            return Err(ContractionError::DuplicateAxis {
742                tensor: "self",
743                pos: pos_a,
744            });
745        }
746        if contracted_b[pos_b] {
747            return Err(ContractionError::DuplicateAxis {
748                tensor: "other",
749                pos: pos_b,
750            });
751        }
752
753        contracted_a[pos_a] = true;
754        contracted_b[pos_b] = true;
755        axes_a.push(pos_a);
756        axes_b.push(pos_b);
757    }
758
759    let result_indices = build_contraction_result_indices(indices_a, &axes_a, indices_b, &axes_b);
760
761    Ok(ContractionSpec {
762        axes_a,
763        axes_b,
764        result_indices,
765    })
766}
767
768fn bool_flags(len: usize) -> BoolVec {
769    let mut flags = BoolVec::with_capacity(len);
770    flags.resize(len, false);
771    flags
772}
773
774fn build_contraction_result_indices<I: IndexLike>(
775    indices_a: &[I],
776    axes_a: &[usize],
777    indices_b: &[I],
778    axes_b: &[usize],
779) -> IndexVec<I> {
780    let mut contracted_a = bool_flags(indices_a.len());
781    let mut contracted_b = bool_flags(indices_b.len());
782    for &axis in axes_a {
783        contracted_a[axis] = true;
784    }
785    for &axis in axes_b {
786        contracted_b[axis] = true;
787    }
788
789    let result_len = indices_a.len() + indices_b.len() - axes_a.len() - axes_b.len();
790    let mut result_indices = IndexVec::with_capacity(result_len);
791
792    for (i, idx) in indices_a.iter().enumerate() {
793        if !contracted_a[i] {
794            result_indices.push(idx.clone());
795        }
796    }
797
798    for (i, idx) in indices_b.iter().enumerate() {
799        if !contracted_b[i] {
800            result_indices.push(idx.clone());
801        }
802    }
803
804    result_indices
805}
806
#[cfg(test)]
mod tests {
    use super::{prepare_contraction, prepare_contraction_pairs};
    use crate::index::DefaultIndex as Index;

    /// Explicit pairs are matched by full index. Pairing `i'` with `i'` must
    /// contract only the primed axis and leave the unprimed `i` in the result.
    #[test]
    fn prepare_contraction_pairs_selects_exact_same_id_prime_index() {
        let i = Index::new_dyn(2);
        let i_prime = i.prime();

        let lhs = [i.clone(), i_prime.clone()];
        let rhs = [i_prime.clone()];
        let pairs = [(i_prime.clone(), i_prime)];

        let spec = prepare_contraction_pairs(&lhs, &[2, 2], &rhs, &[2], &pairs).unwrap();

        assert_eq!(spec.axes_a.as_slice(), &[1]);
        assert_eq!(spec.axes_b.as_slice(), &[0]);
        assert_eq!(spec.result_indices.as_slice(), &[i]);
    }

    /// With 9 indices per side the pairwise-scan limit is exceeded, so this
    /// exercises the hashed lookup path. Its result must match the
    /// linear-scan semantics.
    #[test]
    fn prepare_contraction_large_rank_uses_hash_fallback_semantics() {
        let lhs: Vec<_> = (0..9).map(|_| Index::new_dyn(2)).collect();
        let mut rhs: Vec<_> = (0..9).map(|_| Index::new_dyn(2)).collect();
        rhs[5] = lhs[7].clone();

        let lhs_dims = vec![2; lhs.len()];
        let rhs_dims = vec![2; rhs.len()];
        let spec = prepare_contraction(&lhs, &lhs_dims, &rhs, &rhs_dims).unwrap();

        assert_eq!(spec.axes_a.as_slice(), &[7]);
        assert_eq!(spec.axes_b.as_slice(), &[5]);
        assert_eq!(spec.result_indices.len(), lhs.len() + rhs.len() - 2);

        let expected: Vec<_> = lhs
            .iter()
            .enumerate()
            .filter(|&(pos, _)| pos != 7)
            .chain(rhs.iter().enumerate().filter(|&(pos, _)| pos != 5))
            .map(|(_, idx)| idx.clone())
            .collect();
        assert_eq!(spec.result_indices.as_slice(), expected.as_slice());
    }
}