1use std::cell::RefCell;
26use std::collections::HashMap;
27use std::env;
28use std::time::{Duration, Instant};
29
30use anyhow::Result;
31use petgraph::algo::connected_components;
32use petgraph::prelude::*;
33use tensor4all_tensorbackend::einsum_native_tensors;
34
35use crate::defaults::{DynId, DynIndex, TensorDynLen};
36
37use crate::index_like::IndexLike;
38use crate::tensor_like::AllowedPairs;
39
/// Profiling key component describing a single operand of a contraction.
///
/// Two operands with the same dimensions, the same internal index labels and
/// the same diagonal flag compare equal and hash identically.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ContractOperandSignature {
    /// Dimension of each axis of the operand, in axis order.
    dims: Vec<usize>,
    /// Internal (einsum) label assigned to each axis, in axis order.
    ids: Vec<usize>,
    /// Whether the operand reported itself as diagonal.
    is_diag: bool,
}
46
47#[derive(Debug, Clone, Hash, PartialEq, Eq)]
48struct ContractSignature {
49 operands: Vec<ContractOperandSignature>,
50 output_ids: Vec<usize>,
51 output_dims: Vec<usize>,
52}
53
/// Accumulated profiling statistics for one contraction signature.
///
/// The default value is "never called": zero calls and zero elapsed time.
#[derive(Clone, Debug, Default)]
struct ContractProfileEntry {
    /// Number of recorded calls.
    calls: usize,
    /// Sum of the elapsed time over all recorded calls.
    total_time: Duration,
}
59
60thread_local! {
61 static CONTRACT_PROFILE_STATE: RefCell<HashMap<ContractSignature, ContractProfileEntry>> =
62 RefCell::new(HashMap::new());
63}
64
/// Returns `true` when contraction profiling is switched on.
///
/// Profiling is enabled by the mere presence of the `T4A_PROFILE_CONTRACT`
/// environment variable; its value is not inspected. `var_os` is used so that
/// a value that is not valid Unicode still counts as "set", and so that no
/// `String` has to be allocated just to test for presence.
fn contract_profile_enabled() -> bool {
    env::var_os("T4A_PROFILE_CONTRACT").is_some()
}
68
69fn record_contract_profile(signature: ContractSignature, elapsed: Duration) {
70 if !contract_profile_enabled() {
71 return;
72 }
73 CONTRACT_PROFILE_STATE.with(|state| {
74 let mut state = state.borrow_mut();
75 let entry = state.entry(signature).or_default();
76 entry.calls += 1;
77 entry.total_time += elapsed;
78 });
79}
80
81pub fn reset_contract_profile() {
83 CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
84}
85
86pub fn print_and_reset_contract_profile() {
88 if !contract_profile_enabled() {
89 return;
90 }
91 CONTRACT_PROFILE_STATE.with(|state| {
92 let mut entries: Vec<_> = state
93 .borrow()
94 .iter()
95 .map(|(k, v)| (k.clone(), v.clone()))
96 .collect();
97 state.borrow_mut().clear();
98 entries.sort_by(|(_, lhs), (_, rhs)| rhs.total_time.cmp(&lhs.total_time));
99
100 eprintln!("=== contract_multi Profile ===");
101 for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
102 let operands = signature
103 .operands
104 .iter()
105 .map(|operand| {
106 format!(
107 "dims={:?} ids={:?}{}",
108 operand.dims,
109 operand.ids,
110 if operand.is_diag { " diag" } else { "" }
111 )
112 })
113 .collect::<Vec<_>>()
114 .join(" ; ");
115 eprintln!(
116 "#{idx:02} calls={} total={:.3}s per_call={:.3}us output_dims={:?} output_ids={:?}",
117 entry.calls,
118 entry.total_time.as_secs_f64(),
119 entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
120 signature.output_dims,
121 signature.output_ids,
122 );
123 eprintln!(" {operands}");
124 }
125 });
126}
127
128pub fn contract_multi(
178 tensors: &[&TensorDynLen],
179 allowed: AllowedPairs<'_>,
180) -> Result<TensorDynLen> {
181 match tensors.len() {
182 0 => Err(anyhow::anyhow!("No tensors to contract")),
183 1 => Ok((*tensors[0]).clone()),
184 _ => {
185 if let AllowedPairs::Specified(pairs) = allowed {
187 for &(i, j) in pairs {
188 if !has_contractable_indices(tensors[i], tensors[j]) {
189 return Err(anyhow::anyhow!(
190 "Specified pair ({}, {}) has no contractable indices",
191 i,
192 j
193 ));
194 }
195 }
196 }
197
198 let components = find_tensor_connected_components(tensors, allowed);
200
201 if components.len() == 1 {
202 contract_multi_impl(tensors, allowed, true)
204 } else {
205 let mut results: Vec<TensorDynLen> = Vec::new();
207 for component in &components {
208 let component_tensors: Vec<&TensorDynLen> =
209 component.iter().map(|&i| tensors[i]).collect();
210
211 let remapped_allowed = remap_allowed_pairs(allowed, component);
213 let contracted =
214 contract_multi_impl(&component_tensors, remapped_allowed.as_ref(), true)?;
215 results.push(contracted);
216 }
217
218 let mut results_iter = results.into_iter();
220 let mut result = results_iter.next().unwrap();
221 for other in results_iter {
222 result = result.outer_product(&other)?;
223 }
224 Ok(result)
225 }
226 }
227 }
228}
229
230pub fn contract_connected(
277 tensors: &[&TensorDynLen],
278 allowed: AllowedPairs<'_>,
279) -> Result<TensorDynLen> {
280 match tensors.len() {
281 0 => Err(anyhow::anyhow!("No tensors to contract")),
282 1 => Ok((*tensors[0]).clone()),
283 _ => {
284 let components = find_tensor_connected_components(tensors, allowed);
286 if components.len() > 1 {
287 return Err(anyhow::anyhow!(
288 "Disconnected tensor network: {} components found",
289 components.len()
290 ));
291 }
292 contract_multi_impl(tensors, allowed, true)
294 }
295 }
296}
297
298#[derive(Debug, Clone)]
307pub struct AxisUnionFind {
308 parent: HashMap<DynId, DynId>,
310 rank: HashMap<DynId, usize>,
312}
313
314impl AxisUnionFind {
315 pub fn new() -> Self {
317 Self {
318 parent: HashMap::new(),
319 rank: HashMap::new(),
320 }
321 }
322
323 pub fn make_set(&mut self, id: DynId) {
325 use std::collections::hash_map::Entry;
326 if let Entry::Vacant(e) = self.parent.entry(id) {
327 e.insert(id);
328 self.rank.insert(id, 0);
329 }
330 }
331
332 pub fn find(&mut self, id: DynId) -> DynId {
335 self.make_set(id);
336 if self.parent[&id] != id {
337 let root = self.find(self.parent[&id]);
338 self.parent.insert(id, root);
339 }
340 self.parent[&id]
341 }
342
343 pub fn union(&mut self, a: DynId, b: DynId) {
346 let root_a = self.find(a);
347 let root_b = self.find(b);
348
349 if root_a == root_b {
350 return;
351 }
352
353 let rank_a = self.rank[&root_a];
354 let rank_b = self.rank[&root_b];
355
356 if rank_a < rank_b {
357 self.parent.insert(root_a, root_b);
358 } else if rank_a > rank_b {
359 self.parent.insert(root_b, root_a);
360 } else {
361 self.parent.insert(root_b, root_a);
362 *self.rank.get_mut(&root_a).unwrap() += 1;
363 }
364 }
365
366 pub fn remap(&mut self, id: DynId) -> DynId {
368 self.find(id)
369 }
370
371 pub fn remap_ids(&mut self, ids: &[DynId]) -> Vec<DynId> {
373 ids.iter().map(|id| self.find(*id)).collect()
374 }
375}
376
377impl Default for AxisUnionFind {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383pub fn build_diag_union(tensors: &[&TensorDynLen]) -> AxisUnionFind {
393 let mut uf = AxisUnionFind::new();
394
395 for tensor in tensors {
396 for idx in tensor.indices() {
397 uf.make_set(*idx.id());
398 }
399
400 if tensor.is_diag() && tensor.indices().len() >= 2 {
401 let first_id = *tensor.indices()[0].id();
402 for idx in tensor.indices().iter().skip(1) {
403 uf.union(first_id, *idx.id());
404 }
405 }
406 }
407
408 uf
409}
410
411pub fn remap_tensor_ids(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> Vec<Vec<DynId>> {
416 tensors
417 .iter()
418 .map(|t| t.indices.iter().map(|idx| uf.find(*idx.id())).collect())
419 .collect()
420}
421
422pub fn remap_output_ids(output: &[DynIndex], uf: &mut AxisUnionFind) -> Vec<DynId> {
424 output.iter().map(|idx| uf.find(*idx.id())).collect()
425}
426
427pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashMap<DynId, usize> {
432 let mut sizes = HashMap::new();
433
434 for tensor in tensors {
435 let dims = tensor.dims();
436 for (idx, &dim) in tensor.indices.iter().zip(dims.iter()) {
437 let rep = uf.find(*idx.id());
438 sizes.entry(rep).or_insert(dim);
439 }
440 }
441
442 sizes
443}
444
445fn contract_multi_impl(
460 tensors: &[&TensorDynLen],
461 allowed: AllowedPairs<'_>,
462 _skip_connectivity_check: bool,
463) -> Result<TensorDynLen> {
464 let mut diag_uf = build_diag_union(tensors);
466
467 let (ixs, internal_id_to_original) = build_internal_ids(tensors, allowed, &mut diag_uf)?;
469
470 let mut idx_count: HashMap<usize, usize> = HashMap::new();
472 for ix in &ixs {
473 for &i in ix {
474 *idx_count.entry(i).or_insert(0) += 1;
475 }
476 }
477 let mut output: Vec<usize> = idx_count
478 .iter()
479 .filter(|(_, &count)| count == 1)
480 .map(|(&idx, _)| idx)
481 .collect();
482 output.sort(); let mut sizes: HashMap<usize, usize> = HashMap::new();
489 for (tensor_idx, tensor) in tensors.iter().enumerate() {
490 let dims = tensor.dims();
491 for (pos, &dim) in dims.iter().enumerate() {
492 let internal_id = ixs[tensor_idx][pos];
493 sizes.entry(internal_id).or_insert(dim);
494 }
495 }
496
497 let profile_signature = contract_profile_enabled().then(|| ContractSignature {
498 operands: tensors
499 .iter()
500 .enumerate()
501 .map(|(tensor_idx, tensor)| ContractOperandSignature {
502 dims: tensor.dims().to_vec(),
503 ids: ixs[tensor_idx].clone(),
504 is_diag: tensor.is_diag(),
505 })
506 .collect(),
507 output_ids: output.clone(),
508 output_dims: output.iter().map(|id| sizes[id]).collect(),
509 });
510 let profile_started = contract_profile_enabled().then(Instant::now);
511
512 let native_operands: Vec<_> = tensors
513 .iter()
514 .enumerate()
515 .map(|(tensor_idx, tensor)| (tensor.as_native(), ixs[tensor_idx].as_slice()))
516 .collect();
517
518 let result_native = einsum_native_tensors(&native_operands, &output)?;
519 if let (Some(signature), Some(started)) = (profile_signature, profile_started) {
520 record_contract_profile(signature, started.elapsed());
521 }
522 let final_indices = if output.is_empty() {
523 vec![]
524 } else {
525 output
526 .iter()
527 .map(|&internal_id| {
528 let (tensor_idx, pos) = internal_id_to_original[&internal_id];
529 tensors[tensor_idx].indices[pos].clone()
530 })
531 .collect()
532 };
533 TensorDynLen::from_native(final_indices, result_native)
534}
535
536#[allow(clippy::type_complexity)]
542fn build_internal_ids(
543 tensors: &[&TensorDynLen],
544 allowed: AllowedPairs<'_>,
545 diag_uf: &mut AxisUnionFind,
546) -> Result<(Vec<Vec<usize>>, HashMap<usize, (usize, usize)>)> {
547 let mut next_id = 0usize;
548 let mut dynid_to_internal: HashMap<DynId, usize> = HashMap::new();
549 let mut assigned: HashMap<(usize, usize), usize> = HashMap::new();
550 let mut internal_id_to_original: HashMap<usize, (usize, usize)> = HashMap::new();
551
552 let pairs_to_process: Vec<(usize, usize)> = match allowed {
554 AllowedPairs::All => {
555 let mut pairs = Vec::new();
556 for ti in 0..tensors.len() {
557 for tj in (ti + 1)..tensors.len() {
558 pairs.push((ti, tj));
559 }
560 }
561 pairs
562 }
563 AllowedPairs::Specified(pairs) => pairs.to_vec(),
564 };
565
566 for (ti, tj) in pairs_to_process {
567 for (pi, idx_i) in tensors[ti].indices.iter().enumerate() {
568 for (pj, idx_j) in tensors[tj].indices.iter().enumerate() {
569 if idx_i.is_contractable(idx_j) {
570 let key_i = (ti, pi);
571 let key_j = (tj, pj);
572
573 let remapped_i = diag_uf.find(*idx_i.id());
574 let remapped_j = diag_uf.find(*idx_j.id());
575
576 match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) {
577 (None, None) => {
578 let internal_id = if let Some(&id) = dynid_to_internal.get(&remapped_i)
579 {
580 id
581 } else {
582 let id = next_id;
583 next_id += 1;
584 dynid_to_internal.insert(remapped_i, id);
585 internal_id_to_original.insert(id, key_i);
586 id
587 };
588 assigned.insert(key_i, internal_id);
589 assigned.insert(key_j, internal_id);
590 if remapped_i != remapped_j {
591 dynid_to_internal.insert(remapped_j, internal_id);
592 }
593 }
594 (Some(id), None) => {
595 assigned.insert(key_j, id);
596 dynid_to_internal.insert(remapped_j, id);
597 }
598 (None, Some(id)) => {
599 assigned.insert(key_i, id);
600 dynid_to_internal.insert(remapped_i, id);
601 }
602 (Some(_id_i), Some(_id_j)) => {
603 }
605 }
606 }
607 }
608 }
609 }
610
611 for (tensor_idx, tensor) in tensors.iter().enumerate() {
613 for (pos, idx) in tensor.indices.iter().enumerate() {
614 let key = (tensor_idx, pos);
615 if let std::collections::hash_map::Entry::Vacant(e) = assigned.entry(key) {
616 let remapped_id = diag_uf.find(*idx.id());
617
618 let internal_id = if let Some(&id) = dynid_to_internal.get(&remapped_id) {
619 id
620 } else {
621 let id = next_id;
622 next_id += 1;
623 dynid_to_internal.insert(remapped_id, id);
624 internal_id_to_original.insert(id, key);
625 id
626 };
627 e.insert(internal_id);
628 }
629 }
630 }
631
632 let ixs: Vec<Vec<usize>> = tensors
634 .iter()
635 .enumerate()
636 .map(|(tensor_idx, tensor)| {
637 (0..tensor.indices.len())
638 .map(|pos| assigned[&(tensor_idx, pos)])
639 .collect()
640 })
641 .collect();
642
643 Ok((ixs, internal_id_to_original))
644}
645
646fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool {
652 a.indices
653 .iter()
654 .any(|idx_a| b.indices.iter().any(|idx_b| idx_a.is_contractable(idx_b)))
655}
656
657fn find_tensor_connected_components(
661 tensors: &[&TensorDynLen],
662 allowed: AllowedPairs<'_>,
663) -> Vec<Vec<usize>> {
664 let n = tensors.len();
665 if n == 0 {
666 return vec![];
667 }
668 if n == 1 {
669 return vec![vec![0]];
670 }
671
672 let mut graph = UnGraph::<(), ()>::new_undirected();
674 let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
675
676 match allowed {
678 AllowedPairs::All => {
679 for i in 0..n {
680 for j in (i + 1)..n {
681 if has_contractable_indices(tensors[i], tensors[j]) {
682 graph.add_edge(nodes[i], nodes[j], ());
683 }
684 }
685 }
686 }
687 AllowedPairs::Specified(pairs) => {
688 for &(i, j) in pairs {
689 if has_contractable_indices(tensors[i], tensors[j]) {
690 graph.add_edge(nodes[i], nodes[j], ());
691 }
692 }
693 }
694 }
695
696 let num_components = connected_components(&graph);
698
699 if num_components == 1 {
700 return vec![(0..n).collect()];
701 }
702
703 use petgraph::visit::Dfs;
705 let mut visited = vec![false; n];
706 let mut components = Vec::new();
707
708 for start in 0..n {
709 if !visited[start] {
710 let mut component = Vec::new();
711 let mut dfs = Dfs::new(&graph, nodes[start]);
712 while let Some(node) = dfs.next(&graph) {
713 let idx = node.index();
714 if !visited[idx] {
715 visited[idx] = true;
716 component.push(idx);
717 }
718 }
719 component.sort();
720 components.push(component);
721 }
722 }
723
724 components.sort_by_key(|c| c[0]);
725 components
726}
727
728fn remap_allowed_pairs(allowed: AllowedPairs<'_>, component: &[usize]) -> RemappedAllowedPairs {
730 match allowed {
731 AllowedPairs::All => RemappedAllowedPairs::All,
732 AllowedPairs::Specified(pairs) => {
733 let orig_to_local: HashMap<usize, usize> = component
734 .iter()
735 .enumerate()
736 .map(|(local, &orig)| (orig, local))
737 .collect();
738
739 let remapped: Vec<(usize, usize)> = pairs
740 .iter()
741 .filter_map(
742 |&(i, j)| match (orig_to_local.get(&i), orig_to_local.get(&j)) {
743 (Some(&li), Some(&lj)) => Some((li, lj)),
744 _ => None,
745 },
746 )
747 .collect();
748
749 RemappedAllowedPairs::Specified(remapped)
750 }
751 }
752}
753
754enum RemappedAllowedPairs {
756 All,
757 Specified(Vec<(usize, usize)>),
758}
759
760impl RemappedAllowedPairs {
761 fn as_ref(&self) -> AllowedPairs<'_> {
762 match self {
763 RemappedAllowedPairs::All => AllowedPairs::All,
764 RemappedAllowedPairs::Specified(pairs) => AllowedPairs::Specified(pairs),
765 }
766 }
767}
768
769#[cfg(test)]
770mod tests;