tensor4all_core/
krylov.rs

1//! Krylov subspace methods for solving linear equations with abstract tensors.
2//!
3//! This module provides iterative solvers that work with any type implementing [`TensorLike`],
4//! enabling their use in tensor network algorithms without requiring dense vector representations.
5//!
6//! # Solvers
7//!
8//! - [`gmres`]: Generalized Minimal Residual Method (GMRES) for non-symmetric systems
9//!
10//! # Future Extensions
11//!
12//! - CG (Conjugate Gradient) for symmetric positive definite systems
13//! - BiCGSTAB for non-symmetric systems with better convergence properties
14//!
15//! # Example
16//!
17//! ```ignore
18//! use tensor4all_core::krylov::{gmres, GmresOptions};
19//!
20//! // Define a linear operator as a closure
21//! let apply_operator = |x: &T| -> Result<T> {
22//!     // Apply your linear operator to x
23//!     operator.apply(x)
24//! };
25//!
26//! let result = gmres(&apply_operator, &rhs, &initial_guess, &GmresOptions::default())?;
27//! ```
28
29use crate::any_scalar::AnyScalar;
30use crate::TensorLike;
31use anyhow::Result;
32
/// Options for GMRES solver.
///
/// Construct via [`GmresOptions::default`] (or [`GmresOptions::new`]) and
/// customize with the `with_*` builder methods, mirroring the builder API of
/// [`RestartGmresOptions`].
#[derive(Debug, Clone)]
pub struct GmresOptions {
    /// Maximum number of iterations (restart cycle length).
    /// Default: 100
    pub max_iter: usize,

    /// Convergence tolerance for relative residual norm.
    /// The solver stops when `||r|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum number of restarts.
    /// Total iterations = max_iter * max_restarts.
    /// Default: 10
    pub max_restarts: usize,

    /// Whether to print convergence information.
    /// Default: false
    pub verbose: bool,

    /// When true, verify convergence by computing the true residual `||b - A*x|| / ||b||`
    /// before declaring convergence. This prevents false convergence caused by
    /// truncation corrupting the Krylov basis orthogonality (see Issue #207).
    /// Costs one additional `apply_a` call when convergence is detected.
    /// Default: false
    pub check_true_residual: bool,
}

impl Default for GmresOptions {
    fn default() -> Self {
        Self {
            max_iter: 100,
            rtol: 1e-10,
            max_restarts: 10,
            verbose: false,
            check_true_residual: false,
        }
    }
}

impl GmresOptions {
    /// Create new options with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum number of iterations (restart cycle length).
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance for the relative residual norm.
    pub fn with_rtol(mut self, rtol: f64) -> Self {
        self.rtol = rtol;
        self
    }

    /// Set maximum number of restarts.
    pub fn with_max_restarts(mut self, max_restarts: usize) -> Self {
        self.max_restarts = max_restarts;
        self
    }

