1use std::cell::RefCell;
22use std::cmp::Reverse;
23use std::collections::{HashMap, HashSet};
24use std::env;
25use std::time::{Duration, Instant};
26
27use anyhow::Result;
28use petgraph::algo::connected_components;
29use petgraph::prelude::*;
30use tenferro::eager_einsum::eager_einsum_ad;
31use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned};
32
33use crate::defaults::{DynId, DynIndex, TensorDynLen};
34
35use crate::index_like::IndexLike;
36use crate::tensor_like::AllowedPairs;
37
/// Shape and label fingerprint of a single contraction operand, used inside a
/// profiling key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ContractOperandSignature {
    /// Dimensions of the operand's axes, in axis order.
    dims: Vec<usize>,
    /// Internal einsum labels of the operand's axes, in axis order.
    ids: Vec<usize>,
    /// Whether the operand reported `is_diag()` when it was recorded.
    is_diag: bool,
}

/// Profiling key for one contraction: every operand plus the result labels and dims.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ContractSignature {
    /// Per-operand fingerprints, in operand order.
    operands: Vec<ContractOperandSignature>,
    /// Internal labels of the result axes, in result order.
    output_ids: Vec<usize>,
    /// Dimensions of the result axes, parallel to `output_ids`.
    output_dims: Vec<usize>,
}

