// tensor4all_core/krylov.rs
1//! Krylov subspace methods for solving linear equations with abstract tensors.
2//!
3//! This module provides iterative solvers that work with any type implementing [`TensorLike`],
4//! enabling their use in tensor network algorithms without requiring dense vector representations.
5//!
6//! # Solvers
7//!
8//! - [`gmres`]: Generalized Minimal Residual Method (GMRES) for non-symmetric systems
9//!
10//! # Future Extensions
11//!
12//! - CG (Conjugate Gradient) for symmetric positive definite systems
13//! - BiCGSTAB for non-symmetric systems with better convergence properties
14//!
15//! # Example
16//!
17//! ```
18//! use tensor4all_core::{
19//! krylov::{gmres, GmresOptions},
20//! DynIndex, TensorDynLen, TensorLike,
21//! };
22//!
23//! # fn main() -> anyhow::Result<()> {
24//! let i = DynIndex::new_dyn(2);
25//! let rhs = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, -1.0])?;
26//! let initial_guess = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
27//!
28//! let apply_operator = |x: &TensorDynLen| Ok(x.clone());
29//! let result = gmres(apply_operator, &rhs, &initial_guess, &GmresOptions::default())?;
30//!
31//! assert!(result.converged);
32//! assert!(result.solution.sub(&rhs)?.maxabs() < 1e-12);
33//! # Ok(())
34//! # }
35//! ```
36
37use crate::any_scalar::AnyScalar;
38use crate::TensorLike;
39use anyhow::Result;
40
41/// Options for GMRES solver.
42///
43/// # Examples
44///
45/// ```
46/// use tensor4all_core::krylov::GmresOptions;
47///
48/// let opts = GmresOptions {
49/// max_iter: 50,
50/// rtol: 1e-8,
51/// max_restarts: 5,
52/// verbose: false,
53/// check_true_residual: true,
54/// };
55/// assert_eq!(opts.max_iter, 50);
56/// assert_eq!(opts.rtol, 1e-8);
57/// ```
#[derive(Debug, Clone)]
pub struct GmresOptions {
    /// Maximum number of iterations (restart cycle length).
    /// Also bounds the number of Krylov basis vectors kept per cycle.
    /// Default: 100
    pub max_iter: usize,

    /// Convergence tolerance for relative residual norm.
    /// The solver stops when `||r|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum number of restarts.
    /// Total iterations = max_iter * max_restarts.
    /// Default: 10
    pub max_restarts: usize,

    /// Whether to print convergence information (to stderr).
    /// Default: false
    pub verbose: bool,