    /// Enable or disable verbose output.
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Enable or disable true-residual verification before declaring convergence.
    pub fn with_check_true_residual(mut self, check_true_residual: bool) -> Self {
        self.check_true_residual = check_true_residual;
        self
    }
}
73
/// Result of GMRES solver.
///
/// Returned by [`gmres`] and [`gmres_with_truncation`]. The solution is
/// returned even when `converged` is false, holding the best approximation
/// found within the iteration budget.
#[derive(Debug, Clone)]
pub struct GmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Total number of iterations performed, accumulated across all restarts.
    pub iterations: usize,

    /// Final relative residual norm `||b - A*x|| / ||b||`.
    ///
    /// Depending on the exit path this is either the cheap Hessenberg-space
    /// residual estimate or a recomputed true residual (breakdown and
    /// non-converged exits recompute it explicitly).
    pub residual_norm: f64,

    /// Whether the solver converged to within the requested tolerance.
    pub converged: bool,
}
89
90/// Solve `A x = b` using GMRES (Generalized Minimal Residual Method).
91///
92/// This implements the restarted GMRES algorithm that works with abstract tensor types
93/// through the [`TensorLike`] trait's vector space operations.
94///
95/// # Algorithm
96///
97/// GMRES builds an orthonormal basis for the Krylov subspace
98/// `K_m = span{r_0, A r_0, A^2 r_0, ..., A^{m-1} r_0}` and finds the
99/// solution that minimizes `||b - A x||` over this subspace.
100///
101/// # Type Parameters
102///
103/// * `T` - A tensor type implementing `TensorLike`
104/// * `F` - A function that applies the linear operator: `F(x) = A x`
105///
106/// # Arguments
107///
108/// * `apply_a` - Function that applies the linear operator A to a tensor
109/// * `b` - Right-hand side tensor
110/// * `x0` - Initial guess
111/// * `options` - Solver options
112///
113/// # Returns
114///
115/// A `GmresResult` containing the solution and convergence information.
116///
117/// # Errors
118///
119/// Returns an error if:
120/// - Vector space operations (add, sub, scale, inner_product) fail
121/// - The linear operator application fails
122pub fn gmres<T, F>(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result<GmresResult<T>>
123where
124    T: TensorLike,
125    F: Fn(&T) -> Result<T>,
126{
127    // Validate structural consistency of inputs
128    b.validate()?;
129    x0.validate()?;
130
131    let b_norm = b.norm();
132    if b_norm < 1e-15 {
133        // b is effectively zero, return x0
134        return Ok(GmresResult {
135            solution: x0.clone(),
136            iterations: 0,
137            residual_norm: 0.0,
138            converged: true,
139        });
140    }
141
142    let mut x = x0.clone();
143    let mut total_iters = 0;
144
145    for _restart in 0..options.max_restarts {
146        // Compute initial residual: r = b - A*x
147        let ax = apply_a(&x)?;
148        // Validate operator output on first restart
149        if _restart == 0 {
150            ax.validate()?;
151        }
152        // r = 1.0 * b + (-1.0) * ax
153        let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
154        let r_norm = r.norm();
155        let rel_res = r_norm / b_norm;
156
157        if options.verbose {
158            eprintln!(
159                "GMRES restart {}: initial residual = {:.6e}",
160                _restart, rel_res
161            );
162        }
163
164        if rel_res < options.rtol {
165            return Ok(GmresResult {
166                solution: x,
167                iterations: total_iters,
168                residual_norm: rel_res,
169                converged: true,
170            });
171        }
172
173        // Arnoldi process with modified Gram-Schmidt
174        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
175        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);
176
177        // v_0 = r / ||r||
178        let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
179        v_basis.push(v0);
180
181        // Initialize Givens rotation storage
182        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
183        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
184        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)]; // residual in upper Hessenberg space
185
186        for j in 0..options.max_iter {
187            total_iters += 1;
188
189            // w = A * v_j
190            let w = apply_a(&v_basis[j])?;
191
192            // Modified Gram-Schmidt orthogonalization
193            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
194            let mut w_orth = w;
195
196            for v_i in v_basis.iter().take(j + 1) {
197                let h_ij = v_i.inner_product(&w_orth)?;
198                h_col.push(h_ij.clone());
199                // w_orth = w_orth - h_ij * v_i = 1.0 * w_orth + (-h_ij) * v_i
200                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
201                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
202            }
203
204            let h_jp1_j_real = w_orth.norm();
205            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
206            h_col.push(h_jp1_j);
207
208            // Apply previous Givens rotations to new column
209            #[allow(clippy::needless_range_loop)]
210            for i in 0..j {
211                let h_i = h_col[i].clone();
212                let h_ip1 = h_col[i + 1].clone();
213                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
214                h_col[i] = new_hi;
215                h_col[i + 1] = new_hip1;
216            }
217
218            // Compute new Givens rotation for h_col[j] and h_col[j+1]
219            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
220            cs.push(c_j.clone());
221            sn.push(s_j.clone());
222
223            // Apply new rotation to eliminate h_col[j+1]
224            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
225            h_col[j] = new_hj;
226            h_col[j + 1] = AnyScalar::new_real(0.0);
227
228            // Apply rotation to g
229            let g_j = g[j].clone();
230            let g_jp1 = AnyScalar::new_real(0.0);
231            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
232            g[j] = new_gj;
233            let res_norm = new_gjp1.abs();
234            g.push(new_gjp1);
235
236            h_matrix.push(h_col);
237
238            // Check convergence
239            let rel_res = res_norm / b_norm;
240
241            if options.verbose {
242                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
243            }
244
245            if rel_res < options.rtol {
246                // Solve upper triangular system and update x
247                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
248                x = update_solution(&x, &v_basis[..=j], &y)?;
249                return Ok(GmresResult {
250                    solution: x,
251                    iterations: total_iters,
252                    residual_norm: rel_res,
253                    converged: true,
254                });
255            }
256
257            // Add new basis vector (if not converged and h_jp1_j is not too small)
258            if h_jp1_j_real > 1e-14 {
259                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
260                v_basis.push(v_jp1);
261            } else {
262                // Lucky breakdown - we've found the exact solution in the Krylov subspace
263                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
264                x = update_solution(&x, &v_basis[..=j], &y)?;
265                let ax_final = apply_a(&x)?;
266                let r_final = b.axpby(
267                    AnyScalar::new_real(1.0),
268                    &ax_final,
269                    AnyScalar::new_real(-1.0),
270                )?;
271                let final_res = r_final.norm() / b_norm;
272                return Ok(GmresResult {
273                    solution: x,
274                    iterations: total_iters,
275                    residual_norm: final_res,
276                    converged: final_res < options.rtol,
277                });
278            }
279        }
280
281        // End of restart cycle - update x with current solution
282        let y = solve_upper_triangular(&h_matrix, &g[..options.max_iter])?;
283        x = update_solution(&x, &v_basis[..options.max_iter], &y)?;
284    }
285
286    // Compute final residual
287    let ax_final = apply_a(&x)?;
288    let r_final = b.axpby(
289        AnyScalar::new_real(1.0),
290        &ax_final,
291        AnyScalar::new_real(-1.0),
292    )?;
293    let final_res = r_final.norm() / b_norm;
294
295    Ok(GmresResult {
296        solution: x,
297        iterations: total_iters,
298        residual_norm: final_res,
299        converged: final_res < options.rtol,
300    })
301}
302
/// Solve `A x = b` using GMRES with optional truncation after each iteration.
///
/// This is an extension of [`gmres`] that allows truncating Krylov basis vectors
/// to control bond dimension growth in tensor network representations.
///
/// # Type Parameters
///
/// * `T` - A tensor type implementing `TensorLike`
/// * `F` - A function that applies the linear operator: `F(x) = A x`
/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
///
/// # Arguments
///
/// * `apply_a` - Function that applies the linear operator A to a tensor
/// * `b` - Right-hand side tensor
/// * `x0` - Initial guess
/// * `options` - Solver options
/// * `truncate` - Function that truncates a tensor to control bond dimension
///
/// # Returns
///
/// A `GmresResult` containing the solution and convergence information.
///
/// # Errors
///
/// Returns an error if any vector space operation, the linear operator
/// application, or the truncation callback fails.
///
/// # Note
///
/// Truncation is applied after each Gram-Schmidt orthogonalization step
/// and after the final solution update. This helps control the bond dimension
/// growth that would otherwise occur in MPS/MPO representations.
///
/// Because truncation perturbs the Krylov basis, the Hessenberg-space residual
/// can underestimate the true residual; set
/// [`GmresOptions::check_true_residual`] to verify `||b - A*x|| / ||b||`
/// before declaring convergence.
pub fn gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: &T,
    options: &GmresOptions,
    truncate: Tr,
) -> Result<GmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero: any residual is zero relative to b, so the
        // initial guess is returned as converged.
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Compute the (truncated) initial residual r = b - A*x for this cycle.
        let ax = apply_a(&x)?;
        // Validate operator output on first restart
        if _restart == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        if rel_res < options.rtol {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Arnoldi process (modified Gram-Schmidt) with truncation.
        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        let mut v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        truncate(&mut v0)?;
        // After truncation, v0 might not be unit norm and might point in a different direction.
        // We need to:
        // 1. Renormalize v0 to unit norm for numerical stability
        // 2. Recompute g[0] = <r, v0> to maintain the correct relationship
        let v0_norm = v0.norm();
        let effective_g0 = if v0_norm > 1e-15 {
            v0 = v0.scale(AnyScalar::new_real(1.0 / v0_norm))?;
            // g[0] should be the component of r in the direction of v0
            // Since r was truncated and v0 = truncate(r/||r||)/||truncate(r/||r||)||,
            // g[0] = <r, v0> ≈ ||r|| * ||truncate(r/||r||)|| = r_norm * v0_norm
            r_norm * v0_norm
        } else {
            // Degenerate case: truncation annihilated v0. Fall back to r_norm;
            // NOTE(review): v0 is then (near-)zero but still pushed into the
            // basis — presumably unreachable for sane truncators; confirm.
            r_norm
        };
        v_basis.push(v0);

        // Givens rotation storage (cs, sn) and the rotated RHS g of the
        // Hessenberg least-squares problem.
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(effective_g0)];
        // Set when check_true_residual detects a false convergence: x has
        // already been updated, so skip the end-of-cycle update.
        let mut solution_already_updated = false;

        for j in 0..options.max_iter {
            total_iters += 1;

            // w = A * v_j
            let w = apply_a(&v_basis[j])?;

            // Modified Gram-Schmidt: subtract the projection onto each basis
            // vector, recording the coefficients as column j of H.
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Iterative reorthogonalization with truncation
            // Truncation can change the direction of w_orth, breaking orthogonality.
            // We iterate until all corrections are below a threshold to ensure
            // the Krylov basis remains orthogonal despite truncation.
            const REORTH_THRESHOLD: f64 = 1e-12;
            const MAX_REORTH_ITERS: usize = 10;

            let mut reorth_iter_count = 0;
            for reorth_iter in 0..MAX_REORTH_ITERS {
                reorth_iter_count = reorth_iter + 1;
                let norm_before_truncate = w_orth.norm();
                truncate(&mut w_orth)?;
                let norm_after_truncate = w_orth.norm();

                let mut max_correction = 0.0;
                for (i, v_i) in v_basis.iter().enumerate() {
                    let correction = v_i.inner_product(&w_orth)?;
                    let correction_abs = correction.abs();
                    if correction_abs > max_correction {
                        max_correction = correction_abs;
                    }
                    if correction_abs > REORTH_THRESHOLD {
                        let neg_correction = AnyScalar::new_real(0.0) - correction.clone();
                        w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
                        // Update Hessenberg matrix entry to include correction
                        h_col[i] = h_col[i].clone() + correction;
                    }
                }

                if options.verbose {
                    eprintln!(
                        "  reorth iter {}: norm {:.6e} -> {:.6e}, max_correction = {:.6e}",
                        reorth_iter, norm_before_truncate, norm_after_truncate, max_correction
                    );
                }

                // If all corrections are small enough, we're done
                if max_correction < REORTH_THRESHOLD {
                    break;
                }
            }

            if options.verbose && reorth_iter_count > 1 {
                eprintln!("  (needed {} reorth iterations)", reorth_iter_count);
            }

            // Subdiagonal Hessenberg entry h_{j+1,j} = ||w_orth||.
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply the previously-computed Givens rotations to the new column.
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // Compute and apply a new rotation to eliminate the subdiagonal entry.
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate g; |g[j+1]| is the Hessenberg-space residual estimate.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;

                if options.check_true_residual {
                    // Verify with true residual to prevent false convergence
                    let ax_check = apply_a(&x)?;
                    let mut r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &ax_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    truncate(&mut r_check)?;
                    let true_rel_res = r_check.norm() / b_norm;

                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            rel_res, true_rel_res
                        );
                    }

                    if true_rel_res < options.rtol {
                        return Ok(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_rel_res,
                            converged: true,
                        });
                    }
                    // False convergence detected: x is already updated above,
                    // so skip the end-of-cycle update and go to next restart
                    solution_already_updated = true;
                    break;
                } else {
                    return Ok(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: rel_res,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Create v_{j+1} = w_orth / ||w_orth||
                // w_orth has already been truncated twice (after orthogonalization and after reorthogonalization)
                // so we don't need to truncate again. Scale doesn't increase bond dimensions.
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                // v_jp1 should have norm ~1.0 by construction
                // The Arnoldi relation h_{j+1,j} * v_{j+1} = w_orth is maintained exactly
                v_basis.push(v_jp1);
            } else {
                // Lucky breakdown: the Krylov subspace is invariant. Form the
                // solution and report the recomputed true residual.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        // End of restart cycle: fold the Krylov correction into x unless the
        // false-convergence path above already did so.
        if !solution_already_updated {
            let actual_iters = v_basis.len().min(options.max_iter);
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            x = update_solution_truncated(&x, &v_basis[..actual_iters], &y, &truncate)?;
        }
    }

    // Out of restarts: report the recomputed (truncated) true residual.
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
599
/// Options for restarted GMRES with truncation.
///
/// This is used by [`restart_gmres_with_truncation`] which wraps the standard GMRES
/// with an outer loop that recomputes the true residual at each restart.
///
/// Construct via [`RestartGmresOptions::new`] and customize with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct RestartGmresOptions {
    /// Maximum number of outer restart iterations.
    /// Default: 20
    pub max_outer_iters: usize,

    /// Convergence tolerance for relative residual norm (based on true residual).
    /// The solver stops when `||b - A*x|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum iterations per inner GMRES cycle.
    /// Default: 10
    pub inner_max_iter: usize,

    /// Number of restarts within each inner GMRES (usually 0).
    /// Note: the solver passes `inner_max_restarts + 1` to the inner GMRES,
    /// so 0 means a single inner cycle.
    /// Default: 0
    pub inner_max_restarts: usize,

    /// Stagnation detection threshold.
    /// If the residual reduction ratio exceeds this value (i.e., residual doesn't decrease enough),
    /// the solver considers it stagnated.
    /// For example, 0.99 means stagnation is detected when residual decreases by less than 1%.
    /// Default: None (no stagnation detection)
    pub min_reduction: Option<f64>,

    /// Inner GMRES relative tolerance.
    /// If None, uses 0.1 (solve inner problem loosely).
    /// Default: None
    pub inner_rtol: Option<f64>,

    /// Whether to print convergence information.
    /// Default: false
    pub verbose: bool,
}
639
640impl Default for RestartGmresOptions {
641    fn default() -> Self {
642        Self {
643            max_outer_iters: 20,
644            rtol: 1e-10,
645            inner_max_iter: 10,
646            inner_max_restarts: 0,
647            min_reduction: None,
648            inner_rtol: None,
649            verbose: false,
650        }
651    }
652}
653
654impl RestartGmresOptions {
655    /// Create new options with default values.
656    pub fn new() -> Self {
657        Self::default()
658    }
659
660    /// Set maximum number of outer iterations.
661    pub fn with_max_outer_iters(mut self, max_outer_iters: usize) -> Self {
662        self.max_outer_iters = max_outer_iters;
663        self
664    }
665
666    /// Set convergence tolerance.
667    pub fn with_rtol(mut self, rtol: f64) -> Self {
668        self.rtol = rtol;
669        self
670    }
671
672    /// Set maximum iterations per inner GMRES cycle.
673    pub fn with_inner_max_iter(mut self, inner_max_iter: usize) -> Self {
674        self.inner_max_iter = inner_max_iter;
675        self
676    }
677
678    /// Set number of restarts within each inner GMRES.
679    pub fn with_inner_max_restarts(mut self, inner_max_restarts: usize) -> Self {
680        self.inner_max_restarts = inner_max_restarts;
681        self
682    }
683
684    /// Set stagnation detection threshold.
685    pub fn with_min_reduction(mut self, min_reduction: f64) -> Self {
686        self.min_reduction = Some(min_reduction);
687        self
688    }
689
690    /// Set inner GMRES relative tolerance.
691    pub fn with_inner_rtol(mut self, inner_rtol: f64) -> Self {
692        self.inner_rtol = Some(inner_rtol);
693        self
694    }
695
696    /// Enable verbose output.
697    pub fn with_verbose(mut self, verbose: bool) -> Self {
698        self.verbose = verbose;
699        self
700    }
701}
702
/// Result of restarted GMRES solver.
///
/// Returned by [`restart_gmres_with_truncation`]. The solution is returned
/// even when `converged` is false (best approximation found).
#[derive(Debug, Clone)]
pub struct RestartGmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Total number of inner GMRES iterations performed (summed over all
    /// outer restarts).
    pub iterations: usize,

    /// Number of outer restart iterations performed.
    pub outer_iterations: usize,

    /// Final relative residual norm (true residual, `||b - A*x|| / ||b||`,
    /// computed after truncation of the residual).
    pub residual_norm: f64,

    /// Whether the solver converged to within `rtol`.
    pub converged: bool,
}
721
/// Solve `A x = b` using restarted GMRES with truncation.
///
/// This wraps [`gmres_with_truncation`] with an outer loop that recomputes the true residual
/// at each restart. This is particularly useful for MPS/MPO computations where truncation
/// can cause the inner GMRES residual to be inaccurate.
///
/// # Algorithm
///
/// ```text
/// for outer_iter in 0..max_outer_iters:
///     r = b - A*x0          // Compute true residual
///     r = truncate(r)
///     if ||r|| / ||b|| < rtol:
///         return x0         // Converged
///     x' = gmres_with_truncation(A, r, 0, inner_options, truncate)
///     x0 = truncate(x0 + x')
/// ```
///
/// # Type Parameters
///
/// * `T` - A tensor type implementing `TensorLike`
/// * `F` - A function that applies the linear operator: `F(x) = A x`
/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
///
/// # Arguments
///
/// * `apply_a` - Function that applies the linear operator A to a tensor
/// * `b` - Right-hand side tensor
/// * `x0` - Initial guess (if None, starts from zero)
/// * `options` - Solver options
/// * `truncate` - Function that truncates a tensor to control bond dimension
///
/// # Returns
///
/// A `RestartGmresResult` containing the solution and convergence information.
///
/// # Errors
///
/// Returns an error if any vector space operation, the linear operator
/// application, or the truncation callback fails.
pub fn restart_gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: Option<&T>,
    options: &RestartGmresOptions,
    truncate: Tr,
) -> Result<RestartGmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    if let Some(x) = x0 {
        x.validate()?;
    }

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0 or zero
        let solution = match x0 {
            Some(x) => x.clone(),
            None => b.scale(AnyScalar::new_real(0.0))?,
        };
        return Ok(RestartGmresResult {
            solution,
            iterations: 0,
            outer_iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    // Initialize x: use x0 if provided, otherwise start from zero.
    // Track whether x is zero to avoid unnecessary bond dimension doubling
    // when adding the first correction via axpby.
    let mut x_is_zero = x0.is_none();
    let mut x = match x0 {
        Some(x) => x.clone(),
        None => b.scale(AnyScalar::new_real(0.0))?,
    };

    let mut total_inner_iters = 0;
    let mut prev_residual_norm = f64::INFINITY;

    // Inner GMRES options
    let inner_options = GmresOptions {
        max_iter: options.inner_max_iter,
        rtol: options.inner_rtol.unwrap_or(0.1), // Solve loosely by default
        max_restarts: options.inner_max_restarts + 1, // +1 because max_restarts=0 means 1 cycle
        verbose: options.verbose,
        check_true_residual: true, // Always check in restart context to avoid false convergence
    };

    for outer_iter in 0..options.max_outer_iters {
        // Compute true residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first outer iteration
        if outer_iter == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;

        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "Restart GMRES outer iter {}: true residual = {:.6e}",
                outer_iter, rel_res
            );
        }

        // Check convergence
        if rel_res < options.rtol {
            return Ok(RestartGmresResult {
                solution: x,
                iterations: total_inner_iters,
                outer_iterations: outer_iter,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Check stagnation: give up when the residual did not shrink by at
        // least a factor of `min_reduction` relative to the previous outer step.
        if let Some(min_reduction) = options.min_reduction {
            if outer_iter > 0 && rel_res > prev_residual_norm * min_reduction {
                if options.verbose {
                    eprintln!(
                        "Restart GMRES stagnated: residual ratio = {:.6e} > {:.6e}",
                        rel_res / prev_residual_norm,
                        min_reduction
                    );
                }
                return Ok(RestartGmresResult {
                    solution: x,
                    iterations: total_inner_iters,
                    outer_iterations: outer_iter,
                    residual_norm: rel_res,
                    converged: false,
                });
            }
        }
        prev_residual_norm = rel_res;

        // Solve A*x' = r using inner GMRES with zero initial guess
        // The zero initial guess is created by scaling r by 0
        let zero = r.scale(AnyScalar::new_real(0.0))?;
        let inner_result = gmres_with_truncation(&apply_a, &r, &zero, &inner_options, &truncate)?;

        total_inner_iters += inner_result.iterations;

        if options.verbose {
            eprintln!(
                "  Inner GMRES: {} iterations, residual = {:.6e}, converged = {}",
                inner_result.iterations, inner_result.residual_norm, inner_result.converged
            );
        }

        // Update solution: x = x + x'
        // When x is zero (first iteration with no initial guess), use x' directly
        // to avoid bond dimension doubling from axpby with a zero tensor.
        if x_is_zero {
            x = inner_result.solution;
            x_is_zero = false;
        } else {
            x = x.axpby(
                AnyScalar::new_real(1.0),
                &inner_result.solution,
                AnyScalar::new_real(1.0),
            )?;
        }
        truncate(&mut x)?;
    }

    // Did not converge within max_outer_iters
    // Compute final residual
    let ax = apply_a(&x)?;
    let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
    truncate(&mut r)?;
    let final_rel_res = r.norm() / b_norm;

    Ok(RestartGmresResult {
        solution: x,
        iterations: total_inner_iters,
        outer_iterations: options.max_outer_iters,
        residual_norm: final_rel_res,
        converged: false,
    })
}
909
910/// Compute Givens rotation coefficients to eliminate b in (a, b).
911///
912/// This function keeps computation in `AnyScalar` space to preserve AD metadata
913/// as much as possible.
914fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) {
915    // r^2 = conj(a)*a + conj(b)*b (works for both real and complex)
916    let norm2 = a.clone().conj() * a.clone() + b.clone().conj() * b.clone();
917    let r = norm2.sqrt();
918    if r.abs() < 1e-15 {
919        (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0))
920    } else {
921        (a.clone() / r.clone(), b.clone() / r)
922    }
923}
924
925/// Apply Givens rotation: (c, s) @ (x, y) -> (c*x + s*y, -conj(s)*x + c*y) for complex
926/// or (c*x + s*y, -s*x + c*y) for real.
927///
928/// This function keeps computation in `AnyScalar` space to preserve AD metadata
929/// as much as possible.
930fn apply_givens_rotation(
931    c: &AnyScalar,
932    s: &AnyScalar,
933    x: &AnyScalar,
934    y: &AnyScalar,
935) -> (AnyScalar, AnyScalar) {
936    let new_x = c.clone() * x.clone() + s.clone() * y.clone();
937    let new_y = -(s.clone().conj() * x.clone()) + c.clone() * y.clone();
938    (new_x, new_y)
939}
940
941/// Solve upper triangular system R y = g using back substitution.
942fn solve_upper_triangular(h: &[Vec<AnyScalar>], g: &[AnyScalar]) -> Result<Vec<AnyScalar>> {
943    let n = g.len();
944    if n == 0 {
945        return Ok(vec![]);
946    }
947
948    let mut y = vec![AnyScalar::new_real(0.0); n];
949
950    for i in (0..n).rev() {
951        let mut sum = g[i].clone();
952
953        for j in (i + 1)..n {
954            // sum = sum - h[j][i] * y[j]
955            let prod = h[j][i].clone() * y[j].clone();
956            sum = sum - prod;
957        }
958
959        let h_ii = &h[i][i];
960        if h_ii.abs() < 1e-15 {
961            return Err(anyhow::anyhow!(
962                "Near-singular upper triangular matrix in GMRES"
963            ));
964        }
965
966        y[i] = sum / h_ii.clone();
967    }
968
969    Ok(y)
970}
971
972/// Update solution: x_new = x + sum_i y_i * v_i
973fn update_solution<T: TensorLike>(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result<T> {
974    let mut result = x.clone();
975
976    for (vi, yi) in v_basis.iter().zip(y.iter()) {
977        let scaled_vi = vi.scale(yi.clone())?;
978        // result = result + scaled_vi = 1.0 * result + 1.0 * scaled_vi
979        result = result.axpby(
980            AnyScalar::new_real(1.0),
981            &scaled_vi,
982            AnyScalar::new_real(1.0),
983        )?;
984    }
985
986    Ok(result)
987}
988
989/// Update solution with truncation: x_new = truncate(x + sum_i y_i * v_i)
990fn update_solution_truncated<T, Tr>(
991    x: &T,
992    v_basis: &[T],
993    y: &[AnyScalar],
994    truncate: &Tr,
995) -> Result<T>
996where
997    T: TensorLike,
998    Tr: Fn(&mut T) -> Result<()>,
999{
1000    let mut result = x.clone();
1001    // Detect if x is effectively zero.
1002    // When x is created via scale(0.0), it preserves the original bond structure
1003    // (e.g., bond dim 4), causing axpby to double bond dimensions unnecessarily.
1004    // By detecting zero, we can use scaled_vi directly, avoiding the doubling.
1005    let mut result_is_zero = x.norm() == 0.0;
1006
1007    for (vi, yi) in v_basis.iter().zip(y.iter()) {
1008        let scaled_vi = vi.scale(yi.clone())?;
1009        if result_is_zero {
1010            result = scaled_vi;
1011            result_is_zero = false;
1012        } else {
1013            result = result.axpby(
1014                AnyScalar::new_real(1.0),
1015                &scaled_vi,
1016                AnyScalar::new_real(1.0),
1017            )?;
1018        }
1019        // Truncate after each addition to control bond dimension growth
1020        truncate(&mut result)?;
1021    }
1022
1023    Ok(result)
1024}
1025
1026#[cfg(test)]
1027mod tests;