// tensor4all_core/krylov.rs
1//! Krylov subspace methods for solving linear equations with abstract tensors.
2//!
3//! This module provides iterative solvers that work with any type implementing [`TensorLike`],
4//! enabling their use in tensor network algorithms without requiring dense vector representations.
5//!
6//! # Solvers
7//!
8//! - [`gmres`]: Generalized Minimal Residual Method (GMRES) for non-symmetric systems
9//!
10//! # Future Extensions
11//!
12//! - CG (Conjugate Gradient) for symmetric positive definite systems
13//! - BiCGSTAB for non-symmetric systems with better convergence properties
14//!
15//! # Example
16//!
17//! ```
18//! use tensor4all_core::{
19//!     krylov::{gmres, GmresOptions},
20//!     DynIndex, TensorDynLen, TensorLike,
21//! };
22//!
23//! # fn main() -> anyhow::Result<()> {
24//! let i = DynIndex::new_dyn(2);
25//! let rhs = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, -1.0])?;
26//! let initial_guess = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
27//!
28//! let apply_operator = |x: &TensorDynLen| Ok(x.clone());
29//! let result = gmres(apply_operator, &rhs, &initial_guess, &GmresOptions::default())?;
30//!
31//! assert!(result.converged);
32//! assert!(result.solution.sub(&rhs)?.maxabs() < 1e-12);
33//! # Ok(())
34//! # }
35//! ```
36
37use crate::any_scalar::AnyScalar;
38use crate::TensorLike;
39use anyhow::Result;
40
/// Options for GMRES solver.
///
/// Shared by [`gmres`] and [`gmres_with_truncation`].
///
/// # Examples
///
/// ```
/// use tensor4all_core::krylov::GmresOptions;
///
/// let opts = GmresOptions {
///     max_iter: 50,
///     rtol: 1e-8,
///     max_restarts: 5,
///     verbose: false,
///     check_true_residual: true,
/// };
/// assert_eq!(opts.max_iter, 50);
/// assert_eq!(opts.rtol, 1e-8);
/// ```
#[derive(Debug, Clone)]
pub struct GmresOptions {
    /// Maximum number of iterations (restart cycle length).
    /// This bounds the Krylov subspace dimension (and basis storage) per cycle.
    /// Default: 100
    pub max_iter: usize,

    /// Convergence tolerance for relative residual norm.
    /// The solver stops when `||r|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum number of restarts.
    /// Total iterations = max_iter * max_restarts.
    /// Default: 10
    pub max_restarts: usize,

    /// Whether to print convergence information (written to stderr).
    /// Default: false
    pub verbose: bool,

