1use std::collections::{HashMap, HashSet};
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use anyhow::{anyhow, bail, Context, Result};
12
13use super::contraction::{contract, ContractionOptions};
14use super::decompose::{factorize_tensor_to_treetn_with, TreeTopology};
15use super::TreeTN;
16use tensor4all_core::{
17 AllowedPairs, AnyScalar, DynIndex, FactorizeAlg, FactorizeOptions, IndexLike, TensorDynLen,
18 TensorIndex, TensorLike,
19};
20
/// Output of `apply_diagonal_pairs`: the rewritten input networks plus the
/// index renaming that has to be undone after contraction.
///
/// Tuple fields, in order:
/// - `.0`: the first network, where each diagonal-paired site index has been
///   contracted with a copy tensor so its node now carries an auxiliary index
///   and a "kept" index instead;
/// - `.1`: the second network, where each diagonal-paired site index has been
///   renamed to the matching auxiliary index;
/// - `.2`: the "kept" indices that survive the contraction (rename sources);
/// - `.3`: the original first-network indices the kept indices are renamed
///   back to (rename targets), position-aligned with `.2`.
type DiagonalPairApplication<V> = (
    TreeTN<TensorDynLen, V>,
    TreeTN<TensorDynLen, V>,
    Vec<DynIndex>,
    Vec<DynIndex>,
);
27
/// Describes how the external indices of two `TreeTN`s are paired up by
/// [`partial_contract`].
///
/// Every pair is `(index_of_first_network, index_of_second_network)`. Both
/// indices of a pair must be external indices of their respective network and
/// must have the same dimension. Across `contract_pairs` and `diagonal_pairs`
/// together, no first-network index and no second-network index may be used
/// by more than one pair.
#[derive(Debug, Clone)]
pub struct PartialContractionSpec<I: IndexLike> {
    /// Pairs that are summed over: the second network's index is renamed to
    /// the first network's index before the two networks are contracted.
    pub contract_pairs: Vec<(I, I)>,
    /// Pairs that are identified without summation (diagonal pairing). The
    /// first network's index is present in the result; the second network's
    /// index is consumed by the contraction.
    pub diagonal_pairs: Vec<(I, I)>,
    /// Optional order for the surviving external indices of the result. When
    /// set, it must list exactly those indices, and each of them must sit on a
    /// distinct node of the result.
    pub output_order: Option<Vec<I>>,
}
75
76fn validate_partial_contraction_spec<T, V>(
77 a: &TreeTN<T, V>,
78 b: &TreeTN<T, V>,
79 spec: &PartialContractionSpec<T::Index>,
80) -> Result<()>
81where
82 T: TensorLike,
83 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
84 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Debug + Send + Sync + Ord,
85{
86 let a_external_ids: HashSet<_> = a
87 .external_indices()
88 .into_iter()
89 .map(|idx| idx.id().clone())
90 .collect();
91 let b_external_ids: HashSet<_> = b
92 .external_indices()
93 .into_iter()
94 .map(|idx| idx.id().clone())
95 .collect();
96
97 let mut seen_a_ids = HashSet::new();
98 let mut seen_b_ids = HashSet::new();
99
100 for (kind, pairs) in [
101 ("contract_pairs", &spec.contract_pairs),
102 ("diagonal_pairs", &spec.diagonal_pairs),
103 ] {
104 for (idx_a, idx_b) in pairs {
105 if idx_a.dim() != idx_b.dim() {
106 bail!(
107 "partial_contract: {} index dimension mismatch: {} != {}",
108 kind,
109 idx_a.dim(),
110 idx_b.dim()
111 );
112 }
113
114 if !a_external_ids.contains(idx_a.id()) {
115 bail!(
116 "partial_contract: {:?} from {} not found in first TreeTN external indices",
117 idx_a.id(),
118 kind
119 );
120 }
121 if !b_external_ids.contains(idx_b.id()) {
122 bail!(
123 "partial_contract: {:?} from {} not found in second TreeTN external indices",
124 idx_b.id(),
125 kind
126 );
127 }
128
129 if !seen_a_ids.insert(idx_a.id().clone()) {
130 bail!(
131 "partial_contract: first TreeTN index {:?} appears in multiple pairs",
132 idx_a.id()
133 );
134 }
135 if !seen_b_ids.insert(idx_b.id().clone()) {
136 bail!(
137 "partial_contract: second TreeTN index {:?} appears in multiple pairs",
138 idx_b.id()
139 );
140 }
141 }
142 }
143
144 Ok(())
145}
146
/// Normalize an undirected edge so that the smaller endpoint comes first.
///
/// Equal endpoints are returned in their given order. The result is owned, so
/// both endpoints are cloned.
fn canonical_edge<V>(left: &V, right: &V) -> (V, V)
where
    V: Clone + Ord,
{
    let (low, high) = if right < left {
        (right, left)
    } else {
        (left, right)
    };
    (low.clone(), high.clone())
}
157
158fn sorted_edge_set<V>(tn: &TreeTN<TensorDynLen, V>) -> Vec<(V, V)>
159where
160 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
161{
162 let mut edges: Vec<_> = tn
163 .site_index_network()
164 .edges()
165 .map(|(u, v)| canonical_edge(&u, &v))
166 .collect();
167 edges.sort();
168 edges.dedup();
169 edges
170}
171
172fn compatible_union_node_names<V>(
173 a: &TreeTN<TensorDynLen, V>,
174 b: &TreeTN<TensorDynLen, V>,
175) -> Vec<V>
176where
177 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
178{
179 let mut names: Vec<_> = a.node_names();
180 names.extend(b.node_names());
181 names.sort();
182 names.dedup();
183 names
184}
185
186fn validate_union_topology<V>(node_names: &[V], edges: &[(V, V)]) -> Result<()>
187where
188 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
189{
190 if node_names.is_empty() {
191 bail!("partial_contract: networks must contain at least one node");
192 }
193
194 if edges.len() + 1 != node_names.len() {
195 bail!("partial_contract: networks have incompatible topologies");
196 }
197
198 let mut adjacency: HashMap<V, Vec<V>> = node_names
199 .iter()
200 .cloned()
201 .map(|name| (name, Vec::new()))
202 .collect();
203 for (u, v) in edges {
204 let Some(neighbors_u) = adjacency.get_mut(u) else {
205 bail!("partial_contract: union topology references unknown node");
206 };
207 neighbors_u.push(v.clone());
208 let Some(neighbors_v) = adjacency.get_mut(v) else {
209 bail!("partial_contract: union topology references unknown node");
210 };
211 neighbors_v.push(u.clone());
212 }
213
214 let mut seen = HashSet::new();
215 let mut stack = vec![node_names[0].clone()];
216 while let Some(node) = stack.pop() {
217 if !seen.insert(node.clone()) {
218 continue;
219 }
220 if let Some(neighbors) = adjacency.get(&node) {
221 stack.extend(neighbors.iter().cloned());
222 }
223 }
224
225 if seen.len() != node_names.len() {
226 bail!("partial_contract: networks have incompatible topologies");
227 }
228
229 Ok(())
230}
231
232fn factorize_options_from_contraction_options(
233 options: &ContractionOptions,
234) -> Result<FactorizeOptions> {
235 let mut factorize_options = match options.factorize_alg {
236 FactorizeAlg::SVD => FactorizeOptions::svd(),
237 FactorizeAlg::QR => FactorizeOptions::qr(),
238 FactorizeAlg::LU => FactorizeOptions::lu(),
239 FactorizeAlg::CI => FactorizeOptions::ci(),
240 };
241 if let Some(policy) = options.svd_policy {
242 factorize_options = factorize_options.with_svd_policy(policy);
243 }
244 if let Some(rtol) = options.qr_rtol {
245 factorize_options = factorize_options.with_qr_rtol(rtol);
246 }
247 if let Some(max_rank) = options.max_rank {
248 factorize_options = factorize_options.with_max_rank(max_rank);
249 }
250 factorize_options.validate().map_err(|err| {
251 anyhow!("partial_contract: invalid contraction factorization options: {err}")
252 })?;
253 Ok(factorize_options)
254}
255
256fn union_result_topology<V>(
257 a: &TreeTN<TensorDynLen, V>,
258 b: &TreeTN<TensorDynLen, V>,
259 contracted_tensor: &TensorDynLen,
260) -> Result<TreeTopology<V, <DynIndex as IndexLike>::Id>>
261where
262 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
263 <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
264{
265 let node_names = compatible_union_node_names(a, b);
266 let mut union_edges = sorted_edge_set(a);
267 union_edges.extend(sorted_edge_set(b));
268 union_edges.sort();
269 union_edges.dedup();
270 validate_union_topology(&node_names, &union_edges)?;
271
272 let surviving_ids: HashSet<_> = contracted_tensor
273 .external_indices()
274 .into_iter()
275 .map(|idx| *idx.id())
276 .collect();
277
278 let mut nodes = HashMap::new();
279 for node_name in &node_names {
280 let mut ids = Vec::new();
281
282 if let Some(site_space_a) = a.site_index_network().site_space(node_name) {
283 for site_idx in site_space_a {
284 if surviving_ids.contains(site_idx.id()) {
285 ids.push(*site_idx.id());
286 }
287 }
288 }
289
290 if let Some(site_space_b) = b.site_index_network().site_space(node_name) {
291 for site_idx in site_space_b {
292 if surviving_ids.contains(site_idx.id()) && !ids.contains(site_idx.id()) {
293 ids.push(*site_idx.id());
294 }
295 }
296 }
297
298 nodes.insert(node_name.clone(), ids);
299 }
300
301 Ok(TreeTopology::new(nodes, union_edges))
302}
303
304fn contract_mismatched_topologies<V>(
305 a: &TreeTN<TensorDynLen, V>,
306 b: &TreeTN<TensorDynLen, V>,
307 center: &V,
308 options: ContractionOptions,
309) -> Result<TreeTN<TensorDynLen, V>>
310where
311 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
312 <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
313{
314 let a_dense = a
315 .sim_internal_inds()
316 .contract_to_tensor()
317 .context("partial_contract: failed to contract first mismatched-topology TreeTN")?;
318 let b_dense = b
319 .sim_internal_inds()
320 .contract_to_tensor()
321 .context("partial_contract: failed to contract second mismatched-topology TreeTN")?;
322 let contracted_tensor =
323 <TensorDynLen as TensorLike>::contract(&[&a_dense, &b_dense], AllowedPairs::All)
324 .context("partial_contract: failed dense contraction for mismatched topologies")?;
325
326 if contracted_tensor.external_indices().is_empty() {
327 let mut result = TreeTN::<TensorDynLen, V>::new();
328 result
329 .add_tensor(center.clone(), contracted_tensor)
330 .context("partial_contract: failed to wrap scalar mismatched-topology result")?;
331 result
332 .set_canonical_region([center.clone()])
333 .context("partial_contract: failed to set canonical region for scalar result")?;
334 return Ok(result);
335 }
336
337 let topology = union_result_topology(a, b, &contracted_tensor)?;
338 let factorize_options = factorize_options_from_contraction_options(&options)?;
339 factorize_tensor_to_treetn_with(&contracted_tensor, &topology, factorize_options, center)
340 .context("partial_contract: failed to factorize mismatched-topology dense result")
341}
342
343fn apply_output_order<T, V>(result: TreeTN<T, V>, output_order: &[T::Index]) -> Result<TreeTN<T, V>>
344where
345 T: TensorLike,
346 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
347 T::Index: Clone + Hash + Eq,
348 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
349{
350 let (current_indices, _) = result.all_site_indices()?;
351 if output_order.len() != current_indices.len() {
352 bail!(
353 "partial_contract: output_order length {} does not match surviving external index count {}",
354 output_order.len(),
355 current_indices.len()
356 );
357 }
358
359 let current_ids: HashSet<_> = current_indices.iter().map(|idx| idx.id().clone()).collect();
360 let requested_ids: HashSet<_> = output_order.iter().map(|idx| idx.id().clone()).collect();
361 if current_ids != requested_ids {
362 bail!("partial_contract: output_order must contain exactly the surviving external indices");
363 }
364
365 let mut current_nodes = Vec::with_capacity(current_indices.len());
366 for index in ¤t_indices {
367 let node = result.site_index_network().find_node_by_index(index).ok_or_else(|| {
368 anyhow!(
369 "partial_contract: current result index {:?} is not present in the site index network",
370 index.id()
371 )
372 })?;
373 current_nodes.push(node.clone());
374 }
375
376 let unique_current_nodes: HashSet<_> = current_nodes.iter().cloned().collect();
377 if unique_current_nodes.len() != current_nodes.len() {
378 bail!(
379 "partial_contract: output_order currently requires at most one surviving site index per node"
380 );
381 }
382
383 let mut seen_requested = HashSet::new();
384 let mut ordered_nodes = Vec::with_capacity(result.node_count());
385 let mut ordered_node_set = HashSet::new();
386
387 for index in output_order {
388 if !seen_requested.insert(index.id().clone()) {
389 bail!("partial_contract: output_order contains duplicate indices");
390 }
391 let current_node = result
392 .site_index_network()
393 .find_node_by_index(index)
394 .ok_or_else(|| {
395 anyhow!(
396 "partial_contract: output_order index {:?} is not present in the result",
397 index.id()
398 )
399 })?;
400 if !ordered_node_set.insert(current_node.clone()) {
401 bail!(
402 "partial_contract: output_order currently requires each requested index to occupy a distinct node"
403 );
404 }
405 ordered_nodes.push(current_node.clone());
406 }
407
408 for node_name in result.node_names() {
409 if ordered_node_set.insert(node_name.clone()) {
410 ordered_nodes.push(node_name);
411 }
412 }
413
414 let tensors = ordered_nodes
415 .iter()
416 .map(|node_name| {
417 let node_idx = result.node_index(node_name).ok_or_else(|| {
418 anyhow!(
419 "partial_contract: output_order node {:?} is not present in the result",
420 node_name
421 )
422 })?;
423 result.tensor(node_idx).cloned().ok_or_else(|| {
424 anyhow!(
425 "partial_contract: tensor for output_order node {:?} is missing",
426 node_name
427 )
428 })
429 })
430 .collect::<Result<Vec<_>>>()?;
431
432 let mut reordered = TreeTN::from_tensors(tensors, ordered_nodes)
433 .context("partial_contract: failed to rebuild result in requested output order")?;
434 reordered.canonical_region = result.canonical_region.clone();
435 reordered.canonical_form = result.canonical_form;
436 reordered.ortho_towards = result.ortho_towards.clone();
437 Ok(reordered)
438}
439
440fn diagonal_copy_value(tensor: &TensorDynLen) -> AnyScalar {
441 if tensor.is_complex() {
442 AnyScalar::new_complex(1.0, 0.0)
443 } else {
444 AnyScalar::new_real(1.0)
445 }
446}
447
448fn apply_diagonal_pairs<V>(
449 a: &TreeTN<TensorDynLen, V>,
450 b: &TreeTN<TensorDynLen, V>,
451 diagonal_pairs: &[(DynIndex, DynIndex)],
452) -> Result<DiagonalPairApplication<V>>
453where
454 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
455 <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
456{
457 let mut a_modified = a.clone();
458 let mut b_modified = b.clone();
459 let mut restore_from = Vec::with_capacity(diagonal_pairs.len());
460 let mut restore_to = Vec::with_capacity(diagonal_pairs.len());
461
462 for (idx_a, idx_b) in diagonal_pairs {
463 let node_name = a_modified
464 .site_index_network()
465 .find_node_by_index(idx_a)
466 .cloned()
467 .ok_or_else(|| {
468 anyhow!(
469 "partial_contract: diagonal pair left index {:?} is not a site index of the first TreeTN",
470 idx_a.id()
471 )
472 })?;
473 let node_idx = a_modified.node_index(&node_name).ok_or_else(|| {
474 anyhow!(
475 "partial_contract: node {:?} for left diagonal index {:?} not found",
476 node_name,
477 idx_a.id()
478 )
479 })?;
480 let local_tensor = a_modified.tensor(node_idx).cloned().ok_or_else(|| {
481 anyhow!(
482 "partial_contract: tensor for node {:?} not found while processing diagonal pair {:?}",
483 node_name,
484 idx_a.id()
485 )
486 })?;
487
488 let aux_index = idx_a.sim();
489 let kept_index = idx_a.sim();
490 let copy_tensor = TensorDynLen::copy_tensor(
491 vec![idx_a.clone(), aux_index.clone(), kept_index.clone()],
492 diagonal_copy_value(&local_tensor),
493 )
494 .with_context(|| {
495 format!(
496 "partial_contract: failed to build copy tensor for diagonal pair {:?} <- {:?}",
497 idx_a.id(),
498 idx_b.id()
499 )
500 })?;
501 let expanded_tensor = local_tensor
502 .tensordot(©_tensor, &[(idx_a.clone(), idx_a.clone())])
503 .with_context(|| {
504 format!(
505 "partial_contract: failed to apply diagonal structure for pair {:?} <- {:?}",
506 idx_a.id(),
507 idx_b.id()
508 )
509 })?;
510 a_modified
511 .replace_tensor(node_idx, expanded_tensor)
512 .with_context(|| {
513 format!(
514 "partial_contract: failed to replace tensor at node {:?} for diagonal pair {:?}",
515 node_name,
516 idx_a.id()
517 )
518 })?
519 .ok_or_else(|| {
520 anyhow!(
521 "partial_contract: node {:?} disappeared while processing diagonal pair {:?}",
522 node_name,
523 idx_a.id()
524 )
525 })?;
526
527 b_modified = b_modified.replaceind(idx_b, &aux_index).with_context(|| {
528 format!(
529 "partial_contract: failed to align diagonal pair {:?} <- {:?}",
530 idx_a.id(),
531 idx_b.id()
532 )
533 })?;
534
535 restore_from.push(kept_index);
536 restore_to.push(idx_a.clone());
537 }
538
539 Ok((a_modified, b_modified, restore_from, restore_to))
540}
541
542pub fn partial_contract<V>(
593 a: &TreeTN<TensorDynLen, V>,
594 b: &TreeTN<TensorDynLen, V>,
595 spec: &PartialContractionSpec<DynIndex>,
596 center: &V,
597 options: ContractionOptions,
598) -> Result<TreeTN<TensorDynLen, V>>
599where
600 V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
601 <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
602{
603 validate_partial_contraction_spec(a, b, spec)?;
604
605 let (a_modified, mut b_modified, restore_from, restore_to) =
606 apply_diagonal_pairs(a, b, &spec.diagonal_pairs)?;
607
608 for (idx_a, idx_b) in &spec.contract_pairs {
609 b_modified = b_modified.replaceind(idx_b, idx_a).with_context(|| {
610 format!(
611 "partial_contract: failed to align contract pair {:?} <- {:?}",
612 idx_a.id(),
613 idx_b.id()
614 )
615 })?;
616 }
617
618 let mut result = if a_modified.same_topology(&b_modified) {
619 contract(&a_modified, &b_modified, center, options)
620 .context("partial_contract: contraction failed")?
621 } else {
622 contract_mismatched_topologies(&a_modified, &b_modified, center, options)?
623 };
624
625 if !restore_from.is_empty() {
626 result = result.replaceinds(&restore_from, &restore_to).context(
627 "partial_contract: failed to restore surviving left-hand indices after diagonal pairing",
628 )?;
629 }
630
631 if let Some(output_order) = &spec.output_order {
632 apply_output_order(result, output_order)
633 } else {
634 Ok(result)
635 }
636}
637
638#[cfg(test)]
639mod tests;