    /// When true, verify convergence by computing the true residual `||b - A*x|| / ||b||`
    /// before declaring convergence. This prevents false convergence caused by
    /// truncation corrupting the Krylov basis orthogonality (see Issue #207).
    /// Costs one additional `apply_a` call when convergence is detected.
    /// Honored by [`gmres_with_truncation`]; the plain [`gmres`] does not use it.
    /// Default: false
    pub check_true_residual: bool,
}
85
86impl Default for GmresOptions {
87 fn default() -> Self {
88 Self {
89 max_iter: 100,
90 rtol: 1e-10,
91 max_restarts: 10,
92 verbose: false,
93 check_true_residual: false,
94 }
95 }
96}
97
98/// Result of GMRES solver.
99///
100/// Contains the solution, iteration count, final residual norm, and
101/// convergence status.
102///
103/// # Examples
104///
105/// ```
106/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
107/// use tensor4all_core::krylov::{gmres, GmresOptions};
108///
109/// let i = DynIndex::new_dyn(2);
110/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 7.0]).unwrap();
111/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
112///
113/// let result = gmres(|x: &TensorDynLen| Ok(x.clone()), &b, &x0, &GmresOptions::default()).unwrap();
114/// assert!(result.converged);
115/// assert!(result.residual_norm < 1e-10);
116/// ```
#[derive(Debug, Clone)]
pub struct GmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Number of iterations performed (summed over all restart cycles).
    pub iterations: usize,

    /// Final relative residual norm, i.e. `||r|| / ||b||`.
    pub residual_norm: f64,

    /// Whether the solver reached `rtol` before exhausting its iteration budget.
    pub converged: bool,
}
131
132/// Solve `A x = b` using GMRES (Generalized Minimal Residual Method).
133///
134/// This implements the restarted GMRES algorithm that works with abstract tensor types
135/// through the [`TensorLike`] trait's vector space operations.
136///
137/// # Algorithm
138///
139/// GMRES builds an orthonormal basis for the Krylov subspace
140/// `K_m = span{r_0, A r_0, A^2 r_0, ..., A^{m-1} r_0}` and finds the
141/// solution that minimizes `||b - A x||` over this subspace.
142///
143/// # Type Parameters
144///
145/// * `T` - A tensor type implementing `TensorLike`
146/// * `F` - A function that applies the linear operator: `F(x) = A x`
147///
148/// # Arguments
149///
150/// * `apply_a` - Function that applies the linear operator A to a tensor
151/// * `b` - Right-hand side tensor
152/// * `x0` - Initial guess
153/// * `options` - Solver options
154///
155/// # Returns
156///
157/// A `GmresResult` containing the solution and convergence information.
158///
159/// # Errors
160///
161/// Returns an error if:
162/// - Vector space operations (add, sub, scale, inner_product) fail
163/// - The linear operator application fails
pub fn gmres<T, F>(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result<GmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
{
    // NOTE(review): this variant never consults `options.check_true_residual`;
    // without truncation the Arnoldi basis stays orthogonal (up to round-off),
    // so the Hessenberg residual estimate is trusted. Only
    // `gmres_with_truncation` performs the extra true-residual verification.

    // Validate structural consistency of inputs
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Compute initial residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first restart
        if _restart == 0 {
            ax.validate()?;
        }
        // r = 1.0 * b + (-1.0) * ax
        let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        // Already converged at the start of this cycle.
        if rel_res < options.rtol {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Arnoldi process with modified Gram-Schmidt.
        // `h_matrix` stores the Hessenberg matrix column-by-column; after the
        // Givens rotations below, each stored column is upper triangular.
        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        // v_0 = r / ||r||
        let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        v_basis.push(v0);

        // Initialize Givens rotation storage
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)]; // residual in upper Hessenberg space

        for j in 0..options.max_iter {
            total_iters += 1;

            // w = A * v_j
            let w = apply_a(&v_basis[j])?;

            // Modified Gram-Schmidt orthogonalization
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                // w_orth = w_orth - h_ij * v_i = 1.0 * w_orth + (-h_ij) * v_i
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // h_{j+1,j} = ||w_orth|| (the subdiagonal Hessenberg entry).
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply previous Givens rotations to new column
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // Compute new Givens rotation for h_col[j] and h_col[j+1]
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            // Apply new rotation to eliminate h_col[j+1]
            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Apply rotation to g; |g[j+1]| is the current least-squares residual.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            // Check convergence
            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                // Solve upper triangular system and update x
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: rel_res,
                    converged: true,
                });
            }

            // Add new basis vector (if not converged and h_jp1_j is not too small)
            if h_jp1_j_real > 1e-14 {
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                v_basis.push(v_jp1);
            } else {
                // Lucky breakdown - we've found the exact solution in the Krylov subspace.
                // Report the true residual since we cannot extend the basis further.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        // End of restart cycle - update x with current solution
        let y = solve_upper_triangular(&h_matrix, &g[..options.max_iter])?;
        x = update_solution(&x, &v_basis[..options.max_iter], &y)?;
    }

    // Out of restarts: compute and report the final true residual.
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
344
345/// Solve `A x = b` using GMRES with optional truncation after each iteration.
346///
347/// This is an extension of [`gmres`] that allows truncating Krylov basis vectors
348/// to control bond dimension growth in tensor network representations.
349///
350/// # Type Parameters
351///
352/// * `T` - A tensor type implementing `TensorLike`
353/// * `F` - A function that applies the linear operator: `F(x) = A x`
354/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
355///
356/// # Arguments
357///
358/// * `apply_a` - Function that applies the linear operator A to a tensor
359/// * `b` - Right-hand side tensor
360/// * `x0` - Initial guess
361/// * `options` - Solver options
362/// * `truncate` - Function that truncates a tensor to control bond dimension
363///
364/// # Note
365///
366/// Truncation is applied after each Gram-Schmidt orthogonalization step
367/// and after the final solution update. This helps control the bond dimension
368/// growth that would otherwise occur in MPS/MPO representations.
369///
370/// # Examples
371///
372/// Solve `2x = b` with a no-op truncation function:
373///
374/// ```
375/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar};
376/// use tensor4all_core::krylov::{gmres_with_truncation, GmresOptions};
377///
378/// let i = DynIndex::new_dyn(2);
379/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![4.0, 6.0]).unwrap();
380/// let x0 = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap();
381///
382/// // Operator A = 2*I (scales input by 2)
383/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(2.0));
384/// // No-op truncation
385/// let truncate = |_x: &mut TensorDynLen| Ok(());
386///
387/// let result = gmres_with_truncation(apply_a, &b, &x0, &GmresOptions::default(), truncate).unwrap();
388/// assert!(result.converged);
389/// // Solution should be [2.0, 3.0]
390/// let expected = TensorDynLen::from_dense(vec![i], vec![2.0, 3.0]).unwrap();
391/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
392/// ```
pub fn gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: &T,
    options: &GmresOptions,
    truncate: Tr,
) -> Result<GmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero; x0 is trivially a solution.
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // True residual at the start of the cycle: r = b - A*x (then truncated).
        let ax = apply_a(&x)?;
        // Validate operator output on first restart
        if _restart == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        if rel_res < options.rtol {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        let mut v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        truncate(&mut v0)?;
        // After truncation, v0 might not be unit norm and might point in a different direction.
        // We need to:
        // 1. Renormalize v0 to unit norm for numerical stability
        // 2. Recompute g[0] = <r, v0> to maintain the correct relationship
        let v0_norm = v0.norm();
        let effective_g0 = if v0_norm > 1e-15 {
            v0 = v0.scale(AnyScalar::new_real(1.0 / v0_norm))?;
            // g[0] should be the component of r in the direction of v0
            // Since r was truncated and v0 = truncate(r/||r||)/||truncate(r/||r||)||,
            // g[0] = <r, v0> ≈ ||r|| * ||truncate(r/||r||)|| = r_norm * v0_norm
            r_norm * v0_norm
        } else {
            // Truncation annihilated v0; fall back to the untruncated norm.
            r_norm
        };
        v_basis.push(v0);

        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(effective_g0)];
        // Set when false convergence is detected so the end-of-cycle update is skipped.
        let mut solution_already_updated = false;

        for j in 0..options.max_iter {
            total_iters += 1;

            // w = A * v_j
            let w = apply_a(&v_basis[j])?;

            // Modified Gram-Schmidt orthogonalization.
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Iterative reorthogonalization with truncation
            // Truncation can change the direction of w_orth, breaking orthogonality.
            // We iterate until all corrections are below a threshold to ensure
            // the Krylov basis remains orthogonal despite truncation.
            const REORTH_THRESHOLD: f64 = 1e-12;
            const MAX_REORTH_ITERS: usize = 10;

            let mut reorth_iter_count = 0;
            for reorth_iter in 0..MAX_REORTH_ITERS {
                reorth_iter_count = reorth_iter + 1;
                let norm_before_truncate = w_orth.norm();
                truncate(&mut w_orth)?;
                let norm_after_truncate = w_orth.norm();

                let mut max_correction = 0.0;
                for (i, v_i) in v_basis.iter().enumerate() {
                    let correction = v_i.inner_product(&w_orth)?;
                    let correction_abs = correction.abs();
                    if correction_abs > max_correction {
                        max_correction = correction_abs;
                    }
                    if correction_abs > REORTH_THRESHOLD {
                        let neg_correction = AnyScalar::new_real(0.0) - correction.clone();
                        w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
                        // Update Hessenberg matrix entry to include correction
                        h_col[i] = h_col[i].clone() + correction;
                    }
                }

                if options.verbose {
                    eprintln!(
                        "  reorth iter {}: norm {:.6e} -> {:.6e}, max_correction = {:.6e}",
                        reorth_iter, norm_before_truncate, norm_after_truncate, max_correction
                    );
                }

                // If all corrections are small enough, we're done
                if max_correction < REORTH_THRESHOLD {
                    break;
                }
            }

            if options.verbose && reorth_iter_count > 1 {
                eprintln!("  (needed {} reorth iterations)", reorth_iter_count);
            }

            // h_{j+1,j} = ||w_orth|| after truncation + reorthogonalization.
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply previous Givens rotations to the new Hessenberg column.
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // Compute and apply the new rotation that eliminates h_col[j+1].
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate g; |g[j+1]| is the Hessenberg-space residual estimate.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;

                if options.check_true_residual {
                    // Verify with true residual to prevent false convergence
                    let ax_check = apply_a(&x)?;
                    let mut r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &ax_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    truncate(&mut r_check)?;
                    let true_rel_res = r_check.norm() / b_norm;

                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            rel_res, true_rel_res
                        );
                    }

                    if true_rel_res < options.rtol {
                        return Ok(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_rel_res,
                            converged: true,
                        });
                    }
                    // False convergence detected: x is already updated above,
                    // so skip the end-of-cycle update and go to next restart
                    solution_already_updated = true;
                    break;
                } else {
                    return Ok(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: rel_res,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Create v_{j+1} = w_orth / ||w_orth||
                // w_orth has already been truncated twice (after orthogonalization and after reorthogonalization)
                // so we don't need to truncate again. Scale doesn't increase bond dimensions.
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                // v_jp1 should have norm ~1.0 by construction
                // The Arnoldi relation h_{j+1,j} * v_{j+1} = w_orth is maintained exactly
                v_basis.push(v_jp1);
            } else {
                // Lucky breakdown: report the true residual and stop.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        // End-of-cycle update (skipped when false convergence already updated x above).
        if !solution_already_updated {
            let actual_iters = v_basis.len().min(options.max_iter);
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            x = update_solution_truncated(&x, &v_basis[..actual_iters], &y, &truncate)?;
        }
    }

    // Out of restarts: compute and report the final true residual.
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
665
666/// Options for restarted GMRES with truncation.
667///
668/// This is used by [`restart_gmres_with_truncation`] which wraps the standard GMRES
669/// with an outer loop that recomputes the true residual at each restart.
670///
671/// # Examples
672///
673/// ```
674/// use tensor4all_core::krylov::RestartGmresOptions;
675///
676/// let opts = RestartGmresOptions::new()
677/// .with_max_outer_iters(10)
678/// .with_rtol(1e-6)
679/// .with_inner_max_iter(20)
680/// .with_inner_max_restarts(2)
681/// .with_min_reduction(0.99)
682/// .with_inner_rtol(0.01)
683/// .with_verbose(false);
684///
685/// assert_eq!(opts.max_outer_iters, 10);
686/// assert_eq!(opts.rtol, 1e-6);
687/// assert_eq!(opts.inner_max_iter, 20);
688/// assert_eq!(opts.inner_max_restarts, 2);
689/// assert_eq!(opts.min_reduction, Some(0.99));
690/// assert_eq!(opts.inner_rtol, Some(0.01));
691/// ```
#[derive(Debug, Clone)]
pub struct RestartGmresOptions {
    /// Maximum number of outer restart iterations.
    /// Default: 20
    pub max_outer_iters: usize,

    /// Convergence tolerance for relative residual norm (based on true residual).
    /// The solver stops when `||b - A*x|| / ||b|| < rtol`.
    /// Default: 1e-10
    pub rtol: f64,

    /// Maximum iterations per inner GMRES cycle.
    /// Default: 10
    pub inner_max_iter: usize,

    /// Number of restarts within each inner GMRES (usually 0).
    /// Note: the inner solver is invoked with `max_restarts = inner_max_restarts + 1`,
    /// so 0 means one cycle.
    /// Default: 0
    pub inner_max_restarts: usize,

    /// Stagnation detection threshold.
    /// If the residual reduction ratio exceeds this value (i.e., residual doesn't decrease enough),
    /// the solver considers it stagnated and returns early with `converged = false`.
    /// For example, 0.99 means stagnation is detected when residual decreases by less than 1%.
    /// Default: None (no stagnation detection)
    pub min_reduction: Option<f64>,

    /// Inner GMRES relative tolerance.
    /// If None, uses 0.1 (solve inner problem loosely).
    /// Default: None
    pub inner_rtol: Option<f64>,

    /// Whether to print convergence information (to stderr).
    /// Default: false
    pub verbose: bool,
}
727
728impl Default for RestartGmresOptions {
729 fn default() -> Self {
730 Self {
731 max_outer_iters: 20,
732 rtol: 1e-10,
733 inner_max_iter: 10,
734 inner_max_restarts: 0,
735 min_reduction: None,
736 inner_rtol: None,
737 verbose: false,
738 }
739 }
740}
741
742impl RestartGmresOptions {
743 /// Create new options with default values.
744 pub fn new() -> Self {
745 Self::default()
746 }
747
748 /// Set maximum number of outer iterations.
749 pub fn with_max_outer_iters(mut self, max_outer_iters: usize) -> Self {
750 self.max_outer_iters = max_outer_iters;
751 self
752 }
753
754 /// Set convergence tolerance.
755 pub fn with_rtol(mut self, rtol: f64) -> Self {
756 self.rtol = rtol;
757 self
758 }
759
760 /// Set maximum iterations per inner GMRES cycle.
761 pub fn with_inner_max_iter(mut self, inner_max_iter: usize) -> Self {
762 self.inner_max_iter = inner_max_iter;
763 self
764 }
765
766 /// Set number of restarts within each inner GMRES.
767 pub fn with_inner_max_restarts(mut self, inner_max_restarts: usize) -> Self {
768 self.inner_max_restarts = inner_max_restarts;
769 self
770 }
771
772 /// Set stagnation detection threshold.
773 pub fn with_min_reduction(mut self, min_reduction: f64) -> Self {
774 self.min_reduction = Some(min_reduction);
775 self
776 }
777
778 /// Set inner GMRES relative tolerance.
779 pub fn with_inner_rtol(mut self, inner_rtol: f64) -> Self {
780 self.inner_rtol = Some(inner_rtol);
781 self
782 }
783
784 /// Enable verbose output.
785 pub fn with_verbose(mut self, verbose: bool) -> Self {
786 self.verbose = verbose;
787 self
788 }
789}
790
791/// Result of restarted GMRES solver.
792///
793/// # Examples
794///
795/// ```
796/// use tensor4all_core::{DynIndex, TensorDynLen, AnyScalar};
797/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
798///
799/// let i = DynIndex::new_dyn(2);
800/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 5.0]).unwrap();
801///
802/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(3.0));
803/// let truncate = |_x: &mut TensorDynLen| Ok(());
804///
805/// let result = restart_gmres_with_truncation(
806/// apply_a, &b, None, &RestartGmresOptions::default(), truncate,
807/// ).unwrap();
808///
809/// assert!(result.converged);
810/// assert!(result.residual_norm < 1e-10);
811/// assert!(result.outer_iterations <= 20);
812/// ```
#[derive(Debug, Clone)]
pub struct RestartGmresResult<T> {
    /// The solution vector.
    pub solution: T,

