1use std::cell::RefCell;
22use std::cmp::Reverse;
23use std::collections::{HashMap, HashSet};
24use std::env;
25use std::time::{Duration, Instant};
26
27use anyhow::Result;
28use petgraph::algo::connected_components;
29use petgraph::prelude::*;
30use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad;
31use tenferro::EinsumSubscripts;
32use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned};
33
34use crate::defaults::{DynId, DynIndex, TensorDynLen};
35
36use crate::index_like::IndexLike;
/// Profiling key describing one operand of a contraction.
///
/// Two calls whose operands have the same dims, internal einsum labels and
/// diagonal flag are aggregated under the same profile entry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ContractOperandSignature {
    /// Extent of every axis of the operand, in axis order.
    dims: Vec<usize>,
    /// Internal einsum label assigned to every axis, in axis order.
    ids: Vec<usize>,
    /// Whether the operand reported `is_diag()` when it was recorded.
    is_diag: bool,
}
43
44#[derive(Debug, Clone, Hash, PartialEq, Eq)]
45struct ContractSignature {
46 operands: Vec<ContractOperandSignature>,
47 output_ids: Vec<usize>,
48 output_dims: Vec<usize>,
49}
50
/// Accumulated timing for one [`ContractSignature`].
#[derive(Clone, Debug, Default)]
struct ContractProfileEntry {
    /// Number of recorded contraction calls.
    calls: usize,
    /// Sum of wall-clock durations of those calls.
    total_time: Duration,
}
56
57thread_local! {
58 static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
59 RefCell::new(HashMap::new());
60}
61
/// Returns `true` when contraction profiling is switched on.
///
/// Profiling is on whenever `T4A_PROFILE_CONTRACT` is set to a valid
/// Unicode value (any value). A non-Unicode value counts as "off", because
/// `env::var` reports it as an error.
fn contract_profile_enabled() -> bool {
    const PROFILE_ENV_VAR: &str = "T4A_PROFILE_CONTRACT";
    env::var(PROFILE_ENV_VAR).is_ok()
}
65
66fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
67 if !contract_profile_enabled() {
68 return;
69 }
70 CONTRACT_PROFILE_STATE.with(|state| {
71 let mut state = state.borrow_mut();
72 let entry = state.entry(signature).or_default();
73 entry.calls += 1;
74 entry.total_time += elapsed;
75 });
76}
77
78pub fn reset_contract_profile() {
80 CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
81}
82
83pub fn print_and_reset_contract_profile() {
85 if !contract_profile_enabled() {
86 return;
87 }
88 CONTRACT_PROFILE_STATE.with(|state| {
89 let mut entries: Vec<_> = state
90 .borrow()
91 .iter()
92 .map(|(k, v)| (k.clone(), v.clone()))
93 .collect();
94 state.borrow_mut().clear();
95 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
96
97 eprintln!("=== contract Profile ===");
98 for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
99 let operands = signature
100 .operands
101 .iter()
102 .map(|operand| {
103 format!(
104 "dims={:?} ids={:?}{}",
105 operand.dims,
106 operand.ids,
107 if operand.is_diag { " diag" } else { "" }
108 )
109 })
110 .collect::<Vec<_>>()
111 .join(" ; ");
112 eprintln!(
113 "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
114 entry.calls,
115 entry.total_time.as_secs_f64(),
116 entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
117 signature.output_dims,
118 signature.output_ids,
119 );
120 eprintln!(" {operands}");
121 }
122 });
123}
124
125#[derive(Clone, Copy, Debug)]
146pub struct ContractionOptions<'a> {
147 pub retain_indices: &'a [DynIndex],
149}
150
151impl<'a> ContractionOptions<'a> {
152 pub fn new() -> Self {
154 Self {
155 retain_indices: &[],
156 }
157 }
158
159 pub fn with_retain_indices(mut self, retain_indices: &'a [DynIndex]) -> Self {
161 self.retain_indices = retain_indices;
162 self
163 }
164}
165
166impl Default for ContractionOptions<'_> {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
/// Per-operand options for a pairwise contraction.
///
/// The default (also [`PairwiseContractionOptions::new`]) conjugates neither side.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct PairwiseContractionOptions {
    /// Conjugate the left-hand operand before contracting.
    pub lhs_conj: bool,
    /// Conjugate the right-hand operand before contracting.
    pub rhs_conj: bool,
}

