// tensor4all_itensorlike/linsolve.rs
//! Linear equation solver for tensor trains.
//!
//! This module provides the [`linsolve`] function for solving linear systems
//! of the form `(a₀ + a₁ * A) * x = b` where `A` is an MPO (TensorTrain),
//! and `x`, `b` are MPS (TensorTrain).
//!
//! Internally delegates to [`tensor4all_treetn::square_linsolve`].

use std::collections::HashMap;

use tensor4all_core::{DynIndex, IndexLike};
use tensor4all_treetn::{square_linsolve, IndexMapping, TruncationOptions};

use crate::error::{Result, TensorTrainError};
use crate::options::{validate_svd_truncation_options, CanonicalForm, LinsolveOptions};
use crate::tensortrain::TensorTrain;

18/// Solve `(a₀ + a₁ * A) * x = b` for `x`.
19///
20/// # Arguments
21/// * `operator` - The operator `A` (MPO as TensorTrain)
22/// * `rhs` - The right-hand side `b` (MPS as TensorTrain)
23/// * `init` - Initial guess for `x` (MPS as TensorTrain, consumed)
24/// * `options` - Solver options
25///
26/// # Returns
27/// The solution `x` as a TensorTrain.
28///
29/// # Errors
30/// Returns an error if:
31/// - Any tensor train is empty
32/// - `nhalfsweeps` is not a multiple of 2
33/// - The solver fails internally
34pub fn linsolve(
35    operator: &TensorTrain,
36    rhs: &TensorTrain,
37    init: TensorTrain,
38    options: &LinsolveOptions,
39) -> Result<TensorTrain> {
40    if operator.is_empty() || rhs.is_empty() || init.is_empty() {
41        return Err(TensorTrainError::InvalidStructure {
42            message: "Cannot linsolve with empty tensor trains".to_string(),
43        });
44    }
45
46    if !options.nhalfsweeps().is_multiple_of(2) {
47        return Err(TensorTrainError::OperationError {
48            message: format!(
49                "nhalfsweeps must be a multiple of 2, got {}",
50                options.nhalfsweeps()
51            ),
52        });
53    }
54
55    validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
56
57    if !options.krylov_tol().is_finite() || options.krylov_tol() <= 0.0 {
58        return Err(TensorTrainError::OperationError {
59            message: format!(
60                "krylov_tol must be finite and > 0, got {}",
61                options.krylov_tol()
62            ),
63        });
64    }
65
66    if options.krylov_maxiter() == 0 {
67        return Err(TensorTrainError::OperationError {
68            message: "krylov_maxiter must be >= 1".to_string(),
69        });
70    }
71
72    if options.krylov_dim() == 0 {
73        return Err(TensorTrainError::OperationError {
74            message: "krylov_dim must be >= 1".to_string(),
75        });
76    }
77
78    if let Some(tol) = options.convergence_tol() {
79        if !tol.is_finite() || tol < 0.0 {
80            return Err(TensorTrainError::OperationError {
81                message: format!("convergence_tol must be finite and >= 0, got {}", tol),
82            });
83        }
84    }
85
86    // Convert LinsolveOptions → treetn::LinsolveOptions
87    let nfullsweeps = options.nhalfsweeps() / 2;
88
89    let treetn_options = tensor4all_treetn::LinsolveOptions::new(nfullsweeps)
90        .with_truncation(TruncationOptions::new())
91        .with_krylov_tol(options.krylov_tol())
92        .with_krylov_maxiter(options.krylov_maxiter())
93        .with_krylov_dim(options.krylov_dim())
94        .with_coefficients(options.coefficients().0, options.coefficients().1);
95
96    let treetn_options = if let Some(policy) = options.svd_policy() {
97        treetn_options.with_svd_policy(policy)
98    } else {
99        treetn_options
100    };
101
102    let treetn_options = if let Some(max_rank) = options.max_rank() {
103        treetn_options.with_max_rank(max_rank)
104    } else {
105        treetn_options
106    };
107
108    let treetn_options = if let Some(tol) = options.convergence_tol() {
109        treetn_options.with_convergence_tol(tol)
110    } else {
111        treetn_options
112    };
113
114    // Use the last site as the sweep center
115    let center = init.len() - 1;
116
117    // Auto-infer index mappings when the operator (MPO) has distinct input/output
118    // site indices. For each site, the MPO has 2 site indices while the MPS has 1.
119    // We find the MPO index that shares an ID with init's site index (input) and
120    // the remaining one (output).
121    let (input_mapping, output_mapping) =
122        infer_index_mappings(operator, &init).map_err(|e| TensorTrainError::OperationError {
123            message: format!("Failed to infer index mappings: {}", e),
124        })?;
125
126    let result = square_linsolve(
127        operator.as_treetn(),
128        rhs.as_treetn(),
129        init.treetn,
130        &center,
131        treetn_options,
132        input_mapping,
133        output_mapping,
134    )
135    .map_err(|e| TensorTrainError::OperationError {
136        message: format!("Linsolve failed: {}", e),
137    })?;
138
139    TensorTrain::from_inner(result.solution, Some(CanonicalForm::Unitary))
140}
141
/// Pair of optional per-site index mappings `(input, output)`, keyed by the
/// site position in the train. `(None, None)` means no remapping is required.
type SiteMappings = (
    Option<HashMap<usize, IndexMapping<DynIndex>>>,
    Option<HashMap<usize, IndexMapping<DynIndex>>>,
);