    /// Total number of inner GMRES iterations performed (summed over outer iterations).
    pub iterations: usize,

    /// Number of outer restart iterations performed.
    pub outer_iterations: usize,

    /// Final relative residual norm, computed from the true residual `||b - A*x|| / ||b||`.
    pub residual_norm: f64,

    /// Whether the solver converged (false on stagnation or budget exhaustion).
    pub converged: bool,
}
830
831/// Solve `A x = b` using restarted GMRES with truncation.
832///
833/// This wraps [`gmres_with_truncation`] with an outer loop that recomputes the true residual
834/// at each restart. This is particularly useful for MPS/MPO computations where truncation
835/// can cause the inner GMRES residual to be inaccurate.
836///
837/// # Algorithm
838///
839/// ```text
840/// for outer_iter in 0..max_outer_iters:
841/// r = b - A*x0 // Compute true residual
842/// r = truncate(r)
843/// if ||r|| / ||b|| < rtol:
844/// return x0 // Converged
845/// x' = gmres_with_truncation(A, r, 0, inner_options, truncate)
846/// x0 = truncate(x0 + x')
847/// ```
848///
849/// # Type Parameters
850///
851/// * `T` - A tensor type implementing `TensorLike`
852/// * `F` - A function that applies the linear operator: `F(x) = A x`
853/// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)`
854///
855/// # Arguments
856///
857/// * `apply_a` - Function that applies the linear operator A to a tensor
858/// * `b` - Right-hand side tensor
859/// * `x0` - Initial guess (if None, starts from zero)
860/// * `options` - Solver options
861/// * `truncate` - Function that truncates a tensor to control bond dimension
862///
863/// # Returns
864///
865/// A `RestartGmresResult` containing the solution and convergence information.
866///
867/// # Examples
868///
869/// Solve `5x = b` with no truncation:
870///
871/// ```
872/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar};
873/// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions};
874///
875/// let i = DynIndex::new_dyn(3);
876/// let b = TensorDynLen::from_dense(vec![i.clone()], vec![5.0, 10.0, 15.0]).unwrap();
877///
878/// let apply_a = |x: &TensorDynLen| x.scale(AnyScalar::new_real(5.0));
879/// let truncate = |_x: &mut TensorDynLen| Ok(());
880///
881/// let result = restart_gmres_with_truncation(
882/// apply_a, &b, None, &RestartGmresOptions::default(), truncate,
883/// ).unwrap();
884///
885/// assert!(result.converged);
886/// let expected = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
887/// assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-8);
888/// ```
pub fn restart_gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: Option<&T>,
    options: &RestartGmresOptions,
    truncate: Tr,
) -> Result<RestartGmresResult<T>>
where
    T: TensorLike,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    // Validate structural consistency of inputs
    b.validate()?;
    if let Some(x) = x0 {
        x.validate()?;
    }

    let b_norm = b.norm();
    if b_norm < 1e-15 {
        // b is effectively zero, return x0 or zero
        let solution = match x0 {
            Some(x) => x.clone(),
            None => b.scale(AnyScalar::new_real(0.0))?,
        };
        return Ok(RestartGmresResult {
            solution,
            iterations: 0,
            outer_iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    // Initialize x: use x0 if provided, otherwise start from zero.
    // Track whether x is zero to avoid unnecessary bond dimension doubling
    // when adding the first correction via axpby.
    let mut x_is_zero = x0.is_none();
    let mut x = match x0 {
        Some(x) => x.clone(),
        None => b.scale(AnyScalar::new_real(0.0))?,
    };

    let mut total_inner_iters = 0;
    let mut prev_residual_norm = f64::INFINITY;

    // Inner GMRES options
    let inner_options = GmresOptions {
        max_iter: options.inner_max_iter,
        rtol: options.inner_rtol.unwrap_or(0.1), // Solve loosely by default
        max_restarts: options.inner_max_restarts + 1, // +1 because max_restarts=0 means 1 cycle
        verbose: options.verbose,
        check_true_residual: true, // Always check in restart context to avoid false convergence
    };

    for outer_iter in 0..options.max_outer_iters {
        // Compute true residual: r = b - A*x
        let ax = apply_a(&x)?;
        // Validate operator output on first outer iteration
        if outer_iter == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;

        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "Restart GMRES outer iter {}: true residual = {:.6e}",
                outer_iter, rel_res
            );
        }

        // Check convergence (on the true residual, not the inner estimate)
        if rel_res < options.rtol {
            return Ok(RestartGmresResult {
                solution: x,
                iterations: total_inner_iters,
                outer_iterations: outer_iter,
                residual_norm: rel_res,
                converged: true,
            });
        }

        // Check stagnation: bail out early when the residual stops improving.
        if let Some(min_reduction) = options.min_reduction {
            if outer_iter > 0 && rel_res > prev_residual_norm * min_reduction {
                if options.verbose {
                    eprintln!(
                        "Restart GMRES stagnated: residual ratio = {:.6e} > {:.6e}",
                        rel_res / prev_residual_norm,
                        min_reduction
                    );
                }
                return Ok(RestartGmresResult {
                    solution: x,
                    iterations: total_inner_iters,
                    outer_iterations: outer_iter,
                    residual_norm: rel_res,
                    converged: false,
                });
            }
        }
        prev_residual_norm = rel_res;

        // Solve A*x' = r using inner GMRES with zero initial guess
        // The zero initial guess is created by scaling r by 0
        let zero = r.scale(AnyScalar::new_real(0.0))?;
        let inner_result = gmres_with_truncation(&apply_a, &r, &zero, &inner_options, &truncate)?;

        total_inner_iters += inner_result.iterations;

        if options.verbose {
            eprintln!(
                "  Inner GMRES: {} iterations, residual = {:.6e}, converged = {}",
                inner_result.iterations, inner_result.residual_norm, inner_result.converged
            );
        }

        // Update solution: x = x + x'
        // When x is zero (first iteration with no initial guess), use x' directly
        // to avoid bond dimension doubling from axpby with a zero tensor.
        if x_is_zero {
            x = inner_result.solution;
            x_is_zero = false;
        } else {
            x = x.axpby(
                AnyScalar::new_real(1.0),
                &inner_result.solution,
                AnyScalar::new_real(1.0),
            )?;
        }
        truncate(&mut x)?;
    }

    // Did not converge within max_outer_iters
    // Compute final residual
    let ax = apply_a(&x)?;
    let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
    truncate(&mut r)?;
    let final_rel_res = r.norm() / b_norm;

    Ok(RestartGmresResult {
        solution: x,
        iterations: total_inner_iters,
        outer_iterations: options.max_outer_iters,
        residual_norm: final_rel_res,
        converged: false,
    })
}
1041
1042/// Compute Givens rotation coefficients to eliminate b in (a, b).
1043///
1044/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1045/// as much as possible.
1046fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) {
1047 // r^2 = conj(a)*a + conj(b)*b (works for both real and complex)
1048 let norm2 = a.clone().conj() * a.clone() + b.clone().conj() * b.clone();
1049 let r = norm2.sqrt();
1050 if r.abs() < 1e-15 {
1051 (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0))
1052 } else {
1053 (a.clone() / r.clone(), b.clone() / r)
1054 }
1055}
1056
1057/// Apply Givens rotation: (c, s) @ (x, y) -> (c*x + s*y, -conj(s)*x + c*y) for complex
1058/// or (c*x + s*y, -s*x + c*y) for real.
1059///
1060/// This function keeps computation in `AnyScalar` space to preserve AD metadata
1061/// as much as possible.
1062fn apply_givens_rotation(
1063 c: &AnyScalar,
1064 s: &AnyScalar,
1065 x: &AnyScalar,
1066 y: &AnyScalar,
1067) -> (AnyScalar, AnyScalar) {
1068 let new_x = c.clone() * x.clone() + s.clone() * y.clone();
1069 let new_y = -(s.clone().conj() * x.clone()) + c.clone() * y.clone();
1070 (new_x, new_y)
1071}
1072
1073/// Solve upper triangular system R y = g using back substitution.
1074fn solve_upper_triangular(h: &[Vec<AnyScalar>], g: &[AnyScalar]) -> Result<Vec<AnyScalar>> {
1075 let n = g.len();
1076 if n == 0 {
1077 return Ok(vec![]);
1078 }
1079
1080 let mut y = vec![AnyScalar::new_real(0.0); n];
1081
1082 for i in (0..n).rev() {
1083 let mut sum = g[i].clone();
1084
1085 for j in (i + 1)..n {
1086 // sum = sum - h[j][i] * y[j]
1087 let prod = h[j][i].clone() * y[j].clone();
1088 sum = sum - prod;
1089 }
1090
1091 let h_ii = &h[i][i];
1092 if h_ii.abs() < 1e-15 {
1093 return Err(anyhow::anyhow!(
1094 "Near-singular upper triangular matrix in GMRES"
1095 ));
1096 }
1097
1098 y[i] = sum / h_ii.clone();
1099 }
1100
1101 Ok(y)
1102}
1103
1104/// Update solution: x_new = x + sum_i y_i * v_i
1105fn update_solution<T: TensorLike>(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result<T> {
1106 let mut result = x.clone();
1107
1108 for (vi, yi) in v_basis.iter().zip(y.iter()) {
1109 let scaled_vi = vi.scale(yi.clone())?;
1110 // result = result + scaled_vi = 1.0 * result + 1.0 * scaled_vi
1111 result = result.axpby(
1112 AnyScalar::new_real(1.0),
1113 &scaled_vi,
1114 AnyScalar::new_real(1.0),
1115 )?;
1116 }
1117
1118 Ok(result)
1119}
1120
1121/// Update solution with truncation: x_new = truncate(x + sum_i y_i * v_i)
1122fn update_solution_truncated<T, Tr>(
1123 x: &T,
1124 v_basis: &[T],
1125 y: &[AnyScalar],
1126 truncate: &Tr,
1127) -> Result<T>
1128where
1129 T: TensorLike,
1130 Tr: Fn(&mut T) -> Result<()>,
1131{
1132 let mut result = x.clone();
1133 // Detect if x is effectively zero.
1134 // When x is created via scale(0.0), it preserves the original bond structure
1135 // (e.g., bond dim 4), causing axpby to double bond dimensions unnecessarily.
1136 // By detecting zero, we can use scaled_vi directly, avoiding the doubling.
1137 let mut result_is_zero = x.norm() == 0.0;
1138
1139 for (vi, yi) in v_basis.iter().zip(y.iter()) {
1140 let scaled_vi = vi.scale(yi.clone())?;
1141 if result_is_zero {
1142 result = scaled_vi;
1143 result_is_zero = false;
1144 } else {
1145 result = result.axpby(
1146 AnyScalar::new_real(1.0),
1147 &scaled_vi,
1148 AnyScalar::new_real(1.0),
1149 )?;
1150 }
1151 // Truncate after each addition to control bond dimension growth
1152 truncate(&mut result)?;
1153 }
1154
1155 Ok(result)
1156}
1157
1158#[cfg(test)]
1159mod tests;