// tensor4all_core/index_ops.rs

1use crate::IndexLike;
2
/// Failure modes of the index replacement helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplaceIndsError {
    /// The replacement index does not have the same dimension as the index it replaces.
    SpaceMismatch {
        /// Dimension of the index being replaced.
        from_dim: usize,
        /// Dimension of the index that was supposed to take its place.
        to_dim: usize,
    },
    /// The same index occurs more than once in a collection.
    DuplicateIndices {
        /// Position at which the index was first seen.
        first_pos: usize,
        /// Position at which the index was seen again.
        duplicate_pos: usize,
    },
}

impl std::fmt::Display for ReplaceIndsError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payload fields are `usize`, so matching by value copies them out.
        match *self {
            Self::SpaceMismatch { from_dim, to_dim } => write!(
                f,
                "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
                from_dim, to_dim
            ),
            // The later (duplicate) position is reported first, then the earlier one.
            Self::DuplicateIndices {
                first_pos,
                duplicate_pos,
            } => write!(
                f,
                "Duplicate indices found: index at position {} equals index at position {}",
                duplicate_pos, first_pos
            ),
        }
    }
}

impl std::error::Error for ReplaceIndsError {}
47
48/// Check if a collection of indices contains any duplicate full indices.
49///
50/// # Arguments
51/// * `indices` - Collection of indices to check
52///
53/// # Returns
54/// `Ok(())` if all indices are unique, or `Err(ReplaceIndsError::DuplicateIndices)` if duplicates are found.
55///
56/// # Example
57/// ```
58/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
59/// use tensor4all_core::index_ops::check_unique_indices;
60///
61/// let i = Index::new_dyn(2);
62/// let j = Index::new_dyn(3);
63/// let indices = vec![i.clone(), j.clone()];
64/// assert!(check_unique_indices(&indices).is_ok());
65///
66/// let duplicate = vec![i.clone(), i.clone()];
67/// assert!(check_unique_indices(&duplicate).is_err());
68/// ```
69pub fn check_unique_indices<I: IndexLike>(indices: &[I]) -> Result<(), ReplaceIndsError> {
70    use std::collections::HashMap;
71    let mut seen: HashMap<&I, usize> = HashMap::with_capacity(indices.len());
72    for (pos, idx) in indices.iter().enumerate() {
73        if let Some(&first_pos) = seen.get(idx) {
74            return Err(ReplaceIndsError::DuplicateIndices {
75                first_pos,
76                duplicate_pos: pos,
77            });
78        }
79        seen.insert(idx, pos);
80    }
81    Ok(())
82}
83
84/// Replace indices in a collection based on full-index matching.
85///
86/// This corresponds to ITensors.jl's `replaceinds` function. It replaces indices
87/// in `indices` that equal any of the `(old, new)` pairs in `replacements`.
88/// The replacement index must have the same dimension as the original.
89///
90/// # Arguments
91/// * `indices` - Collection of indices to modify
92/// * `replacements` - Pairs of `(old_index, new_index)` where indices equal to `old_index` are replaced with `new_index`
93///
/// # Returns
/// A new vector with replacements applied, or an error if validation fails.
///
/// # Errors
/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
/// Returns `ReplaceIndsError::DuplicateIndices` if `indices` contains duplicates, either before or after the replacements are applied.
99///
100/// # Example
101/// ```
102/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
103/// use tensor4all_core::index_ops::replaceinds;
104///
105/// let i = Index::new_dyn(2);
106/// let j = Index::new_dyn(3);
107/// let k = Index::new_dyn(4);
108/// let new_j = Index::new_dyn(3);  // Same size as j
109///
110/// let indices = vec![i.clone(), j.clone(), k.clone()];
111/// let replacements = vec![(j.clone(), new_j.clone())];
112///
113/// let replaced = replaceinds(indices, &replacements).unwrap();
114/// assert_eq!(replaced.len(), 3);
115/// assert_eq!(replaced[1].id, new_j.id);
116/// ```
117pub fn replaceinds<I: IndexLike>(
118    indices: Vec<I>,
119    replacements: &[(I, I)],
120) -> Result<Vec<I>, ReplaceIndsError> {
121    // Check for duplicates in input indices
122    check_unique_indices(&indices)?;
123
124    // Build a map from old index to new index for fast lookup.
125    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
126    for (old, new) in replacements {
127        // Validate dimension match
128        if old.dim() != new.dim() {
129            return Err(ReplaceIndsError::SpaceMismatch {
130                from_dim: old.dim(),
131                to_dim: new.dim(),
132            });
133        }
134        replacement_map.insert(old.clone(), new);
135    }
136
137    // Apply replacements
138    let mut result = Vec::with_capacity(indices.len());
139    for idx in indices {
140        if let Some(new_idx) = replacement_map.get(&idx) {
141            result.push((*new_idx).clone());
142        } else {
143            result.push(idx);
144        }
145    }
146
147    // Check for duplicates in result indices
148    check_unique_indices(&result)?;
149    Ok(result)
150}
151
152/// Replace indices in-place based on full-index matching.
153///
/// This is an in-place variant of `replaceinds` that modifies the input slice directly.
/// Useful when you want to avoid allocating a new result vector; a temporary lookup map
/// for the replacements is still built internally.
156///
157/// # Arguments
158/// * `indices` - Mutable slice of indices to modify
159/// * `replacements` - Pairs of `(old_index, new_index)` where indices equal to `old_index` are replaced with `new_index`
160///
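/// # Returns
/// `Ok(())` on success, or an error if validation fails.
///
/// # Errors
/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
/// Returns `ReplaceIndsError::DuplicateIndices` if `indices` contains duplicates, either before or after the replacements are applied.
/// In the latter case the slice has already been modified when the error is returned.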
166///
167/// # Example
168/// ```
169/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
170/// use tensor4all_core::index_ops::replaceinds_in_place;
171///
172/// let i = Index::new_dyn(2);
173/// let j = Index::new_dyn(3);
174/// let k = Index::new_dyn(4);
175/// let new_j = Index::new_dyn(3);
176///
177/// let mut indices = vec![i.clone(), j.clone(), k.clone()];
178/// let replacements = vec![(j.clone(), new_j.clone())];
179///
180/// replaceinds_in_place(&mut indices, &replacements).unwrap();
181/// assert_eq!(indices[1].id, new_j.id);
182/// ```
183pub fn replaceinds_in_place<I: IndexLike>(
184    indices: &mut [I],
185    replacements: &[(I, I)],
186) -> Result<(), ReplaceIndsError> {
187    // Check for duplicates in input indices
188    check_unique_indices(indices)?;
189
190    // Build a map from old index to new index for fast lookup.
191    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
192    for (old, new) in replacements {
193        // Validate dimension match
194        if old.dim() != new.dim() {
195            return Err(ReplaceIndsError::SpaceMismatch {
196                from_dim: old.dim(),
197                to_dim: new.dim(),
198            });
199        }
200        replacement_map.insert(old.clone(), new);
201    }
202
203    // Apply replacements in-place
204    for idx in indices.iter_mut() {
205        if let Some(new_idx) = replacement_map.get(idx) {
206            *idx = (*new_idx).clone();
207        }
208    }
209
210    // Check for duplicates in result indices
211    check_unique_indices(indices)?;
212    Ok(())
213}
214
215/// Find indices that are unique to the first collection (set difference A \ B).
216///
217/// Returns indices that appear in `indices_a` but not in `indices_b` (matched by full index).
218/// This corresponds to ITensors.jl's `uniqueinds` function.
219///
220/// # Arguments
221/// * `indices_a` - First collection of indices
222/// * `indices_b` - Second collection of indices
223///
224/// # Returns
225/// A vector containing indices from `indices_a` that are not in `indices_b`.
226///
227/// # Example
228/// ```
229/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
230/// use tensor4all_core::index_ops::unique_inds;
231///
232/// let i = Index::new_dyn(2);
233/// let j = Index::new_dyn(3);
234/// let k = Index::new_dyn(4);
235///
236/// let indices_a = vec![i.clone(), j.clone()];
237/// let indices_b = vec![j.clone(), k.clone()];
238///
239/// let unique = unique_inds(&indices_a, &indices_b);
240/// assert_eq!(unique.len(), 1);
241/// assert_eq!(unique[0].id, i.id);
242/// ```
243pub fn unique_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
244    indices_a
245        .iter()
246        .filter(|idx| !indices_b.iter().any(|other| other == *idx))
247        .cloned()
248        .collect()
249}
250
251/// Find indices that are not common between two collections (symmetric difference).
252///
253/// Returns indices that appear in either `indices_a` or `indices_b` but not in both
254/// (matched by full index). This corresponds to ITensors.jl's `noncommoninds` function.
255///
256/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
257///
258/// # Arguments
259/// * `indices_a` - First collection of indices
260/// * `indices_b` - Second collection of indices
261///
262/// # Returns
263/// A vector containing indices from both collections that are not common to both.
264/// Order: indices from A first (in original order), then indices from B (in original order).
265///
266/// # Example
267/// ```
268/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
269/// use tensor4all_core::index_ops::noncommon_inds;
270///
271/// let i = Index::new_dyn(2);
272/// let j = Index::new_dyn(3);
273/// let k = Index::new_dyn(4);
274///
275/// let indices_a = vec![i.clone(), j.clone()];
276/// let indices_b = vec![j.clone(), k.clone()];
277///
278/// let noncommon = noncommon_inds(&indices_a, &indices_b);
279/// assert_eq!(noncommon.len(), 2);  // i and k
280/// ```
281pub fn noncommon_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
282    // Pre-allocate with estimated capacity (worst case: no common indices)
283    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
284
285    // Add indices from A that are not in B
286    result.extend(
287        indices_a
288            .iter()
289            .filter(|idx| !indices_b.iter().any(|other| other == *idx))
290            .cloned(),
291    );
292    // Add indices from B that are not in A
293    result.extend(
294        indices_b
295            .iter()
296            .filter(|idx| !indices_a.iter().any(|other| other == *idx))
297            .cloned(),
298    );
299    result
300}
301
302/// Find the union of two index collections.
303///
304/// Returns all unique indices from both collections (matched by full index).
305/// This corresponds to ITensors.jl's `unioninds` function.
306///
307/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
308///
309/// # Arguments
310/// * `indices_a` - First collection of indices
311/// * `indices_b` - Second collection of indices
312///
313/// # Returns
314/// A vector containing all unique indices from both collections.
315///
316/// # Example
317/// ```
318/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
319/// use tensor4all_core::index_ops::union_inds;
320///
321/// let i = Index::new_dyn(2);
322/// let j = Index::new_dyn(3);
323/// let k = Index::new_dyn(4);
324///
325/// let indices_a = vec![i.clone(), j.clone()];
326/// let indices_b = vec![j.clone(), k.clone()];
327///
328/// let union = union_inds(&indices_a, &indices_b);
329/// assert_eq!(union.len(), 3);  // i, j, k
330/// ```
331pub fn union_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
332    let mut seen: std::collections::HashSet<&I> =
333        std::collections::HashSet::with_capacity(indices_a.len() + indices_b.len());
334    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
335
336    for idx in indices_a {
337        if seen.insert(idx) {
338            result.push(idx.clone());
339        }
340    }
341    for idx in indices_b {
342        if seen.insert(idx) {
343            result.push(idx.clone());
344        }
345    }
346    result
347}
348
349/// Check if a collection contains a specific full index.
350///
351/// This corresponds to ITensors.jl's `hasind` function.
352///
353/// # Arguments
354/// * `indices` - Collection of indices to search
355/// * `index` - The index to look for
356///
357/// # Returns
/// `true` if an index equal to `index` (full-index comparison) is found, `false` otherwise.
359///
360/// # Example
361/// ```
362/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
363/// use tensor4all_core::index_ops::hasind;
364///
365/// let i = Index::new_dyn(2);
366/// let j = Index::new_dyn(3);
367/// let indices = vec![i.clone(), j.clone()];
368///
369/// assert!(hasind(&indices, &i));
370/// assert!(!hasind(&indices, &Index::new_dyn(4)));
371/// ```
372pub fn hasind<I: IndexLike>(indices: &[I], index: &I) -> bool {
373    indices.iter().any(|idx| idx == index)
374}
375
376/// Check if a collection contains all of the specified full indices.
377///
378/// This corresponds to ITensors.jl's `hasinds` function.
379///
380/// # Arguments
381/// * `indices` - Collection of indices to search
382/// * `targets` - The indices to look for
383///
384/// # Returns
385/// `true` if all target indices are found, `false` otherwise.
386///
387/// # Example
388/// ```
389/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
390/// use tensor4all_core::index_ops::hasinds;
391///
392/// let i = Index::new_dyn(2);
393/// let j = Index::new_dyn(3);
394/// let k = Index::new_dyn(4);
395/// let indices = vec![i.clone(), j.clone(), k.clone()];
396///
397/// assert!(hasinds(&indices, &[i.clone(), j.clone()]));
398/// assert!(!hasinds(&indices, &[i.clone(), Index::new_dyn(5)]));
399/// ```
400pub fn hasinds<I: IndexLike>(indices: &[I], targets: &[I]) -> bool {
401    targets
402        .iter()
403        .all(|target| indices.iter().any(|idx| idx == target))
404}
405
406/// Check if two collections have any common full indices.
407///
408/// This corresponds to ITensors.jl's `hascommoninds` function.
409///
410/// # Arguments
411/// * `indices_a` - First collection of indices
412/// * `indices_b` - Second collection of indices
413///
414/// # Returns
415/// `true` if there is at least one common index, `false` otherwise.
416///
417/// # Example
418/// ```
419/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
420/// use tensor4all_core::index_ops::hascommoninds;
421///
422/// let i = Index::new_dyn(2);
423/// let j = Index::new_dyn(3);
424/// let k = Index::new_dyn(4);
425///
426/// let indices_a = vec![i.clone(), j.clone()];
427/// let indices_b = vec![j.clone(), k.clone()];
428///
429/// assert!(hascommoninds(&indices_a, &indices_b));
430/// assert!(!hascommoninds(&[i.clone()], &[k.clone()]));
431/// ```
432pub fn hascommoninds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> bool {
433    indices_a
434        .iter()
435        .any(|idx| indices_b.iter().any(|other| other == idx))
436}
437
438/// Find common indices between two index collections.
439///
440/// Returns a vector of indices that appear in both `indices_a` and `indices_b`
441/// (set intersection). This is similar to ITensors.jl's `commoninds` function.
442///
443/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
444///
445/// # Arguments
446/// * `indices_a` - First collection of indices
447/// * `indices_b` - Second collection of indices
448///
449/// # Returns
450/// A vector containing indices that are common to both collections (matched by full index).
451///
452/// # Example
453/// ```
454/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
455/// use tensor4all_core::index_ops::common_inds;
456///
457/// let i = Index::new_dyn(2);
458/// let j = Index::new_dyn(3);
459/// let k = Index::new_dyn(4);
460///
461/// let indices_a = vec![i.clone(), j.clone()];
462/// let indices_b = vec![j.clone(), k.clone()];
463///
464/// let common = common_inds(&indices_a, &indices_b);
465/// assert_eq!(common.len(), 1);
466/// assert_eq!(common[0].id, j.id);
467/// ```
468pub fn common_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
469    indices_a
470        .iter()
471        .filter(|idx| indices_b.iter().any(|other| other == *idx))
472        .cloned()
473        .collect()
474}
475
476/// Find contractable indices between two slices and return their positions.
477///
478/// Returns a vector of `(pos_a, pos_b)` tuples where each tuple indicates
479/// that `indices_a[pos_a]` and `indices_b[pos_b]` are contractable
480/// (same ID, same dimension, and compatible ConjState).
481///
482/// # Example
483/// ```
484/// use tensor4all_core::index::DefaultIndex as Index;
485/// use tensor4all_core::index_ops::common_ind_positions;
486///
487/// let i = Index::new_dyn(2);
488/// let j = Index::new_dyn(3);
489/// let k = Index::new_dyn(4);
490///
491/// let indices_a = vec![i.clone(), j.clone()];
492/// let indices_b = vec![j.clone(), k.clone()];
493///
494/// let positions = common_ind_positions(&indices_a, &indices_b);
495/// assert_eq!(positions, vec![(1, 0)]); // j is at position 1 in a, position 0 in b
496/// ```
497pub fn common_ind_positions<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<(usize, usize)> {
498    let mut positions = Vec::new();
499    for (pos_a, idx_a) in indices_a.iter().enumerate() {
500        for (pos_b, idx_b) in indices_b.iter().enumerate() {
501            if idx_a.is_contractable(idx_b) {
502                positions.push((pos_a, pos_b));
503                break; // Each index in a can match at most one in b
504            }
505        }
506    }
507    positions
508}
509
510/// Result of preparing a tensor contraction.
511///
512/// Contains all the information needed to perform the contraction:
513/// - Which axes to contract from each tensor
514/// - The resulting indices and dimensions after contraction
515#[derive(Debug, Clone)]
516pub struct ContractionSpec<I: IndexLike> {
517    /// Axes to contract from the first tensor (positions in `indices_a`).
518    pub axes_a: Vec<usize>,
519    /// Axes to contract from the second tensor (positions in `indices_b`).
520    pub axes_b: Vec<usize>,
521    /// Indices of the result tensor (non-contracted indices from both tensors).
522    pub result_indices: Vec<I>,
523    /// Dimensions of the result tensor.
524    pub result_dims: Vec<usize>,
525}
526
/// Failure modes of the contraction preparation functions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionError {
    /// There is nothing to contract over.
    NoCommonIndices,
    /// A pair of axes that should be contracted have different dimensions.
    DimensionMismatch {
        /// Axis position in the first tensor.
        pos_a: usize,
        /// Axis position in the second tensor.
        pos_b: usize,
        /// Dimension found at `pos_a` in the first tensor.
        dim_a: usize,
        /// Dimension found at `pos_b` in the second tensor.
        dim_b: usize,
    },
    /// The same axis was requested for contraction more than once.
    DuplicateAxis {
        /// Label of the tensor holding the repeated axis ("self" or "other").
        tensor: &'static str,
        /// Position of the repeated axis.
        pos: usize,
    },
    /// A requested index is not present in the tensor.
    IndexNotFound {
        /// Label of the tensor in which the index was looked up.
        tensor: &'static str,
    },
    /// Batch contraction is not supported yet.
    BatchContractionNotImplemented,
}

impl std::fmt::Display for ContractionError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching by value copies them out.
        match *self {
            Self::NoCommonIndices => f.write_str("No common indices found for contraction"),
            Self::DimensionMismatch {
                pos_a,
                pos_b,
                dim_a,
                dim_b,
            } => write!(
                f,
                "Dimension mismatch: tensor_a[{}]={} != tensor_b[{}]={}",
                pos_a, dim_a, pos_b, dim_b
            ),
            Self::DuplicateAxis { tensor, pos } => {
                write!(f, "Duplicate axis {} in {} tensor", pos, tensor)
            }
            Self::IndexNotFound { tensor } => write!(f, "Index not found in {} tensor", tensor),
            Self::BatchContractionNotImplemented => {
                f.write_str("Batch contraction not yet implemented")
            }
        }
    }
}

impl std::error::Error for ContractionError {}
591
592/// Prepare contraction data for two tensors that share common indices.
593///
594/// This function finds common indices and computes the axes to contract
595/// and the resulting indices/dimensions.
596///
597/// # Example
598/// ```
599/// use tensor4all_core::index::DefaultIndex as Index;
600/// use tensor4all_core::index_ops::prepare_contraction;
601///
602/// let i = Index::new_dyn(2);
603/// let j = Index::new_dyn(3);
604/// let k = Index::new_dyn(4);
605///
606/// let indices_a = vec![i.clone(), j.clone()];
607/// let dims_a = vec![2, 3];
608/// let indices_b = vec![j.clone(), k.clone()];
609/// let dims_b = vec![3, 4];
610///
611/// let spec = prepare_contraction(&indices_a, &dims_a, &indices_b, &dims_b).unwrap();
612/// assert_eq!(spec.axes_a, vec![1]);  // j is at position 1 in a
613/// assert_eq!(spec.axes_b, vec![0]);  // j is at position 0 in b
614/// assert_eq!(spec.result_dims, vec![2, 4]);  // [i, k]
615/// ```
616pub fn prepare_contraction<I: IndexLike>(
617    indices_a: &[I],
618    dims_a: &[usize],
619    indices_b: &[I],
620    dims_b: &[usize],
621) -> Result<ContractionSpec<I>, ContractionError> {
622    // Find common indices and their positions.
623    // If no common indices exist, this becomes an outer product (empty axes).
624    let positions = common_ind_positions(indices_a, indices_b);
625
626    let (axes_a, axes_b): (Vec<_>, Vec<_>) = positions.iter().copied().unzip();
627
628    // Verify dimensions match
629    for &(pos_a, pos_b) in &positions {
630        if dims_a[pos_a] != dims_b[pos_b] {
631            return Err(ContractionError::DimensionMismatch {
632                pos_a,
633                pos_b,
634                dim_a: dims_a[pos_a],
635                dim_b: dims_b[pos_b],
636            });
637        }
638    }
639
640    // Build result indices and dimensions (non-contracted indices)
641    let mut result_indices = Vec::new();
642    let mut result_dims = Vec::new();
643
644    for (i, idx) in indices_a.iter().enumerate() {
645        if !axes_a.contains(&i) {
646            result_indices.push(idx.clone());
647            result_dims.push(dims_a[i]);
648        }
649    }
650
651    for (i, idx) in indices_b.iter().enumerate() {
652        if !axes_b.contains(&i) {
653            result_indices.push(idx.clone());
654            result_dims.push(dims_b[i]);
655        }
656    }
657
658    Ok(ContractionSpec {
659        axes_a,
660        axes_b,
661        result_indices,
662        result_dims,
663    })
664}
665
666/// Prepare contraction data for explicit index pairs (like tensordot).
667///
668/// Unlike `prepare_contraction`, this function takes explicit pairs of indices
669/// to contract, allowing contraction of indices with different IDs.
670///
671/// # Example
672/// ```
673/// use tensor4all_core::index::DefaultIndex as Index;
674/// use tensor4all_core::index_ops::prepare_contraction_pairs;
675///
676/// let i = Index::new_dyn(2);
677/// let j = Index::new_dyn(3);
678/// let k = Index::new_dyn(3);  // Same dim as j but different ID
679/// let l = Index::new_dyn(4);
680///
681/// let indices_a = vec![i.clone(), j.clone()];
682/// let dims_a = vec![2, 3];
683/// let indices_b = vec![k.clone(), l.clone()];
684/// let dims_b = vec![3, 4];
685///
686/// // Contract j with k
687/// let spec = prepare_contraction_pairs(
688///     &indices_a, &dims_a,
689///     &indices_b, &dims_b,
690///     &[(j.clone(), k.clone())]
691/// ).unwrap();
692/// assert_eq!(spec.axes_a, vec![1]);
693/// assert_eq!(spec.axes_b, vec![0]);
694/// assert_eq!(spec.result_dims, vec![2, 4]);
695/// ```
696pub fn prepare_contraction_pairs<I: IndexLike>(
697    indices_a: &[I],
698    dims_a: &[usize],
699    indices_b: &[I],
700    dims_b: &[usize],
701    pairs: &[(I, I)],
702) -> Result<ContractionSpec<I>, ContractionError> {
703    use std::collections::HashSet;
704
705    if pairs.is_empty() {
706        return Err(ContractionError::NoCommonIndices);
707    }
708
709    // Check for batch contraction (common indices not in pairs). The explicit
710    // pair list identifies axes by full index metadata, not by ID alone.
711    let contracted_a_indices: HashSet<_> = pairs.iter().map(|(idx, _)| idx).collect();
712    let contracted_b_indices: HashSet<_> = pairs.iter().map(|(_, idx)| idx).collect();
713
714    let common_positions = common_ind_positions(indices_a, indices_b);
715    for (pos_a, pos_b) in &common_positions {
716        let idx_a = &indices_a[*pos_a];
717        let idx_b = &indices_b[*pos_b];
718        if !contracted_a_indices.contains(idx_a) || !contracted_b_indices.contains(idx_b) {
719            return Err(ContractionError::BatchContractionNotImplemented);
720        }
721    }
722
723    // Find positions and validate
724    let mut axes_a = Vec::new();
725    let mut axes_b = Vec::new();
726
727    for (idx_a, idx_b) in pairs {
728        let pos_a = indices_a
729            .iter()
730            .position(|idx| idx == idx_a)
731            .ok_or(ContractionError::IndexNotFound { tensor: "self" })?;
732
733        let pos_b = indices_b
734            .iter()
735            .position(|idx| idx == idx_b)
736            .ok_or(ContractionError::IndexNotFound { tensor: "other" })?;
737
738        // Verify dimensions match
739        if dims_a[pos_a] != dims_b[pos_b] {
740            return Err(ContractionError::DimensionMismatch {
741                pos_a,
742                pos_b,
743                dim_a: dims_a[pos_a],
744                dim_b: dims_b[pos_b],
745            });
746        }
747
748        // Check for duplicate axes
749        if axes_a.contains(&pos_a) {
750            return Err(ContractionError::DuplicateAxis {
751                tensor: "self",
752                pos: pos_a,
753            });
754        }
755        if axes_b.contains(&pos_b) {
756            return Err(ContractionError::DuplicateAxis {
757                tensor: "other",
758                pos: pos_b,
759            });
760        }
761
762        axes_a.push(pos_a);
763        axes_b.push(pos_b);
764    }
765
766    // Build result indices and dimensions
767    let mut result_indices = Vec::new();
768    let mut result_dims = Vec::new();
769
770    for (i, idx) in indices_a.iter().enumerate() {
771        if !axes_a.contains(&i) {
772            result_indices.push(idx.clone());
773            result_dims.push(dims_a[i]);
774        }
775    }
776
777    for (i, idx) in indices_b.iter().enumerate() {
778        if !axes_b.contains(&i) {
779            result_indices.push(idx.clone());
780            result_dims.push(dims_b[i]);
781        }
782    }
783
784    Ok(ContractionSpec {
785        axes_a,
786        axes_b,
787        result_indices,
788        result_dims,
789    })
790}