impl PairwiseContractionOptions {
    /// Creates options that conjugate neither operand.
    pub fn new() -> Self {
        Self {
            lhs_conj: false,
            rhs_conj: false,
        }
    }

    /// Returns a copy with the left-hand conjugation flag set to `lhs_conj`.
    pub fn with_lhs_conj(self, lhs_conj: bool) -> Self {
        Self { lhs_conj, ..self }
    }

    /// Returns a copy with the right-hand conjugation flag set to `rhs_conj`.
    pub fn with_rhs_conj(self, rhs_conj: bool) -> Self {
        Self { rhs_conj, ..self }
    }

    /// Returns `true` if at least one operand is to be conjugated.
    pub(crate) fn has_conj(self) -> bool {
        !(!self.lhs_conj && !self.rhs_conj)
    }
}
263
264pub fn contract(tensors: &[&TensorDynLen]) -> Result<TensorDynLen> {
274 contract_with_options(tensors, ContractionOptions::new())
275}
276
277pub fn contract_with_options(
279 tensors: &[&TensorDynLen],
280 options: ContractionOptions<'_>,
281) -> Result<TensorDynLen> {
282 contract_with_options_impl(tensors, options)
283}
284
285pub fn contract_owned(tensors: Vec<TensorDynLen>) -> Result<TensorDynLen> {
287 contract_owned_with_options(tensors, ContractionOptions::new())
288}
289
290pub fn contract_owned_with_options(
292 tensors: Vec<TensorDynLen>,
293 options: ContractionOptions<'_>,
294) -> Result<TensorDynLen> {
295 let tensor_refs = tensors.iter().collect::<Vec<_>>();
296 let components =
297 find_tensor_connected_components_with_retained(&tensor_refs, options.retain_indices);
298 if components.len() > 1 {
299 return Err(anyhow::anyhow!(
300 "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product"
301 ));
302 }
303 drop(tensor_refs);
304 contract_owned_with_options_impl(tensors, options)
305}
306
307pub fn contract_pair(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result<TensorDynLen> {
313 lhs.try_contract_pairwise_default_with_options(rhs, PairwiseContractionOptions::new())
314}
315
316pub fn contract_pair_with_operand_options(
352 lhs: &TensorDynLen,
353 rhs: &TensorDynLen,
354 options: PairwiseContractionOptions,
355) -> Result<TensorDynLen> {
356 lhs.try_contract_pairwise_default_with_options(rhs, options)
357}
358
359pub fn contract_pair_with_options(
361 lhs: &TensorDynLen,
362 rhs: &TensorDynLen,
363 options: ContractionOptions<'_>,
364) -> Result<TensorDynLen> {
365 contract_with_options(&[lhs, rhs], options)
366}
367
368pub fn tensordot(
370 lhs: &TensorDynLen,
371 rhs: &TensorDynLen,
372 pairs: &[(DynIndex, DynIndex)],
373) -> Result<TensorDynLen> {
374 lhs.try_tensordot_pairwise_explicit(rhs, pairs)
375}
376
377pub fn outer_product(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result<TensorDynLen> {
382 lhs.try_outer_product_pairwise(rhs)
383}
384
385fn contract_owned_with_options_impl(
394 tensors: Vec<TensorDynLen>,
395 options: ContractionOptions<'_>,
396) -> Result<TensorDynLen> {
397 match tensors.len() {
398 0 => Err(anyhow::anyhow!("No tensors to contract")),
399 _ => {
400 let tensor_refs = tensors.iter().collect::<Vec<_>>();
401 validate_retained_indices_exist(&tensor_refs, options.retain_indices)?;
402
403 if tensors.len() == 1 {
404 drop(tensor_refs);
405 let Some(tensor) = tensors.into_iter().next() else {
406 return Err(anyhow::anyhow!("No tensors to contract"));
407 };
408 return Ok(tensor);
409 }
410
411 let requires_borrowed_path = tensor_refs.iter().any(|tensor| tensor.tracks_grad())
412 || tensor_refs
413 .iter()
414 .any(|tensor| !has_dense_axis_classes(tensor));
415 if requires_borrowed_path {
416 return contract_with_options(&tensor_refs, options);
417 }
418
419 let components = find_tensor_connected_components_with_retained(
420 &tensor_refs,
421 options.retain_indices,
422 );
423 if components.len() > 1 {
424 return Err(anyhow::anyhow!(
425 "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product"
426 ));
427 }
428
429 let mut diag_uf = AxisUnionFind::new();
430 let plan = build_contraction_plan(&tensor_refs, options, &mut diag_uf)?;
431 drop(tensor_refs);
432 let native_operands = tensors
433 .into_iter()
434 .enumerate()
435 .map(|(tensor_idx, tensor)| {
436 Ok((
437 tensor.as_native()?.clone(),
438 plan.input_ids[tensor_idx].clone(),
439 ))
440 })
441 .collect::<Result<Vec<_>>>()?;
442 let result_native = einsum_native_tensors_owned(native_operands, &plan.output_ids)?;
443 TensorDynLen::from_native_with_axis_classes(
444 plan.result_indices,
445 result_native,
446 plan.result_axis_classes,
447 )
448 }
449 }
450}
451
452fn has_dense_axis_classes(tensor: &TensorDynLen) -> bool {
453 let storage = tensor.storage();
454 storage
455 .axis_classes()
456 .iter()
457 .copied()
458 .eq(0..tensor.indices().len())
459}
460
461fn contract_with_options_impl(
462 tensors: &[&TensorDynLen],
463 options: ContractionOptions<'_>,
464) -> Result<TensorDynLen> {
465 match tensors.len() {
466 0 => Err(anyhow::anyhow!("No tensors to contract")),
467 _ => {
468 validate_retained_indices_exist(tensors, options.retain_indices)?;
469 if tensors.len() == 1 {
470 return Ok((*tensors[0]).clone());
471 }
472
473 let components =
475 find_tensor_connected_components_with_retained(tensors, options.retain_indices);
476 if components.len() > 1 {
477 return Err(anyhow::anyhow!(
478 "Disconnected tensor network: {} components found",
479 components.len()
480 ));
481 }
482 contract_impl(tensors, options)
484 }
485 }
486}
487
488#[derive(Debug, Clone)]
497pub struct AxisUnionFind {
498 parent: HashMap<DynId, DynId>,
500 rank: HashMap<DynId, usize>,
502}
503
504impl AxisUnionFind {
505 pub fn new() -> Self {
507 Self {
508 parent: HashMap::new(),
509 rank: HashMap::new(),
510 }
511 }
512
513 pub fn make_set(&mut self, id: DynId) {
515 use std::collections::hash_map::Entry;
516 if let Entry::Vacant(e) = self.parent.entry(id) {
517 e.insert(id);
518 self.rank.insert(id, 0);
519 }
520 }
521
522 pub fn find(&mut self, id: DynId) -> DynId {
525 self.make_set(id);
526 if self.parent[&id] != id {
527 let root = self.find(self.parent[&id]);
528 self.parent.insert(id, root);
529 }
530 self.parent[&id]
531 }
532
533 pub fn union(&mut self, a: DynId, b: DynId) {
536 let root_a = self.find(a);
537 let root_b = self.find(b);
538
539 if root_a == root_b {
540 return;
541 }
542
543 let rank_a = self.rank[&root_a];
544 let rank_b = self.rank[&root_b];
545
546 if rank_a < rank_b {
547 self.parent.insert(root_a, root_b);
548 } else if rank_a > rank_b {
549 self.parent.insert(root_b, root_a);
550 } else {
551 self.parent.insert(root_b, root_a);
552 if let Some(rank) = self.rank.get_mut(&root_a) {
553 *rank += 1;
554 }
555 }
556 }
557
558 pub fn remap(&mut self, id: DynId) -> DynId {
560 self.find(id)
561 }
562
563 pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
565 ids.iter().map(|id| self.find(*id)).collect()
566 }
567}
568
569impl Default for AxisUnionFind {
570 fn default() -> Self {
571 Self::new()
572 }
573}
574
575pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
585 let mut uf = AxisUnionFind::new();
586
587 for tensor in tensors {
588 for idx in tensor.indices() {
589 uf.make_set(*idx.id());
590 }
591
592 if tensor.is_diag() && tensor.indices().len() >= 2 {
593 let first_id = *tensor.indices()[0].id();
594 for idx in tensor.indices().iter().skip(1) {
595 uf.union(first_id, *idx.id());
596 }
597 }
598 }
599
600 uf
601}
602
603pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
608 tensors
609 .iter()
610 .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
611 .collect()
612}
613
614pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
616 output.iter().map(|idx| uf.find(*idx.id())).collect()
617}
618
619pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
624 let mut sizes = HashMap::new();
625
626 for tensor in tensors {
627 let dims = tensor.dims();
628 for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
629 let rep = uf.find(*idx.id());
630 sizes.entry(rep).or_insert(dim);
631 }
632 }
633
634 sizes
635}
636
637fn contract_impl(
649 tensors: &[&TensorDynLen],
650 options: ContractionOptions<'_>,
651) -> Result<TensorDynLen> {
652 let mut diag_uf = AxisUnionFind::new();
656
657 let plan = build_contraction_plan(tensors, options, &mut diag_uf)?;
659
660 let mut sizes: HashMap<usize, usize> = HashMap::new();
665 for (tensor_idx, tensor) in tensors.iter().enumerate() {
666 let dims = tensor.dims();
667 for (pos, &dim) in dims.iter().enumerate() {
668 let internal_id = plan.input_ids[tensor_idx][pos];
669 match sizes.entry(internal_id) {
670 std::collections::hash_map::Entry::Vacant(entry) => {
671 entry.insert(dim);
672 }
673 std::collections::hash_map::Entry::Occupied(entry) => {
674 if *entry.get() != dim {
675 return Err(anyhow::anyhow!(
676 "Internal label dimension mismatch: label {} has dimensions {} and {}",
677 internal_id,
678 entry.get(),
679 dim
680 ));
681 }
682 }
683 }
684 }
685 }
686
687 let profile_signature = contract_profile_enabled().then(|| ContractSignature {
688 operands: tensors
689 .iter()
690 .enumerate()
691 .map(|(tensor_idx, tensor)| ContractOperandSignature {
692 dims: tensor.dims().to_vec(),
693 ids: plan.input_ids[tensor_idx].clone(),
694 is_diag: tensor.is_diag(),
695 })
696 .collect(),
697 output_ids: plan.output_ids.clone(),
698 output_dims: plan.output_ids.iter().map(|id| sizes[id]).collect(),
699 });
700 let profile_started = contract_profile_enabled().then(Instant::now);
701
702 let result = execute_contraction_plan(tensors, &plan, !options.retain_indices.is_empty())?;
703 if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
704 record_contract_profile(signature, started.elapsed());
705 }
706 Ok(result)
707}
708
709fn execute_contraction_plan(
710 tensors: &[&TensorDynLen],
711 plan: &ContractionPlan,
712 has_retained_indices: bool,
713) -> Result<TensorDynLen> {
714 let any_grad = tensors.iter().any(|tensor| tensor.tracks_grad());
715 let first_dtype = tensors[0].as_native()?.dtype();
716 let same_dtype = tensors
717 .iter()
718 .map(|tensor| {
719 tensor
720 .as_native()
721 .map(|native| native.dtype() == first_dtype)
722 })
723 .collect::<Result<Vec<_>>>()?
724 .into_iter()
725 .all(|same| same);
726 let has_non_dense_axis_classes = tensors.iter().any(|tensor| {
727 tensor
728 .storage()
729 .axis_classes()
730 .iter()
731 .copied()
732 .enumerate()
733 .any(|(axis, class)| axis != class)
734 });
735
736 if any_grad && same_dtype && has_non_dense_axis_classes {
737 if has_retained_indices {
738 return Err(anyhow::anyhow!(
739 "Retained AD contraction with structured storage is not yet supported"
740 ));
741 }
742
743 let mut iter = tensors.iter();
746 let Some(first) = iter.next() else {
747 return Err(anyhow::anyhow!("No tensors to contract"));
748 };
749 let mut result = (*first).clone();
750 for tensor in iter {
751 result = contract_pair(&result, tensor)?;
752 }
753 return Ok(result);
754 }
755
756 if any_grad && same_dtype {
757 let operands = tensors
758 .iter()
759 .map(|tensor| tensor.as_inner())
760 .collect::<Result<Vec<_>>>()?;
761 let subscripts = build_einsum_subscripts_from_usize_ids(&plan.input_ids, &plan.output_ids)?;
762 let result = eager_einsum_ad(&operands, &subscripts)?;
763 return TensorDynLen::from_inner_with_axis_classes(
764 plan.result_indices.clone(),
765 result,
766 plan.result_axis_classes.clone(),
767 );
768 }
769
770 let native_operands: Vec<_> = tensors
771 .iter()
772 .enumerate()
773 .map(|(tensor_idx, tensor)| {
774 Ok((tensor.as_native()?, plan.input_ids[tensor_idx].as_slice()))
775 })
776 .collect::<Result<Vec<_>>>()?;
777 let result_native = einsum_native_tensors(&native_operands, &plan.output_ids)?;
778 TensorDynLen::from_native_with_axis_classes(
779 plan.result_indices.clone(),
780 result_native,
781 plan.result_axis_classes.clone(),
782 )
783}
784
785fn build_einsum_subscripts_from_usize_ids(
786 input_ids: &[Vec<usize>],
787 output_ids: &[usize],
788) -> Result<EinsumSubscripts> {
789 let inputs = input_ids
790 .iter()
791 .map(|ids| {
792 ids.iter()
793 .map(|&id| {
794 u32::try_from(id)
795 .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
796 })
797 .collect::<Result<Vec<_>>>()
798 })
799 .collect::<Result<Vec<_>>>()?;
800 let output = output_ids
801 .iter()
802 .map(|&id| {
803 u32::try_from(id).map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
804 })
805 .collect::<Result<Vec<_>>>()?;
806 let input_refs = inputs.iter().map(Vec::as_slice).collect::<Vec<_>>();
807 Ok(EinsumSubscripts::new(&input_refs, &output))
808}
809
810#[derive(Debug, Clone)]
812struct ContractionPlan {
813 input_ids: Vec<Vec<usize>>,
814 output_ids: Vec<usize>,
815 result_indices: Vec<DynIndex>,
816 result_axis_classes: Vec<usize>,
817}
818
819fn build_contraction_plan(
820 tensors: &[&TensorDynLen],
821 options: ContractionOptions<'_>,
822 diag_uf: &mut AxisUnionFind,
823) -> Result<ContractionPlan> {
824 let retained_indices: HashSet<DynIndex> = options.retain_indices.iter().cloned().collect();
825 let (input_ids, internal_id_to_original) =
826 build_internal_ids(tensors, diag_uf, &retained_indices)?;
827
828 let mut counts: HashMap<usize, usize> = HashMap::new();
829 for ids in &input_ids {
830 for &internal_id in ids {
831 *counts.entry(internal_id).or_insert(0) += 1;
832 }
833 }
834 let mut output_ids = Vec::new();
835 let mut seen_output = HashSet::new();
836 let mut found_retained = HashSet::new();
837
838 for (tensor_idx, tensor) in tensors.iter().enumerate() {
839 for (axis, idx) in tensor.indices.iter().enumerate() {
840 let internal_id = input_ids[tensor_idx][axis];
841 let should_output = counts[&internal_id] == 1 || retained_indices.contains(idx);
842 if should_output && seen_output.insert(internal_id) {
843 output_ids.push(internal_id);
844 }
845 if retained_indices.contains(idx) {
846 found_retained.insert(idx.clone());
847 }
848 }
849 }
850
851 for retained in retained_indices {
852 if !found_retained.contains(&retained) {
853 return Err(anyhow::anyhow!(
854 "Retained index {:?} does not appear in the input tensors",
855 retained
856 ));
857 }
858 }
859
860 let result_indices: Vec<DynIndex> = output_ids
861 .iter()
862 .map(|&internal_id| {
863 let (tensor_idx, pos) = internal_id_to_original[&internal_id];
864 tensors[tensor_idx].indices[pos].clone()
865 })
866 .collect();
867 validate_unique_output_indices(&result_indices)?;
868 let result_axis_classes =
869 output_axis_classes(tensors, &input_ids, &output_ids, &internal_id_to_original);
870
871 Ok(ContractionPlan {
872 input_ids,
873 output_ids,
874 result_indices,
875 result_axis_classes,
876 })
877}
878
879fn validate_retained_indices_exist(
880 tensors: &[&TensorDynLen],
881 retain_indices: &[DynIndex],
882) -> Result<()> {
883 for retain in retain_indices {
884 let found = tensors
885 .iter()
886 .any(|tensor| tensor.indices().iter().any(|idx| idx == retain));
887 if !found {
888 return Err(anyhow::anyhow!(
889 "Retained index {:?} does not appear in the input tensors",
890 retain
891 ));
892 }
893 }
894 Ok(())
895}
896
897fn validate_unique_output_indices(indices: &[DynIndex]) -> Result<()> {
898 let mut seen = HashSet::new();
899 for idx in indices {
900 if !seen.insert(idx.clone()) {
901 return Err(anyhow::anyhow!(
902 "Contraction result would contain duplicate output indices"
903 ));
904 }
905 }
906 Ok(())
907}
908
909fn output_axis_classes(
910 tensors: &[&TensorDynLen],
911 ixs: &[Vec<usize>],
912 output: &[usize],
913 internal_id_to_original: &HashMap<usize, (usize, usize)>,
914) -> Vec<usize> {
915 fn find(parent: &mut [usize], value: usize) -> usize {
916 if parent[value] != value {
917 parent[value] = find(parent, parent[value]);
918 }
919 parent[value]
920 }
921
922 fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
923 let lhs_root = find(parent, lhs);
924 let rhs_root = find(parent, rhs);
925 if lhs_root != rhs_root {
926 parent[rhs_root] = lhs_root;
927 }
928 }
929
930 let mut class_offsets = Vec::with_capacity(tensors.len());
931 let mut next_node = 0usize;
932 for tensor in tensors {
933 class_offsets.push(next_node);
934 let payload_rank = tensor
935 .storage()
936 .axis_classes()
937 .iter()
938 .copied()
939 .max()
940 .map(|value| value + 1)
941 .unwrap_or(0);
942 next_node += payload_rank;
943 }
944 let mut parent: Vec<usize> = (0..next_node).collect();
945 let mut axes_by_internal_id: HashMap<usize, Vec<usize>> = HashMap::new();
946
947 for (tensor_idx, tensor) in tensors.iter().enumerate() {
948 for (axis, &internal_id) in ixs[tensor_idx].iter().enumerate() {
949 let class_id = tensor.storage().axis_classes()[axis];
950 let node = class_offsets[tensor_idx] + class_id;
951 axes_by_internal_id
952 .entry(internal_id)
953 .or_default()
954 .push(node);
955 }
956 }
957
958 for nodes in axes_by_internal_id.values() {
959 if let Some((&first, rest)) = nodes.split_first() {
960 for &node in rest {
961 union(&mut parent, first, node);
962 }
963 }
964 }
965
966 let mut root_to_class = HashMap::new();
967 let mut next_class = 0usize;
968 output
969 .iter()
970 .map(|internal_id| {
971 let (tensor_idx, axis) = internal_id_to_original[internal_id];
972 let class_id = tensors[tensor_idx].storage().axis_classes()[axis];
973 let node = class_offsets[tensor_idx] + class_id;
974 let root = find(&mut parent, node);
975 *root_to_class.entry(root).or_insert_with(|| {
976 let class = next_class;
977 next_class += 1;
978 class
979 })
980 })
981 .collect()
982}
983
984#[allow(clippy::type_complexity)]
992fn build_internal_ids(
993 tensors: &[&TensorDynLen],
994 _diag_uf: &mut AxisUnionFind,
995 retained_indices: &HashSet<DynIndex>,
996) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
997 let mut next_id = 0usize;
998 let mut index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
999 let mut retained_index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
1000 let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
1001 let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();
1002
1003 for ti in 0..tensors.len() {
1004 for tj in (ti + 1)..tensors.len() {
1005 for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
1006 for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
1007 if idx_i.is_contractable(idx_j) {
1008 let key_i = (ti, pi);
1009 let key_j = (tj, pj);
1010
1011 match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
1012 (None, None) => {
1013 let internal_id = if let Some(&id) = index_to_internal.get(idx_i) {
1014 id
1015 } else {
1016 let id = next_id;
1017 next_id += 1;
1018 index_to_internal.insert(idx_i.clone(), id);
1019 internal_id_to_original.insert(id, key_i);
1020 id
1021 };
1022 assigned.insert(key_i, internal_id);
1023 assigned.insert(key_j, internal_id);
1024 if idx_i != idx_j {
1025 index_to_internal.insert(idx_j.clone(), internal_id);
1026 }
1027 }
1028 (Some(id), None) => {
1029 assigned.insert(key_j, id);
1030 index_to_internal.insert(idx_j.clone(), id);
1031 }
1032 (None, Some(id)) => {
1033 assigned.insert(key_i, id);
1034 index_to_internal.insert(idx_i.clone(), id);
1035 }
1036 (Some(_id_i), Some(_id_j)) => {
1037 }
1039 }
1040 }
1041 }
1042 }
1043 }
1044 }
1045
1046 for (tensor_idx, tensor) in tensors.iter().enumerate() {
1048 for (pos, idx) in tensor.indices.iter().enumerate() {
1049 let key = (tensor_idx, pos);
1050 if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
1051 let internal_id = if retained_indices.contains(idx) {
1052 if let Some(&id) = retained_index_to_internal.get(idx) {
1053 id
1054 } else {
1055 let id = next_id;
1056 next_id += 1;
1057 retained_index_to_internal.insert(idx.clone(), id);
1058 internal_id_to_original.insert(id, key);
1059 id
1060 }
1061 } else {
1062 let id = next_id;
1063 next_id += 1;
1064 internal_id_to_original.insert(id, key);
1065 id
1066 };
1067 e.insert(internal_id);
1068 }
1069 }
1070 }
1071
1072 let ixs: Vec<Vec<usize>> = tensors
1074 .iter()
1075 .enumerate()
1076 .map(|(tensor_idx, tensor)| {
1077 (0..tensor.indices.len())
1078 .map(|pos| assigned[&(tensor_idx, pos)])
1079 .collect()
1080 })
1081 .collect();
1082
1083 Ok((ixs, internal_id_to_original))
1084}
1085
1086fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
1092 a.indices
1093 .iter()
1094 .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
1095}
1096
1097#[allow(dead_code)]
1101fn find_tensor_connected_components(tensors: &[&TensorDynLen]) -> Vec<Vec<usize>> {
1102 find_tensor_connected_components_with_retained(tensors, &[])
1103}
1104
1105fn find_tensor_connected_components_with_retained(
1106 tensors: &[&TensorDynLen],
1107 retain_indices: &[DynIndex],
1108) -> Vec<Vec<usize>> {
1109 let n = tensors.len();
1110 if n == 0 {
1111 return vec![];
1112 }
1113 if n == 1 {
1114 return vec![vec![0]];
1115 }
1116
1117 let mut graph = UnGraph::<(), ()>::new_undirected();
1119 let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
1120
1121 for i in 0..n {
1122 for j in (i + 1)..n {
1123 if has_contractable_indices(tensors[i], tensors[j]) {
1124 graph.add_edge(nodes[i], nodes[j], ());
1125 }
1126 }
1127 }
1128
1129 if !retain_indices.is_empty() {
1130 for i in 0..n {
1131 for j in (i + 1)..n {
1132 if shares_retained_index(tensors[i], tensors[j], retain_indices) {
1133 graph.add_edge(nodes[i], nodes[j], ());
1134 }
1135 }
1136 }
1137 }
1138
1139 let num_components = connected_components(&graph);
1141
1142 if num_components == 1 {
1143 return vec![(0..n).collect()];
1144 }
1145
1146 use petgraph::visit::Dfs;
1148 let mut visited = vec![false; n];
1149 let mut components = Vec::new();
1150
1151 for start in 0..n {
1152 if !visited[start] {
1153 let mut component = Vec::new();
1154 let mut dfs = Dfs::new(&graph, nodes[start]);
1155 while let Some(node) = dfs.next(&graph) {
1156 let idx = node.index();
1157 if !visited[idx] {
1158 visited[idx] = true;
1159 component.push(idx);
1160 }
1161 }
1162 component.sort();
1163 components.push(component);
1164 }
1165 }
1166
1167 components.sort_by_key(|c| c[0]);
1168 components
1169}
1170
1171fn shares_retained_index(a: &TensorDynLen, b: &TensorDynLen, retain_indices: &[DynIndex]) -> bool {
1172 retain_indices.iter().any(|retain| {
1173 a.indices().iter().any(|idx_a| idx_a == retain)
1174 && b.indices().iter().any(|idx_b| idx_b == retain)
1175 })
1176}
1177
1178#[cfg(test)]
1179mod tests;