Skip to main content

tensor4all_core/
krylov.rs

1//! Krylov subspace methods for solving linear equations with abstract tensors.
2//!
3//! This module provides iterative solvers that work with any type implementing [`TensorVectorSpace`],
4//! enabling their use in tensor network algorithms without requiring dense vector representations.
5//!
6//! # Solvers
7//!
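//! - [`gmres`]: Generalized Minimal Residual Method (GMRES) for non-symmetric systems
//! - [`gmres_with_absolute_tolerance`] and [`gmres_with_total_iteration_limit`]: GMRES with an
//!   absolute residual tolerance, or with a cap on the total number of Arnoldi steps
//! - [`gmres_affine`] and [`gmres_affine_with_absolute_tolerance`]: GMRES for affine systems
//!   `(a0 I + a1 A) x = b`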
9//!
10//! # Future Extensions
11//!
12//! - CG (Conjugate Gradient) for symmetric positive definite systems
//! - BiCGSTAB for non-symmetric systems, using short recurrences so memory does not grow
//!   with the iteration count as it does for GMRES
14//!
15//! # Example
16//!
17//! ```
18//! use tensor4all_core::{
19//!     krylov::{gmres, GmresOptions},
20//!     DynIndex, TensorDynLen, TensorVectorSpace,
21//! };
22//!
23//! # fn main() -> anyhow::Result<()> {
24//! let i = DynIndex::new_dyn(2);
25//! let rhs = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, -1.0])?;
26//! let initial_guess = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
27//!
28//! let apply_operator = |x: &TensorDynLen| Ok(x.clone());
29//! let result = gmres(apply_operator, &rhs, &initial_guess, &GmresOptions::default())?;
30//!
31//! assert!(result.converged);
32//! assert!(result.solution.sub(&rhs)?.maxabs() < 1e-12);
33//! # Ok(())
34//! # }
35//! ```
36
37use crate::any_scalar::AnyScalar;
38use crate::TensorVectorSpace;
39use anyhow::Result;
40use std::sync::atomic::{AtomicUsize, Ordering};
41use std::time::{Duration, Instant};
42
/// Hands out a distinct id to every profiled solve so that their stderr
/// summary lines can be told apart.
static GMRES_OP_PROFILE_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Wall-clock time and call counts accumulated per vector-space operation
/// during one GMRES solve.
///
/// The affine GMRES driver fills this in when the `T4A_GMRES_OP_PROFILE`
/// environment variable is set and prints it once the solve finishes.
#[derive(Debug, Clone)]
struct GmresOpProfile {
    /// When profiling began; [`GmresOpProfile::print`] reports time elapsed since.
    started: Instant,
    /// Time spent computing `||b||` (tracked without a call counter).
    b_norm: Duration,
    /// Time spent inside the operator callback.
    apply: Duration,
    /// Time spent in inner products.
    inner_product: Duration,
    /// Time spent in `axpby` linear combinations.
    axpby: Duration,
    /// Time spent computing norms other than `||b||`.
    norm: Duration,
    /// Time spent scaling vectors.
    scale: Duration,
    /// Time spent in the small upper-triangular solve.
    triangular_solve: Duration,
    /// Time spent forming `x + V y`.
    solution_update: Duration,
    apply_calls: usize,
    inner_product_calls: usize,
    axpby_calls: usize,
    norm_calls: usize,
    scale_calls: usize,
    triangular_solve_calls: usize,
    solution_update_calls: usize,
}

impl Default for GmresOpProfile {
    /// Starts the clock now, with every timer and counter at zero.
    fn default() -> Self {
        let zero = Duration::ZERO;
        Self {
            started: Instant::now(),
            b_norm: zero,
            apply: zero,
            inner_product: zero,
            axpby: zero,
            norm: zero,
            scale: zero,
            triangular_solve: zero,
            solution_update: zero,
            apply_calls: 0,
            inner_product_calls: 0,
            axpby_calls: 0,
            norm_calls: 0,
            scale_calls: 0,
            triangular_solve_calls: 0,
            solution_update_calls: 0,
        }
    }
}

impl GmresOpProfile {
    /// Sum of all individually timed categories.
    fn measured(&self) -> Duration {
        let parts = [
            self.b_norm,
            self.apply,
            self.inner_product,
            self.axpby,
            self.norm,
            self.scale,
            self.triangular_solve,
            self.solution_update,
        ];
        parts.iter().sum()
    }

    /// Writes a one-line summary of this profile to stderr.
    ///
    /// `other_ms` is the wall time since `started` that no timed category
    /// accounts for (clamped at zero).
    fn print(&self, id: usize, iterations: usize, residual_norm: f64, converged: bool) {
        let ms = |d: Duration| d.as_secs_f64() * 1000.0;
        let total = self.started.elapsed();
        let untracked = total.saturating_sub(self.measured());
        eprintln!(
            "T4A gmres_op_profile #{id}: iterations={iterations} residual={residual_norm:.6e} converged={converged} total_ms={:.3} apply_ms={:.3} apply_calls={} inner_ms={:.3} inner_calls={} axpby_ms={:.3} axpby_calls={} norm_ms={:.3} norm_calls={} scale_ms={:.3} scale_calls={} update_ms={:.3} update_calls={} triangular_ms={:.3} triangular_calls={} b_norm_ms={:.3} other_ms={:.3}",
            ms(total),
            ms(self.apply),
            self.apply_calls,
            ms(self.inner_product),
            self.inner_product_calls,
            ms(self.axpby),
            self.axpby_calls,
            ms(self.norm),
            self.norm_calls,
            ms(self.scale),
            self.scale_calls,
            ms(self.solution_update),
            self.solution_update_calls,
            ms(self.triangular_solve),
            self.triangular_solve_calls,
            ms(self.b_norm),
            ms(untracked),
        );
    }
}
125
/// Options for the GMRES solvers in this module.
///
/// # Examples
///
/// ```
/// use tensor4all_core::krylov::GmresOptions;
///
/// let opts = GmresOptions {
///     max_iter: 50,
///     rtol: 1e-8,
///     max_restarts: 5,
///     verbose: false,
///     check_true_residual: true,
/// };
/// assert_eq!(opts.max_iter, 50);
/// assert_eq!(opts.rtol, 1e-8);
/// ```
#[derive(Debug, Clone)]
pub struct GmresOptions {
    /// Maximum number of Arnoldi steps per restart cycle (the restart length).
    /// Default: 100
    pub max_iter: usize,

    /// Convergence tolerance for the relative residual norm.
    /// The relative-tolerance solvers stop when `||r|| / ||b|| < rtol`; the
    /// `*_with_absolute_tolerance` variants use their `atol` argument as the
    /// stopping threshold instead.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum number of restart cycles.
    /// The total number of iterations is therefore at most
    /// `max_iter * max_restarts`; the solver returns earlier once it converges.
    /// Default: 10
    pub max_restarts: usize,

    /// Whether to print per-restart and per-iteration residuals to stderr.
    /// Default: false
    pub verbose: bool,

    /// When true, verify convergence by computing the true residual `||b - A*x|| / ||b||`
    /// before declaring convergence. This prevents false convergence caused by
    /// truncation corrupting the Krylov basis orthogonality (see Issue #207).
    /// Costs one additional `apply_a` call when convergence is detected.
    /// Default: false
    pub check_true_residual: bool,
}

impl Default for GmresOptions {
    /// 100 steps per cycle, `rtol = 1e-10`, 10 restarts, quiet, no true-residual check.
    fn default() -> Self {
        Self {
            max_iter: 100,
            rtol: 1e-10,
            max_restarts: 10,
            verbose: false,
            check_true_residual: false,
        }
    }
}
182
/// Stopping criterion shared by the GMRES drivers.
#[derive(Debug, Clone, Copy, PartialEq)]
enum GmresTolerance {
    /// Converged once `||r|| / ||b||` is strictly below the stored threshold.
    Relative(f64),
    /// Converged once `||r||` is strictly below the stored threshold.
    Absolute(f64),
}

impl GmresTolerance {
    /// The residual in the units this criterion compares against:
    /// `||r|| / ||b||` for `Relative`, plain `||r||` for `Absolute`.
    fn residual_value(self, residual_norm: f64, b_norm: f64) -> f64 {
        if matches!(self, Self::Relative(_)) {
            residual_norm / b_norm
        } else {
            residual_norm
        }
    }

    /// Whether the residual satisfies this criterion (strict `<`).
    fn is_converged(self, residual_norm: f64, b_norm: f64) -> bool {
        let (Self::Relative(threshold) | Self::Absolute(threshold)) = self;
        self.residual_value(residual_norm, b_norm) < threshold
    }
}
204
/// Result of a GMRES solve.
///
/// Contains the solution, iteration count, final residual measure, and
/// convergence status. Running out of iterations is not an error; it is
/// reported through `converged == false`.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace};
/// use tensor4all_core::krylov::{gmres, GmresOptions};
///
/// let i = DynIndex::new_dyn(2);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 7.0]).unwrap();
/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
///
/// let result = gmres(|x: &TensorDynLen| Ok(x.clone()), &b, &x0, &GmresOptions::default()).unwrap();
/// assert!(result.converged);
/// assert!(result.residual_norm < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct GmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Total number of Arnoldi steps performed, summed over all restart cycles.
    /// Zero when the initial guess already satisfied the tolerance or a
    /// shortcut path was taken.
    pub iterations: usize,

    /// Final residual, in the units of the tolerance that was used:
    /// `||r|| / ||b||` for the relative-tolerance solvers and `||r||` for the
    /// `*_with_absolute_tolerance` variants.
    ///
    /// Depending on the exit path this is either the GMRES (Hessenberg)
    /// estimate or a freshly evaluated true residual. It is reported as `0.0`
    /// when `b` is numerically zero.
    pub residual_norm: f64,

    /// Whether the solver met its tolerance.
    pub converged: bool,
}
238
239/// Solve `A x = b` using GMRES (Generalized Minimal Residual Method).
240///
241/// This implements the restarted GMRES algorithm that works with abstract tensor types
242/// through the [`TensorVectorSpace`] trait's vector space operations.
243///
244/// # Algorithm
245///
246/// GMRES builds an orthonormal basis for the Krylov subspace
247/// `K_m = span{r_0, A r_0, A^2 r_0, ..., A^{m-1} r_0}` and finds the
248/// solution that minimizes `||b - A x||` over this subspace.
249///
250/// # Type Parameters
251///
252/// * `T` - A tensor type implementing `TensorVectorSpace`
253/// * `F` - A function that applies the linear operator: `F(x) = A x`
254///
255/// # Arguments
256///
257/// * `apply_a` - Function that applies the linear operator A to a tensor
258/// * `b` - Right-hand side tensor
259/// * `x0` - Initial guess
260/// * `options` - Solver options
261///
262/// # Returns
263///
264/// A `GmresResult` containing the solution and convergence information.
265///
266/// # Errors
267///
268/// Returns an error if:
269/// - Vector space operations (add, sub, scale, inner_product) fail
270/// - The linear operator application fails
271pub fn gmres<T, F>(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result<GmresResult<T>>
272where
273    T: TensorVectorSpace,
274    F: Fn(&T) -> Result<T>,
275{
276    gmres_impl(
277        apply_a,
278        b,
279        x0,
280        options,
281        GmresTolerance::Relative(options.rtol),
282        None,
283    )
284}
285
286/// Solve `A x = b` using GMRES with an absolute residual tolerance.
287///
288/// This variant stops when `||b - A*x|| < atol`. The default [`gmres`] API uses
289/// relative residual tolerance and is preferred for scale-independent solves.
290pub fn gmres_with_absolute_tolerance<T, F>(
291    apply_a: F,
292    b: &T,
293    x0: &T,
294    options: &GmresOptions,
295    atol: f64,
296) -> Result<GmresResult<T>>
297where
298    T: TensorVectorSpace,
299    F: Fn(&T) -> Result<T>,
300{
301    gmres_impl(
302        apply_a,
303        b,
304        x0,
305        options,
306        GmresTolerance::Absolute(atol),
307        None,
308    )
309}
310
311/// Solve `(a0 I + a1 A) x = b` using GMRES with relative residual tolerance.
312///
313/// The Arnoldi basis is built from the unshifted `A` callback, while affine
314/// coefficients are applied in the projected Hessenberg problem, matching
315/// KrylovKit's affine linear-solve convention.
316pub fn gmres_affine<T, F>(
317    apply_a: F,
318    b: &T,
319    x0: &T,
320    a0: AnyScalar,
321    a1: AnyScalar,
322    options: &GmresOptions,
323) -> Result<GmresResult<T>>
324where
325    T: TensorVectorSpace,
326    F: Fn(&T) -> Result<T>,
327{
328    gmres_affine_impl(
329        apply_a,
330        b,
331        x0,
332        a0,
333        a1,
334        options,
335        GmresTolerance::Relative(options.rtol),
336    )
337}
338
339/// Solve `(a0 I + a1 A) x = b` using GMRES with an absolute residual tolerance.
340///
341/// The Arnoldi basis is built from the unshifted `A` callback, while the affine
342/// coefficients are applied to the small Hessenberg problem. This mirrors
343/// KrylovKit's `linsolve(operator, b, a0, a1)` algorithm and avoids changing the
344/// Krylov basis when affine coefficients are present.
345pub fn gmres_affine_with_absolute_tolerance<T, F>(
346    apply_a: F,
347    b: &T,
348    x0: &T,
349    a0: AnyScalar,
350    a1: AnyScalar,
351    options: &GmresOptions,
352    atol: f64,
353) -> Result<GmresResult<T>>
354where
355    T: TensorVectorSpace,
356    F: Fn(&T) -> Result<T>,
357{
358    gmres_affine_impl(
359        apply_a,
360        b,
361        x0,
362        a0,
363        a1,
364        options,
365        GmresTolerance::Absolute(atol),
366    )
367}
368
369fn gmres_affine_impl<T, F>(
370    apply_a: F,
371    b: &T,
372    x0: &T,
373    a0: AnyScalar,
374    a1: AnyScalar,
375    options: &GmresOptions,
376    tolerance: GmresTolerance,
377) -> Result<GmresResult<T>>
378where
379    T: TensorVectorSpace,
380    F: Fn(&T) -> Result<T>,
381{
382    b.validate()?;
383    x0.validate()?;
384
385    let profile_enabled = std::env::var_os("T4A_GMRES_OP_PROFILE").is_some();
386    let profile_id = if profile_enabled {
387        GMRES_OP_PROFILE_COUNTER.fetch_add(1, Ordering::Relaxed)
388    } else {
389        0
390    };
391    let mut profile = GmresOpProfile::default();
392    let mut total_iters = 0usize;
393
394    macro_rules! finish {
395        ($result:expr) => {{
396            let result = $result;
397            if profile_enabled {
398                profile.print(
399                    profile_id,
400                    result.iterations,
401                    result.residual_norm,
402                    result.converged,
403                );
404            }
405            return Ok(result);
406        }};
407    }
408
409    let started = Instant::now();
410    let b_norm = b.norm();
411    if profile_enabled {
412        profile.b_norm += started.elapsed();
413    }
414    if b_norm < 1e-15 {
415        finish!(GmresResult {
416            solution: x0.clone(),
417            iterations: 0,
418            residual_norm: 0.0,
419            converged: true,
420        });
421    }
422    if a0.is_zero() && a1.is_zero() {
423        anyhow::bail!("gmres_affine: at least one affine coefficient must be nonzero");
424    }
425    if a1.is_zero() {
426        let started = Instant::now();
427        let solution = b.scale(AnyScalar::new_real(1.0) / a0)?;
428        if profile_enabled {
429            profile.scale += started.elapsed();
430            profile.scale_calls += 1;
431        }
432        finish!(GmresResult {
433            solution,
434            iterations: 0,
435            residual_norm: 0.0,
436            converged: true,
437        });
438    }
439
440    let mut x = x0.clone();
441
442    for restart in 0..options.max_restarts {
443        let started = Instant::now();
444        let ax = apply_a(&x)?;
445        if profile_enabled {
446            profile.apply += started.elapsed();
447            profile.apply_calls += 1;
448        }
449        if restart == 0 {
450            ax.validate()?;
451        }
452        let started = Instant::now();
453        let affine_x = x.axpby(a0.clone(), &ax, a1.clone())?;
454        if profile_enabled {
455            profile.axpby += started.elapsed();
456            profile.axpby_calls += 1;
457        }
458        let started = Instant::now();
459        let r = b.axpby(
460            AnyScalar::new_real(1.0),
461            &affine_x,
462            AnyScalar::new_real(-1.0),
463        )?;
464        if profile_enabled {
465            profile.axpby += started.elapsed();
466            profile.axpby_calls += 1;
467        }
468        let started = Instant::now();
469        let r_norm = r.norm();
470        if profile_enabled {
471            profile.norm += started.elapsed();
472            profile.norm_calls += 1;
473        }
474        let residual_value = tolerance.residual_value(r_norm, b_norm);
475        if options.verbose {
476            eprintln!(
477                "GMRES restart {}: initial residual = {:.6e}",
478                restart, residual_value
479            );
480        }
481        if tolerance.is_converged(r_norm, b_norm) {
482            finish!(GmresResult {
483                solution: x,
484                iterations: total_iters,
485                residual_norm: residual_value,
486                converged: true,
487            });
488        }
489
490        let cycle_max_iter = options.max_iter;
491        let mut v_basis: Vec<T> = Vec::with_capacity(cycle_max_iter + 1);
492        let started = Instant::now();
493        v_basis.push(r.scale(AnyScalar::new_real(1.0 / r_norm))?);
494        if profile_enabled {
495            profile.scale += started.elapsed();
496            profile.scale_calls += 1;
497        }
498
499        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(cycle_max_iter);
500        let mut cs: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
501        let mut sn: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
502        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)];
503        let mut solution_already_updated = false;
504
505        for j in 0..cycle_max_iter {
506            total_iters += 1;
507
508            let started = Instant::now();
509            let w = apply_a(&v_basis[j])?;
510            if profile_enabled {
511                profile.apply += started.elapsed();
512                profile.apply_calls += 1;
513            }
514            let mut h_a_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
515            let mut w_orth = w;
516
517            for v_i in v_basis.iter().take(j + 1) {
518                let started = Instant::now();
519                let h_ij = v_i.inner_product(&w_orth)?;
520                if profile_enabled {
521                    profile.inner_product += started.elapsed();
522                    profile.inner_product_calls += 1;
523                }
524                h_a_col.push(h_ij.clone());
525                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
526                let started = Instant::now();
527                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
528                if profile_enabled {
529                    profile.axpby += started.elapsed();
530                    profile.axpby_calls += 1;
531                }
532            }
533            for (i, v_i) in v_basis.iter().take(j + 1).enumerate() {
534                let started = Instant::now();
535                let correction = v_i.inner_product(&w_orth)?;
536                if profile_enabled {
537                    profile.inner_product += started.elapsed();
538                    profile.inner_product_calls += 1;
539                }
540                h_a_col[i] = h_a_col[i].clone() + correction.clone();
541                let neg_correction = AnyScalar::new_real(0.0) - correction;
542                let started = Instant::now();
543                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
544                if profile_enabled {
545                    profile.axpby += started.elapsed();
546                    profile.axpby_calls += 1;
547                }
548            }
549
550            let started = Instant::now();
551            let h_jp1_j_real = w_orth.norm();
552            if profile_enabled {
553                profile.norm += started.elapsed();
554                profile.norm_calls += 1;
555            }
556            h_a_col.push(AnyScalar::new_real(h_jp1_j_real));
557
558            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
559            for h in h_a_col.iter().take(j) {
560                h_col.push(a1.clone() * h.clone());
561            }
562            h_col.push(a0.clone() + a1.clone() * h_a_col[j].clone());
563            h_col.push(a1.clone() * h_a_col[j + 1].clone());
564
565            #[allow(clippy::needless_range_loop)]
566            for i in 0..j {
567                let h_i = h_col[i].clone();
568                let h_ip1 = h_col[i + 1].clone();
569                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
570                h_col[i] = new_hi;
571                h_col[i + 1] = new_hip1;
572            }
573
574            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
575            cs.push(c_j.clone());
576            sn.push(s_j.clone());
577
578            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
579            h_col[j] = new_hj;
580            h_col[j + 1] = AnyScalar::new_real(0.0);
581
582            let g_j = g[j].clone();
583            let g_jp1 = AnyScalar::new_real(0.0);
584            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
585            g[j] = new_gj;
586            let res_norm = new_gjp1.abs();
587            g.push(new_gjp1);
588
589            h_matrix.push(h_col);
590            let residual_value = tolerance.residual_value(res_norm, b_norm);
591            if options.verbose {
592                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value);
593            }
594
595            if tolerance.is_converged(res_norm, b_norm) {
596                let started = Instant::now();
597                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
598                if profile_enabled {
599                    profile.triangular_solve += started.elapsed();
600                    profile.triangular_solve_calls += 1;
601                }
602                let started = Instant::now();
603                x = update_solution(&x, &v_basis[..=j], &y)?;
604                if profile_enabled {
605                    profile.solution_update += started.elapsed();
606                    profile.solution_update_calls += 1;
607                }
608                if options.check_true_residual {
609                    let started = Instant::now();
610                    let ax_check = apply_a(&x)?;
611                    if profile_enabled {
612                        profile.apply += started.elapsed();
613                        profile.apply_calls += 1;
614                    }
615                    let started = Instant::now();
616                    let affine_check = x.axpby(a0.clone(), &ax_check, a1.clone())?;
617                    if profile_enabled {
618                        profile.axpby += started.elapsed();
619                        profile.axpby_calls += 1;
620                    }
621                    let started = Instant::now();
622                    let r_check = b.axpby(
623                        AnyScalar::new_real(1.0),
624                        &affine_check,
625                        AnyScalar::new_real(-1.0),
626                    )?;
627                    if profile_enabled {
628                        profile.axpby += started.elapsed();
629                        profile.axpby_calls += 1;
630                    }
631                    let started = Instant::now();
632                    let true_abs_res = r_check.norm();
633                    if profile_enabled {
634                        profile.norm += started.elapsed();
635                        profile.norm_calls += 1;
636                    }
637                    let true_residual_value = tolerance.residual_value(true_abs_res, b_norm);
638                    if options.verbose {
639                        eprintln!(
640                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
641                            residual_value, true_residual_value
642                        );
643                    }
644                    if tolerance.is_converged(true_abs_res, b_norm) {
645                        finish!(GmresResult {
646                            solution: x,
647                            iterations: total_iters,
648                            residual_norm: true_residual_value,
649                            converged: true,
650                        });
651                    }
652                    solution_already_updated = true;
653                    break;
654                } else {
655                    finish!(GmresResult {
656                        solution: x,
657                        iterations: total_iters,
658                        residual_norm: residual_value,
659                        converged: true,
660                    });
661                }
662            }
663
664            if h_jp1_j_real > 1e-14 {
665                let started = Instant::now();
666                v_basis.push(w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?);
667                if profile_enabled {
668                    profile.scale += started.elapsed();
669                    profile.scale_calls += 1;
670                }
671            } else {
672                let started = Instant::now();
673                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
674                if profile_enabled {
675                    profile.triangular_solve += started.elapsed();
676                    profile.triangular_solve_calls += 1;
677                }
678                let started = Instant::now();
679                x = update_solution(&x, &v_basis[..=j], &y)?;
680                if profile_enabled {
681                    profile.solution_update += started.elapsed();
682                    profile.solution_update_calls += 1;
683                }
684                let started = Instant::now();
685                let ax_final = apply_a(&x)?;
686                if profile_enabled {
687                    profile.apply += started.elapsed();
688                    profile.apply_calls += 1;
689                }
690                let started = Instant::now();
691                let affine_final = x.axpby(a0.clone(), &ax_final, a1.clone())?;
692                if profile_enabled {
693                    profile.axpby += started.elapsed();
694                    profile.axpby_calls += 1;
695                }
696                let started = Instant::now();
697                let r_final = b.axpby(
698                    AnyScalar::new_real(1.0),
699                    &affine_final,
700                    AnyScalar::new_real(-1.0),
701                )?;
702                if profile_enabled {
703                    profile.axpby += started.elapsed();
704                    profile.axpby_calls += 1;
705                }
706                let started = Instant::now();
707                let final_abs_res = r_final.norm();
708                if profile_enabled {
709                    profile.norm += started.elapsed();
710                    profile.norm_calls += 1;
711                }
712                let final_res = tolerance.residual_value(final_abs_res, b_norm);
713                finish!(GmresResult {
714                    solution: x,
715                    iterations: total_iters,
716                    residual_norm: final_res,
717                    converged: tolerance.is_converged(final_abs_res, b_norm),
718                });
719            }
720        }
721
722        if !solution_already_updated {
723            let actual_iters = h_matrix.len();
724            let started = Instant::now();
725            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
726            if profile_enabled {
727                profile.triangular_solve += started.elapsed();
728                profile.triangular_solve_calls += 1;
729            }
730            let started = Instant::now();
731            x = update_solution(&x, &v_basis[..actual_iters], &y)?;
732            if profile_enabled {
733                profile.solution_update += started.elapsed();
734                profile.solution_update_calls += 1;
735            }
736        }
737    }
738
739    let started = Instant::now();
740    let ax_final = apply_a(&x)?;
741    if profile_enabled {
742        profile.apply += started.elapsed();
743        profile.apply_calls += 1;
744    }
745    let started = Instant::now();
746    let affine_final = x.axpby(a0, &ax_final, a1)?;
747    if profile_enabled {
748        profile.axpby += started.elapsed();
749        profile.axpby_calls += 1;
750    }
751    let started = Instant::now();
752    let r_final = b.axpby(
753        AnyScalar::new_real(1.0),
754        &affine_final,
755        AnyScalar::new_real(-1.0),
756    )?;
757    if profile_enabled {
758        profile.axpby += started.elapsed();
759        profile.axpby_calls += 1;
760    }
761    let started = Instant::now();
762    let final_abs_res = r_final.norm();
763    if profile_enabled {
764        profile.norm += started.elapsed();
765        profile.norm_calls += 1;
766    }
767    let final_res = tolerance.residual_value(final_abs_res, b_norm);
768
769    finish!(GmresResult {
770        solution: x,
771        iterations: total_iters,
772        residual_norm: final_res,
773        converged: tolerance.is_converged(final_abs_res, b_norm),
774    })
775}
776
777/// Solve `A x = b` using GMRES while enforcing a total iteration limit.
778///
779/// [`GmresOptions::max_iter`] remains the restart cycle length and
780/// [`GmresOptions::max_restarts`] remains the maximum number of restart cycles.
781/// `max_total_iter` caps the total number of Arnoldi steps across all restart
782/// cycles; the final cycle is shortened when necessary.
783pub fn gmres_with_total_iteration_limit<T, F>(
784    apply_a: F,
785    b: &T,
786    x0: &T,
787    options: &GmresOptions,
788    max_total_iter: usize,
789) -> Result<GmresResult<T>>
790where
791    T: TensorVectorSpace,
792    F: Fn(&T) -> Result<T>,
793{
794    gmres_impl(
795        apply_a,
796        b,
797        x0,
798        options,
799        GmresTolerance::Relative(options.rtol),
800        Some(max_total_iter),
801    )
802}
803
804fn gmres_impl<T, F>(
805    apply_a: F,
806    b: &T,
807    x0: &T,
808    options: &GmresOptions,
809    tolerance: GmresTolerance,
810    max_total_iter: Option<usize>,
811) -> Result<GmresResult<T>>
812where
813    T: TensorVectorSpace,
814    F: Fn(&T) -> Result<T>,
815{
816    // Validate structural consistency of inputs
817    b.validate()?;
818    x0.validate()?;
819
820    let b_norm = b.norm();
821    if b_norm < 1e-15 {
822        // b is effectively zero, return x0
823        return Ok(GmresResult {
824            solution: x0.clone(),
825            iterations: 0,
826            residual_norm: 0.0,
827            converged: true,
828        });
829    }
830
831    let mut x = x0.clone();
832    let mut total_iters = 0;
833
834    for _restart in 0..options.max_restarts {
835        let cycle_max_iter = match max_total_iter {
836            Some(limit) => {
837                let remaining = limit.saturating_sub(total_iters);
838                if remaining == 0 {
839                    break;
840                }
841                options.max_iter.min(remaining)
842            }
843            None => options.max_iter,
844        };
845        if cycle_max_iter == 0 {
846            break;
847        }
848
849        // Compute initial residual: r = b - A*x
850        let ax = apply_a(&x)?;
851        // Validate operator output on first restart
852        if _restart == 0 {
853            ax.validate()?;
854        }
855        // r = 1.0 * b + (-1.0) * ax
856        let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
857        let r_norm = r.norm();
858        let residual_value = tolerance.residual_value(r_norm, b_norm);
859
860        if options.verbose {
861            eprintln!(
862                "GMRES restart {}: initial residual = {:.6e}",
863                _restart, residual_value
864            );
865        }
866
867        if tolerance.is_converged(r_norm, b_norm) {
868            return Ok(GmresResult {
869                solution: x,
870                iterations: total_iters,
871                residual_norm: residual_value,
872                converged: true,
873            });
874        }
875
876        // Arnoldi process with modified Gram-Schmidt
877        let mut v_basis: Vec<T> = Vec::with_capacity(cycle_max_iter + 1);
878        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(cycle_max_iter);
879
880        // v_0 = r / ||r||
881        let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
882        v_basis.push(v0);
883
884        // Initialize Givens rotation storage
885        let mut cs: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
886        let mut sn: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
887        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)]; // residual in upper Hessenberg space
888        let mut solution_already_updated = false;
889
890        for j in 0..cycle_max_iter {
891            total_iters += 1;
892
893            // w = A * v_j
894            let w = apply_a(&v_basis[j])?;
895
896            // Modified Gram-Schmidt orthogonalization
897            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
898            let mut w_orth = w;
899
900            for v_i in v_basis.iter().take(j + 1) {
901                let h_ij = v_i.inner_product(&w_orth)?;
902                h_col.push(h_ij.clone());
903                // w_orth = w_orth - h_ij * v_i = 1.0 * w_orth + (-h_ij) * v_i
904                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
905                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
906            }
907
908            // KrylovKit's default orthogonalizer is ModifiedGramSchmidt2.
909            // The second pass is important for long Krylov bases and complex
910            // non-Hermitian local problems.
911            for (i, v_i) in v_basis.iter().take(j + 1).enumerate() {
912                let correction = v_i.inner_product(&w_orth)?;
913                h_col[i] = h_col[i].clone() + correction.clone();
914                let neg_correction = AnyScalar::new_real(0.0) - correction;
915                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
916            }
917
918            let h_jp1_j_real = w_orth.norm();
919            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
920            h_col.push(h_jp1_j);
921
922            // Apply previous Givens rotations to new column
923            #[allow(clippy::needless_range_loop)]
924            for i in 0..j {
925                let h_i = h_col[i].clone();
926                let h_ip1 = h_col[i + 1].clone();
927                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
928                h_col[i] = new_hi;
929                h_col[i + 1] = new_hip1;
930            }
931
932            // Compute new Givens rotation for h_col[j] and h_col[j+1]
933            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
934            cs.push(c_j.clone());
935            sn.push(s_j.clone());
936
937            // Apply new rotation to eliminate h_col[j+1]
938            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
939            h_col[j] = new_hj;
940            h_col[j + 1] = AnyScalar::new_real(0.0);
941
942            // Apply rotation to g
943            let g_j = g[j].clone();
944            let g_jp1 = AnyScalar::new_real(0.0);
945            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
946            g[j] = new_gj;
947            let res_norm = new_gjp1.abs();
948            g.push(new_gjp1);
949
950            h_matrix.push(h_col);
951
952            // Check convergence
953            let residual_value = tolerance.residual_value(res_norm, b_norm);
954
955            if options.verbose {
956                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value);
957            }
958
959            if tolerance.is_converged(res_norm, b_norm) {
960                // Solve upper triangular system and update x
961                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
962                x = update_solution(&x, &v_basis[..=j], &y)?;
963                if options.check_true_residual {
964                    let ax_check = apply_a(&x)?;
965                    let r_check = b.axpby(
966                        AnyScalar::new_real(1.0),
967                        &ax_check,
968                        AnyScalar::new_real(-1.0),
969                    )?;
970                    let true_abs_res = r_check.norm();
971                    let true_residual_value = tolerance.residual_value(true_abs_res, b_norm);
972
973                    if options.verbose {
974                        eprintln!(
975                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
976                            residual_value, true_residual_value
977                        );
978                    }
979
980                    if tolerance.is_converged(true_abs_res, b_norm) {
981                        return Ok(GmresResult {
982                            solution: x,
983                            iterations: total_iters,
984                            residual_norm: true_residual_value,
985                            converged: true,
986                        });
987                    }
988                    solution_already_updated = true;
989                    break;
990                } else {
991                    return Ok(GmresResult {
992                        solution: x,
993                        iterations: total_iters,
994                        residual_norm: residual_value,
995                        converged: true,
996                    });
997                }
998            }
999
1000            // Add new basis vector (if not converged and h_jp1_j is not too small)
1001            if h_jp1_j_real > 1e-14 {
1002                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
1003                v_basis.push(v_jp1);
1004            } else {
1005                // Lucky breakdown - we've found the exact solution in the Krylov subspace
1006                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
1007                x = update_solution(&x, &v_basis[..=j], &y)?;
1008                let ax_final = apply_a(&x)?;
1009                let r_final = b.axpby(
1010                    AnyScalar::new_real(1.0),
1011                    &ax_final,
1012                    AnyScalar::new_real(-1.0),
1013                )?;
1014                let final_abs_res = r_final.norm();
1015                let final_res = tolerance.residual_value(final_abs_res, b_norm);
1016                return Ok(GmresResult {
1017                    solution: x,
1018                    iterations: total_iters,
1019                    residual_norm: final_res,
1020                    converged: tolerance.is_converged(final_abs_res, b_norm),
1021                });
1022            }
1023        }
1024
1025        // End of restart cycle - update x with current solution
1026        if !solution_already_updated {
1027            let actual_iters = h_matrix.len();
1028            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
1029            x = update_solution(&x, &v_basis[..actual_iters], &y)?;
1030        }
1031    }
1032
1033    // Compute final residual
1034    let ax_final = apply_a(&x)?;
1035    let r_final = b.axpby(
1036        AnyScalar::new_real(1.0),
1037        &ax_final,
1038        AnyScalar::new_real(-1.0),
1039    )?;
1040    let final_abs_res = r_final.norm();
1041    let final_res = tolerance.residual_value(final_abs_res, b_norm);
1042
1043    Ok(GmresResult {
1044        solution: x,
1045        iterations: total_iters,
1046        residual_norm: final_res,
1047        converged: tolerance.is_converged(final_abs_res, b_norm),
1048    })
1049}
1050
1051/// Solve `A x = b` using GMRES with optional truncation after each iteration.
1052///
1053/// This is an extension of [`gmres`] that allows truncating Krylov basis vectors
1054/// to control bond dimension growth in tensor network representations.
1055///
1056/// # Type Parameters
1057///
1058/// * `T` - A tensor type implementing `TensorVectorSpace`
1059/// * `F` - A function that applies the linear operator: `F(x) = A x`
1060/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
1061///
1062/// # Arguments
1063///
1064/// * `apply_a` - Function that applies the linear operator A to a tensor
1065/// * `b` - Right-hand side tensor
1066/// * `x0` - Initial guess
1067/// * `options` - Solver options
1068/// * `truncate` - Function that truncates a tensor to control bond dimension
1069///
1070/// # Note
1071///
1072/// Truncation is applied after each Gram-Schmidt orthogonalization step
1073/// and after the final solution update. This helps control the bond dimension
1074/// growth that would otherwise occur in MPS/MPO representations.
1075///
1076/// # Examples
1077///
1078/// Solve `2x = b` with a no-op truncation function:
1079///
1080/// ```
1081/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace, AnyScalar};
1082/// use tensor4all_core::krylov::{gmres_with_truncation, GmresOptions};
1083///
1084/// let i = DynIndex::new_dyn(2);
1085/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![4.0, 6.0]).unwrap();
1086/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
1087///
1088/// // Operator A = 2*I (scales input by 2)
1089/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(2.0));
1090/// // No-op truncation
1091/// let truncate = |_x: &mut TensorDynLen| Ok(());
1092///
1093/// let result = gmres_with_truncation(apply_a, &b, &x0, &GmresOptions::default(), truncate).unwrap();
1094/// assert!(result.converged);
1095/// // Solution should be [2.0, 3.0]
1096/// let expected = TensorDynLen::from_dense(vec![i], vec![2.0, 3.0]).unwrap();
1097/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
1098/// ```
1099pub fn gmres_with_truncation<T, F, Tr>(
1100    apply_a: F,
1101    b: &T,
1102    x0: &T,
1103    options: &GmresOptions,
1104    truncate: Tr,
1105) -> Result<GmresResult<T>>
1106where
1107    T: TensorVectorSpace,
1108    F: Fn(&T) -> Result<T>,
1109    Tr: Fn(&mut T) -> Result<()>,
1110{
1111    // Validate structural consistency of inputs
1112    b.validate()?;
1113    x0.validate()?;
1114
1115    let b_norm = b.norm();
1116    if b_norm < 1e-15 {
1117        return Ok(GmresResult {
1118            solution: x0.clone(),
1119            iterations: 0,
1120            residual_norm: 0.0,
1121            converged: true,
1122        });
1123    }
1124
1125    let mut x = x0.clone();
1126    let mut total_iters = 0;
1127
1128    for _restart in 0..options.max_restarts {
1129        let ax = apply_a(&x)?;
1130        // Validate operator output on first restart
1131        if _restart == 0 {
1132            ax.validate()?;
1133        }
1134        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
1135        truncate(&mut r)?;
1136        let r_norm = r.norm();
1137        let rel_res = r_norm / b_norm;
1138
1139        if options.verbose {
1140            eprintln!(
1141                "GMRES restart {}: initial residual = {:.6e}",
1142                _restart, rel_res
1143            );
1144        }
1145
1146        if rel_res < options.rtol {
1147            return Ok(GmresResult {
1148                solution: x,
1149                iterations: total_iters,
1150                residual_norm: rel_res,
1151                converged: true,
1152            });
1153        }
1154
1155        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
1156        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);
1157
1158        let mut v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
1159        truncate(&mut v0)?;
1160        // After truncation, v0 might not be unit norm and might point in a different direction.
1161        // We need to:
1162        // 1. Renormalize v0 to unit norm for numerical stability
1163        // 2. Recompute g[0] = <r, v0> to maintain the correct relationship
1164        let v0_norm = v0.norm();
1165        let effective_g0 = if v0_norm > 1e-15 {
1166            v0 = v0.scale(AnyScalar::new_real(1.0 / v0_norm))?;
1167            // g[0] should be the component of r in the direction of v0
1168            // Since r was truncated and v0 = truncate(r/||r||)/||truncate(r/||r||)||,
1169            // g[0] = <r, v0> ≈ ||r|| * ||truncate(r/||r||)|| = r_norm * v0_norm
1170            r_norm * v0_norm
1171        } else {
1172            r_norm
1173        };
1174        v_basis.push(v0);
1175
1176        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
1177        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
1178        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(effective_g0)];
1179        let mut solution_already_updated = false;
1180
1181        for j in 0..options.max_iter {
1182            total_iters += 1;
1183
1184            let w = apply_a(&v_basis[j])?;
1185
1186            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
1187            let mut w_orth = w;
1188
1189            for v_i in v_basis.iter().take(j + 1) {
1190                let h_ij = v_i.inner_product(&w_orth)?;
1191                h_col.push(h_ij.clone());
1192                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
1193                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
1194            }
1195
1196            // Iterative reorthogonalization with truncation
1197            // Truncation can change the direction of w_orth, breaking orthogonality.
1198            // We iterate until all corrections are below a threshold to ensure
1199            // the Krylov basis remains orthogonal despite truncation.
1200            const REORTH_THRESHOLD: f64 = 1e-12;
1201            const MAX_REORTH_ITERS: usize = 10;
1202
1203            let mut reorth_iter_count = 0;
1204            for reorth_iter in 0..MAX_REORTH_ITERS {
1205                reorth_iter_count = reorth_iter + 1;
1206                let norm_before_truncate = w_orth.norm();
1207                truncate(&mut w_orth)?;
1208                let norm_after_truncate = w_orth.norm();
1209
1210                let mut max_correction = 0.0;
1211                for (i, v_i) in v_basis.iter().enumerate() {
1212                    let correction = v_i.inner_product(&w_orth)?;
1213                    let correction_abs = correction.abs();
1214                    if correction_abs > max_correction {
1215                        max_correction = correction_abs;
1216                    }
1217                    if correction_abs > REORTH_THRESHOLD {
1218                        let neg_correction = AnyScalar::new_real(0.0) - correction.clone();
1219                        w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
1220                        // Update Hessenberg matrix entry to include correction
1221                        h_col[i] = h_col[i].clone() + correction;
1222                    }
1223                }
1224
1225                if options.verbose {
1226                    eprintln!(
1227                        "  reorth iter {}: norm {:.6e} -> {:.6e}, max_correction = {:.6e}",
1228                        reorth_iter, norm_before_truncate, norm_after_truncate, max_correction
1229                    );
1230                }
1231
1232                // If all corrections are small enough, we're done
1233                if max_correction < REORTH_THRESHOLD {
1234                    break;
1235                }
1236            }
1237
1238            if options.verbose && reorth_iter_count > 1 {
1239                eprintln!("  (needed {} reorth iterations)", reorth_iter_count);
1240            }
1241
1242            let h_jp1_j_real = w_orth.norm();
1243            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
1244            h_col.push(h_jp1_j);
1245
1246            #[allow(clippy::needless_range_loop)]
1247            for i in 0..j {
1248                let h_i = h_col[i].clone();
1249                let h_ip1 = h_col[i + 1].clone();
1250                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
1251                h_col[i] = new_hi;
1252                h_col[i + 1] = new_hip1;
1253            }
1254
1255            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
1256            cs.push(c_j.clone());
1257            sn.push(s_j.clone());
1258
1259            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
1260            h_col[j] = new_hj;
1261            h_col[j + 1] = AnyScalar::new_real(0.0);
1262
1263            let g_j = g[j].clone();
1264            let g_jp1 = AnyScalar::new_real(0.0);
1265            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
1266            g[j] = new_gj;
1267            let res_norm = new_gjp1.abs();
1268            g.push(new_gjp1);
1269
1270            h_matrix.push(h_col);
1271
1272            let rel_res = res_norm / b_norm;
1273
1274            if options.verbose {
1275                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
1276            }
1277
1278            if rel_res < options.rtol {
1279                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
1280                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
1281
1282                if options.check_true_residual {
1283                    // Verify with true residual to prevent false convergence
1284                    let ax_check = apply_a(&x)?;
1285                    let mut r_check = b.axpby(
1286                        AnyScalar::new_real(1.0),
1287                        &ax_check,
1288                        AnyScalar::new_real(-1.0),
1289                    )?;
1290                    truncate(&mut r_check)?;
1291                    let true_rel_res = r_check.norm() / b_norm;
1292
1293                    if options.verbose {
1294                        eprintln!(
1295                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
1296                            rel_res, true_rel_res
1297                        );
1298                    }
1299
1300                    if true_rel_res < options.rtol {
1301                        return Ok(GmresResult {
1302                            solution: x,
1303                            iterations: total_iters,
1304                            residual_norm: true_rel_res,
1305                            converged: true,
1306                        });
1307                    }
1308                    // False convergence detected: x is already updated above,
1309                    // so skip the end-of-cycle update and go to next restart
1310                    solution_already_updated = true;
1311                    break;
1312                } else {
1313                    return Ok(GmresResult {
1314                        solution: x,
1315                        iterations: total_iters,
1316                        residual_norm: rel_res,
1317                        converged: true,
1318                    });
1319                }
1320            }
1321
1322            if h_jp1_j_real > 1e-14 {
1323                // Create v_{j+1} = w_orth / ||w_orth||
1324                // w_orth has already been truncated twice (after orthogonalization and after reorthogonalization)
1325                // so we don't need to truncate again. Scale doesn't increase bond dimensions.
1326                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
1327                // v_jp1 should have norm ~1.0 by construction
1328                // The Arnoldi relation h_{j+1,j} * v_{j+1} = w_orth is maintained exactly
1329                v_basis.push(v_jp1);
1330            } else {
1331                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
1332                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
1333                let ax_final = apply_a(&x)?;
1334                let r_final = b.axpby(
1335                    AnyScalar::new_real(1.0),
1336                    &ax_final,
1337                    AnyScalar::new_real(-1.0),
1338                )?;
1339                let final_res = r_final.norm() / b_norm;
1340                return Ok(GmresResult {
1341                    solution: x,
1342                    iterations: total_iters,
1343                    residual_norm: final_res,
1344                    converged: final_res < options.rtol,
1345                });
1346            }
1347        }
1348
1349        if !solution_already_updated {
1350            let actual_iters = v_basis.len().min(options.max_iter);
1351            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
1352            x = update_solution_truncated(&x, &v_basis[..actual_iters], &y, &truncate)?;
1353        }
1354    }
1355
1356    let ax_final = apply_a(&x)?;
1357    let r_final = b.axpby(
1358        AnyScalar::new_real(1.0),
1359        &ax_final,
1360        AnyScalar::new_real(-1.0),
1361    )?;
1362    let final_res = r_final.norm() / b_norm;
1363
1364    Ok(GmresResult {
1365        solution: x,
1366        iterations: total_iters,
1367        residual_norm: final_res,
1368        converged: final_res < options.rtol,
1369    })
1370}
1371
/// Options for restarted GMRES with truncation.
///
/// This is used by [`restart_gmres_with_truncation`], which wraps the standard GMRES
/// with an outer loop that recomputes the true residual at each restart.
///
/// The `with_*` setters consume and return `self`, so they can be chained after
/// [`RestartGmresOptions::new`]. Ignoring their return value discards the change,
/// which is why they are marked `#[must_use]`.
///
/// # Examples
///
/// ```
/// use tensor4all_core::krylov::RestartGmresOptions;
///
/// let opts = RestartGmresOptions::new()
///     .with_max_outer_iters(10)
///     .with_rtol(1e-6)
///     .with_inner_max_iter(20)
///     .with_inner_max_restarts(2)
///     .with_min_reduction(0.99)
///     .with_inner_rtol(0.01)
///     .with_verbose(false);
///
/// assert_eq!(opts.max_outer_iters, 10);
/// assert_eq!(opts.rtol, 1e-6);
/// assert_eq!(opts.inner_max_iter, 20);
/// assert_eq!(opts.inner_max_restarts, 2);
/// assert_eq!(opts.min_reduction, Some(0.99));
/// assert_eq!(opts.inner_rtol, Some(0.01));
/// ```
#[derive(Debug, Clone)]
pub struct RestartGmresOptions {
    /// Maximum number of outer restart iterations.
    /// Default: 20
    pub max_outer_iters: usize,

    /// Convergence tolerance for the relative residual norm.
    /// The residual is recomputed at every outer iteration as `r = b - A*x` and then
    /// passed through the truncation function.
    /// The solver stops when `||r|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum Krylov iterations per inner GMRES cycle.
    /// Default: 10
    pub inner_max_iter: usize,

    /// Number of additional restarts within each inner GMRES (usually 0).
    /// The inner solver runs `inner_max_restarts + 1` cycles.
    /// Default: 0
    pub inner_max_restarts: usize,

    /// Stagnation detection threshold.
    /// From the second outer iteration on, the solver stops without converging if the
    /// new relative residual exceeds `min_reduction` times the previous one.
    /// For example, 0.99 means stagnation is detected when the residual decreases by less than 1%.
    /// Default: None (no stagnation detection)
    pub min_reduction: Option<f64>,

    /// Relative tolerance for the inner GMRES.
    /// If `None`, 0.1 is used (the inner problem is solved loosely).
    /// Default: None
    pub inner_rtol: Option<f64>,

    /// Whether to print convergence information to stderr.
    /// The flag is also forwarded to the inner GMRES.
    /// Default: false
    pub verbose: bool,
}

impl Default for RestartGmresOptions {
    fn default() -> Self {
        Self {
            max_outer_iters: 20,
            rtol: 1e-10,
            inner_max_iter: 10,
            inner_max_restarts: 0,
            min_reduction: None,
            inner_rtol: None,
            verbose: false,
        }
    }
}

impl RestartGmresOptions {
    /// Create new options with default values (see [`Default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of outer iterations.
    #[must_use]
    pub fn with_max_outer_iters(mut self, max_outer_iters: usize) -> Self {
        self.max_outer_iters = max_outer_iters;
        self
    }

    /// Set the relative convergence tolerance of the outer loop.
    #[must_use]
    pub fn with_rtol(mut self, rtol: f64) -> Self {
        self.rtol = rtol;
        self
    }

    /// Set the maximum iterations per inner GMRES cycle.
    #[must_use]
    pub fn with_inner_max_iter(mut self, inner_max_iter: usize) -> Self {
        self.inner_max_iter = inner_max_iter;
        self
    }

    /// Set the number of additional restarts within each inner GMRES.
    #[must_use]
    pub fn with_inner_max_restarts(mut self, inner_max_restarts: usize) -> Self {
        self.inner_max_restarts = inner_max_restarts;
        self
    }

    /// Enable stagnation detection with the given reduction threshold.
    #[must_use]
    pub fn with_min_reduction(mut self, min_reduction: f64) -> Self {
        self.min_reduction = Some(min_reduction);
        self
    }

    /// Set the inner GMRES relative tolerance.
    #[must_use]
    pub fn with_inner_rtol(mut self, inner_rtol: f64) -> Self {
        self.inner_rtol = Some(inner_rtol);
        self
    }

    /// Enable or disable verbose output.
    #[must_use]
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }
}
1496
/// Outcome of a [`restart_gmres_with_truncation`] solve.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{DynIndex, TensorDynLen, AnyScalar};
/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
///
/// let i = DynIndex::new_dyn(2);
/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 5.0]).unwrap();
///
/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(3.0));
/// let truncate = |_x: &mut TensorDynLen| Ok(());
///
/// let result = restart_gmres_with_truncation(
///     apply_a, &b, None, &RestartGmresOptions::default(), truncate,
/// ).unwrap();
///
/// assert!(result.converged);
/// assert!(result.residual_norm < 1e-10);
/// assert!(result.outer_iterations <= 20);
/// ```
#[derive(Clone, Debug)]
pub struct RestartGmresResult<T> {
    /// Approximate solution `x` of `A x = b`.
    pub solution: T,

    /// Sum of the iteration counts reported by every inner GMRES call.
    pub iterations: usize,

    /// How many outer restart iterations were carried out.
    pub outer_iterations: usize,

    /// Relative residual `||r|| / ||b||` of the returned solution, with `r = b - A x`
    /// recomputed explicitly rather than taken from the inner GMRES estimate.
    pub residual_norm: f64,

    /// `true` when the solver reached the requested tolerance.
    pub converged: bool,
}
1536
1537/// Solve `A x = b` using restarted GMRES with truncation.
1538///
1539/// This wraps [`gmres_with_truncation`] with an outer loop that recomputes the true residual
1540/// at each restart. This is particularly useful for MPS/MPO computations where truncation
1541/// can cause the inner GMRES residual to be inaccurate.
1542///
1543/// # Algorithm
1544///
1545/// ```text
1546/// for outer_iter in 0..max_outer_iters:
1547///     r = b - A*x0          // Compute true residual
1548///     r = truncate(r)
1549///     if ||r|| / ||b|| < rtol:
1550///         return x0         // Converged
1551///     x' = gmres_with_truncation(A, r, 0, inner_options, truncate)
1552///     x0 = truncate(x0 + x')
1553/// ```
1554///
1555/// # Type Parameters
1556///
1557/// * `T` - A tensor type implementing `TensorVectorSpace`
1558/// * `F` - A function that applies the linear operator: `F(x) = A x`
1559/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
1560///
1561/// # Arguments
1562///
1563/// * `apply_a` - Function that applies the linear operator A to a tensor
1564/// * `b` - Right-hand side tensor
1565/// * `x0` - Initial guess (if None, starts from zero)
1566/// * `options` - Solver options
1567/// * `truncate` - Function that truncates a tensor to control bond dimension
1568///
1569/// # Returns
1570///
1571/// A `RestartGmresResult` containing the solution and convergence information.
1572///
1573/// # Examples
1574///
1575/// Solve `5x = b` with no truncation:
1576///
1577/// ```
1578/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace, AnyScalar};
1579/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
1580///
1581/// let i = DynIndex::new_dyn(3);
1582/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![5.0, 10.0, 15.0]).unwrap();
1583///
1584/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(5.0));
1585/// let truncate = |_x: &mut TensorDynLen| Ok(());
1586///
1587/// let result = restart_gmres_with_truncation(
1588///     apply_a, &b, None, &RestartGmresOptions::default(), truncate,
1589/// ).unwrap();
1590///
1591/// assert!(result.converged);
1592/// let expected = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
1593/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
1594/// ```
1595pub fn restart_gmres_with_truncation<T, F, Tr>(
1596    apply_a: F,
1597    b: &T,
1598    x0: Option<&T>,
1599    options: &RestartGmresOptions,
1600    truncate: Tr,
1601) -> Result<RestartGmresResult<T>>
1602where
1603    T: TensorVectorSpace,
1604    F: Fn(&T) -> Result<T>,
1605    Tr: Fn(&mut T) -> Result<()>,
1606{
1607    // Validate structural consistency of inputs
1608    b.validate()?;
1609    if let Some(x) = x0 {
1610        x.validate()?;
1611    }
1612
1613    let b_norm = b.norm();
1614    if b_norm < 1e-15 {
1615        // b is effectively zero, return x0 or zero
1616        let solution = match x0 {
1617            Some(x) => x.clone(),
1618            None => b.scale(AnyScalar::new_real(0.0))?,
1619        };
1620        return Ok(RestartGmresResult {
1621            solution,
1622            iterations: 0,
1623            outer_iterations: 0,
1624            residual_norm: 0.0,
1625            converged: true,
1626        });
1627    }
1628
1629    // Initialize x: use x0 if provided, otherwise start from zero.
1630    // Track whether x is zero to avoid unnecessary bond dimension doubling
1631    // when adding the first correction via axpby.
1632    let mut x_is_zero = x0.is_none();
1633    let mut x = match x0 {
1634        Some(x) => x.clone(),
1635        None => b.scale(AnyScalar::new_real(0.0))?,
1636    };
1637
1638    let mut total_inner_iters = 0;
1639    let mut prev_residual_norm = f64::INFINITY;
1640
1641    // Inner GMRES options
1642    let inner_options = GmresOptions {
1643        max_iter: options.inner_max_iter,
1644        rtol: options.inner_rtol.unwrap_or(0.1), // Solve loosely by default
1645        max_restarts: options.inner_max_restarts + 1, // +1 because max_restarts=0 means 1 cycle
1646        verbose: options.verbose,
1647        check_true_residual: true, // Always check in restart context to avoid false convergence
1648    };
1649
1650    for outer_iter in 0..options.max_outer_iters {
1651        // Compute true residual: r = b - A*x
1652        let ax = apply_a(&x)?;
1653        // Validate operator output on first outer iteration
1654        if outer_iter == 0 {
1655            ax.validate()?;
1656        }
1657        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
1658        truncate(&mut r)?;
1659
1660        let r_norm = r.norm();
1661        let rel_res = r_norm / b_norm;
1662
1663        if options.verbose {
1664            eprintln!(
1665                "Restart GMRES outer iter {}: true residual = {:.6e}",
1666                outer_iter, rel_res
1667            );
1668        }
1669
1670        // Check convergence
1671        if rel_res < options.rtol {
1672            return Ok(RestartGmresResult {
1673                solution: x,
1674                iterations: total_inner_iters,
1675                outer_iterations: outer_iter,
1676                residual_norm: rel_res,
1677                converged: true,
1678            });
1679        }
1680
1681        // Check stagnation
1682        if let Some(min_reduction) = options.min_reduction {
1683            if outer_iter > 0 && rel_res > prev_residual_norm * min_reduction {
1684                if options.verbose {
1685                    eprintln!(
1686                        "Restart GMRES stagnated: residual ratio = {:.6e} > {:.6e}",
1687                        rel_res / prev_residual_norm,
1688                        min_reduction
1689                    );
1690                }
1691                return Ok(RestartGmresResult {
1692                    solution: x,
1693                    iterations: total_inner_iters,
1694                    outer_iterations: outer_iter,
1695                    residual_norm: rel_res,
1696                    converged: false,
1697                });
1698            }
1699        }
1700        prev_residual_norm = rel_res;
1701
1702        // Solve A*x' = r using inner GMRES with zero initial guess
1703        // The zero initial guess is created by scaling r by 0
1704        let zero = r.scale(AnyScalar::new_real(0.0))?;
1705        let inner_result = gmres_with_truncation(&apply_a, &r, &zero, &inner_options, &truncate)?;
1706
1707        total_inner_iters += inner_result.iterations;
1708
1709        if options.verbose {
1710            eprintln!(
1711                "  Inner GMRES: {} iterations, residual = {:.6e}, converged = {}",
1712                inner_result.iterations, inner_result.residual_norm, inner_result.converged
1713            );
1714        }
1715
1716        // Update solution: x = x + x'
1717        // When x is zero (first iteration with no initial guess), use x' directly
1718        // to avoid bond dimension doubling from axpby with a zero tensor.
1719        if x_is_zero {
1720            x = inner_result.solution;
1721            x_is_zero = false;
1722        } else {
1723            x = x.axpby(
1724                AnyScalar::new_real(1.0),
1725                &inner_result.solution,
1726                AnyScalar::new_real(1.0),
1727            )?;
1728        }
1729        truncate(&mut x)?;
1730    }
1731
1732    // Did not converge within max_outer_iters
1733    // Compute final residual
1734    let ax = apply_a(&x)?;
1735    let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
1736    truncate(&mut r)?;
1737    let final_rel_res = r.norm() / b_norm;
1738
1739    Ok(RestartGmresResult {
1740        solution: x,
1741        iterations: total_inner_iters,
1742        outer_iterations: options.max_outer_iters,
1743        residual_norm: final_rel_res,
1744        converged: false,
1745    })
1746}
1747
1748/// Compute Givens rotation coefficients to eliminate b in (a, b).
1749///
1750/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1751/// as much as possible.
1752fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) {
1753    let a_abs = a.abs();
1754    let b_abs = b.abs();
1755    let r = (a_abs * a_abs + b_abs * b_abs).sqrt();
1756    if r < 1e-15 {
1757        (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0))
1758    } else if a_abs < 1e-15 {
1759        (
1760            AnyScalar::new_real(0.0),
1761            b.clone().conj() / AnyScalar::new_real(r),
1762        )
1763    } else {
1764        let phase = a.clone() / AnyScalar::new_real(a_abs);
1765        (
1766            AnyScalar::new_real(a_abs / r),
1767            phase * b.clone().conj() / AnyScalar::new_real(r),
1768        )
1769    }
1770}
1771
1772/// Apply Givens rotation: (c, s) @ (x, y) -> (c*x + s*y, -conj(s)*x + c*y) for complex
1773/// or (c*x + s*y, -s*x + c*y) for real.
1774///
1775/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1776/// as much as possible.
1777fn apply_givens_rotation(
1778    c: &AnyScalar,
1779    s: &AnyScalar,
1780    x: &AnyScalar,
1781    y: &AnyScalar,
1782) -> (AnyScalar, AnyScalar) {
1783    let new_x = c.clone() * x.clone() + s.clone() * y.clone();
1784    let new_y = -(s.clone().conj() * x.clone()) + c.clone() * y.clone();
1785    (new_x, new_y)
1786}
1787
1788/// Solve upper triangular system R y = g using back substitution.
1789fn solve_upper_triangular(h: &[Vec<AnyScalar>], g: &[AnyScalar]) -> Result<Vec<AnyScalar>> {
1790    let n = g.len();
1791    if n == 0 {
1792        return Ok(vec![]);
1793    }
1794
1795    let mut y = vec![AnyScalar::new_real(0.0); n];
1796
1797    for i in (0..n).rev() {
1798        let mut sum = g[i].clone();
1799
1800        for j in (i + 1)..n {
1801            // sum = sum - h[j][i] * y[j]
1802            let prod = h[j][i].clone() * y[j].clone();
1803            sum = sum - prod;
1804        }
1805
1806        let h_ii = &h[i][i];
1807        if h_ii.abs() < 1e-15 {
1808            return Err(anyhow::anyhow!(
1809                "Near-singular upper triangular matrix in GMRES"
1810            ));
1811        }
1812
1813        y[i] = sum / h_ii.clone();
1814    }
1815
1816    Ok(y)
1817}
1818
1819/// Update solution: x_new = x + sum_i y_i * v_i
1820fn update_solution<T: TensorVectorSpace>(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result<T> {
1821    let mut result = x.clone();
1822
1823    for (vi, yi) in v_basis.iter().zip(y.iter()) {
1824        let scaled_vi = vi.scale(yi.clone())?;
1825        // result = result + scaled_vi = 1.0 * result + 1.0 * scaled_vi
1826        result = result.axpby(
1827            AnyScalar::new_real(1.0),
1828            &scaled_vi,
1829            AnyScalar::new_real(1.0),
1830        )?;
1831    }
1832
1833    Ok(result)
1834}
1835
1836/// Update solution with truncation: x_new = truncate(x + sum_i y_i * v_i)
1837fn update_solution_truncated<T, Tr>(
1838    x: &T,
1839    v_basis: &[T],
1840    y: &[AnyScalar],
1841    truncate: &Tr,
1842) -> Result<T>
1843where
1844    T: TensorVectorSpace,
1845    Tr: Fn(&mut T) -> Result<()>,
1846{
1847    let mut result = x.clone();
1848    // Detect if x is effectively zero.
1849    // When x is created via scale(0.0), it preserves the original bond structure
1850    // (e.g., bond dim 4), causing axpby to double bond dimensions unnecessarily.
1851    // By detecting zero, we can use scaled_vi directly, avoiding the doubling.
1852    let mut result_is_zero = x.norm() == 0.0;
1853
1854    for (vi, yi) in v_basis.iter().zip(y.iter()) {
1855        let scaled_vi = vi.scale(yi.clone())?;
1856        if result_is_zero {
1857            result = scaled_vi;
1858            result_is_zero = false;
1859        } else {
1860            result = result.axpby(
1861                AnyScalar::new_real(1.0),
1862                &scaled_vi,
1863                AnyScalar::new_real(1.0),
1864            )?;
1865        }
1866        // Truncate after each addition to control bond dimension growth
1867        truncate(&mut result)?;
1868    }
1869
1870    Ok(result)
1871}
1872
1873#[cfg(test)]
1874mod tests;