147/// Infer index mappings from the MPO/MPS structure.
148///
149/// For each site where the operator has exactly 2 site indices and init has 1,
150/// finds the operator index sharing an ID with init's index (input) and the
151/// remaining one (output). Returns `(None, None)` when no mappings are needed
152/// (operator and init share all site indices).
153fn infer_index_mappings(
154    operator: &TensorTrain,
155    init: &TensorTrain,
156) -> std::result::Result<SiteMappings, String> {
157    let op_treetn = operator.as_treetn();
158    let init_treetn = init.as_treetn();
159    let nsites = init.len();
160
161    let mut needs_mapping = false;
162
163    // First pass: check if any site needs mappings
164    for site in 0..nsites {
165        let op_site = op_treetn.site_space(&site);
166        let init_site = init_treetn.site_space(&site);
167
168        if let (Some(op_indices), Some(init_indices)) = (op_site, init_site) {
169            if op_indices.len() == 2 && init_indices.len() == 1 {
170                // MPO-like site: check if we need mappings
171                let init_idx = init_indices.iter().next().unwrap();
172                let has_shared = op_indices.iter().any(|idx| idx.same_id(init_idx));
173                if has_shared {
174                    // The shared index exists but may differ by plev → need mapping
175                    let input_idx = op_indices.iter().find(|idx| idx.same_id(init_idx));
176                    if let Some(input_idx) = input_idx {
177                        if input_idx != init_idx {
178                            needs_mapping = true;
179                        }
180                    }
181                    // Check if output index differs from init
182                    let output_idx = op_indices.iter().find(|idx| !idx.same_id(init_idx));
183                    if output_idx.is_some() {
184                        needs_mapping = true;
185                    }
186                } else {
187                    return Err(format!(
188                        "Site {}: operator has 2 site indices but none share an ID with init's \
189                         site index {:?}. Cannot auto-infer index mappings. \
190                         Use the treetn-level API with explicit IndexMapping.",
191                        site,
192                        init_idx.id()
193                    ));
194                }
195            }
196        }
197    }
198
199    if !needs_mapping {
200        return Ok((None, None));
201    }
202
203    // Second pass: build mappings
204    let mut input_mapping = HashMap::new();
205    let mut output_mapping = HashMap::new();
206
207    for site in 0..nsites {
208        let op_site = op_treetn.site_space(&site);
209        let init_site = init_treetn.site_space(&site);
210
211        if let (Some(op_indices), Some(init_indices)) = (op_site, init_site) {
212            if op_indices.len() == 2 && init_indices.len() == 1 {
213                let init_idx = init_indices.iter().next().unwrap();
214
215                let op_input = op_indices.iter().find(|idx| idx.same_id(init_idx)).unwrap();
216                let op_output = op_indices
217                    .iter()
218                    .find(|idx| !idx.same_id(init_idx))
219                    .unwrap();
220
221                input_mapping.insert(
222                    site,
223                    IndexMapping {
224                        true_index: init_idx.clone(),
225                        internal_index: op_input.clone(),
226                    },
227                );
228                output_mapping.insert(
229                    site,
230                    IndexMapping {
231                        true_index: init_idx.clone(),
232                        internal_index: op_output.clone(),
233                    },
234                );
235            }
236        }
237    }
238
239    Ok((Some(input_mapping), Some(output_mapping)))
240}
241
impl TensorTrain {
    /// Solve `(a₀ + a₁ * A) * x = b` for `x`.
    ///
    /// `self` is the operator `A`, `rhs` is `b`, and `init` is the initial
    /// guess for `x` (consumed). See the free function [`linsolve`] for the
    /// full argument, option, and error documentation.
    pub fn linsolve(&self, rhs: &Self, init: Self, options: &LinsolveOptions) -> Result<Self> {
        // Thin convenience wrapper: delegates to the module-level solver.
        linsolve(self, rhs, init, options)
    }
}