    /// When true, verify convergence by computing the true residual `||b - A*x|| / ||b||`
    /// before declaring convergence. This prevents false convergence caused by
    /// truncation corrupting the Krylov basis orthogonality (see Issue #207).
    /// Costs one additional `apply_a` call when convergence is detected.
    /// Note: only [`gmres_with_truncation`] consults this flag; plain [`gmres`]
    /// ignores it (without truncation the Hessenberg residual is reliable).
    /// Default: false
    pub check_true_residual: bool,
}
85
86impl Default for GmresOptions {
87    fn default() -> Self {
88        Self {
89            max_iter: 100,
90            rtol: 1e-10,
91            max_restarts: 10,
92            verbose: false,
93            check_true_residual: false,
94        }
95    }
96}
97
/// Result of GMRES solver.
///
/// Contains the solution, iteration count, final residual norm, and
/// convergence status.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
/// use tensor4all_core::krylov::{gmres, GmresOptions};
///
/// let i = DynIndex::new_dyn(2);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 7.0]).unwrap();
/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
///
/// let result = gmres(|x: &TensorDynLen| Ok(x.clone()), &b, &x0, &GmresOptions::default()).unwrap();
/// assert!(result.converged);
/// assert!(result.residual_norm < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct GmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Number of iterations performed (inner Arnoldi steps, summed over all restarts).
    pub iterations: usize,

    /// Final relative residual norm, i.e. `||r|| / ||b||`.
    pub residual_norm: f64,

    /// Whether the solver converged (`residual_norm < rtol`).
    pub converged: bool,
}
131
/// Solve `A x = b` using GMRES (Generalized Minimal Residual Method).
///
/// This implements the restarted GMRES algorithm that works with abstract tensor types
/// through the [`TensorLike`] trait's vector space operations.
///
/// # Algorithm
///
/// GMRES builds an orthonormal basis for the Krylov subspace
/// `K_m = span{r_0, A r_0, A^2 r_0, ..., A^{m-1} r_0}` and finds the
/// solution that minimizes `||b - A x||` over this subspace.
///
/// The small Hessenberg least-squares problem is solved incrementally with
/// Givens rotations, so the residual norm is available at every iteration
/// without explicitly forming the solution. After `options.max_iter` Arnoldi
/// steps the method restarts from the updated iterate, up to
/// `options.max_restarts` times. `options.check_true_residual` is not
/// consulted by this function (it only affects [`gmres_with_truncation`]).
///
/// # Type Parameters
///
/// * `T` - A tensor type implementing `TensorLike`
/// * `F` - A function that applies the linear operator: `F(x) = A x`
///
/// # Arguments
///
/// * `apply_a` - Function that applies the linear operator A to a tensor
/// * `b` - Right-hand side tensor
/// * `x0` - Initial guess
/// * `options` - Solver options
///
/// # Returns
///
/// A `GmresResult` containing the solution and convergence information.
/// If the iteration budget is exhausted, the result carries the best iterate
/// found with `converged == false`.
///
/// # Errors
///
/// Returns an error if:
/// - Vector space operations (add, sub, scale, inner_product) fail
/// - The linear operator application fails
/// - `validate()` fails on `b`, `x0`, or the first operator output
pub fn gmres<T, F>(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result<GmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0
        // (also avoids dividing by b_norm below)
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Compute initial residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first restart
        if _restart == 0 {
            ax.validate()?;
        }
        // r = 1.0 * b + (-1.0) * ax
        let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        if rel_res < options.rtol {
            // Already converged before starting this cycle.
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Arnoldi process with modified Gram-Schmidt.
        // v_basis holds the orthonormal Krylov basis; h_matrix stores the
        // Hessenberg matrix column by column (with Givens rotations applied).
        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        // v_0 = r / ||r||
        let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        v_basis.push(v0);

        // Initialize Givens rotation storage (cosines/sines of each rotation)
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)]; // residual in upper Hessenberg space

        for j in 0..options.max_iter {
            total_iters += 1;

            // w = A * v_j
            let w = apply_a(&v_basis[j])?;

            // Modified Gram-Schmidt orthogonalization of w against v_0..v_j
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                // w_orth = w_orth - h_ij * v_i = 1.0 * w_orth + (-h_ij) * v_i
                // (negation spelled as 0 - h_ij via AnyScalar's Sub impl)
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Subdiagonal entry h_{j+1,j} = ||w_orth||
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply previous Givens rotations to new column
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // Compute new Givens rotation for h_col[j] and h_col[j+1]
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            // Apply new rotation to eliminate h_col[j+1], keeping H upper triangular
            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Apply rotation to g; the rotated |g[j+1]| is the running
            // residual norm of the least-squares problem.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            // Check convergence
            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                // Solve upper triangular system and update x
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: rel_res,
                    converged: true,
                });
            }

            // Add new basis vector (if not converged and h_jp1_j is not too small)
            if h_jp1_j_real > 1e-14 {
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                v_basis.push(v_jp1);
            } else {
                // Lucky breakdown - we've found the exact solution in the Krylov subspace.
                // Form the iterate and report the true residual, since the
                // Hessenberg estimate is not updated past this point.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        // End of restart cycle - update x with current solution
        // (full max_iter columns were accumulated if we reach this point)
        let y = solve_upper_triangular(&h_matrix, &g[..options.max_iter])?;
        x = update_solution(&x, &v_basis[..options.max_iter], &y)?;
    }

    // Iteration budget exhausted: compute final (true) residual for reporting
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
344
/// Solve `A x = b` using GMRES with optional truncation after each iteration.
///
/// This is an extension of [`gmres`] that allows truncating Krylov basis vectors
/// to control bond dimension growth in tensor network representations.
///
/// # Type Parameters
///
/// * `T` - A tensor type implementing `TensorLike`
/// * `F` - A function that applies the linear operator: `F(x) = A x`
/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
///
/// # Arguments
///
/// * `apply_a` - Function that applies the linear operator A to a tensor
/// * `b` - Right-hand side tensor
/// * `x0` - Initial guess
/// * `options` - Solver options; unlike [`gmres`], `check_true_residual` is
///   honored here to guard against false convergence caused by truncation
/// * `truncate` - Function that truncates a tensor to control bond dimension
///
/// # Note
///
/// Truncation is applied after each Gram-Schmidt orthogonalization step
/// and after the final solution update. This helps control the bond dimension
/// growth that would otherwise occur in MPS/MPO representations. Because
/// truncation perturbs directions, the basis is iteratively reorthogonalized
/// after each truncation (see the reorthogonalization loop below).
///
/// # Examples
///
/// Solve `2x = b` with a no-op truncation function:
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar};
/// use tensor4all_core::krylov::{gmres_with_truncation, GmresOptions};
///
/// let i = DynIndex::new_dyn(2);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![4.0, 6.0]).unwrap();
/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
///
/// // Operator A = 2*I (scales input by 2)
/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(2.0));
/// // No-op truncation
/// let truncate = |_x: &mut TensorDynLen| Ok(());
///
/// let result = gmres_with_truncation(apply_a, &b, &x0, &GmresOptions::default(), truncate).unwrap();
/// assert!(result.converged);
/// // Solution should be [2.0, 3.0]
/// let expected = TensorDynLen::from_dense(vec![i], vec![2.0, 3.0]).unwrap();
/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
/// ```
pub fn gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: &T,
    options: &GmresOptions,
    truncate: Tr,
) -> Result<GmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0 (avoids division by b_norm below)
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Compute (truncated) initial residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first restart
        if _restart == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        if rel_res < options.rtol {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Arnoldi / Givens state for this restart cycle (see `gmres` for the
        // untruncated variant of this machinery).
        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        let mut v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        truncate(&mut v0)?;
        // After truncation, v0 might not be unit norm and might point in a different direction.
        // We need to:
        // 1. Renormalize v0 to unit norm for numerical stability
        // 2. Recompute g[0] = <r, v0> to maintain the correct relationship
        let v0_norm = v0.norm();
        let effective_g0 = if v0_norm > 1e-15 {
            v0 = v0.scale(AnyScalar::new_real(1.0 / v0_norm))?;
            // g[0] should be the component of r in the direction of v0
            // Since r was truncated and v0 = truncate(r/||r||)/||truncate(r/||r||)||,
            // g[0] = <r, v0> ≈ ||r|| * ||truncate(r/||r||)|| = r_norm * v0_norm
            r_norm * v0_norm
        } else {
            // Truncation annihilated v0; fall back to the untruncated norm.
            r_norm
        };
        v_basis.push(v0);

        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(effective_g0)];
        // Set when a false convergence is detected, so the end-of-cycle
        // solution update is not applied a second time.
        let mut solution_already_updated = false;

        for j in 0..options.max_iter {
            total_iters += 1;

            // w = A * v_j
            let w = apply_a(&v_basis[j])?;

            // Modified Gram-Schmidt orthogonalization of w against v_0..v_j
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                // Negation spelled as 0 - h_ij via AnyScalar's Sub impl
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Iterative reorthogonalization with truncation
            // Truncation can change the direction of w_orth, breaking orthogonality.
            // We iterate until all corrections are below a threshold to ensure
            // the Krylov basis remains orthogonal despite truncation.
            const REORTH_THRESHOLD: f64 = 1e-12;
            const MAX_REORTH_ITERS: usize = 10;

            let mut reorth_iter_count = 0;
            for reorth_iter in 0..MAX_REORTH_ITERS {
                reorth_iter_count = reorth_iter + 1;
                let norm_before_truncate = w_orth.norm();
                truncate(&mut w_orth)?;
                let norm_after_truncate = w_orth.norm();

                let mut max_correction = 0.0;
                for (i, v_i) in v_basis.iter().enumerate() {
                    let correction = v_i.inner_product(&w_orth)?;
                    let correction_abs = correction.abs();
                    if correction_abs > max_correction {
                        max_correction = correction_abs;
                    }
                    if correction_abs > REORTH_THRESHOLD {
                        let neg_correction = AnyScalar::new_real(0.0) - correction.clone();
                        w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
                        // Update Hessenberg matrix entry to include correction
                        h_col[i] = h_col[i].clone() + correction;
                    }
                }

                if options.verbose {
                    eprintln!(
                        "  reorth iter {}: norm {:.6e} -> {:.6e}, max_correction = {:.6e}",
                        reorth_iter, norm_before_truncate, norm_after_truncate, max_correction
                    );
                }

                // If all corrections are small enough, we're done
                if max_correction < REORTH_THRESHOLD {
                    break;
                }
            }

            if options.verbose && reorth_iter_count > 1 {
                eprintln!("  (needed {} reorth iterations)", reorth_iter_count);
            }

            // Subdiagonal entry h_{j+1,j} = ||w_orth||
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply previous Givens rotations to the new Hessenberg column
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // New rotation eliminating h_col[j+1], keeping H upper triangular
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate g; the rotated |g[j+1]| is the Hessenberg-space residual estimate
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;

                if options.check_true_residual {
                    // Verify with true residual to prevent false convergence
                    let ax_check = apply_a(&x)?;
                    let mut r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &ax_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    truncate(&mut r_check)?;
                    let true_rel_res = r_check.norm() / b_norm;

                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            rel_res, true_rel_res
                        );
                    }

                    if true_rel_res < options.rtol {
                        return Ok(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_rel_res,
                            converged: true,
                        });
                    }
                    // False convergence detected: x is already updated above,
                    // so skip the end-of-cycle update and go to next restart
                    solution_already_updated = true;
                    break;
                } else {
                    return Ok(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: rel_res,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Create v_{j+1} = w_orth / ||w_orth||
                // w_orth has already been truncated twice (after orthogonalization and after reorthogonalization)
                // so we don't need to truncate again. Scale doesn't increase bond dimensions.
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                // v_jp1 should have norm ~1.0 by construction
                // The Arnoldi relation h_{j+1,j} * v_{j+1} = w_orth is maintained exactly
                v_basis.push(v_jp1);
            } else {
                // Lucky breakdown: form the iterate and report the true residual
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        if !solution_already_updated {
            // End of cycle: fold the accumulated Krylov correction into x
            let actual_iters = v_basis.len().min(options.max_iter);
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            x = update_solution_truncated(&x, &v_basis[..actual_iters], &y, &truncate)?;
        }
    }

    // Iteration budget exhausted: compute final (true) residual for reporting
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
665
/// Options for restarted GMRES with truncation.
///
/// This is used by [`restart_gmres_with_truncation`] which wraps the standard GMRES
/// with an outer loop that recomputes the true residual at each restart.
///
/// All setters follow the consuming-builder pattern (`with_*` methods).
///
/// # Examples
///
/// ```
/// use tensor4all_core::krylov::RestartGmresOptions;
///
/// let opts = RestartGmresOptions::new()
///     .with_max_outer_iters(10)
///     .with_rtol(1e-6)
///     .with_inner_max_iter(20)
///     .with_inner_max_restarts(2)
///     .with_min_reduction(0.99)
///     .with_inner_rtol(0.01)
///     .with_verbose(false);
///
/// assert_eq!(opts.max_outer_iters, 10);
/// assert_eq!(opts.rtol, 1e-6);
/// assert_eq!(opts.inner_max_iter, 20);
/// assert_eq!(opts.inner_max_restarts, 2);
/// assert_eq!(opts.min_reduction, Some(0.99));
/// assert_eq!(opts.inner_rtol, Some(0.01));
/// ```
#[derive(Debug, Clone)]
pub struct RestartGmresOptions {
    /// Maximum number of outer restart iterations.
    /// Default: 20
    pub max_outer_iters: usize,

    /// Convergence tolerance for relative residual norm (based on true residual).
    /// The solver stops when `||b - A*x|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum iterations per inner GMRES cycle.
    /// Default: 10
    pub inner_max_iter: usize,

    /// Number of restarts within each inner GMRES (usually 0).
    /// Default: 0
    pub inner_max_restarts: usize,

    /// Stagnation detection threshold.
    /// If the residual reduction ratio exceeds this value (i.e., residual doesn't decrease enough),
    /// the solver considers it stagnated.
    /// For example, 0.99 means stagnation is detected when residual decreases by less than 1%.
    /// Default: None (no stagnation detection)
    pub min_reduction: Option<f64>,

    /// Inner GMRES relative tolerance.
    /// If None, uses 0.1 (solve inner problem loosely).
    /// Default: None
    pub inner_rtol: Option<f64>,

    /// Whether to print convergence information (written to stderr).
    /// Default: false
    pub verbose: bool,
}
727
728impl Default for RestartGmresOptions {
729    fn default() -> Self {
730        Self {
731            max_outer_iters: 20,
732            rtol: 1e-10,
733            inner_max_iter: 10,
734            inner_max_restarts: 0,
735            min_reduction: None,
736            inner_rtol: None,
737            verbose: false,
738        }
739    }
740}
741
742impl RestartGmresOptions {
743    /// Create new options with default values.
744    pub fn new() -> Self {
745        Self::default()
746    }
747
748    /// Set maximum number of outer iterations.
749    pub fn with_max_outer_iters(mut self, max_outer_iters: usize) -> Self {
750        self.max_outer_iters = max_outer_iters;
751        self
752    }
753
754    /// Set convergence tolerance.
755    pub fn with_rtol(mut self, rtol: f64) -> Self {
756        self.rtol = rtol;
757        self
758    }
759
760    /// Set maximum iterations per inner GMRES cycle.
761    pub fn with_inner_max_iter(mut self, inner_max_iter: usize) -> Self {
762        self.inner_max_iter = inner_max_iter;
763        self
764    }
765
766    /// Set number of restarts within each inner GMRES.
767    pub fn with_inner_max_restarts(mut self, inner_max_restarts: usize) -> Self {
768        self.inner_max_restarts = inner_max_restarts;
769        self
770    }
771
772    /// Set stagnation detection threshold.
773    pub fn with_min_reduction(mut self, min_reduction: f64) -> Self {
774        self.min_reduction = Some(min_reduction);
775        self
776    }
777
778    /// Set inner GMRES relative tolerance.
779    pub fn with_inner_rtol(mut self, inner_rtol: f64) -> Self {
780        self.inner_rtol = Some(inner_rtol);
781        self
782    }
783
784    /// Enable verbose output.
785    pub fn with_verbose(mut self, verbose: bool) -> Self {
786        self.verbose = verbose;
787        self
788    }
789}
790
/// Result of restarted GMRES solver.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, AnyScalar};
/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
///
/// let i = DynIndex::new_dyn(2);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 5.0]).unwrap();
///
/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(3.0));
/// let truncate = |_x: &mut TensorDynLen| Ok(());
///
/// let result = restart_gmres_with_truncation(
///     apply_a, &b, None, &RestartGmresOptions::default(), truncate,
/// ).unwrap();
///
/// assert!(result.converged);
/// assert!(result.residual_norm < 1e-10);
/// assert!(result.outer_iterations <= 20);
/// ```
#[derive(Debug, Clone)]
pub struct RestartGmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Total number of inner GMRES iterations performed (summed over outer restarts).
    pub iterations: usize,

    /// Number of outer restart iterations performed.
    pub outer_iterations: usize,

    /// Final relative residual norm (true residual, `||b - A*x|| / ||b||`).
    pub residual_norm: f64,

    /// Whether the solver converged.
    pub converged: bool,
}
830
/// Solve `A x = b` using restarted GMRES with truncation.
///
/// This wraps [`gmres_with_truncation`] with an outer loop that recomputes the true residual
/// at each restart. This is particularly useful for MPS/MPO computations where truncation
/// can cause the inner GMRES residual to be inaccurate.
///
/// # Algorithm
///
/// ```text
/// for outer_iter in 0..max_outer_iters:
///     r = b - A*x0          // Compute true residual
///     r = truncate(r)
///     if ||r|| / ||b|| < rtol:
///         return x0         // Converged
///     x' = gmres_with_truncation(A, r, 0, inner_options, truncate)
///     x0 = truncate(x0 + x')
/// ```
///
/// # Type Parameters
///
/// * `T` - A tensor type implementing `TensorLike`
/// * `F` - A function that applies the linear operator: `F(x) = A x`
/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
///
/// # Arguments
///
/// * `apply_a` - Function that applies the linear operator A to a tensor
/// * `b` - Right-hand side tensor
/// * `x0` - Initial guess (if None, starts from zero)
/// * `options` - Solver options
/// * `truncate` - Function that truncates a tensor to control bond dimension
///
/// # Returns
///
/// A `RestartGmresResult` containing the solution and convergence information.
///
/// # Notes
///
/// * If `||b|| < 1e-15`, the right-hand side is treated as zero: the initial
///   guess (or a zero tensor) is returned immediately with `converged = true`
///   and `residual_norm = 0.0`, without ever applying the operator.
/// * If `options.min_reduction` is set, the outer loop aborts with
///   `converged = false` as soon as the true residual fails to shrink by that
///   factor between consecutive restarts (stagnation).
///
/// # Examples
///
/// Solve `5x = b` with no truncation:
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar};
/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
///
/// let i = DynIndex::new_dyn(3);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![5.0, 10.0, 15.0]).unwrap();
///
/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(5.0));
/// let truncate = |_x: &mut TensorDynLen| Ok(());
///
/// let result = restart_gmres_with_truncation(
///     apply_a, &b, None, &RestartGmresOptions::default(), truncate,
/// ).unwrap();
///
/// assert!(result.converged);
/// let expected = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
/// ```
pub fn restart_gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: Option<&T>,
    options: &RestartGmresOptions,
    truncate: Tr,
) -> Result<RestartGmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    if let Some(x) = x0 {
        x.validate()?;
    }

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0 or zero.
        // NOTE: a relative residual is ill-defined for ||b|| ~ 0, so the result
        // reports residual_norm = 0.0 and converged = true by convention, even
        // though a nonzero x0 (if supplied) is returned unchanged.
        let solution = match x0 {
            Some(x) => x.clone(),
            None => b.scale(AnyScalar::new_real(0.0))?,
        };
        return Ok(RestartGmresResult {
            solution,
            iterations: 0,
            outer_iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    // Initialize x: use x0 if provided, otherwise start from zero.
    // Track whether x is zero to avoid unnecessary bond dimension doubling
    // when adding the first correction via axpby.
    let mut x_is_zero = x0.is_none();
    let mut x = match x0 {
        Some(x) => x.clone(),
        None => b.scale(AnyScalar::new_real(0.0))?,
    };

    let mut total_inner_iters = 0;
    // Residual of the previous outer iteration, used by the stagnation check.
    // Starting at INFINITY (together with the `outer_iter > 0` guard below)
    // disables the check on the first pass.
    let mut prev_residual_norm = f64::INFINITY;

    // Inner GMRES options
    let inner_options = GmresOptions {
        max_iter: options.inner_max_iter,
        rtol: options.inner_rtol.unwrap_or(0.1), // Solve loosely by default
        max_restarts: options.inner_max_restarts + 1, // +1 because max_restarts=0 means 1 cycle
        verbose: options.verbose,
        check_true_residual: true, // Always check in restart context to avoid false convergence
    };

    for outer_iter in 0..options.max_outer_iters {
        // Compute true residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first outer iteration
        if outer_iter == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        // Truncate the residual so the inner solve starts from a bounded bond dimension.
        truncate(&mut r)?;

        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "Restart GMRES outer iter {}: true residual = {:.6e}",
                outer_iter, rel_res
            );
        }

        // Check convergence
        if rel_res < options.rtol {
            return Ok(RestartGmresResult {
                solution: x,
                iterations: total_inner_iters,
                outer_iterations: outer_iter,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Check stagnation: abort if the true residual failed to shrink by at
        // least `min_reduction` relative to the previous outer iteration.
        if let Some(min_reduction) = options.min_reduction {
            if outer_iter > 0 && rel_res > prev_residual_norm * min_reduction {
                if options.verbose {
                    eprintln!(
                        "Restart GMRES stagnated: residual ratio = {:.6e} > {:.6e}",
                        rel_res / prev_residual_norm,
                        min_reduction
                    );
                }
                return Ok(RestartGmresResult {
                    solution: x,
                    iterations: total_inner_iters,
                    outer_iterations: outer_iter,
                    residual_norm: rel_res,
                    converged: false,
                });
            }
        }
        prev_residual_norm = rel_res;

        // Solve A*x' = r using inner GMRES with zero initial guess
        // The zero initial guess is created by scaling r by 0
        let zero = r.scale(AnyScalar::new_real(0.0))?;
        let inner_result = gmres_with_truncation(&apply_a, &r, &zero, &inner_options, &truncate)?;

        total_inner_iters += inner_result.iterations;

        if options.verbose {
            eprintln!(
                "  Inner GMRES: {} iterations, residual = {:.6e}, converged = {}",
                inner_result.iterations, inner_result.residual_norm, inner_result.converged
            );
        }

        // Update solution: x = x + x'
        // When x is zero (first iteration with no initial guess), use x' directly
        // to avoid bond dimension doubling from axpby with a zero tensor.
        if x_is_zero {
            x = inner_result.solution;
            x_is_zero = false;
        } else {
            x = x.axpby(
                AnyScalar::new_real(1.0),
                &inner_result.solution,
                AnyScalar::new_real(1.0),
            )?;
        }
        // Keep the accumulated solution's bond dimension under control.
        truncate(&mut x)?;
    }

    // Did not converge within max_outer_iters
    // Compute final residual
    let ax = apply_a(&x)?;
    let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
    truncate(&mut r)?;
    let final_rel_res = r.norm() / b_norm;

    Ok(RestartGmresResult {
        solution: x,
        iterations: total_inner_iters,
        outer_iterations: options.max_outer_iters,
        residual_norm: final_rel_res,
        converged: false,
    })
}
1041
1042/// Compute Givens rotation coefficients to eliminate b in (a, b).
1043///
1044/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1045/// as much as possible.
1046fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) {
1047    // r^2 = conj(a)*a + conj(b)*b (works for both real and complex)
1048    let norm2 = a.clone().conj() * a.clone() + b.clone().conj() * b.clone();
1049    let r = norm2.sqrt();
1050    if r.abs() < 1e-15 {
1051        (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0))
1052    } else {
1053        (a.clone() / r.clone(), b.clone() / r)
1054    }
1055}
1056
1057/// Apply Givens rotation: (c, s) @ (x, y) -> (c*x + s*y, -conj(s)*x + c*y) for complex
1058/// or (c*x + s*y, -s*x + c*y) for real.
1059///
1060/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1061/// as much as possible.
1062fn apply_givens_rotation(
1063    c: &AnyScalar,
1064    s: &AnyScalar,
1065    x: &AnyScalar,
1066    y: &AnyScalar,
1067) -> (AnyScalar, AnyScalar) {
1068    let new_x = c.clone() * x.clone() + s.clone() * y.clone();
1069    let new_y = -(s.clone().conj() * x.clone()) + c.clone() * y.clone();
1070    (new_x, new_y)
1071}
1072
1073/// Solve upper triangular system R y = g using back substitution.
1074fn solve_upper_triangular(h: &[Vec<AnyScalar>], g: &[AnyScalar]) -> Result<Vec<AnyScalar>> {
1075    let n = g.len();
1076    if n == 0 {
1077        return Ok(vec![]);
1078    }
1079
1080    let mut y = vec![AnyScalar::new_real(0.0); n];
1081
1082    for i in (0..n).rev() {
1083        let mut sum = g[i].clone();
1084
1085        for j in (i + 1)..n {
1086            // sum = sum - h[j][i] * y[j]
1087            let prod = h[j][i].clone() * y[j].clone();
1088            sum = sum - prod;
1089        }
1090
1091        let h_ii = &h[i][i];
1092        if h_ii.abs() < 1e-15 {
1093            return Err(anyhow::anyhow!(
1094                "Near-singular upper triangular matrix in GMRES"
1095            ));
1096        }
1097
1098        y[i] = sum / h_ii.clone();
1099    }
1100
1101    Ok(y)
1102}
1103
1104/// Update solution: x_new = x + sum_i y_i * v_i
1105fn update_solution<T: TensorLike>(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result<T> {
1106    let mut result = x.clone();
1107
1108    for (vi, yi) in v_basis.iter().zip(y.iter()) {
1109        let scaled_vi = vi.scale(yi.clone())?;
1110        // result = result + scaled_vi = 1.0 * result + 1.0 * scaled_vi
1111        result = result.axpby(
1112            AnyScalar::new_real(1.0),
1113            &scaled_vi,
1114            AnyScalar::new_real(1.0),
1115        )?;
1116    }
1117
1118    Ok(result)
1119}
1120
1121/// Update solution with truncation: x_new = truncate(x + sum_i y_i * v_i)
1122fn update_solution_truncated<T, Tr>(
1123    x: &T,
1124    v_basis: &[T],
1125    y: &[AnyScalar],
1126    truncate: &Tr,
1127) -> Result<T>
1128where
1129    T: TensorLike,
1130    Tr: Fn(&mut T) -> Result<()>,
1131{
1132    let mut result = x.clone();
1133    // Detect if x is effectively zero.
1134    // When x is created via scale(0.0), it preserves the original bond structure
1135    // (e.g., bond dim 4), causing axpby to double bond dimensions unnecessarily.
1136    // By detecting zero, we can use scaled_vi directly, avoiding the doubling.
1137    let mut result_is_zero = x.norm() == 0.0;
1138
1139    for (vi, yi) in v_basis.iter().zip(y.iter()) {
1140        let scaled_vi = vi.scale(yi.clone())?;
1141        if result_is_zero {
1142            result = scaled_vi;
1143            result_is_zero = false;
1144        } else {
1145            result = result.axpby(
1146                AnyScalar::new_real(1.0),
1147                &scaled_vi,
1148                AnyScalar::new_real(1.0),
1149            )?;
1150        }
1151        // Truncate after each addition to control bond dimension growth
1152        truncate(&mut result)?;
1153    }
1154
1155    Ok(result)
1156}
1157
1158#[cfg(test)]
1159mod tests;