tensor4all_simplett/mpo/dispatch.rs
//! Algorithm dispatch for MPO contraction.
//!
//! Provides a unified [`contract`] function that dispatches to the appropriate
//! algorithm based on [`ContractionAlgorithm`].
5
/// Selects which algorithm is used to contract two MPOs.
///
/// The default variant is [`ContractionAlgorithm::Naive`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContractionAlgorithm {
    /// Naive contraction: exact, but memory-intensive. This is the default.
    #[default]
    Naive,
    /// Zip-up contraction, which compresses on the fly.
    ZipUp,
    /// Variational fitting.
    Fit,
}
17
18use super::contract_fit::{contract_fit, FitOptions};
19use super::contract_naive::contract_naive;
20use super::contract_zipup::contract_zipup;
21use super::contraction::ContractionOptions;
22use super::error::Result;
23use super::factorize::SVDScalar;
24use super::mpo::MPO;
25use crate::einsum_helper::EinsumScalar;
26
27/// Unified contraction function with algorithm dispatch
28///
29/// Contracts two MPOs using the specified algorithm.
30///
31/// # Arguments
32/// * `mpo_a` - First MPO
33/// * `mpo_b` - Second MPO
34/// * `algorithm` - Which algorithm to use
35/// * `options` - Contraction options (tolerance, max_bond_dim, etc.)
36///
37/// # Returns
38/// The contracted MPO
39///
40/// # Example
41///
42/// ```
43/// use tensor4all_simplett::mpo::{contract, MPO, ContractionOptions, ContractionAlgorithm};
44///
45/// let mpo_a = MPO::<f64>::identity(&[2, 2]).unwrap();
46/// let mpo_b = MPO::<f64>::identity(&[2, 2]).unwrap();
47/// let options = ContractionOptions::default();
48///
49/// // Use naive algorithm
50/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::Naive, &options).unwrap();
51/// assert_eq!(result.site_dims(), vec![(2, 2), (2, 2)]);
52///
53/// // Use zip-up algorithm for memory efficiency
54/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::ZipUp, &options).unwrap();
55/// assert_eq!(result.len(), 2);
56///
57/// // Use variational fitting for controlled bond dimension
58/// let result = contract(&mpo_a, &mpo_b, ContractionAlgorithm::Fit, &options).unwrap();
59/// assert_eq!(result.len(), 2);
60/// ```
61pub fn contract<T: SVDScalar + EinsumScalar>(
62 mpo_a: &MPO<T>,
63 mpo_b: &MPO<T>,
64 algorithm: ContractionAlgorithm,
65 options: &ContractionOptions,
66) -> Result<MPO<T>>
67where
68 <T as num_complex::ComplexFloat>::Real: Into<f64>,
69{
70 match algorithm {
71 ContractionAlgorithm::Naive => contract_naive(mpo_a, mpo_b, Some(options.clone())),
72 ContractionAlgorithm::ZipUp => contract_zipup(mpo_a, mpo_b, options),
73 ContractionAlgorithm::Fit => {
74 let fit_options = FitOptions {
75 tolerance: options.tolerance,
76 max_bond_dim: options.max_bond_dim,
77 factorize_method: options.factorize_method,
78 ..Default::default()
79 };
80 contract_fit(mpo_a, mpo_b, &fit_options, None)
81 }
82 }
83}
84
85#[cfg(test)]
86mod tests;