tensor4all_core/
index_ops.rs

1use crate::IndexLike;
2
/// Error type for index replacement operations.
///
/// Produced by [`check_unique_indices`], [`replaceinds`] and
/// [`replaceinds_in_place`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplaceIndsError {
    /// The replacement index is not compatible with the index it replaces.
    ///
    /// In this module the check is a comparison of the `dim()` of the old and
    /// the new index of a replacement pair; both values are reported here.
    SpaceMismatch {
        /// The dimension/size of the original (to-be-replaced) index
        from_dim: usize,
        /// The dimension/size of the replacement index
        to_dim: usize,
    },
    /// Two indices in the same collection share an ID.
    DuplicateIndices {
        /// The position of the first occurrence of the repeated ID
        first_pos: usize,
        /// The position of the later index that repeats that ID
        duplicate_pos: usize,
    },
}
21
22impl std::fmt::Display for ReplaceIndsError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            ReplaceIndsError::SpaceMismatch { from_dim, to_dim } => {
26                write!(
27                    f,
28                    "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
29                    from_dim, to_dim
30                )
31            }
32            ReplaceIndsError::DuplicateIndices {
33                first_pos,
34                duplicate_pos,
35            } => {
36                write!(
37                    f,
38                    "Duplicate indices found: index at position {} has the same ID as index at position {}",
39                    duplicate_pos, first_pos
40                )
41            }
42        }
43    }
44}
45
// Marker impl (default methods only) so `ReplaceIndsError` can be used wherever
// a `std::error::Error` is expected, e.g. boxed as `Box<dyn std::error::Error>`.
impl std::error::Error for ReplaceIndsError {}
47
48/// Check if a collection of indices contains any duplicates (by ID).
49///
50/// # Arguments
51/// * `indices` - Collection of indices to check
52///
53/// # Returns
54/// `Ok(())` if all indices are unique, or `Err(ReplaceIndsError::DuplicateIndices)` if duplicates are found.
55///
56/// # Example
57/// ```
58/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
59/// use tensor4all_core::index_ops::check_unique_indices;
60///
61/// let i = Index::new_dyn(2);
62/// let j = Index::new_dyn(3);
63/// let indices = vec![i.clone(), j.clone()];
64/// assert!(check_unique_indices(&indices).is_ok());
65///
66/// let duplicate = vec![i.clone(), i.clone()];
67/// assert!(check_unique_indices(&duplicate).is_err());
68/// ```
69pub fn check_unique_indices<I: IndexLike>(indices: &[I]) -> Result<(), ReplaceIndsError> {
70    use std::collections::HashMap;
71    let mut seen: HashMap<&I::Id, usize> = HashMap::with_capacity(indices.len());
72    for (pos, idx) in indices.iter().enumerate() {
73        if let Some(&first_pos) = seen.get(idx.id()) {
74            return Err(ReplaceIndsError::DuplicateIndices {
75                first_pos,
76                duplicate_pos: pos,
77            });
78        }
79        seen.insert(idx.id(), pos);
80    }
81    Ok(())
82}
83
84/// Replace indices in a collection based on ID matching.
85///
86/// This corresponds to ITensors.jl's `replaceinds` function. It replaces indices
87/// in `indices` that match (by ID) any of the `(old, new)` pairs in `replacements`.
88/// The replacement index must have the same dimension as the original.
89///
90/// # Arguments
91/// * `indices` - Collection of indices to modify
92/// * `replacements` - Pairs of `(old_index, new_index)` where indices matching `old_index.id` are replaced with `new_index`
93///
94/// # Returns
95/// A new vector with replacements applied, or an error if any replacement has a dimension mismatch.
96///
97/// # Errors
98/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
99///
100/// # Example
101/// ```
102/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
103/// use tensor4all_core::index_ops::replaceinds;
104///
105/// let i = Index::new_dyn(2);
106/// let j = Index::new_dyn(3);
107/// let k = Index::new_dyn(4);
108/// let new_j = Index::new_dyn(3);  // Same size as j
109///
110/// let indices = vec![i.clone(), j.clone(), k.clone()];
111/// let replacements = vec![(j.clone(), new_j.clone())];
112///
113/// let replaced = replaceinds(indices, &replacements).unwrap();
114/// assert_eq!(replaced.len(), 3);
115/// assert_eq!(replaced[1].id, new_j.id);
116/// ```
117pub fn replaceinds<I: IndexLike>(
118    indices: Vec<I>,
119    replacements: &[(I, I)],
120) -> Result<Vec<I>, ReplaceIndsError> {
121    // Check for duplicates in input indices
122    check_unique_indices(&indices)?;
123
124    // Build a map from old ID to new index for fast lookup
125    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
126    for (old, new) in replacements {
127        // Validate dimension match
128        if old.dim() != new.dim() {
129            return Err(ReplaceIndsError::SpaceMismatch {
130                from_dim: old.dim(),
131                to_dim: new.dim(),
132            });
133        }
134        replacement_map.insert(old.id(), new);
135    }
136
137    // Apply replacements
138    let mut result = Vec::with_capacity(indices.len());
139    for idx in indices {
140        if let Some(new_idx) = replacement_map.get(idx.id()) {
141            result.push((*new_idx).clone());
142        } else {
143            result.push(idx);
144        }
145    }
146
147    // Check for duplicates in result indices
148    check_unique_indices(&result)?;
149    Ok(result)
150}
151
152/// Replace indices in-place based on ID matching.
153///
154/// This is an in-place variant of `replaceinds` that modifies the input slice directly.
155/// Useful for performance-critical code where you want to avoid allocations.
156///
157/// # Arguments
158/// * `indices` - Mutable slice of indices to modify
159/// * `replacements` - Pairs of `(old_index, new_index)` where indices matching `old_index.id` are replaced with `new_index`
160///
161/// # Returns
162/// `Ok(())` on success, or an error if any replacement has a dimension mismatch.
163///
164/// # Errors
165/// Returns `ReplaceIndsError::SpaceMismatch` if any replacement index has a different dimension than the original.
166///
167/// # Example
168/// ```
169/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
170/// use tensor4all_core::index_ops::replaceinds_in_place;
171///
172/// let i = Index::new_dyn(2);
173/// let j = Index::new_dyn(3);
174/// let k = Index::new_dyn(4);
175/// let new_j = Index::new_dyn(3);
176///
177/// let mut indices = vec![i.clone(), j.clone(), k.clone()];
178/// let replacements = vec![(j.clone(), new_j.clone())];
179///
180/// replaceinds_in_place(&mut indices, &replacements).unwrap();
181/// assert_eq!(indices[1].id, new_j.id);
182/// ```
183pub fn replaceinds_in_place<I: IndexLike>(
184    indices: &mut [I],
185    replacements: &[(I, I)],
186) -> Result<(), ReplaceIndsError> {
187    // Check for duplicates in input indices
188    check_unique_indices(indices)?;
189
190    // Build a map from old ID to new index for fast lookup
191    let mut replacement_map = std::collections::HashMap::with_capacity(replacements.len());
192    for (old, new) in replacements {
193        // Validate dimension match
194        if old.dim() != new.dim() {
195            return Err(ReplaceIndsError::SpaceMismatch {
196                from_dim: old.dim(),
197                to_dim: new.dim(),
198            });
199        }
200        replacement_map.insert(old.id(), new);
201    }
202
203    // Apply replacements in-place
204    for idx in indices.iter_mut() {
205        if let Some(new_idx) = replacement_map.get(idx.id()) {
206            *idx = (*new_idx).clone();
207        }
208    }
209
210    // Check for duplicates in result indices
211    check_unique_indices(indices)?;
212    Ok(())
213}
214
215/// Find indices that are unique to the first collection (set difference A \ B).
216///
217/// Returns indices that appear in `indices_a` but not in `indices_b` (matched by ID).
218/// This corresponds to ITensors.jl's `uniqueinds` function.
219///
220/// # Arguments
221/// * `indices_a` - First collection of indices
222/// * `indices_b` - Second collection of indices
223///
224/// # Returns
225/// A vector containing indices from `indices_a` that are not in `indices_b`.
226///
227/// # Example
228/// ```
229/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
230/// use tensor4all_core::index_ops::unique_inds;
231///
232/// let i = Index::new_dyn(2);
233/// let j = Index::new_dyn(3);
234/// let k = Index::new_dyn(4);
235///
236/// let indices_a = vec![i.clone(), j.clone()];
237/// let indices_b = vec![j.clone(), k.clone()];
238///
239/// let unique = unique_inds(&indices_a, &indices_b);
240/// assert_eq!(unique.len(), 1);
241/// assert_eq!(unique[0].id, i.id);
242/// ```
243pub fn unique_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
244    let b_ids: std::collections::HashSet<_> = indices_b.iter().map(|idx| idx.id()).collect();
245    indices_a
246        .iter()
247        .filter(|idx| !b_ids.contains(idx.id()))
248        .cloned()
249        .collect()
250}
251
252/// Find indices that are not common between two collections (symmetric difference).
253///
254/// Returns indices that appear in either `indices_a` or `indices_b` but not in both
255/// (matched by ID). This corresponds to ITensors.jl's `noncommoninds` function.
256///
257/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
258///
259/// # Arguments
260/// * `indices_a` - First collection of indices
261/// * `indices_b` - Second collection of indices
262///
263/// # Returns
264/// A vector containing indices from both collections that are not common to both.
265/// Order: indices from A first (in original order), then indices from B (in original order).
266///
267/// # Example
268/// ```
269/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
270/// use tensor4all_core::index_ops::noncommon_inds;
271///
272/// let i = Index::new_dyn(2);
273/// let j = Index::new_dyn(3);
274/// let k = Index::new_dyn(4);
275///
276/// let indices_a = vec![i.clone(), j.clone()];
277/// let indices_b = vec![j.clone(), k.clone()];
278///
279/// let noncommon = noncommon_inds(&indices_a, &indices_b);
280/// assert_eq!(noncommon.len(), 2);  // i and k
281/// ```
282pub fn noncommon_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
283    let a_ids: std::collections::HashSet<_> = indices_a.iter().map(|idx| idx.id()).collect();
284    let b_ids: std::collections::HashSet<_> = indices_b.iter().map(|idx| idx.id()).collect();
285
286    // Pre-allocate with estimated capacity (worst case: no common indices)
287    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
288
289    // Add indices from A that are not in B
290    result.extend(
291        indices_a
292            .iter()
293            .filter(|idx| !b_ids.contains(idx.id()))
294            .cloned(),
295    );
296    // Add indices from B that are not in A
297    result.extend(
298        indices_b
299            .iter()
300            .filter(|idx| !a_ids.contains(idx.id()))
301            .cloned(),
302    );
303    result
304}
305
306/// Find the union of two index collections.
307///
308/// Returns all unique indices from both collections (matched by ID).
309/// This corresponds to ITensors.jl's `unioninds` function.
310///
311/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
312///
313/// # Arguments
314/// * `indices_a` - First collection of indices
315/// * `indices_b` - Second collection of indices
316///
317/// # Returns
318/// A vector containing all unique indices from both collections.
319///
320/// # Example
321/// ```
322/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
323/// use tensor4all_core::index_ops::union_inds;
324///
325/// let i = Index::new_dyn(2);
326/// let j = Index::new_dyn(3);
327/// let k = Index::new_dyn(4);
328///
329/// let indices_a = vec![i.clone(), j.clone()];
330/// let indices_b = vec![j.clone(), k.clone()];
331///
332/// let union = union_inds(&indices_a, &indices_b);
333/// assert_eq!(union.len(), 3);  // i, j, k
334/// ```
335pub fn union_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
336    let mut seen: std::collections::HashSet<&I::Id> =
337        std::collections::HashSet::with_capacity(indices_a.len() + indices_b.len());
338    let mut result = Vec::with_capacity(indices_a.len() + indices_b.len());
339
340    for idx in indices_a {
341        if seen.insert(idx.id()) {
342            result.push(idx.clone());
343        }
344    }
345    for idx in indices_b {
346        if seen.insert(idx.id()) {
347            result.push(idx.clone());
348        }
349    }
350    result
351}
352
353/// Check if a collection contains a specific index (by ID).
354///
355/// This corresponds to ITensors.jl's `hasind` function.
356///
357/// # Arguments
358/// * `indices` - Collection of indices to search
359/// * `index` - The index to look for
360///
361/// # Returns
362/// `true` if an index with matching ID is found, `false` otherwise.
363///
364/// # Example
365/// ```
366/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
367/// use tensor4all_core::index_ops::hasind;
368///
369/// let i = Index::new_dyn(2);
370/// let j = Index::new_dyn(3);
371/// let indices = vec![i.clone(), j.clone()];
372///
373/// assert!(hasind(&indices, &i));
374/// assert!(!hasind(&indices, &Index::new_dyn(4)));
375/// ```
376pub fn hasind<I: IndexLike>(indices: &[I], index: &I) -> bool {
377    indices.iter().any(|idx| idx == index)
378}
379
380/// Check if a collection contains all of the specified indices (by ID).
381///
382/// This corresponds to ITensors.jl's `hasinds` function.
383///
384/// # Arguments
385/// * `indices` - Collection of indices to search
386/// * `targets` - The indices to look for
387///
388/// # Returns
389/// `true` if all target indices (by ID) are found, `false` otherwise.
390///
391/// # Example
392/// ```
393/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
394/// use tensor4all_core::index_ops::hasinds;
395///
396/// let i = Index::new_dyn(2);
397/// let j = Index::new_dyn(3);
398/// let k = Index::new_dyn(4);
399/// let indices = vec![i.clone(), j.clone(), k.clone()];
400///
401/// assert!(hasinds(&indices, &[i.clone(), j.clone()]));
402/// assert!(!hasinds(&indices, &[i.clone(), Index::new_dyn(5)]));
403/// ```
404pub fn hasinds<I: IndexLike>(indices: &[I], targets: &[I]) -> bool {
405    let index_ids: std::collections::HashSet<_> = indices.iter().map(|idx| idx.id()).collect();
406    targets.iter().all(|target| index_ids.contains(target.id()))
407}
408
409/// Check if two collections have any common indices (by ID).
410///
411/// This corresponds to ITensors.jl's `hascommoninds` function.
412///
413/// # Arguments
414/// * `indices_a` - First collection of indices
415/// * `indices_b` - Second collection of indices
416///
417/// # Returns
418/// `true` if there is at least one common index (by ID), `false` otherwise.
419///
420/// # Example
421/// ```
422/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
423/// use tensor4all_core::index_ops::hascommoninds;
424///
425/// let i = Index::new_dyn(2);
426/// let j = Index::new_dyn(3);
427/// let k = Index::new_dyn(4);
428///
429/// let indices_a = vec![i.clone(), j.clone()];
430/// let indices_b = vec![j.clone(), k.clone()];
431///
432/// assert!(hascommoninds(&indices_a, &indices_b));
433/// assert!(!hascommoninds(&[i.clone()], &[k.clone()]));
434/// ```
435pub fn hascommoninds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> bool {
436    let b_ids: std::collections::HashSet<_> = indices_b.iter().map(|idx| idx.id()).collect();
437    indices_a.iter().any(|idx| b_ids.contains(idx.id()))
438}
439
440/// Find common indices between two index collections.
441///
442/// Returns a vector of indices that appear in both `indices_a` and `indices_b`
443/// (set intersection). This is similar to ITensors.jl's `commoninds` function.
444///
445/// Time complexity: O(n + m) where n = len(indices_a), m = len(indices_b).
446///
447/// # Arguments
448/// * `indices_a` - First collection of indices
449/// * `indices_b` - Second collection of indices
450///
451/// # Returns
452/// A vector containing indices that are common to both collections (matched by ID).
453///
454/// # Example
455/// ```
456/// use tensor4all_core::index::{DefaultIndex as Index, DynId};
457/// use tensor4all_core::index_ops::common_inds;
458///
459/// let i = Index::new_dyn(2);
460/// let j = Index::new_dyn(3);
461/// let k = Index::new_dyn(4);
462///
463/// let indices_a = vec![i.clone(), j.clone()];
464/// let indices_b = vec![j.clone(), k.clone()];
465///
466/// let common = common_inds(&indices_a, &indices_b);
467/// assert_eq!(common.len(), 1);
468/// assert_eq!(common[0].id, j.id);
469/// ```
470pub fn common_inds<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<I> {
471    let b_ids: std::collections::HashSet<_> = indices_b.iter().map(|idx| idx.id()).collect();
472    indices_a
473        .iter()
474        .filter(|idx| b_ids.contains(idx.id()))
475        .cloned()
476        .collect()
477}
478
479/// Find contractable indices between two slices and return their positions.
480///
481/// Returns a vector of `(pos_a, pos_b)` tuples where each tuple indicates
482/// that `indices_a[pos_a]` and `indices_b[pos_b]` are contractable
483/// (same ID, same dimension, and compatible ConjState).
484///
485/// # Example
486/// ```
487/// use tensor4all_core::index::DefaultIndex as Index;
488/// use tensor4all_core::index_ops::common_ind_positions;
489///
490/// let i = Index::new_dyn(2);
491/// let j = Index::new_dyn(3);
492/// let k = Index::new_dyn(4);
493///
494/// let indices_a = vec![i.clone(), j.clone()];
495/// let indices_b = vec![j.clone(), k.clone()];
496///
497/// let positions = common_ind_positions(&indices_a, &indices_b);
498/// assert_eq!(positions, vec![(1, 0)]); // j is at position 1 in a, position 0 in b
499/// ```
500pub fn common_ind_positions<I: IndexLike>(indices_a: &[I], indices_b: &[I]) -> Vec<(usize, usize)> {
501    let mut positions = Vec::new();
502    for (pos_a, idx_a) in indices_a.iter().enumerate() {
503        for (pos_b, idx_b) in indices_b.iter().enumerate() {
504            if idx_a.is_contractable(idx_b) {
505                positions.push((pos_a, pos_b));
506                break; // Each index in a can match at most one in b
507            }
508        }
509    }
510    positions
511}
512
/// Result of preparing a tensor contraction.
///
/// Built by [`prepare_contraction`] and [`prepare_contraction_pairs`]. It
/// contains all the information needed to perform the contraction:
/// - Which axes to contract from each tensor
/// - The resulting indices and dimensions after contraction
///
/// `axes_a` and `axes_b` have the same length and are paired element-wise:
/// `axes_a[n]` is contracted with `axes_b[n]`. Likewise `result_indices[n]`
/// has dimension `result_dims[n]`.
#[derive(Debug, Clone)]
pub struct ContractionSpec<I: IndexLike> {
    /// Axes to contract from the first tensor (positions in `indices_a`).
    pub axes_a: Vec<usize>,
    /// Axes to contract from the second tensor (positions in `indices_b`).
    pub axes_b: Vec<usize>,
    /// Indices of the result tensor: the non-contracted indices of the first
    /// tensor followed by the non-contracted indices of the second tensor,
    /// each group in its original order.
    pub result_indices: Vec<I>,
    /// Dimensions of the result tensor, in the same order as `result_indices`.
    pub result_dims: Vec<usize>,
}
529
/// Error type for contraction preparation.
///
/// Returned by [`prepare_contraction`] and [`prepare_contraction_pairs`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionError {
    /// No common indices found for contraction.
    ///
    /// Within this module it is returned by `prepare_contraction_pairs` when
    /// the list of pairs is empty.
    NoCommonIndices,
    /// Dimension mismatch between the two axes that are to be contracted.
    DimensionMismatch {
        /// Position in the first tensor.
        pos_a: usize,
        /// Position in the second tensor.
        pos_b: usize,
        /// Dimension in the first tensor.
        dim_a: usize,
        /// Dimension in the second tensor.
        dim_b: usize,
    },
    /// The same axis was requested more than once in a contraction.
    DuplicateAxis {
        /// Which tensor has the duplicate: `"self"` for the first tensor,
        /// `"other"` for the second.
        tensor: &'static str,
        /// Position of the duplicate axis.
        pos: usize,
    },
    /// An index named in a contraction pair does not occur in its tensor.
    IndexNotFound {
        /// Which tensor the index was not found in: `"self"` for the first
        /// tensor, `"other"` for the second.
        tensor: &'static str,
    },
    /// Batch contraction not yet implemented.
    ///
    /// Returned by `prepare_contraction_pairs` when the two tensors share a
    /// contractable index that is not covered by the explicit pairs.
    BatchContractionNotImplemented,
}
561
562impl std::fmt::Display for ContractionError {
563    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564        match self {
565            ContractionError::NoCommonIndices => {
566                write!(f, "No common indices found for contraction")
567            }
568            ContractionError::DimensionMismatch {
569                pos_a,
570                pos_b,
571                dim_a,
572                dim_b,
573            } => {
574                write!(
575                    f,
576                    "Dimension mismatch: tensor_a[{}]={} != tensor_b[{}]={}",
577                    pos_a, dim_a, pos_b, dim_b
578                )
579            }
580            ContractionError::DuplicateAxis { tensor, pos } => {
581                write!(f, "Duplicate axis {} in {} tensor", pos, tensor)
582            }
583            ContractionError::IndexNotFound { tensor } => {
584                write!(f, "Index not found in {} tensor", tensor)
585            }
586            ContractionError::BatchContractionNotImplemented => {
587                write!(f, "Batch contraction not yet implemented")
588            }
589        }
590    }
591}
592
// Marker impl (default methods only) so `ContractionError` can be used wherever
// a `std::error::Error` is expected, e.g. boxed as `Box<dyn std::error::Error>`.
impl std::error::Error for ContractionError {}
594
595/// Prepare contraction data for two tensors that share common indices.
596///
597/// This function finds common indices and computes the axes to contract
598/// and the resulting indices/dimensions.
599///
600/// # Example
601/// ```
602/// use tensor4all_core::index::DefaultIndex as Index;
603/// use tensor4all_core::index_ops::prepare_contraction;
604///
605/// let i = Index::new_dyn(2);
606/// let j = Index::new_dyn(3);
607/// let k = Index::new_dyn(4);
608///
609/// let indices_a = vec![i.clone(), j.clone()];
610/// let dims_a = vec![2, 3];
611/// let indices_b = vec![j.clone(), k.clone()];
612/// let dims_b = vec![3, 4];
613///
614/// let spec = prepare_contraction(&indices_a, &dims_a, &indices_b, &dims_b).unwrap();
615/// assert_eq!(spec.axes_a, vec![1]);  // j is at position 1 in a
616/// assert_eq!(spec.axes_b, vec![0]);  // j is at position 0 in b
617/// assert_eq!(spec.result_dims, vec![2, 4]);  // [i, k]
618/// ```
619pub fn prepare_contraction<I: IndexLike>(
620    indices_a: &[I],
621    dims_a: &[usize],
622    indices_b: &[I],
623    dims_b: &[usize],
624) -> Result<ContractionSpec<I>, ContractionError> {
625    // Find common indices and their positions.
626    // If no common indices exist, this becomes an outer product (empty axes).
627    let positions = common_ind_positions(indices_a, indices_b);
628
629    let (axes_a, axes_b): (Vec<_>, Vec<_>) = positions.iter().copied().unzip();
630
631    // Verify dimensions match
632    for &(pos_a, pos_b) in &positions {
633        if dims_a[pos_a] != dims_b[pos_b] {
634            return Err(ContractionError::DimensionMismatch {
635                pos_a,
636                pos_b,
637                dim_a: dims_a[pos_a],
638                dim_b: dims_b[pos_b],
639            });
640        }
641    }
642
643    // Build result indices and dimensions (non-contracted indices)
644    let mut result_indices = Vec::new();
645    let mut result_dims = Vec::new();
646
647    for (i, idx) in indices_a.iter().enumerate() {
648        if !axes_a.contains(&i) {
649            result_indices.push(idx.clone());
650            result_dims.push(dims_a[i]);
651        }
652    }
653
654    for (i, idx) in indices_b.iter().enumerate() {
655        if !axes_b.contains(&i) {
656            result_indices.push(idx.clone());
657            result_dims.push(dims_b[i]);
658        }
659    }
660
661    Ok(ContractionSpec {
662        axes_a,
663        axes_b,
664        result_indices,
665        result_dims,
666    })
667}
668
669/// Prepare contraction data for explicit index pairs (like tensordot).
670///
671/// Unlike `prepare_contraction`, this function takes explicit pairs of indices
672/// to contract, allowing contraction of indices with different IDs.
673///
674/// # Example
675/// ```
676/// use tensor4all_core::index::DefaultIndex as Index;
677/// use tensor4all_core::index_ops::prepare_contraction_pairs;
678///
679/// let i = Index::new_dyn(2);
680/// let j = Index::new_dyn(3);
681/// let k = Index::new_dyn(3);  // Same dim as j but different ID
682/// let l = Index::new_dyn(4);
683///
684/// let indices_a = vec![i.clone(), j.clone()];
685/// let dims_a = vec![2, 3];
686/// let indices_b = vec![k.clone(), l.clone()];
687/// let dims_b = vec![3, 4];
688///
689/// // Contract j with k
690/// let spec = prepare_contraction_pairs(
691///     &indices_a, &dims_a,
692///     &indices_b, &dims_b,
693///     &[(j.clone(), k.clone())]
694/// ).unwrap();
695/// assert_eq!(spec.axes_a, vec![1]);
696/// assert_eq!(spec.axes_b, vec![0]);
697/// assert_eq!(spec.result_dims, vec![2, 4]);
698/// ```
699pub fn prepare_contraction_pairs<I: IndexLike>(
700    indices_a: &[I],
701    dims_a: &[usize],
702    indices_b: &[I],
703    dims_b: &[usize],
704    pairs: &[(I, I)],
705) -> Result<ContractionSpec<I>, ContractionError> {
706    use std::collections::HashSet;
707
708    if pairs.is_empty() {
709        return Err(ContractionError::NoCommonIndices);
710    }
711
712    // Check for batch contraction (common indices not in pairs)
713    let contracted_a_ids: HashSet<_> = pairs.iter().map(|(idx, _)| idx.id()).collect();
714    let contracted_b_ids: HashSet<_> = pairs.iter().map(|(_, idx)| idx.id()).collect();
715
716    let common_positions = common_ind_positions(indices_a, indices_b);
717    for (pos_a, pos_b) in &common_positions {
718        let id_a = indices_a[*pos_a].id();
719        let id_b = indices_b[*pos_b].id();
720        if !contracted_a_ids.contains(id_a) || !contracted_b_ids.contains(id_b) {
721            return Err(ContractionError::BatchContractionNotImplemented);
722        }
723    }
724
725    // Find positions and validate
726    let mut axes_a = Vec::new();
727    let mut axes_b = Vec::new();
728
729    for (idx_a, idx_b) in pairs {
730        let pos_a = indices_a
731            .iter()
732            .position(|idx| idx.id() == idx_a.id())
733            .ok_or(ContractionError::IndexNotFound { tensor: "self" })?;
734
735        let pos_b = indices_b
736            .iter()
737            .position(|idx| idx.id() == idx_b.id())
738            .ok_or(ContractionError::IndexNotFound { tensor: "other" })?;
739
740        // Verify dimensions match
741        if dims_a[pos_a] != dims_b[pos_b] {
742            return Err(ContractionError::DimensionMismatch {
743                pos_a,
744                pos_b,
745                dim_a: dims_a[pos_a],
746                dim_b: dims_b[pos_b],
747            });
748        }
749
750        // Check for duplicate axes
751        if axes_a.contains(&pos_a) {
752            return Err(ContractionError::DuplicateAxis {
753                tensor: "self",
754                pos: pos_a,
755            });
756        }
757        if axes_b.contains(&pos_b) {
758            return Err(ContractionError::DuplicateAxis {
759                tensor: "other",
760                pos: pos_b,
761            });
762        }
763
764        axes_a.push(pos_a);
765        axes_b.push(pos_b);
766    }
767
768    // Build result indices and dimensions
769    let mut result_indices = Vec::new();
770    let mut result_dims = Vec::new();
771
772    for (i, idx) in indices_a.iter().enumerate() {
773        if !axes_a.contains(&i) {
774            result_indices.push(idx.clone());
775            result_dims.push(dims_a[i]);
776        }
777    }
778
779    for (i, idx) in indices_b.iter().enumerate() {
780        if !axes_b.contains(&i) {
781            result_indices.push(idx.clone());
782            result_dims.push(dims_b[i]);
783        }
784    }
785
786    Ok(ContractionSpec {
787        axes_a,
788        axes_b,
789        result_indices,
790        result_dims,
791    })
792}