Skip to main content

tensor4all_treetn/treetn/
partial_contraction.rs

1//! Partial site contraction for TreeTN.
2//!
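//! Provides [`partial_contract`], which lets callers choose which site indices
//! are summed over and which are linked through explicit diagonal/copy
//! structure before the networks are handed to the existing TreeTN contraction
//! pipeline. Networks whose topologies differ fall back to a dense contraction
//! that is then refactorized onto the union of both topologies.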
6
7use std::collections::{HashMap, HashSet};
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use anyhow::{anyhow, bail, Context, Result};
12
13use super::contraction::{contract, ContractionOptions};
14use super::decompose::{factorize_tensor_to_treetn_with, TreeTopology};
15use super::TreeTN;
16use tensor4all_core::{
17    AllowedPairs, AnyScalar, DynIndex, FactorizeAlg, FactorizeOptions, IndexLike, TensorDynLen,
18    TensorIndex, TensorLike,
19};
20
/// Result of `apply_diagonal_pairs`, in order:
/// 1. the modified first network (copy tensors inserted),
/// 2. the modified second network (right-hand diagonal indices renamed),
/// 3. the fresh "kept" indices introduced for each diagonal pair,
/// 4. the original left-hand indices that those kept indices are renamed back
///    to after contraction.
///
/// Lists 3 and 4 are parallel: entry `i` of one corresponds to entry `i` of
/// the other.
type DiagonalPairApplication<V> = (
    TreeTN<TensorDynLen, V>,
    TreeTN<TensorDynLen, V>,
    Vec<DynIndex>,
    Vec<DynIndex>,
);
27
28/// Specification for partial site contraction between two TreeTNs.
29///
30/// - `contract_pairs`: Site index pairs to sum over and remove from the result.
31/// - `diagonal_pairs`: Site index pairs to identify through diagonal/copy
32///   structure while keeping the left-hand site leg in the result.
33/// - Remaining (unmentioned) site indices pass through as external legs.
34///
35/// Uses `Index` objects directly (not raw IDs), following Julia ITensor-style
36/// conventions.
37///
38/// # Examples
39///
40/// ```
41/// use tensor4all_core::DynIndex;
42/// use tensor4all_treetn::PartialContractionSpec;
43///
44/// let idx_a = DynIndex::new_dyn(4);
45/// let idx_b = DynIndex::new_dyn(4);
46/// let idx_c = DynIndex::new_dyn(3);
47/// let idx_d = DynIndex::new_dyn(3);
48///
49/// let spec = PartialContractionSpec {
50///     contract_pairs: vec![(idx_a.clone(), idx_b.clone())],
51///     diagonal_pairs: vec![(idx_c.clone(), idx_d.clone())],
52///     output_order: None,
53/// };
54///
55/// assert_eq!(spec.contract_pairs.len(), 1);
56/// assert_eq!(spec.diagonal_pairs.len(), 1);
57/// ```
58#[derive(Debug, Clone)]
59pub struct PartialContractionSpec<I: IndexLike> {
60    /// Site index pairs to contract (summed over, removed from result).
61    pub contract_pairs: Vec<(I, I)>,
62    /// Site index pairs to link through diagonal/copy structure while keeping
63    /// the left-hand site index in the result.
64    pub diagonal_pairs: Vec<(I, I)>,
65    /// Optional order for the surviving external site indices in the result.
66    ///
67    /// The indices must refer to the final result indices after applying
68    /// `contract_pairs` and `diagonal_pairs`. When provided, the result is
69    /// post-processed so that these indices appear in the requested order.
70    ///
71    /// Current implementation requires that each surviving site index occupies a
72    /// distinct node in the result.
73    pub output_order: Option<Vec<I>>,
74}
75
76fn validate_partial_contraction_spec<T, V>(
77    a: &TreeTN<T, V>,
78    b: &TreeTN<T, V>,
79    spec: &PartialContractionSpec<T::Index>,
80) -> Result<()>
81where
82    T: TensorLike,
83    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
84    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Debug + Send + Sync + Ord,
85{
86    let a_external_ids: HashSet<_> = a
87        .external_indices()
88        .into_iter()
89        .map(|idx| idx.id().clone())
90        .collect();
91    let b_external_ids: HashSet<_> = b
92        .external_indices()
93        .into_iter()
94        .map(|idx| idx.id().clone())
95        .collect();
96
97    let mut seen_a_ids = HashSet::new();
98    let mut seen_b_ids = HashSet::new();
99
100    for (kind, pairs) in [
101        ("contract_pairs", &spec.contract_pairs),
102        ("diagonal_pairs", &spec.diagonal_pairs),
103    ] {
104        for (idx_a, idx_b) in pairs {
105            if idx_a.dim() != idx_b.dim() {
106                bail!(
107                    "partial_contract: {} index dimension mismatch: {} != {}",
108                    kind,
109                    idx_a.dim(),
110                    idx_b.dim()
111                );
112            }
113
114            if !a_external_ids.contains(idx_a.id()) {
115                bail!(
116                    "partial_contract: {:?} from {} not found in first TreeTN external indices",
117                    idx_a.id(),
118                    kind
119                );
120            }
121            if !b_external_ids.contains(idx_b.id()) {
122                bail!(
123                    "partial_contract: {:?} from {} not found in second TreeTN external indices",
124                    idx_b.id(),
125                    kind
126                );
127            }
128
129            if !seen_a_ids.insert(idx_a.id().clone()) {
130                bail!(
131                    "partial_contract: first TreeTN index {:?} appears in multiple pairs",
132                    idx_a.id()
133                );
134            }
135            if !seen_b_ids.insert(idx_b.id().clone()) {
136                bail!(
137                    "partial_contract: second TreeTN index {:?} appears in multiple pairs",
138                    idx_b.id()
139                );
140            }
141        }
142    }
143
144    Ok(())
145}
146
/// Return the endpoints of an undirected edge as `(smaller, larger)`.
///
/// Normalizing orientation lets edges from different networks be sorted and
/// deduplicated regardless of which direction they were listed in.
fn canonical_edge<V>(left: &V, right: &V) -> (V, V)
where
    V: Clone + Ord,
{
    let (lo, hi) = if right < left {
        (right, left)
    } else {
        (left, right)
    };
    (lo.clone(), hi.clone())
}
157
158fn sorted_edge_set<V>(tn: &TreeTN<TensorDynLen, V>) -> Vec<(V, V)>
159where
160    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
161{
162    let mut edges: Vec<_> = tn
163        .site_index_network()
164        .edges()
165        .map(|(u, v)| canonical_edge(&u, &v))
166        .collect();
167    edges.sort();
168    edges.dedup();
169    edges
170}
171
172fn compatible_union_node_names<V>(
173    a: &TreeTN<TensorDynLen, V>,
174    b: &TreeTN<TensorDynLen, V>,
175) -> Vec<V>
176where
177    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
178{
179    let mut names: Vec<_> = a.node_names();
180    names.extend(b.node_names());
181    names.sort();
182    names.dedup();
183    names
184}
185
186fn validate_union_topology<V>(node_names: &[V], edges: &[(V, V)]) -> Result<()>
187where
188    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
189{
190    if node_names.is_empty() {
191        bail!("partial_contract: networks must contain at least one node");
192    }
193
194    if edges.len() + 1 != node_names.len() {
195        bail!("partial_contract: networks have incompatible topologies");
196    }
197
198    let mut adjacency: HashMap<V, Vec<V>> = node_names
199        .iter()
200        .cloned()
201        .map(|name| (name, Vec::new()))
202        .collect();
203    for (u, v) in edges {
204        let Some(neighbors_u) = adjacency.get_mut(u) else {
205            bail!("partial_contract: union topology references unknown node");
206        };
207        neighbors_u.push(v.clone());
208        let Some(neighbors_v) = adjacency.get_mut(v) else {
209            bail!("partial_contract: union topology references unknown node");
210        };
211        neighbors_v.push(u.clone());
212    }
213
214    let mut seen = HashSet::new();
215    let mut stack = vec![node_names[0].clone()];
216    while let Some(node) = stack.pop() {
217        if !seen.insert(node.clone()) {
218            continue;
219        }
220        if let Some(neighbors) = adjacency.get(&node) {
221            stack.extend(neighbors.iter().cloned());
222        }
223    }
224
225    if seen.len() != node_names.len() {
226        bail!("partial_contract: networks have incompatible topologies");
227    }
228
229    Ok(())
230}
231
232fn factorize_options_from_contraction_options(
233    options: &ContractionOptions,
234) -> Result<FactorizeOptions> {
235    let mut factorize_options = match options.factorize_alg {
236        FactorizeAlg::SVD => FactorizeOptions::svd(),
237        FactorizeAlg::QR => FactorizeOptions::qr(),
238        FactorizeAlg::LU => FactorizeOptions::lu(),
239        FactorizeAlg::CI => FactorizeOptions::ci(),
240    };
241    if let Some(policy) = options.svd_policy {
242        factorize_options = factorize_options.with_svd_policy(policy);
243    }
244    if let Some(rtol) = options.qr_rtol {
245        factorize_options = factorize_options.with_qr_rtol(rtol);
246    }
247    if let Some(max_rank) = options.max_rank {
248        factorize_options = factorize_options.with_max_rank(max_rank);
249    }
250    factorize_options.validate().map_err(|err| {
251        anyhow!("partial_contract: invalid contraction factorization options: {err}")
252    })?;
253    Ok(factorize_options)
254}
255
256fn union_result_topology<V>(
257    a: &TreeTN<TensorDynLen, V>,
258    b: &TreeTN<TensorDynLen, V>,
259    contracted_tensor: &TensorDynLen,
260) -> Result<TreeTopology<V, <DynIndex as IndexLike>::Id>>
261where
262    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
263    <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
264{
265    let node_names = compatible_union_node_names(a, b);
266    let mut union_edges = sorted_edge_set(a);
267    union_edges.extend(sorted_edge_set(b));
268    union_edges.sort();
269    union_edges.dedup();
270    validate_union_topology(&node_names, &union_edges)?;
271
272    let surviving_ids: HashSet<_> = contracted_tensor
273        .external_indices()
274        .into_iter()
275        .map(|idx| *idx.id())
276        .collect();
277
278    let mut nodes = HashMap::new();
279    for node_name in &node_names {
280        let mut ids = Vec::new();
281
282        if let Some(site_space_a) = a.site_index_network().site_space(node_name) {
283            for site_idx in site_space_a {
284                if surviving_ids.contains(site_idx.id()) {
285                    ids.push(*site_idx.id());
286                }
287            }
288        }
289
290        if let Some(site_space_b) = b.site_index_network().site_space(node_name) {
291            for site_idx in site_space_b {
292                if surviving_ids.contains(site_idx.id()) && !ids.contains(site_idx.id()) {
293                    ids.push(*site_idx.id());
294                }
295            }
296        }
297
298        nodes.insert(node_name.clone(), ids);
299    }
300
301    Ok(TreeTopology::new(nodes, union_edges))
302}
303
304fn contract_mismatched_topologies<V>(
305    a: &TreeTN<TensorDynLen, V>,
306    b: &TreeTN<TensorDynLen, V>,
307    center: &V,
308    options: ContractionOptions,
309) -> Result<TreeTN<TensorDynLen, V>>
310where
311    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
312    <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
313{
314    let a_dense = a
315        .sim_internal_inds()
316        .contract_to_tensor()
317        .context("partial_contract: failed to contract first mismatched-topology TreeTN")?;
318    let b_dense = b
319        .sim_internal_inds()
320        .contract_to_tensor()
321        .context("partial_contract: failed to contract second mismatched-topology TreeTN")?;
322    let contracted_tensor =
323        <TensorDynLen as TensorLike>::contract(&[&a_dense, &b_dense], AllowedPairs::All)
324            .context("partial_contract: failed dense contraction for mismatched topologies")?;
325
326    if contracted_tensor.external_indices().is_empty() {
327        let mut result = TreeTN::<TensorDynLen, V>::new();
328        result
329            .add_tensor(center.clone(), contracted_tensor)
330            .context("partial_contract: failed to wrap scalar mismatched-topology result")?;
331        result
332            .set_canonical_region([center.clone()])
333            .context("partial_contract: failed to set canonical region for scalar result")?;
334        return Ok(result);
335    }
336
337    let topology = union_result_topology(a, b, &contracted_tensor)?;
338    let factorize_options = factorize_options_from_contraction_options(&options)?;
339    factorize_tensor_to_treetn_with(&contracted_tensor, &topology, factorize_options, center)
340        .context("partial_contract: failed to factorize mismatched-topology dense result")
341}
342
343fn apply_output_order<T, V>(result: TreeTN<T, V>, output_order: &[T::Index]) -> Result<TreeTN<T, V>>
344where
345    T: TensorLike,
346    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
347    T::Index: Clone + Hash + Eq,
348    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
349{
350    let (current_indices, _) = result.all_site_indices()?;
351    if output_order.len() != current_indices.len() {
352        bail!(
353            "partial_contract: output_order length {} does not match surviving external index count {}",
354            output_order.len(),
355            current_indices.len()
356        );
357    }
358
359    let current_ids: HashSet<_> = current_indices.iter().map(|idx| idx.id().clone()).collect();
360    let requested_ids: HashSet<_> = output_order.iter().map(|idx| idx.id().clone()).collect();
361    if current_ids != requested_ids {
362        bail!("partial_contract: output_order must contain exactly the surviving external indices");
363    }
364
365    let mut current_nodes = Vec::with_capacity(current_indices.len());
366    for index in &current_indices {
367        let node = result.site_index_network().find_node_by_index(index).ok_or_else(|| {
368            anyhow!(
369                "partial_contract: current result index {:?} is not present in the site index network",
370                index.id()
371            )
372        })?;
373        current_nodes.push(node.clone());
374    }
375
376    let unique_current_nodes: HashSet<_> = current_nodes.iter().cloned().collect();
377    if unique_current_nodes.len() != current_nodes.len() {
378        bail!(
379            "partial_contract: output_order currently requires at most one surviving site index per node"
380        );
381    }
382
383    let mut seen_requested = HashSet::new();
384    let mut ordered_nodes = Vec::with_capacity(result.node_count());
385    let mut ordered_node_set = HashSet::new();
386
387    for index in output_order {
388        if !seen_requested.insert(index.id().clone()) {
389            bail!("partial_contract: output_order contains duplicate indices");
390        }
391        let current_node = result
392            .site_index_network()
393            .find_node_by_index(index)
394            .ok_or_else(|| {
395                anyhow!(
396                    "partial_contract: output_order index {:?} is not present in the result",
397                    index.id()
398                )
399            })?;
400        if !ordered_node_set.insert(current_node.clone()) {
401            bail!(
402                "partial_contract: output_order currently requires each requested index to occupy a distinct node"
403            );
404        }
405        ordered_nodes.push(current_node.clone());
406    }
407
408    for node_name in result.node_names() {
409        if ordered_node_set.insert(node_name.clone()) {
410            ordered_nodes.push(node_name);
411        }
412    }
413
414    let tensors = ordered_nodes
415        .iter()
416        .map(|node_name| {
417            let node_idx = result.node_index(node_name).ok_or_else(|| {
418                anyhow!(
419                    "partial_contract: output_order node {:?} is not present in the result",
420                    node_name
421                )
422            })?;
423            result.tensor(node_idx).cloned().ok_or_else(|| {
424                anyhow!(
425                    "partial_contract: tensor for output_order node {:?} is missing",
426                    node_name
427                )
428            })
429        })
430        .collect::<Result<Vec<_>>>()?;
431
432    let mut reordered = TreeTN::from_tensors(tensors, ordered_nodes)
433        .context("partial_contract: failed to rebuild result in requested output order")?;
434    reordered.canonical_region = result.canonical_region.clone();
435    reordered.canonical_form = result.canonical_form;
436    reordered.ortho_towards = result.ortho_towards.clone();
437    Ok(reordered)
438}
439
440fn diagonal_copy_value(tensor: &TensorDynLen) -> AnyScalar {
441    if tensor.is_complex() {
442        AnyScalar::new_complex(1.0, 0.0)
443    } else {
444        AnyScalar::new_real(1.0)
445    }
446}
447
448fn apply_diagonal_pairs<V>(
449    a: &TreeTN<TensorDynLen, V>,
450    b: &TreeTN<TensorDynLen, V>,
451    diagonal_pairs: &[(DynIndex, DynIndex)],
452) -> Result<DiagonalPairApplication<V>>
453where
454    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
455    <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
456{
457    let mut a_modified = a.clone();
458    let mut b_modified = b.clone();
459    let mut restore_from = Vec::with_capacity(diagonal_pairs.len());
460    let mut restore_to = Vec::with_capacity(diagonal_pairs.len());
461
462    for (idx_a, idx_b) in diagonal_pairs {
463        let node_name = a_modified
464            .site_index_network()
465            .find_node_by_index(idx_a)
466            .cloned()
467            .ok_or_else(|| {
468                anyhow!(
469                    "partial_contract: diagonal pair left index {:?} is not a site index of the first TreeTN",
470                    idx_a.id()
471                )
472            })?;
473        let node_idx = a_modified.node_index(&node_name).ok_or_else(|| {
474            anyhow!(
475                "partial_contract: node {:?} for left diagonal index {:?} not found",
476                node_name,
477                idx_a.id()
478            )
479        })?;
480        let local_tensor = a_modified.tensor(node_idx).cloned().ok_or_else(|| {
481            anyhow!(
482                "partial_contract: tensor for node {:?} not found while processing diagonal pair {:?}",
483                node_name,
484                idx_a.id()
485            )
486        })?;
487
488        let aux_index = idx_a.sim();
489        let kept_index = idx_a.sim();
490        let copy_tensor = TensorDynLen::copy_tensor(
491            vec![idx_a.clone(), aux_index.clone(), kept_index.clone()],
492            diagonal_copy_value(&local_tensor),
493        )
494        .with_context(|| {
495            format!(
496                "partial_contract: failed to build copy tensor for diagonal pair {:?} <- {:?}",
497                idx_a.id(),
498                idx_b.id()
499            )
500        })?;
501        let expanded_tensor = local_tensor
502            .tensordot(&copy_tensor, &[(idx_a.clone(), idx_a.clone())])
503            .with_context(|| {
504                format!(
505                    "partial_contract: failed to apply diagonal structure for pair {:?} <- {:?}",
506                    idx_a.id(),
507                    idx_b.id()
508                )
509            })?;
510        a_modified
511            .replace_tensor(node_idx, expanded_tensor)
512            .with_context(|| {
513                format!(
514                    "partial_contract: failed to replace tensor at node {:?} for diagonal pair {:?}",
515                    node_name,
516                    idx_a.id()
517                )
518            })?
519            .ok_or_else(|| {
520                anyhow!(
521                    "partial_contract: node {:?} disappeared while processing diagonal pair {:?}",
522                    node_name,
523                    idx_a.id()
524                )
525            })?;
526
527        b_modified = b_modified.replaceind(idx_b, &aux_index).with_context(|| {
528            format!(
529                "partial_contract: failed to align diagonal pair {:?} <- {:?}",
530                idx_a.id(),
531                idx_b.id()
532            )
533        })?;
534
535        restore_from.push(kept_index);
536        restore_to.push(idx_a.clone());
537    }
538
539    Ok((a_modified, b_modified, restore_from, restore_to))
540}
541
542/// Partially contract two TreeTNs according to the given specification.
543///
544/// # Arguments
545/// * `a` - First tensor network
546/// * `b` - Second tensor network
547/// * `spec` - Which site indices to contract versus link through diagonal
548///   structure
549/// * `center` - Canonical center node for the result
550/// * `options` - Contraction algorithm options
551///
552/// # Index handling
553///
554/// - **contract_pairs**: Both indices are traced over (inner product).
555///   Neither appears in the result.
556/// - **diagonal_pairs**: The two indices are linked through explicit diagonal
557///   structure so that only matching values contribute, while the left-hand site
558///   index remains in the result.
559/// - **Unmentioned indices**: Pass through unchanged as external legs.
560///
561/// # Examples
562///
563/// ```no_run
564/// use tensor4all_core::{DynIndex, TensorDynLen};
565/// use tensor4all_treetn::{
566///     contraction::ContractionOptions,
567///     partial_contract,
568///     PartialContractionSpec,
569///     TreeTN,
570/// };
571///
572/// let idx_a = DynIndex::new_dyn(2);
573/// let idx_b = DynIndex::new_dyn(2);
574/// let a = TreeTN::<TensorDynLen, usize>::from_tensors(
575///     vec![TensorDynLen::from_dense(vec![idx_a.clone()], vec![1.0, 2.0]).unwrap()],
576///     vec![0],
577/// ).unwrap();
578/// let b = TreeTN::<TensorDynLen, usize>::from_tensors(
579///     vec![TensorDynLen::from_dense(vec![idx_b.clone()], vec![3.0, 4.0]).unwrap()],
580///     vec![0],
581/// ).unwrap();
582///
583/// let spec = PartialContractionSpec {
584///     contract_pairs: vec![(idx_a.clone(), idx_b.clone())],
585///     diagonal_pairs: vec![],
586///     output_order: None,
587/// };
588///
589/// let result = partial_contract(&a, &b, &spec, &0usize, ContractionOptions::default()).unwrap();
590/// assert_eq!(result.node_count(), 1);
591/// ```
592pub fn partial_contract<V>(
593    a: &TreeTN<TensorDynLen, V>,
594    b: &TreeTN<TensorDynLen, V>,
595    spec: &PartialContractionSpec<DynIndex>,
596    center: &V,
597    options: ContractionOptions,
598) -> Result<TreeTN<TensorDynLen, V>>
599where
600    V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
601    <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
602{
603    validate_partial_contraction_spec(a, b, spec)?;
604
605    let (a_modified, mut b_modified, restore_from, restore_to) =
606        apply_diagonal_pairs(a, b, &spec.diagonal_pairs)?;
607
608    for (idx_a, idx_b) in &spec.contract_pairs {
609        b_modified = b_modified.replaceind(idx_b, idx_a).with_context(|| {
610            format!(
611                "partial_contract: failed to align contract pair {:?} <- {:?}",
612                idx_a.id(),
613                idx_b.id()
614            )
615        })?;
616    }
617
618    let mut result = if a_modified.same_topology(&b_modified) {
619        contract(&a_modified, &b_modified, center, options)
620            .context("partial_contract: contraction failed")?
621    } else {
622        contract_mismatched_topologies(&a_modified, &b_modified, center, options)?
623    };
624
625    if !restore_from.is_empty() {
626        result = result.replaceinds(&restore_from, &restore_to).context(
627            "partial_contract: failed to restore surviving left-hand indices after diagonal pairing",
628        )?;
629    }
630
631    if let Some(output_order) = &spec.output_order {
632        apply_output_order(result, output_order)
633    } else {
634        Ok(result)
635    }
636}
637
// Unit tests live in an out-of-line `tests` submodule file and are compiled
// only for test builds.
#[cfg(test)]
mod tests;