Skip to main content

tensor4all_simplett/mpo/
dispatch.rs

1//! Algorithm dispatch for MPO contraction
2//!
3//! Provides a unified `contract` function that dispatches to the appropriate
4//! algorithm based on [`ContractionAlgorithm`].
5
/// Algorithm used for MPO–MPO contraction by [`contract`].
///
/// Defaults to [`ContractionAlgorithm::Naive`]. The type is `Copy`, `Eq` and
/// `Hash`, so it can be passed by value and used as a map/set key (e.g. to
/// tabulate results per algorithm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ContractionAlgorithm {
    /// Naive contraction (exact but memory-intensive).
    #[default]
    Naive,
    /// Zip-up contraction with on-the-fly compression.
    ZipUp,
    /// Variational fitting algorithm.
    Fit,
}
17
18use super::contract_fit::{contract_fit, FitOptions};
19use super::contract_naive::contract_naive;
20use super::contract_zipup::contract_zipup;
21use super::contraction::ContractionOptions;
22use super::error::Result;
23use super::factorize::SVDScalar;
24use super::mpo::MPO;
25use crate::einsum_helper::EinsumScalar;
26
27/// Unified contraction function with algorithm dispatch
28///
29/// Contracts two MPOs using the specified algorithm.
30///
31/// # Arguments
32/// * `mpo_a` - First MPO
33/// * `mpo_b` - Second MPO
34/// * `algorithm` - Which algorithm to use
35/// * `options` - Contraction options (tolerance, max_bond_dim, etc.)
36///
37/// # Returns
38/// The contracted MPO
39///
40/// # Example
41///
42/// ```
43/// use tensor4all_simplett::mpo::{contract, MPO, ContractionOptions, ContractionAlgorithm};
44///
45/// let mpo_a = MPO::<f64>::identity(&[2, 2]).unwrap();
46/// let mpo_b = MPO::<f64>::identity(&[2, 2]).unwrap();
47/// let options = ContractionOptions::default();
48///
49/// // Use naive algorithm
50/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::Naive, &options).unwrap();
51/// assert_eq!(result.site_dims(), vec![(2, 2), (2, 2)]);
52///
53/// // Use zip-up algorithm for memory efficiency
54/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::ZipUp, &options).unwrap();
55/// assert_eq!(result.len(), 2);
56///
57/// // Use variational fitting for controlled bond dimension
58/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::Fit, &options).unwrap();
59/// assert_eq!(result.len(), 2);
60/// ```
61pub fn contract<T: SVDScalar + EinsumScalar>(
62    mpo_a: &MPO<T>,
63    mpo_b: &MPO<T>,
64    algorithm: ContractionAlgorithm,
65    options: &ContractionOptions,
66) -> Result<MPO<T>>
67where
68    <T as num_complex::ComplexFloat>::Real: Into<f64>,
69{
70    match algorithm {
71        ContractionAlgorithm::Naive => contract_naive(mpo_a, mpo_b, Some(options.clone())),
72        ContractionAlgorithm::ZipUp => contract_zipup(mpo_a, mpo_b, options),
73        ContractionAlgorithm::Fit => {
74            let fit_options = FitOptions {
75                tolerance: options.tolerance,
76                max_bond_dim: options.max_bond_dim,
77                factorize_method: options.factorize_method,
78                ..Default::default()
79            };
80            contract_fit(mpo_a, mpo_b, &fit_options, None)
81        }
82    }
83}
84
85#[cfg(test)]
86mod tests;