/// Accumulated timing statistics for one [`ContractSignature`].
#[derive(Clone, Debug, Default)]
struct ContractProfileEntry {
    /// Number of recorded calls with this signature.
    calls: usize,
    /// Total wall-clock time across all recorded calls.
    total_time: Duration,
}
57
58thread_local! {
59 static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
60 RefCell::new(HashMap::new());
61}
62
/// Reports whether contraction profiling was requested through the
/// `T4A_PROFILE_CONTRACT` environment variable.
///
/// Any value (including an empty string) enables profiling, provided it is valid
/// Unicode. The environment is consulted on every call.
fn contract_profile_enabled() -> bool {
    let setting = env::var("T4A_PROFILE_CONTRACT");
    setting.is_ok()
}
66
67fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
68 if !contract_profile_enabled() {
69 return;
70 }
71 CONTRACT_PROFILE_STATE.with(|state| {
72 let mut state = state.borrow_mut();
73 let entry = state.entry(signature).or_default();
74 entry.calls += 1;
75 entry.total_time += elapsed;
76 });
77}
78
79pub fn reset_contract_profile() {
81 CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
82}
83
84pub fn print_and_reset_contract_profile() {
86 if !contract_profile_enabled() {
87 return;
88 }
89 CONTRACT_PROFILE_STATE.with(|state| {
90 let mut entries: Vec<_> = state
91 .borrow()
92 .iter()
93 .map(|(k, v)| (k.clone(), v.clone()))
94 .collect();
95 state.borrow_mut().clear();
96 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
97
98 eprintln!("=== contract_multi Profile ===");
99 for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
100 let operands = signature
101 .operands
102 .iter()
103 .map(|operand| {
104 format!(
105 "dims={:?} ids={:?}{}",
106 operand.dims,
107 operand.ids,
108 if operand.is_diag { " diag" } else { "" }
109 )
110 })
111 .collect::<Vec<_>>()
112 .join(" ; ");
113 eprintln!(
114 "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
115 entry.calls,
116 entry.total_time.as_secs_f64(),
117 entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
118 signature.output_dims,
119 signature.output_ids,
120 );
121 eprintln!(" {operands}");
122 }
123 });
124}
125
126#[derive(Clone, Copy, Debug)]
148pub struct ContractionOptions<'a> {
149 pub allowed: AllowedPairs<'a>,
151 pub retain_indices: &'a [DynIndex],
153}
154
155impl<'a> ContractionOptions<'a> {
156 pub fn new(allowed: AllowedPairs<'a>) -> Self {
158 Self {
159 allowed,
160 retain_indices: &[],
161 }
162 }
163
164 pub fn with_retain_indices(mut self, retain_indices: &'a [DynIndex]) -> Self {
166 self.retain_indices = retain_indices;
167 self
168 }
169}
170
171pub fn contract_multi(
217 tensors: &[&TensorDynLen],
218 allowed: AllowedPairs<'_>,
219) -> Result<TensorDynLen> {
220 contract_multi_with_options(tensors, ContractionOptions::new(allowed))
221}
222
223pub fn contract_multi_with_options(
259 tensors: &[&TensorDynLen],
260 options: ContractionOptions<'_>,
261) -> Result<TensorDynLen> {
262 match tensors.len() {
263 0 => Err(anyhow::anyhow!("No tensors to contract")),
264 _ => {
265 validate_retained_indices_exist(tensors, options.retain_indices)?;
266 if tensors.len() == 1 {
267 return Ok((*tensors[0]).clone());
268 }
269
270 if let AllowedPairs::Specified(pairs) = options.allowed {
272 for &(i, j) in pairs {
273 if !has_contractable_indices(tensors[i], tensors[j]) {
274 return Err(anyhow::anyhow!(
275 "Specified pair ({}, {}) has no contractable indices",
276 i,
277 j
278 ));
279 }
280 }
281 }
282
283 let components = find_tensor_connected_components_with_retained(
285 tensors,
286 options.allowed,
287 options.retain_indices,
288 );
289
290 if components.len() == 1 {
291 contract_multi_impl(tensors, options)
293 } else {
294 let mut results: Vec<TensorDynLen> = Vec::new();
296 for component in &components {
297 let component_tensors: Vec<&TensorDynLen> =
298 component.iter().map(|&i| tensors[i]).collect();
299 let component_retain_indices =
300 retained_indices_for_component(tensors, component, options.retain_indices);
301
302 let remapped_allowed = remap_allowed_pairs(options.allowed, component);
304 let component_options = ContractionOptions {
305 allowed: remapped_allowed.as_ref(),
306 retain_indices: &component_retain_indices,
307 };
308 let contracted = contract_multi_impl(&component_tensors, component_options)?;
309 results.push(contracted);
310 }
311
312 let mut results_iter = results.into_iter();
314 let Some(mut result) = results_iter.next() else {
315 return Err(anyhow::anyhow!("No contracted components produced"));
316 };
317 for other in results_iter {
318 result = result.outer_product(&other)?;
319 }
320 Ok(result)
321 }
322 }
323 }
324}
325
326pub fn contract_multi_owned(
365 tensors: Vec<TensorDynLen>,
366 options: ContractionOptions<'_>,
367) -> Result<TensorDynLen> {
368 match tensors.len() {
369 0 => Err(anyhow::anyhow!("No tensors to contract")),
370 _ => {
371 let tensor_refs = tensors.iter().collect::<Vec<_>>();
372 validate_retained_indices_exist(&tensor_refs, options.retain_indices)?;
373
374 if tensors.len() == 1 {
375 drop(tensor_refs);
376 let Some(tensor) = tensors.into_iter().next() else {
377 return Err(anyhow::anyhow!("No tensors to contract"));
378 };
379 return Ok(tensor);
380 }
381
382 if let AllowedPairs::Specified(pairs) = options.allowed {
383 for &(i, j) in pairs {
384 if !has_contractable_indices(tensor_refs[i], tensor_refs[j]) {
385 return Err(anyhow::anyhow!(
386 "Specified pair ({}, {}) has no contractable indices",
387 i,
388 j
389 ));
390 }
391 }
392 }
393
394 let requires_borrowed_path = tensor_refs.iter().any(|tensor| tensor.tracks_grad())
395 || tensor_refs
396 .iter()
397 .any(|tensor| !has_dense_axis_classes(tensor));
398 if requires_borrowed_path {
399 return contract_multi_with_options(&tensor_refs, options);
400 }
401
402 let components = find_tensor_connected_components_with_retained(
403 &tensor_refs,
404 options.allowed,
405 options.retain_indices,
406 );
407 if components.len() > 1 {
408 return contract_multi_with_options(&tensor_refs, options);
409 }
410
411 let mut diag_uf = AxisUnionFind::new();
412 let plan = build_contraction_plan(&tensor_refs, options, &mut diag_uf)?;
413 drop(tensor_refs);
414 let native_operands = tensors
415 .into_iter()
416 .enumerate()
417 .map(|(tensor_idx, tensor)| {
418 (
419 tensor.as_native().clone(),
420 plan.input_ids[tensor_idx].clone(),
421 )
422 })
423 .collect::<Vec<_>>();
424 let result_native = einsum_native_tensors_owned(native_operands, &plan.output_ids)?;
425 TensorDynLen::from_native_with_axis_classes(
426 plan.result_indices,
427 result_native,
428 plan.result_axis_classes,
429 )
430 }
431 }
432}
433
434fn has_dense_axis_classes(tensor: &TensorDynLen) -> bool {
435 let storage = tensor.storage();
436 storage
437 .axis_classes()
438 .iter()
439 .copied()
440 .eq(0..tensor.indices().len())
441}
442
443pub fn contract_connected(
489 tensors: &[&TensorDynLen],
490 allowed: AllowedPairs<'_>,
491) -> Result<TensorDynLen> {
492 contract_connected_with_options(tensors, ContractionOptions::new(allowed))
493}
494
495pub fn contract_connected_with_options(
541 tensors: &[&TensorDynLen],
542 options: ContractionOptions<'_>,
543) -> Result<TensorDynLen> {
544 match tensors.len() {
545 0 => Err(anyhow::anyhow!("No tensors to contract")),
546 _ => {
547 validate_retained_indices_exist(tensors, options.retain_indices)?;
548 if tensors.len() == 1 {
549 return Ok((*tensors[0]).clone());
550 }
551
552 let components = find_tensor_connected_components_with_retained(
554 tensors,
555 options.allowed,
556 options.retain_indices,
557 );
558 if components.len() > 1 {
559 return Err(anyhow::anyhow!(
560 "Disconnected tensor network: {} components found",
561 components.len()
562 ));
563 }
564 contract_multi_impl(tensors, options)
566 }
567 }
568}
569
570#[derive(Debug, Clone)]
579pub struct AxisUnionFind {
580 parent: HashMap<DynId, DynId>,
582 rank: HashMap<DynId, usize>,
584}
585
586impl AxisUnionFind {
587 pub fn new() -> Self {
589 Self {
590 parent: HashMap::new(),
591 rank: HashMap::new(),
592 }
593 }
594
595 pub fn make_set(&mut self, id: DynId) {
597 use std::collections::hash_map::Entry;
598 if let Entry::Vacant(e) = self.parent.entry(id) {
599 e.insert(id);
600 self.rank.insert(id, 0);
601 }
602 }
603
604 pub fn find(&mut self, id: DynId) -> DynId {
607 self.make_set(id);
608 if self.parent[&id] != id {
609 let root = self.find(self.parent[&id]);
610 self.parent.insert(id, root);
611 }
612 self.parent[&id]
613 }
614
615 pub fn union(&mut self, a: DynId, b: DynId) {
618 let root_a = self.find(a);
619 let root_b = self.find(b);
620
621 if root_a == root_b {
622 return;
623 }
624
625 let rank_a = self.rank[&root_a];
626 let rank_b = self.rank[&root_b];
627
628 if rank_a < rank_b {
629 self.parent.insert(root_a, root_b);
630 } else if rank_a > rank_b {
631 self.parent.insert(root_b, root_a);
632 } else {
633 self.parent.insert(root_b, root_a);
634 if let Some(rank) = self.rank.get_mut(&root_a) {
635 *rank += 1;
636 }
637 }
638 }
639
640 pub fn remap(&mut self, id: DynId) -> DynId {
642 self.find(id)
643 }
644
645 pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
647 ids.iter().map(|id| self.find(*id)).collect()
648 }
649}
650
651impl Default for AxisUnionFind {
652 fn default() -> Self {
653 Self::new()
654 }
655}
656
657pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
667 let mut uf = AxisUnionFind::new();
668
669 for tensor in tensors {
670 for idx in tensor.indices() {
671 uf.make_set(*idx.id());
672 }
673
674 if tensor.is_diag() && tensor.indices().len() >= 2 {
675 let first_id = *tensor.indices()[0].id();
676 for idx in tensor.indices().iter().skip(1) {
677 uf.union(first_id, *idx.id());
678 }
679 }
680 }
681
682 uf
683}
684
685pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
690 tensors
691 .iter()
692 .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
693 .collect()
694}
695
696pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
698 output.iter().map(|idx| uf.find(*idx.id())).collect()
699}
700
701pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
706 let mut sizes = HashMap::new();
707
708 for tensor in tensors {
709 let dims = tensor.dims();
710 for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
711 let rep = uf.find(*idx.id());
712 sizes.entry(rep).or_insert(dim);
713 }
714 }
715
716 sizes
717}
718
719fn contract_multi_impl(
731 tensors: &[&TensorDynLen],
732 options: ContractionOptions<'_>,
733) -> Result<TensorDynLen> {
734 let mut diag_uf = AxisUnionFind::new();
738
739 let plan = build_contraction_plan(tensors, options, &mut diag_uf)?;
741
742 let mut sizes: HashMap<usize, usize> = HashMap::new();
747 for (tensor_idx, tensor) in tensors.iter().enumerate() {
748 let dims = tensor.dims();
749 for (pos, &dim) in dims.iter().enumerate() {
750 let internal_id = plan.input_ids[tensor_idx][pos];
751 match sizes.entry(internal_id) {
752 std::collections::hash_map::Entry::Vacant(entry) => {
753 entry.insert(dim);
754 }
755 std::collections::hash_map::Entry::Occupied(entry) => {
756 if *entry.get() != dim {
757 return Err(anyhow::anyhow!(
758 "Internal label dimension mismatch: label {} has dimensions {} and {}",
759 internal_id,
760 entry.get(),
761 dim
762 ));
763 }
764 }
765 }
766 }
767 }
768
769 let profile_signature = contract_profile_enabled().then(|| ContractSignature {
770 operands: tensors
771 .iter()
772 .enumerate()
773 .map(|(tensor_idx, tensor)| ContractOperandSignature {
774 dims: tensor.dims().to_vec(),
775 ids: plan.input_ids[tensor_idx].clone(),
776 is_diag: tensor.is_diag(),
777 })
778 .collect(),
779 output_ids: plan.output_ids.clone(),
780 output_dims: plan.output_ids.iter().map(|id| sizes[id]).collect(),
781 });
782 let profile_started = contract_profile_enabled().then(Instant::now);
783
784 let result = execute_contraction_plan(tensors, &plan, !options.retain_indices.is_empty())?;
785 if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
786 record_contract_profile(signature, started.elapsed());
787 }
788 Ok(result)
789}
790
791fn execute_contraction_plan(
792 tensors: &[&TensorDynLen],
793 plan: &ContractionPlan,
794 has_retained_indices: bool,
795) -> Result<TensorDynLen> {
796 let any_grad = tensors.iter().any(|tensor| tensor.tracks_grad());
797 let first_dtype = tensors[0].as_native().dtype();
798 let same_dtype = tensors
799 .iter()
800 .all(|tensor| tensor.as_native().dtype() == first_dtype);
801 let has_non_dense_axis_classes = tensors.iter().any(|tensor| {
802 tensor
803 .storage()
804 .axis_classes()
805 .iter()
806 .copied()
807 .enumerate()
808 .any(|(axis, class)| axis != class)
809 });
810
811 if any_grad && same_dtype && has_non_dense_axis_classes {
812 if has_retained_indices {
813 return Err(anyhow::anyhow!(
814 "Retained AD contraction with structured storage is not yet supported"
815 ));
816 }
817
818 let mut iter = tensors.iter();
821 let Some(first) = iter.next() else {
822 return Err(anyhow::anyhow!("No tensors to contract"));
823 };
824 let mut result = (*first).clone();
825 for tensor in iter {
826 result = result.contract_pairwise_default(tensor);
827 }
828 return Ok(result);
829 }
830
831 if any_grad && same_dtype {
832 let operands = tensors
833 .iter()
834 .map(|tensor| tensor.as_inner())
835 .collect::<Vec<_>>();
836 let subscripts = build_einsum_subscripts_from_usize_ids(&plan.input_ids, &plan.output_ids)?;
837 let result = eager_einsum_ad(&operands, &subscripts)?;
838 return TensorDynLen::from_inner_with_axis_classes(
839 plan.result_indices.clone(),
840 result,
841 plan.result_axis_classes.clone(),
842 );
843 }
844
845 let native_operands: Vec<_> = tensors
846 .iter()
847 .enumerate()
848 .map(|(tensor_idx, tensor)| (tensor.as_native(), plan.input_ids[tensor_idx].as_slice()))
849 .collect();
850 let result_native = einsum_native_tensors(&native_operands, &plan.output_ids)?;
851 TensorDynLen::from_native_with_axis_classes(
852 plan.result_indices.clone(),
853 result_native,
854 plan.result_axis_classes.clone(),
855 )
856}
857
858fn build_einsum_subscripts_from_usize_ids(
859 input_ids: &[Vec<usize>],
860 output_ids: &[usize],
861) -> Result<String> {
862 fn ids_to_subscript(ids: &[usize]) -> Result<String> {
863 const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
864 let mut out = String::with_capacity(ids.len());
865 for &id in ids {
866 let letter = LETTERS.get(id).ok_or_else(|| {
867 anyhow::anyhow!("einsum label {id} exceeds supported label range")
868 })?;
869 out.push(char::from(*letter));
870 }
871 Ok(out)
872 }
873
874 let inputs = input_ids
875 .iter()
876 .map(|ids| ids_to_subscript(ids))
877 .collect::<Result<Vec<_>>>()?;
878 Ok(format!(
879 "{}->{}",
880 inputs.join(","),
881 ids_to_subscript(output_ids)?
882 ))
883}
884
885#[derive(Debug, Clone)]
887struct ContractionPlan {
888 input_ids: Vec<Vec<usize>>,
889 output_ids: Vec<usize>,
890 result_indices: Vec<DynIndex>,
891 result_axis_classes: Vec<usize>,
892}
893
894fn build_contraction_plan(
895 tensors: &[&TensorDynLen],
896 options: ContractionOptions<'_>,
897 diag_uf: &mut AxisUnionFind,
898) -> Result<ContractionPlan> {
899 let retained_indices: HashSet<DynIndex> = options.retain_indices.iter().cloned().collect();
900 let (input_ids, internal_id_to_original) =
901 build_internal_ids(tensors, options.allowed, diag_uf, &retained_indices)?;
902
903 let mut counts: HashMap<usize, usize> = HashMap::new();
904 for ids in &input_ids {
905 for &internal_id in ids {
906 *counts.entry(internal_id).or_insert(0) += 1;
907 }
908 }
909 let mut output_ids = Vec::new();
910 let mut seen_output = HashSet::new();
911 let mut found_retained = HashSet::new();
912
913 for (tensor_idx, tensor) in tensors.iter().enumerate() {
914 for (axis, idx) in tensor.indices.iter().enumerate() {
915 let internal_id = input_ids[tensor_idx][axis];
916 let should_output = counts[&internal_id] == 1 || retained_indices.contains(idx);
917 if should_output && seen_output.insert(internal_id) {
918 output_ids.push(internal_id);
919 }
920 if retained_indices.contains(idx) {
921 found_retained.insert(idx.clone());
922 }
923 }
924 }
925
926 for retained in retained_indices {
927 if !found_retained.contains(&retained) {
928 return Err(anyhow::anyhow!(
929 "Retained index {:?} does not appear in the input tensors",
930 retained
931 ));
932 }
933 }
934
935 let result_indices: Vec<DynIndex> = output_ids
936 .iter()
937 .map(|&internal_id| {
938 let (tensor_idx, pos) = internal_id_to_original[&internal_id];
939 tensors[tensor_idx].indices[pos].clone()
940 })
941 .collect();
942 validate_unique_output_indices(&result_indices)?;
943 let result_axis_classes =
944 output_axis_classes(tensors, &input_ids, &output_ids, &internal_id_to_original);
945
946 Ok(ContractionPlan {
947 input_ids,
948 output_ids,
949 result_indices,
950 result_axis_classes,
951 })
952}
953
954fn validate_retained_indices_exist(
955 tensors: &[&TensorDynLen],
956 retain_indices: &[DynIndex],
957) -> Result<()> {
958 for retain in retain_indices {
959 let found = tensors
960 .iter()
961 .any(|tensor| tensor.indices().iter().any(|idx| idx == retain));
962 if !found {
963 return Err(anyhow::anyhow!(
964 "Retained index {:?} does not appear in the input tensors",
965 retain
966 ));
967 }
968 }
969 Ok(())
970}
971
972fn retained_indices_for_component(
973 tensors: &[&TensorDynLen],
974 component: &[usize],
975 retain_indices: &[DynIndex],
976) -> Vec<DynIndex> {
977 let mut seen = HashSet::new();
978 let mut retained = Vec::new();
979 for retain in retain_indices {
980 if seen.insert(retain.clone())
981 && component.iter().any(|&tensor_idx| {
982 tensors[tensor_idx]
983 .indices()
984 .iter()
985 .any(|idx| idx == retain)
986 })
987 {
988 retained.push(retain.clone());
989 }
990 }
991 retained
992}
993
994fn validate_unique_output_indices(indices: &[DynIndex]) -> Result<()> {
995 let mut seen = HashSet::new();
996 for idx in indices {
997 if !seen.insert(idx.clone()) {
998 return Err(anyhow::anyhow!(
999 "Contraction result would contain duplicate output indices"
1000 ));
1001 }
1002 }
1003 Ok(())
1004}
1005
/// Computes the storage axis class of every result axis.
///
/// Each input tensor's storage assigns every logical axis a class id. Presumably, axes
/// sharing a class share one stored payload axis, as in diagonal storage — confirm
/// against the storage type.
///
/// This function:
/// 1. Creates one union-find node per (tensor, class) pair. Each tensor's classes are
///    offset by the sum of the preceding tensors' `max class + 1`.
/// 2. Merges the nodes of all operand axes that carry the same internal label.
/// 3. Numbers the resulting groups densely from 0, in order of first appearance in
///    `output`.
///
/// * `tensors` - input operands, in operand order.
/// * `ixs` - internal label of every operand axis (`ixs[tensor][axis]`).
/// * `output` - internal labels of the result axes, in result order.
/// * `internal_id_to_original` - for each label, a `(tensor_idx, axis)` where it occurs.
///
/// Returns one class id per entry of `output`. Panics if an output label is missing
/// from `internal_id_to_original`.
fn output_axis_classes(
    tensors: &[&TensorDynLen],
    ixs: &[Vec<usize>],
    output: &[usize],
    internal_id_to_original: &HashMap<usize, (usize, usize)>,
) -> Vec<usize> {
    // Recursive find with path compression over the flat `parent` array.
    fn find(parent: &mut [usize], value: usize) -> usize {
        if parent[value] != value {
            parent[value] = find(parent, parent[value]);
        }
        parent[value]
    }

    // Links `rhs`'s root under `lhs`'s root (no rank heuristic).
    fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
        let lhs_root = find(parent, lhs);
        let rhs_root = find(parent, rhs);
        if lhs_root != rhs_root {
            parent[rhs_root] = lhs_root;
        }
    }

    // Node range for tensor t is class_offsets[t] .. class_offsets[t] + payload_rank(t).
    let mut class_offsets = Vec::with_capacity(tensors.len());
    let mut next_node = 0usize;
    for tensor in tensors {
        class_offsets.push(next_node);
        let payload_rank = tensor
            .storage()
            .axis_classes()
            .iter()
            .copied()
            .max()
            .map(|value| value + 1)
            .unwrap_or(0);
        next_node += payload_rank;
    }
    let mut parent: Vec<usize> = (0..next_node).collect();
    let mut axes_by_internal_id: HashMap<usize, Vec<usize>> = HashMap::new();

    // Collect the class node of every axis, grouped by the axis' internal label.
    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (axis, &internal_id) in ixs[tensor_idx].iter().enumerate() {
            let class_id = tensor.storage().axis_classes()[axis];
            let node = class_offsets[tensor_idx] + class_id;
            axes_by_internal_id
                .entry(internal_id)
                .or_default()
                .push(node);
        }
    }

    // Axes sharing a label must end up in the same class.
    for nodes in axes_by_internal_id.values() {
        if let Some((&first, rest)) = nodes.split_first() {
            for &node in rest {
                union(&mut parent, first, node);
            }
        }
    }

    // Dense renumbering of roots in output order, so the HashMap iteration order above
    // does not affect the returned class ids.
    let mut root_to_class = HashMap::new();
    let mut next_class = 0usize;
    output
        .iter()
        .map(|internal_id| {
            let (tensor_idx, axis) = internal_id_to_original[internal_id];
            let class_id = tensors[tensor_idx].storage().axis_classes()[axis];
            let node = class_offsets[tensor_idx] + class_id;
            let root = find(&mut parent, node);
            *root_to_class.entry(root).or_insert_with(|| {
                let class = next_class;
                next_class += 1;
                class
            })
        })
        .collect()
}
1080
/// Assigns a dense `usize` einsum label to every axis of every tensor.
///
/// Phase 1 walks the candidate tensor pairs: every `ti < tj` pair for
/// `AllowedPairs::All`, or the given pairs for `Specified`. Mutually contractable axes
/// get a shared label. A label is reused when the same `DynIndex` was already labelled
/// by an earlier pair.
///
/// Phase 2 gives every still-unlabelled axis a fresh label. The exception is that all
/// still-unlabelled occurrences of the same retained index share one label.
///
/// Labels are allocated from 0 upward in processing order.
///
/// Returns `(ixs, internal_id_to_original)`:
/// - `ixs[t][axis]` is the label of `axis` of tensor `t`;
/// - `internal_id_to_original` maps each label to the `(tensor_idx, axis)` at which it
///   was first created.
///
/// `_diag_uf` is currently unused. The function has no `Err` path in its current form.
#[allow(clippy::type_complexity)]
fn build_internal_ids(
    tensors: &[&TensorDynLen],
    allowed: AllowedPairs<'_>,
    _diag_uf: &mut AxisUnionFind,
    retained_indices: &HashSet<DynIndex>,
) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
    let mut next_id = 0usize;
    // Label already given to a (non-retained-phase) index value.
    let mut index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // Label shared by the phase-2 occurrences of each retained index.
    let mut retained_index_to_internal: HashMap<DynIndex, usize> = HashMap::new();
    // Label of each (tensor_idx, axis) position.
    let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
    let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();

    let pairs_to_process: Vec<(usize, usize)> = match allowed {
        AllowedPairs::All => {
            let mut pairs = Vec::new();
            for ti in 0..tensors.len() {
                for tj in (ti + 1)..tensors.len() {
                    pairs.push((ti, tj));
                }
            }
            pairs
        }
        AllowedPairs::Specified(pairs) => pairs.to_vec(),
    };

    // Phase 1: shared labels for contractable axis pairs.
    for (ti, tj) in pairs_to_process {
        for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
            for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
                if idx_i.is_contractable(idx_j) {
                    let key_i = (ti, pi);
                    let key_j = (tj, pj);

                    match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
                        (None, None) => {
                            let internal_id = if let Some(&id) = index_to_internal.get(idx_i) {
                                id
                            } else {
                                let id = next_id;
                                next_id += 1;
                                index_to_internal.insert(idx_i.clone(), id);
                                internal_id_to_original.insert(id, key_i);
                                id
                            };
                            assigned.insert(key_i, internal_id);
                            assigned.insert(key_j, internal_id);
                            // Contractable but distinct index values share the label too.
                            if idx_i != idx_j {
                                index_to_internal.insert(idx_j.clone(), internal_id);
                            }
                        }
                        (Some(id), None) => {
                            assigned.insert(key_j, id);
                            index_to_internal.insert(idx_j.clone(), id);
                        }
                        (None, Some(id)) => {
                            assigned.insert(key_i, id);
                            index_to_internal.insert(idx_i.clone(), id);
                        }
                        (Some(_id_i), Some(_id_j)) => {
                            // Both axes already labelled; the two labels are left as they
                            // are, even if they differ (no merging happens here).
                        }
                    }
                }
            }
        }
    }

    // Phase 2: fresh labels for every axis phase 1 did not label.
    for (tensor_idx, tensor) in tensors.iter().enumerate() {
        for (pos, idx) in tensor.indices.iter().enumerate() {
            let key = (tensor_idx, pos);
            if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
                // NOTE(review): only `retained_index_to_internal` is consulted here. A
                // retained index that was partly labelled in phase 1 therefore gets a
                // separate label for its remaining occurrences — confirm intended.
                let internal_id = if retained_indices.contains(idx) {
                    if let Some(&id) = retained_index_to_internal.get(idx) {
                        id
                    } else {
                        let id = next_id;
                        next_id += 1;
                        retained_index_to_internal.insert(idx.clone(), id);
                        internal_id_to_original.insert(id, key);
                        id
                    }
                } else {
                    let id = next_id;
                    next_id += 1;
                    internal_id_to_original.insert(id, key);
                    id
                };
                e.insert(internal_id);
            }
        }
    }

    // Every position is labelled now; lay the labels out per tensor in axis order.
    let ixs: Vec<Vec<usize>> = tensors
        .iter()
        .enumerate()
        .map(|(tensor_idx, tensor)| {
            (0..tensor.indices.len())
                .map(|pos| assigned[&(tensor_idx, pos)])
                .collect()
        })
        .collect();

    Ok((ixs, internal_id_to_original))
}
1195
1196fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
1202 a.indices
1203 .iter()
1204 .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
1205}
1206
1207#[allow(dead_code)]
1211fn find_tensor_connected_components(
1212 tensors: &[&TensorDynLen],
1213 allowed: AllowedPairs<'_>,
1214) -> Vec<Vec<usize>> {
1215 find_tensor_connected_components_with_retained(tensors, allowed, &[])
1216}
1217
1218fn find_tensor_connected_components_with_retained(
1219 tensors: &[&TensorDynLen],
1220 allowed: AllowedPairs<'_>,
1221 retain_indices: &[DynIndex],
1222) -> Vec<Vec<usize>> {
1223 let n = tensors.len();
1224 if n == 0 {
1225 return vec![];
1226 }
1227 if n == 1 {
1228 return vec![vec![0]];
1229 }
1230
1231 let mut graph = UnGraph::<(), ()>::new_undirected();
1233 let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
1234
1235 match allowed {
1237 AllowedPairs::All => {
1238 for i in 0..n {
1239 for j in (i + 1)..n {
1240 if has_contractable_indices(tensors[i], tensors[j]) {
1241 graph.add_edge(nodes[i], nodes[j], ());
1242 }
1243 }
1244 }
1245 }
1246 AllowedPairs::Specified(pairs) => {
1247 for &(i, j) in pairs {
1248 if has_contractable_indices(tensors[i], tensors[j]) {
1249 graph.add_edge(nodes[i], nodes[j], ());
1250 }
1251 }
1252 }
1253 }
1254
1255 if !retain_indices.is_empty() {
1256 for i in 0..n {
1257 for j in (i + 1)..n {
1258 if shares_retained_index(tensors[i], tensors[j], retain_indices) {
1259 graph.add_edge(nodes[i], nodes[j], ());
1260 }
1261 }
1262 }
1263 }
1264
1265 let num_components = connected_components(&graph);
1267
1268 if num_components == 1 {
1269 return vec![(0..n).collect()];
1270 }
1271
1272 use petgraph::visit::Dfs;
1274 let mut visited = vec![false; n];
1275 let mut components = Vec::new();
1276
1277 for start in 0..n {
1278 if !visited[start] {
1279 let mut component = Vec::new();
1280 let mut dfs = Dfs::new(&graph, nodes[start]);
1281 while let Some(node) = dfs.next(&graph) {
1282 let idx = node.index();
1283 if !visited[idx] {
1284 visited[idx] = true;
1285 component.push(idx);
1286 }
1287 }
1288 component.sort();
1289 components.push(component);
1290 }
1291 }
1292
1293 components.sort_by_key(|c| c[0]);
1294 components
1295}
1296
1297fn shares_retained_index(a: &TensorDynLen, b: &TensorDynLen, retain_indices: &[DynIndex]) -> bool {
1298 retain_indices.iter().any(|retain| {
1299 a.indices().iter().any(|idx_a| idx_a == retain)
1300 && b.indices().iter().any(|idx_b| idx_b == retain)
1301 })
1302}
1303
1304fn remap_allowed_pairs(allowed: AllowedPairs<'_>, component: &[usize]) -> RemappedAllowedPairs {
1306 match allowed {
1307 AllowedPairs::All => RemappedAllowedPairs::All,
1308 AllowedPairs::Specified(pairs) => {
1309 let orig_to_local: HashMap<usize, usize> = component
1310 .iter()
1311 .enumerate()
1312 .map(|(local, &orig)| (orig, local))
1313 .collect();
1314
1315 let remapped: Vec<(usize, usize)> = pairs
1316 .iter()
1317 .filter_map(
1318 |&(i, j)| match (orig_to_local.get(&i), orig_to_local.get(&j)) {
1319 (Some(&li), Some(&lj)) => Some((li, lj)),
1320 _ => None,
1321 },
1322 )
1323 .collect();
1324
1325 RemappedAllowedPairs::Specified(remapped)
1326 }
1327 }
1328}
1329
1330enum RemappedAllowedPairs {
1332 All,
1333 Specified(Vec<(usize, usize)>),
1334}
1335
1336impl RemappedAllowedPairs {
1337 fn as_ref(&self) -> AllowedPairs<'_> {
1338 match self {
1339 RemappedAllowedPairs::All => AllowedPairs::All,
1340 RemappedAllowedPairs::Specified(pairs) => AllowedPairs::Specified(pairs),
1341 }
1342 }
1343}
1344
1345#[cfg(test)]
1346mod tests;