Skip to main content

tensor4all_treetn/linsolve/square/
mod.rs

1//! Linear equation solver for Tree Tensor Networks (V_in = V_out case).
2//!
3//! This module provides the `square_linsolve` function for solving linear systems
4//! of the form `(a₀ + a₁ * A) * x = b` where A is a TTN operator and x, b are TTN states,
5//! and the input and output spaces are the same (V_in = V_out).
6//!
7//! # Algorithm
8//!
9//! The algorithm uses alternating updates (sweeping) similar to DMRG:
10//! 1. Position environments to expose a local region
11//! 2. Solve the local linear problem using GMRES (via kryst)
12//! 3. Factorize the result and move the orthogonality center
13//! 4. Update environment caches
14//!
15//! # Inspired By
16//!
17//! This implementation is inspired by:
18//! - [ITensorMPS.jl](https://github.com/ITensor/ITensorMPS.jl) - Core algorithm structure
19//! - [KrylovKit.jl](https://github.com/Jutho/KrylovKit.jl) - Krylov solver integration pattern
20//! - [kryst](https://github.com/tmathis720/kryst) - Rust GMRES implementation
21//!
22//! # References
23//!
24//! - Phys. Rev. B 72, 180403 (2005) - Noise term technique (not implemented in initial version)
25
26mod local_linop;
27mod projected_state;
28mod updater;
29
30pub use projected_state::ProjectedState;
31pub use updater::{LinsolveVerifyReport, NodeVerifyDetail, SquareLinsolveUpdater};
32
33use std::collections::HashMap;
34use std::hash::Hash;
35
36use anyhow::Result;
37
38use tensor4all_core::{IndexLike, TensorLike};
39
40use crate::linsolve::common::LinsolveOptions;
41use crate::operator::IndexMapping;
42use crate::{apply_local_update_sweep, CanonicalizationOptions, LocalUpdateSweepPlan, TreeTN};
43
44/// Result of square_linsolve operation.
45#[derive(Debug, Clone)]
46pub struct SquareLinsolveResult<T, V>
47where
48    T: TensorLike,
49    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
50{
51    /// The solution TreeTN
52    pub solution: TreeTN<T, V>,
53    /// Number of sweeps performed
54    pub sweeps: usize,
55    /// Final residual norm (if computed)
56    pub residual: Option<f64>,
57    /// Converged flag
58    pub converged: bool,
59}
60
61/// Validate that operator, rhs, and init have compatible structures for linsolve.
62///
63/// Checks:
64/// 1. Operator can act on init (same topology)
65/// 2. Result of operator action has compatible site dimensions with rhs
66fn validate_linsolve_inputs<T, V>(
67    operator: &TreeTN<T, V>,
68    rhs: &TreeTN<T, V>,
69    init: &TreeTN<T, V>,
70) -> Result<()>
71where
72    T: TensorLike,
73    <T::Index as tensor4all_core::IndexLike>::Id:
74        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
75    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
76{
77    let init_network = init.site_index_network();
78    let op_network = operator.site_index_network();
79    let rhs_network = rhs.site_index_network();
80
81    // Check 1: Operator can act on init
82    let result_network = init_network
83        .apply_operator_topology(op_network)
84        .map_err(|e| anyhow::anyhow!("Operator cannot act on init: {}", e))?;
85
86    // Check 2: Result has compatible dimensions with rhs
87    if !result_network.compatible_site_dimensions(rhs_network) {
88        return Err(anyhow::anyhow!(
89            "Result of operator action is not compatible with RHS"
90        ));
91    }
92
93    Ok(())
94}
95
96/// Solve the linear system `(a₀ + a₁ * H) |x⟩ = |b⟩` for TreeTN.
97///
98/// This solver is for the square case where V_in = V_out (input and output spaces are the same).
99///
100/// # Arguments
101///
102/// * `operator` - The operator H as a TreeTN (must have compatible structure with `rhs`)
103/// * `rhs` - The right-hand side |b⟩ as a TreeTN
104/// * `init` - Initial guess for |x⟩
105/// * `center` - Node to use as sweep center
106/// * `options` - Solver options
107/// * `input_mapping` - Optional per-node mapping from state site index to operator input index.
108///   Required when the operator (MPO) uses internal indices distinct from the state's site indices.
109/// * `output_mapping` - Optional per-node mapping from state site index to operator output index.
110///   Required when the operator (MPO) uses internal indices distinct from the state's site indices.
111///
112/// # Returns
113///
114/// The solution TreeTN, or an error if solving fails.
115///
116/// # Example
117///
118/// ```no_run
119/// use tensor4all_core::{DynIndex, TensorDynLen};
120/// use tensor4all_treetn::{square_linsolve, LinsolveOptions, TreeTN};
121///
122/// # fn main() -> anyhow::Result<()> {
123/// let s = DynIndex::new_dyn(2);
124/// let operator_tensor = TensorDynLen::from_dense(vec![s.clone()], vec![1.0, 1.0])?;
125/// let rhs_tensor = TensorDynLen::from_dense(vec![s.clone()], vec![1.0, 2.0])?;
126/// let init_tensor = TensorDynLen::from_dense(vec![s.clone()], vec![0.0, 0.0])?;
127///
128/// let operator = TreeTN::<TensorDynLen, usize>::from_tensors(vec![operator_tensor], vec![0])?;
129/// let rhs = TreeTN::<TensorDynLen, usize>::from_tensors(vec![rhs_tensor], vec![0])?;
130/// let init = TreeTN::<TensorDynLen, usize>::from_tensors(vec![init_tensor], vec![0])?;
131///
132/// let result = square_linsolve(&operator, &rhs, init, &0usize, LinsolveOptions::default(), None, None)?;
133/// assert_eq!(result.solution.node_count(), 1);
134/// # Ok(())
135/// # }
136/// ```
137pub fn square_linsolve<T, V>(
138    operator: &TreeTN<T, V>,
139    rhs: &TreeTN<T, V>,
140    init: TreeTN<T, V>,
141    center: &V,
142    options: LinsolveOptions,
143    input_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
144    output_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
145) -> Result<SquareLinsolveResult<T, V>>
146where
147    T: TensorLike + 'static,
148    <T::Index as IndexLike>::Id:
149        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
150    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
151{
152    // Validate inputs before proceeding
153    validate_linsolve_inputs(operator, rhs, &init)?;
154
155    // Canonicalize initial guess towards center
156    let mut x = init.canonicalize([center.clone()], CanonicalizationOptions::default())?;
157
158    // Create SquareLinsolveUpdater with or without index mappings
159    let mut updater = match (input_mapping, output_mapping) {
160        (Some(input), Some(output)) => SquareLinsolveUpdater::with_index_mappings(
161            operator.clone(),
162            input,
163            output,
164            rhs.clone(),
165            options.clone(),
166        ),
167        (None, None) => SquareLinsolveUpdater::new(operator.clone(), rhs.clone(), options.clone()),
168        _ => {
169            return Err(anyhow::anyhow!(
170                "input_mapping and output_mapping must both be Some or both be None"
171            ));
172        }
173    };
174
175    // Create sweep plan (nsite=2 for 2-site updates)
176    let plan = LocalUpdateSweepPlan::from_treetn(&x, center, 2)
177        .ok_or_else(|| anyhow::anyhow!("Failed to create sweep plan"))?;
178
179    let mut final_sweeps = 0;
180
181    // Perform sweeps
182    for sweep in 0..options.nfullsweeps {
183        final_sweeps = sweep + 1;
184        apply_local_update_sweep(&mut x, &plan, &mut updater)?;
185    }
186
187    // Note: Residual computation (||Hx - b|| / ||b||) and convergence checking
188    // are not yet implemented. Currently, all requested sweeps are performed.
189    Ok(SquareLinsolveResult {
190        solution: x,
191        sweeps: final_sweeps,
192        residual: None,
193        converged: false,
194    })
195}
196
197#[cfg(test)]
198mod tests;