1use crate::any_scalar::AnyScalar;
38use crate::TensorVectorSpace;
39use anyhow::Result;
40use std::sync::atomic::{AtomicUsize, Ordering};
41use std::time::{Duration, Instant};
42
/// Counter that hands out a distinct id to each profiled `gmres_affine_impl` run.
/// It is only incremented when profiling is enabled, and the id is echoed in the
/// `T4A gmres_op_profile #<id>` report line.
static GMRES_OP_PROFILE_COUNTER: AtomicUsize = AtomicUsize::new(0);
44
/// Per-run time and call-count buckets for the vector operations performed by
/// the affine GMRES solver. One instance is created per solve and reported
/// through [`GmresOpProfile::print`].
#[derive(Debug, Clone)]
struct GmresOpProfile {
    /// Creation time of the profile; `print` reports the time elapsed since then as the total.
    started: Instant,
    /// Time spent computing the norm of the right-hand side.
    b_norm: Duration,
    /// Time spent in operator applications.
    apply: Duration,
    /// Time spent in inner products.
    inner_product: Duration,
    /// Time spent in `axpby` combinations.
    axpby: Duration,
    /// Time spent in vector norms (other than the right-hand-side norm).
    norm: Duration,
    /// Time spent in vector scalings.
    scale: Duration,
    /// Time spent in upper-triangular solves.
    triangular_solve: Duration,
    /// Time spent assembling the updated solution.
    solution_update: Duration,
    /// Number of timed operator applications.
    apply_calls: usize,
    /// Number of timed inner products.
    inner_product_calls: usize,
    /// Number of timed `axpby` combinations.
    axpby_calls: usize,
    /// Number of timed vector norms.
    norm_calls: usize,
    /// Number of timed vector scalings.
    scale_calls: usize,
    /// Number of timed upper-triangular solves.
    triangular_solve_calls: usize,
    /// Number of timed solution updates.
    solution_update_calls: usize,
}

impl Default for GmresOpProfile {
    /// Starts the clock now and zeroes every bucket and counter.
    fn default() -> Self {
        const NO_TIME: Duration = Duration::ZERO;
        const NO_CALLS: usize = 0;
        GmresOpProfile {
            started: Instant::now(),
            b_norm: NO_TIME,
            apply: NO_TIME,
            inner_product: NO_TIME,
            axpby: NO_TIME,
            norm: NO_TIME,
            scale: NO_TIME,
            triangular_solve: NO_TIME,
            solution_update: NO_TIME,
            apply_calls: NO_CALLS,
            inner_product_calls: NO_CALLS,
            axpby_calls: NO_CALLS,
            norm_calls: NO_CALLS,
            scale_calls: NO_CALLS,
            triangular_solve_calls: NO_CALLS,
            solution_update_calls: NO_CALLS,
        }
    }
}

impl GmresOpProfile {
    /// Sum of all explicitly timed buckets.
    fn measured(&self) -> Duration {
        let buckets = [
            self.b_norm,
            self.apply,
            self.inner_product,
            self.axpby,
            self.norm,
            self.scale,
            self.triangular_solve,
            self.solution_update,
        ];
        buckets.iter().sum()
    }

    /// Writes a single-line report to stderr. All times are in milliseconds;
    /// `other_ms` is the total elapsed time minus the measured buckets,
    /// clamped at zero.
    fn print(&self, id: usize, iterations: usize, residual_norm: f64, converged: bool) {
        let to_ms = |span: Duration| span.as_secs_f64() * 1000.0;
        let elapsed = self.started.elapsed();
        let unaccounted = elapsed.saturating_sub(self.measured());
        eprintln!(
            "T4A gmres_op_profile #{id}: iterations={iterations} residual={residual_norm:.6e} converged={converged} total_ms={:.3} apply_ms={:.3} apply_calls={} inner_ms={:.3} inner_calls={} axpby_ms={:.3} axpby_calls={} norm_ms={:.3} norm_calls={} scale_ms={:.3} scale_calls={} update_ms={:.3} update_calls={} triangular_ms={:.3} triangular_calls={} b_norm_ms={:.3} other_ms={:.3}",
            to_ms(elapsed),
            to_ms(self.apply),
            self.apply_calls,
            to_ms(self.inner_product),
            self.inner_product_calls,
            to_ms(self.axpby),
            self.axpby_calls,
            to_ms(self.norm),
            self.norm_calls,
            to_ms(self.scale),
            self.scale_calls,
            to_ms(self.solution_update),
            self.solution_update_calls,
            to_ms(self.triangular_solve),
            self.triangular_solve_calls,
            to_ms(self.b_norm),
            to_ms(unaccounted),
        );
    }
}
125
/// Tuning parameters shared by the GMRES entry points in this file.
#[derive(Debug, Clone)]
pub struct GmresOptions {
    /// Upper bound on the number of Krylov iterations in one restart cycle.
    pub max_iter: usize,

    /// Relative tolerance: the solvers that use it stop once
    /// `residual / ||b||` falls below this value. The
    /// `*_with_absolute_tolerance` entry points use their `atol` argument
    /// instead.
    pub rtol: f64,

    /// Number of restart cycles the solver is allowed to run.
    pub max_restarts: usize,

    /// When `true`, residual progress is written to stderr.
    pub verbose: bool,

    /// When `true`, a convergence signal from the Hessenberg residual estimate
    /// is confirmed by recomputing the residual from the updated solution; if
    /// that recomputed residual misses the tolerance the solver moves on to the
    /// next restart cycle instead of returning.
    pub check_true_residual: bool,
}

impl Default for GmresOptions {
    /// 100 iterations per cycle, `rtol = 1e-10`, 10 restart cycles, quiet, and
    /// no true-residual confirmation.
    fn default() -> Self {
        GmresOptions {
            verbose: false,
            check_true_residual: false,
            rtol: 1e-10,
            max_iter: 100,
            max_restarts: 10,
        }
    }
}
182
/// Stopping criterion used by the GMRES implementations: either a bound on the
/// residual relative to `||b||` or a bound on the residual itself.
#[derive(Debug, Clone, Copy, PartialEq)]
enum GmresTolerance {
    /// Converged when `residual_norm / b_norm` is strictly below the payload.
    Relative(f64),
    /// Converged when `residual_norm` is strictly below the payload.
    Absolute(f64),
}

impl GmresTolerance {
    /// Residual in the units of this criterion: divided by `b_norm` in
    /// relative mode, unchanged in absolute mode.
    fn residual_value(self, residual_norm: f64, b_norm: f64) -> f64 {
        if let Self::Relative(_) = self {
            residual_norm / b_norm
        } else {
            residual_norm
        }
    }

    /// `true` when the residual, expressed in this criterion's units, is
    /// strictly smaller than the configured threshold.
    fn is_converged(self, residual_norm: f64, b_norm: f64) -> bool {
        let threshold = match self {
            Self::Relative(limit) | Self::Absolute(limit) => limit,
        };
        self.residual_value(residual_norm, b_norm) < threshold
    }
}
204
/// Outcome of a GMRES solve.
#[derive(Debug, Clone)]
pub struct GmresResult<T> {
    /// Final iterate returned by the solver.
    pub solution: T,

    /// Total number of inner (Krylov) iterations performed, summed over all restart cycles.
    pub iterations: usize,

    /// Residual reported by the solver, in the units of the active stopping
    /// criterion: divided by `||b||` for relative-tolerance solves, unscaled
    /// for absolute-tolerance solves.
    pub residual_norm: f64,

    /// Whether the stopping criterion was met.
    pub converged: bool,
}
238
239pub fn gmres<T, F>(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result<GmresResult<T>>
272where
273 T: TensorVectorSpace,
274 F: Fn(&T) -> Result<T>,
275{
276 gmres_impl(
277 apply_a,
278 b,
279 x0,
280 options,
281 GmresTolerance::Relative(options.rtol),
282 None,
283 )
284}
285
286pub fn gmres_with_absolute_tolerance<T, F>(
291 apply_a: F,
292 b: &T,
293 x0: &T,
294 options: &GmresOptions,
295 atol: f64,
296) -> Result<GmresResult<T>>
297where
298 T: TensorVectorSpace,
299 F: Fn(&T) -> Result<T>,
300{
301 gmres_impl(
302 apply_a,
303 b,
304 x0,
305 options,
306 GmresTolerance::Absolute(atol),
307 None,
308 )
309}
310
311pub fn gmres_affine<T, F>(
317 apply_a: F,
318 b: &T,
319 x0: &T,
320 a0: AnyScalar,
321 a1: AnyScalar,
322 options: &GmresOptions,
323) -> Result<GmresResult<T>>
324where
325 T: TensorVectorSpace,
326 F: Fn(&T) -> Result<T>,
327{
328 gmres_affine_impl(
329 apply_a,
330 b,
331 x0,
332 a0,
333 a1,
334 options,
335 GmresTolerance::Relative(options.rtol),
336 )
337}
338
339pub fn gmres_affine_with_absolute_tolerance<T, F>(
346 apply_a: F,
347 b: &T,
348 x0: &T,
349 a0: AnyScalar,
350 a1: AnyScalar,
351 options: &GmresOptions,
352 atol: f64,
353) -> Result<GmresResult<T>>
354where
355 T: TensorVectorSpace,
356 F: Fn(&T) -> Result<T>,
357{
358 gmres_affine_impl(
359 apply_a,
360 b,
361 x0,
362 a0,
363 a1,
364 options,
365 GmresTolerance::Absolute(atol),
366 )
367}
368
/// Shared implementation behind `gmres_affine` and
/// `gmres_affine_with_absolute_tolerance`: restarted GMRES for the affine
/// system `(a0 * I + a1 * A) x = b`.
///
/// The Krylov basis is built by applying `A` alone; the affine coefficients
/// are folded into the Hessenberg column afterwards, so each inner iteration
/// costs one call to `apply_a`.
///
/// When the environment variable `T4A_GMRES_OP_PROFILE` is set, per-operation
/// timings and call counts are collected and printed to stderr on every
/// successful return.
///
/// # Errors
///
/// Fails when both `a0` and `a1` are zero (checked only after the
/// zero-right-hand-side shortcut), and propagates errors from `validate`,
/// `apply_a`, the vector-space operations, and the triangular solve / solution
/// update helpers.
fn gmres_affine_impl<T, F>(
    apply_a: F,
    b: &T,
    x0: &T,
    a0: AnyScalar,
    a1: AnyScalar,
    options: &GmresOptions,
    tolerance: GmresTolerance,
) -> Result<GmresResult<T>>
where
    T: TensorVectorSpace,
    F: Fn(&T) -> Result<T>,
{
    b.validate()?;
    x0.validate()?;

    // Profiling is opt-in via the environment; only the presence of the variable matters.
    let profile_enabled = std::env::var_os("T4A_GMRES_OP_PROFILE").is_some();
    let profile_id = if profile_enabled {
        GMRES_OP_PROFILE_COUNTER.fetch_add(1, Ordering::Relaxed)
    } else {
        0
    };
    let mut profile = GmresOpProfile::default();
    let mut total_iters = 0usize;

    // Single exit path for successful returns: prints the profile (if enabled)
    // and returns `Ok(result)` from the enclosing function.
    macro_rules! finish {
        ($result:expr) => {{
            let result = $result;
            if profile_enabled {
                profile.print(
                    profile_id,
                    result.iterations,
                    result.residual_norm,
                    result.converged,
                );
            }
            return Ok(result);
        }};
    }

    let started = Instant::now();
    let b_norm = b.norm();
    if profile_enabled {
        profile.b_norm += started.elapsed();
    }
    // Negligible right-hand side: return the initial guess and report a zero residual.
    // NOTE(review): x0 is returned unchanged even if it is nonzero — confirm this is the intended contract.
    if b_norm < 1e-15 {
        finish!(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }
    if a0.is_zero() && a1.is_zero() {
        anyhow::bail!("gmres_affine: at least one affine coefficient must be nonzero");
    }
    // With a1 == 0 the system reduces to a0 * x = b, solved directly by scaling b.
    // The residual is reported as 0.0 without being computed.
    if a1.is_zero() {
        let started = Instant::now();
        let solution = b.scale(AnyScalar::new_real(1.0) / a0)?;
        if profile_enabled {
            profile.scale += started.elapsed();
            profile.scale_calls += 1;
        }
        finish!(GmresResult {
            solution,
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();

    for restart in 0..options.max_restarts {
        // Residual of the current iterate: r = b - (a0 * x + a1 * A x).
        // NOTE(review): this reads `u.axpby(alpha, &v, beta)` as `alpha * u + beta * v` — confirm
        // against the `TensorVectorSpace` definition.
        let started = Instant::now();
        let ax = apply_a(&x)?;
        if profile_enabled {
            profile.apply += started.elapsed();
            profile.apply_calls += 1;
        }
        // The operator output is validated once, on the first cycle only.
        if restart == 0 {
            ax.validate()?;
        }
        let started = Instant::now();
        let affine_x = x.axpby(a0.clone(), &ax, a1.clone())?;
        if profile_enabled {
            profile.axpby += started.elapsed();
            profile.axpby_calls += 1;
        }
        let started = Instant::now();
        let r = b.axpby(
            AnyScalar::new_real(1.0),
            &affine_x,
            AnyScalar::new_real(-1.0),
        )?;
        if profile_enabled {
            profile.axpby += started.elapsed();
            profile.axpby_calls += 1;
        }
        let started = Instant::now();
        let r_norm = r.norm();
        if profile_enabled {
            profile.norm += started.elapsed();
            profile.norm_calls += 1;
        }
        let residual_value = tolerance.residual_value(r_norm, b_norm);
        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                restart, residual_value
            );
        }
        if tolerance.is_converged(r_norm, b_norm) {
            finish!(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: residual_value,
                converged: true,
            });
        }

        // Start a new Krylov cycle from the normalized residual.
        let cycle_max_iter = options.max_iter;
        let mut v_basis: Vec<T> = Vec::with_capacity(cycle_max_iter + 1);
        let started = Instant::now();
        v_basis.push(r.scale(AnyScalar::new_real(1.0 / r_norm))?);
        if profile_enabled {
            profile.scale += started.elapsed();
            profile.scale_calls += 1;
        }

        // h_matrix holds the rotated Hessenberg columns; cs/sn hold one Givens
        // rotation per column; g is the rotated right-hand side of the
        // least-squares problem, seeded with the residual norm.
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(cycle_max_iter);
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)];
        // Set when x has already been advanced inside the inner loop, so the
        // end-of-cycle update below must be skipped.
        let mut solution_already_updated = false;

        for j in 0..cycle_max_iter {
            total_iters += 1;

            // Only A is applied here; a0 and a1 enter through h_col below.
            let started = Instant::now();
            let w = apply_a(&v_basis[j])?;
            if profile_enabled {
                profile.apply += started.elapsed();
                profile.apply_calls += 1;
            }
            let mut h_a_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            // First orthogonalization pass: project out each basis vector in
            // turn, using the already-updated w_orth for each coefficient.
            for v_i in v_basis.iter().take(j + 1) {
                let started = Instant::now();
                let h_ij = v_i.inner_product(&w_orth)?;
                if profile_enabled {
                    profile.inner_product += started.elapsed();
                    profile.inner_product_calls += 1;
                }
                h_a_col.push(h_ij.clone());
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                let started = Instant::now();
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
                if profile_enabled {
                    profile.axpby += started.elapsed();
                    profile.axpby_calls += 1;
                }
            }
            // Second pass (re-orthogonalization): remaining components are
            // removed again and accumulated into the same coefficients.
            for (i, v_i) in v_basis.iter().take(j + 1).enumerate() {
                let started = Instant::now();
                let correction = v_i.inner_product(&w_orth)?;
                if profile_enabled {
                    profile.inner_product += started.elapsed();
                    profile.inner_product_calls += 1;
                }
                h_a_col[i] = h_a_col[i].clone() + correction.clone();
                let neg_correction = AnyScalar::new_real(0.0) - correction;
                let started = Instant::now();
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
                if profile_enabled {
                    profile.axpby += started.elapsed();
                    profile.axpby_calls += 1;
                }
            }

            // Norm of the orthogonalized vector becomes the subdiagonal entry.
            let started = Instant::now();
            let h_jp1_j_real = w_orth.norm();
            if profile_enabled {
                profile.norm += started.elapsed();
                profile.norm_calls += 1;
            }
            h_a_col.push(AnyScalar::new_real(h_jp1_j_real));

            // Convert the Hessenberg column of A into that of a0 * I + a1 * A:
            // every entry is scaled by a1 and a0 is added on the diagonal (row j).
            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            for h in h_a_col.iter().take(j) {
                h_col.push(a1.clone() * h.clone());
            }
            h_col.push(a0.clone() + a1.clone() * h_a_col[j].clone());
            h_col.push(a1.clone() * h_a_col[j + 1].clone());

            // Apply the rotations stored for earlier columns to the new column.
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // New rotation computed from rows j and j + 1; after applying it the
            // subdiagonal entry is set to zero explicitly.
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate the right-hand side; the magnitude of the new trailing
            // entry is used as the residual estimate for this iteration.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);
            let residual_value = tolerance.residual_value(res_norm, b_norm);
            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value);
            }

            if tolerance.is_converged(res_norm, b_norm) {
                // The estimate says we are done: solve the small triangular
                // system and form the new iterate from the first j + 1 basis vectors.
                let started = Instant::now();
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                if profile_enabled {
                    profile.triangular_solve += started.elapsed();
                    profile.triangular_solve_calls += 1;
                }
                let started = Instant::now();
                x = update_solution(&x, &v_basis[..=j], &y)?;
                if profile_enabled {
                    profile.solution_update += started.elapsed();
                    profile.solution_update_calls += 1;
                }
                if options.check_true_residual {
                    // Confirm the estimate by recomputing b - (a0 * x + a1 * A x).
                    let started = Instant::now();
                    let ax_check = apply_a(&x)?;
                    if profile_enabled {
                        profile.apply += started.elapsed();
                        profile.apply_calls += 1;
                    }
                    let started = Instant::now();
                    let affine_check = x.axpby(a0.clone(), &ax_check, a1.clone())?;
                    if profile_enabled {
                        profile.axpby += started.elapsed();
                        profile.axpby_calls += 1;
                    }
                    let started = Instant::now();
                    let r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &affine_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    if profile_enabled {
                        profile.axpby += started.elapsed();
                        profile.axpby_calls += 1;
                    }
                    let started = Instant::now();
                    let true_abs_res = r_check.norm();
                    if profile_enabled {
                        profile.norm += started.elapsed();
                        profile.norm_calls += 1;
                    }
                    let true_residual_value = tolerance.residual_value(true_abs_res, b_norm);
                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            residual_value, true_residual_value
                        );
                    }
                    if tolerance.is_converged(true_abs_res, b_norm) {
                        finish!(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_residual_value,
                            converged: true,
                        });
                    }
                    // The recomputed residual missed the tolerance: keep the
                    // updated x and move on to the next restart cycle.
                    solution_already_updated = true;
                    break;
                } else {
                    // Without the check, the Hessenberg estimate is trusted as-is.
                    finish!(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: residual_value,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Extend the basis with the normalized orthogonal vector.
                let started = Instant::now();
                v_basis.push(w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?);
                if profile_enabled {
                    profile.scale += started.elapsed();
                    profile.scale_calls += 1;
                }
            } else {
                // Breakdown: the new direction is numerically zero, so the
                // basis cannot grow. Form the iterate from what is available,
                // recompute the residual, and return with whatever convergence
                // status that residual gives.
                let started = Instant::now();
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                if profile_enabled {
                    profile.triangular_solve += started.elapsed();
                    profile.triangular_solve_calls += 1;
                }
                let started = Instant::now();
                x = update_solution(&x, &v_basis[..=j], &y)?;
                if profile_enabled {
                    profile.solution_update += started.elapsed();
                    profile.solution_update_calls += 1;
                }
                let started = Instant::now();
                let ax_final = apply_a(&x)?;
                if profile_enabled {
                    profile.apply += started.elapsed();
                    profile.apply_calls += 1;
                }
                let started = Instant::now();
                let affine_final = x.axpby(a0.clone(), &ax_final, a1.clone())?;
                if profile_enabled {
                    profile.axpby += started.elapsed();
                    profile.axpby_calls += 1;
                }
                let started = Instant::now();
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &affine_final,
                    AnyScalar::new_real(-1.0),
                )?;
                if profile_enabled {
                    profile.axpby += started.elapsed();
                    profile.axpby_calls += 1;
                }
                let started = Instant::now();
                let final_abs_res = r_final.norm();
                if profile_enabled {
                    profile.norm += started.elapsed();
                    profile.norm_calls += 1;
                }
                let final_res = tolerance.residual_value(final_abs_res, b_norm);
                finish!(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: tolerance.is_converged(final_abs_res, b_norm),
                });
            }
        }

        // Cycle exhausted without convergence: advance x using every column
        // produced in this cycle before restarting.
        if !solution_already_updated {
            let actual_iters = h_matrix.len();
            let started = Instant::now();
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            if profile_enabled {
                profile.triangular_solve += started.elapsed();
                profile.triangular_solve_calls += 1;
            }
            let started = Instant::now();
            x = update_solution(&x, &v_basis[..actual_iters], &y)?;
            if profile_enabled {
                profile.solution_update += started.elapsed();
                profile.solution_update_calls += 1;
            }
        }
    }

    // All restart cycles used: report the recomputed residual of the last iterate.
    let started = Instant::now();
    let ax_final = apply_a(&x)?;
    if profile_enabled {
        profile.apply += started.elapsed();
        profile.apply_calls += 1;
    }
    let started = Instant::now();
    let affine_final = x.axpby(a0, &ax_final, a1)?;
    if profile_enabled {
        profile.axpby += started.elapsed();
        profile.axpby_calls += 1;
    }
    let started = Instant::now();
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &affine_final,
        AnyScalar::new_real(-1.0),
    )?;
    if profile_enabled {
        profile.axpby += started.elapsed();
        profile.axpby_calls += 1;
    }
    let started = Instant::now();
    let final_abs_res = r_final.norm();
    if profile_enabled {
        profile.norm += started.elapsed();
        profile.norm_calls += 1;
    }
    let final_res = tolerance.residual_value(final_abs_res, b_norm);

    finish!(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: tolerance.is_converged(final_abs_res, b_norm),
    })
}
776
777pub fn gmres_with_total_iteration_limit<T, F>(
784 apply_a: F,
785 b: &T,
786 x0: &T,
787 options: &GmresOptions,
788 max_total_iter: usize,
789) -> Result<GmresResult<T>>
790where
791 T: TensorVectorSpace,
792 F: Fn(&T) -> Result<T>,
793{
794 gmres_impl(
795 apply_a,
796 b,
797 x0,
798 options,
799 GmresTolerance::Relative(options.rtol),
800 Some(max_total_iter),
801 )
802}
803
/// Shared implementation behind `gmres`, `gmres_with_absolute_tolerance` and
/// `gmres_with_total_iteration_limit`: restarted GMRES for `A x = b`.
///
/// * `tolerance` - stopping criterion (relative or absolute).
/// * `max_total_iter` - optional cap on inner iterations summed over all
///   restart cycles; `None` leaves only the per-cycle and restart limits from
///   `options`.
///
/// # Errors
///
/// Propagates errors from `validate`, `apply_a`, the vector-space operations,
/// and the triangular solve / solution update helpers.
fn gmres_impl<T, F>(
    apply_a: F,
    b: &T,
    x0: &T,
    options: &GmresOptions,
    tolerance: GmresTolerance,
    max_total_iter: Option<usize>,
) -> Result<GmresResult<T>>
where
    T: TensorVectorSpace,
    F: Fn(&T) -> Result<T>,
{
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    // Negligible right-hand side: return the initial guess and report a zero residual.
    // NOTE(review): x0 is returned unchanged even if it is nonzero — confirm this is the intended contract.
    if b_norm < 1e-15 {
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Size of this cycle: the per-cycle limit, shrunk to the remaining
        // total budget when one is set. An exhausted budget ends the restarts.
        let cycle_max_iter = match max_total_iter {
            Some(limit) => {
                let remaining = limit.saturating_sub(total_iters);
                if remaining == 0 {
                    break;
                }
                options.max_iter.min(remaining)
            }
            None => options.max_iter,
        };
        if cycle_max_iter == 0 {
            break;
        }

        // Residual of the current iterate, r = b - A x.
        // NOTE(review): this reads `u.axpby(alpha, &v, beta)` as `alpha * u + beta * v` — confirm
        // against the `TensorVectorSpace` definition.
        let ax = apply_a(&x)?;
        // The operator output is validated once, on the first cycle only.
        if _restart == 0 {
            ax.validate()?;
        }
        let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        let r_norm = r.norm();
        let residual_value = tolerance.residual_value(r_norm, b_norm);

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, residual_value
            );
        }

        if tolerance.is_converged(r_norm, b_norm) {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: residual_value,
                converged: true,
            });
        }

        let mut v_basis: Vec<T> = Vec::with_capacity(cycle_max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(cycle_max_iter);

        // Start the Krylov basis from the normalized residual.
        let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        v_basis.push(v0);

        // cs/sn hold one Givens rotation per column; g is the rotated
        // right-hand side of the least-squares problem, seeded with ||r||.
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(cycle_max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(r_norm)];
        // Set when x has already been advanced inside the inner loop, so the
        // end-of-cycle update below must be skipped.
        let mut solution_already_updated = false;

        for j in 0..cycle_max_iter {
            total_iters += 1;

            let w = apply_a(&v_basis[j])?;

            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            // First orthogonalization pass: project out each basis vector in
            // turn, using the already-updated w_orth for each coefficient.
            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Second pass (re-orthogonalization): remaining components are
            // removed again and accumulated into the same coefficients.
            for (i, v_i) in v_basis.iter().take(j + 1).enumerate() {
                let correction = v_i.inner_product(&w_orth)?;
                h_col[i] = h_col[i].clone() + correction.clone();
                let neg_correction = AnyScalar::new_real(0.0) - correction;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
            }

            // Norm of the orthogonalized vector becomes the subdiagonal entry.
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply the rotations stored for earlier columns to the new column.
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // New rotation computed from rows j and j + 1; after applying it the
            // subdiagonal entry is set to zero explicitly.
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate the right-hand side; the magnitude of the new trailing
            // entry is used as the residual estimate for this iteration.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            let residual_value = tolerance.residual_value(res_norm, b_norm);

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value);
            }

            if tolerance.is_converged(res_norm, b_norm) {
                // The estimate says we are done: solve the small triangular
                // system and form the new iterate from the first j + 1 basis vectors.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                if options.check_true_residual {
                    // Confirm the estimate by recomputing b - A x.
                    let ax_check = apply_a(&x)?;
                    let r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &ax_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    let true_abs_res = r_check.norm();
                    let true_residual_value = tolerance.residual_value(true_abs_res, b_norm);

                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            residual_value, true_residual_value
                        );
                    }

                    if tolerance.is_converged(true_abs_res, b_norm) {
                        return Ok(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_residual_value,
                            converged: true,
                        });
                    }
                    // The recomputed residual missed the tolerance: keep the
                    // updated x and move on to the next restart cycle.
                    solution_already_updated = true;
                    break;
                } else {
                    // Without the check, the Hessenberg estimate is trusted as-is.
                    return Ok(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: residual_value,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Extend the basis with the normalized orthogonal vector.
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                v_basis.push(v_jp1);
            } else {
                // Breakdown: the new direction is numerically zero, so the
                // basis cannot grow. Form the iterate from what is available,
                // recompute the residual, and return with whatever convergence
                // status that residual gives.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution(&x, &v_basis[..=j], &y)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_abs_res = r_final.norm();
                let final_res = tolerance.residual_value(final_abs_res, b_norm);
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: tolerance.is_converged(final_abs_res, b_norm),
                });
            }
        }

        // Cycle exhausted without convergence: advance x using every column
        // produced in this cycle before restarting.
        if !solution_already_updated {
            let actual_iters = h_matrix.len();
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            x = update_solution(&x, &v_basis[..actual_iters], &y)?;
        }
    }

    // Restart or iteration budget used up: report the recomputed residual of
    // the last iterate.
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_abs_res = r_final.norm();
    let final_res = tolerance.residual_value(final_abs_res, b_norm);

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: tolerance.is_converged(final_abs_res, b_norm),
    })
}
1050
/// Restarted GMRES for `A x = b` in which intermediate vectors are passed
/// through a caller-supplied `truncate` callback.
///
/// `truncate` is applied to the initial residual of each cycle, to the first
/// basis vector, to every new Krylov vector (interleaved with repeated
/// re-orthogonalization), and to the residual used by the optional
/// true-residual check; it is also handed to `update_solution_truncated`.
/// Convergence is always judged on the relative residual against
/// `options.rtol`.
///
/// * `apply_a` - applies the operator `A` to a vector.
/// * `b` - right-hand side.
/// * `x0` - initial guess.
/// * `options` - iteration limits, tolerance and diagnostics switches.
/// * `truncate` - in-place truncation of a vector.
///
/// # Errors
///
/// Propagates errors from `validate`, `apply_a`, `truncate`, the vector-space
/// operations, and the triangular solve / solution update helpers.
pub fn gmres_with_truncation<T, F, Tr>(
    apply_a: F,
    b: &T,
    x0: &T,
    options: &GmresOptions,
    truncate: Tr,
) -> Result<GmresResult<T>>
where
    T: TensorVectorSpace,
    F: Fn(&T) -> Result<T>,
    Tr: Fn(&mut T) -> Result<()>,
{
    b.validate()?;
    x0.validate()?;

    let b_norm = b.norm();
    // Negligible right-hand side: return the initial guess and report a zero residual.
    // NOTE(review): x0 is returned unchanged even if it is nonzero — confirm this is the intended contract.
    if b_norm < 1e-15 {
        return Ok(GmresResult {
            solution: x0.clone(),
            iterations: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    let mut x = x0.clone();
    let mut total_iters = 0;

    for _restart in 0..options.max_restarts {
        // Truncated residual of the current iterate, r = truncate(b - A x).
        let ax = apply_a(&x)?;
        // The operator output is validated once, on the first cycle only.
        if _restart == 0 {
            ax.validate()?;
        }
        let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
        truncate(&mut r)?;
        let r_norm = r.norm();
        let rel_res = r_norm / b_norm;

        if options.verbose {
            eprintln!(
                "GMRES restart {}: initial residual = {:.6e}",
                _restart, rel_res
            );
        }

        if rel_res < options.rtol {
            return Ok(GmresResult {
                solution: x,
                iterations: total_iters,
                residual_norm: rel_res,
                converged: true,
            });
        }

        let mut v_basis: Vec<T> = Vec::with_capacity(options.max_iter + 1);
        let mut h_matrix: Vec<Vec<AnyScalar>> = Vec::with_capacity(options.max_iter);

        // Truncating the normalized residual can change its norm. When the
        // truncated vector is not negligible it is renormalized and the first
        // entry of g is scaled by the same factor; otherwise it is kept as-is.
        let mut v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?;
        truncate(&mut v0)?;
        let v0_norm = v0.norm();
        let effective_g0 = if v0_norm > 1e-15 {
            v0 = v0.scale(AnyScalar::new_real(1.0 / v0_norm))?;
            r_norm * v0_norm
        } else {
            r_norm
        };
        v_basis.push(v0);

        // cs/sn hold one Givens rotation per column; g is the rotated
        // right-hand side of the least-squares problem.
        let mut cs: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut sn: Vec<AnyScalar> = Vec::with_capacity(options.max_iter);
        let mut g: Vec<AnyScalar> = vec![AnyScalar::new_real(effective_g0)];
        // Set when x has already been advanced inside the inner loop, so the
        // end-of-cycle update below must be skipped.
        let mut solution_already_updated = false;

        for j in 0..options.max_iter {
            total_iters += 1;

            let w = apply_a(&v_basis[j])?;

            let mut h_col: Vec<AnyScalar> = Vec::with_capacity(j + 2);
            let mut w_orth = w;

            // First orthogonalization pass: project out each basis vector in
            // turn, using the already-updated w_orth for each coefficient.
            for v_i in v_basis.iter().take(j + 1) {
                let h_ij = v_i.inner_product(&w_orth)?;
                h_col.push(h_ij.clone());
                let neg_h_ij = AnyScalar::new_real(0.0) - h_ij;
                w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?;
            }

            // Truncation can reintroduce components along earlier basis
            // vectors, so truncate and re-orthogonalize repeatedly until every
            // correction is below REORTH_THRESHOLD or MAX_REORTH_ITERS rounds
            // have run.
            const REORTH_THRESHOLD: f64 = 1e-12;
            const MAX_REORTH_ITERS: usize = 10;

            let mut reorth_iter_count = 0;
            for reorth_iter in 0..MAX_REORTH_ITERS {
                reorth_iter_count = reorth_iter + 1;
                let norm_before_truncate = w_orth.norm();
                truncate(&mut w_orth)?;
                let norm_after_truncate = w_orth.norm();

                let mut max_correction = 0.0;
                for (i, v_i) in v_basis.iter().enumerate() {
                    let correction = v_i.inner_product(&w_orth)?;
                    let correction_abs = correction.abs();
                    if correction_abs > max_correction {
                        max_correction = correction_abs;
                    }
                    // Only corrections above the threshold are applied and
                    // accumulated into the Hessenberg column.
                    if correction_abs > REORTH_THRESHOLD {
                        let neg_correction = AnyScalar::new_real(0.0) - correction.clone();
                        w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?;
                        h_col[i] = h_col[i].clone() + correction;
                    }
                }

                if options.verbose {
                    eprintln!(
                        "  reorth iter {}: norm {:.6e} -> {:.6e}, max_correction = {:.6e}",
                        reorth_iter, norm_before_truncate, norm_after_truncate, max_correction
                    );
                }

                if max_correction < REORTH_THRESHOLD {
                    break;
                }
            }

            if options.verbose && reorth_iter_count > 1 {
                eprintln!("  (needed {} reorth iterations)", reorth_iter_count);
            }

            // Norm of the orthogonalized vector becomes the subdiagonal entry.
            let h_jp1_j_real = w_orth.norm();
            let h_jp1_j = AnyScalar::new_real(h_jp1_j_real);
            h_col.push(h_jp1_j);

            // Apply the rotations stored for earlier columns to the new column.
            #[allow(clippy::needless_range_loop)]
            for i in 0..j {
                let h_i = h_col[i].clone();
                let h_ip1 = h_col[i + 1].clone();
                let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1);
                h_col[i] = new_hi;
                h_col[i + 1] = new_hip1;
            }

            // New rotation computed from rows j and j + 1; after applying it the
            // subdiagonal entry is set to zero explicitly.
            let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]);
            cs.push(c_j.clone());
            sn.push(s_j.clone());

            let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]);
            h_col[j] = new_hj;
            h_col[j + 1] = AnyScalar::new_real(0.0);

            // Rotate the right-hand side; the magnitude of the new trailing
            // entry is used as the residual estimate for this iteration.
            let g_j = g[j].clone();
            let g_jp1 = AnyScalar::new_real(0.0);
            let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1);
            g[j] = new_gj;
            let res_norm = new_gjp1.abs();
            g.push(new_gjp1);

            h_matrix.push(h_col);

            let rel_res = res_norm / b_norm;

            if options.verbose {
                eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res);
            }

            if rel_res < options.rtol {
                // The estimate says we are done: solve the small triangular
                // system and form the new iterate from the first j + 1 basis vectors.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;

                if options.check_true_residual {
                    // Confirm the estimate with the truncated residual of the new iterate.
                    let ax_check = apply_a(&x)?;
                    let mut r_check = b.axpby(
                        AnyScalar::new_real(1.0),
                        &ax_check,
                        AnyScalar::new_real(-1.0),
                    )?;
                    truncate(&mut r_check)?;
                    let true_rel_res = r_check.norm() / b_norm;

                    if options.verbose {
                        eprintln!(
                            "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}",
                            rel_res, true_rel_res
                        );
                    }

                    if true_rel_res < options.rtol {
                        return Ok(GmresResult {
                            solution: x,
                            iterations: total_iters,
                            residual_norm: true_rel_res,
                            converged: true,
                        });
                    }
                    // The recomputed residual missed the tolerance: keep the
                    // updated x and move on to the next restart cycle.
                    solution_already_updated = true;
                    break;
                } else {
                    // Without the check, the Hessenberg estimate is trusted as-is.
                    return Ok(GmresResult {
                        solution: x,
                        iterations: total_iters,
                        residual_norm: rel_res,
                        converged: true,
                    });
                }
            }

            if h_jp1_j_real > 1e-14 {
                // Extend the basis with the normalized orthogonal vector.
                let v_jp1 = w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?;
                v_basis.push(v_jp1);
            } else {
                // Breakdown: the new direction is numerically zero, so the
                // basis cannot grow. Form the iterate from what is available
                // and return with the recomputed residual.
                // NOTE(review): unlike the initial and checked residuals above,
                // this residual is not passed through `truncate` — confirm intended.
                let y = solve_upper_triangular(&h_matrix, &g[..=j])?;
                x = update_solution_truncated(&x, &v_basis[..=j], &y, &truncate)?;
                let ax_final = apply_a(&x)?;
                let r_final = b.axpby(
                    AnyScalar::new_real(1.0),
                    &ax_final,
                    AnyScalar::new_real(-1.0),
                )?;
                let final_res = r_final.norm() / b_norm;
                return Ok(GmresResult {
                    solution: x,
                    iterations: total_iters,
                    residual_norm: final_res,
                    converged: final_res < options.rtol,
                });
            }
        }

        // Cycle exhausted without convergence: advance x before restarting.
        // The basis may hold one more vector than there are Hessenberg
        // columns, hence the clamp to max_iter.
        if !solution_already_updated {
            let actual_iters = v_basis.len().min(options.max_iter);
            let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?;
            x = update_solution_truncated(&x, &v_basis[..actual_iters], &y, &truncate)?;
        }
    }

    // All restart cycles used: report the recomputed (untruncated) residual of
    // the last iterate.
    let ax_final = apply_a(&x)?;
    let r_final = b.axpby(
        AnyScalar::new_real(1.0),
        &ax_final,
        AnyScalar::new_real(-1.0),
    )?;
    let final_res = r_final.norm() / b_norm;

    Ok(GmresResult {
        solution: x,
        iterations: total_iters,
        residual_norm: final_res,
        converged: final_res < options.rtol,
    })
}
1371
/// Configuration for `restart_gmres_with_truncation`: an outer loop that
/// repeatedly runs an inner GMRES solve.
#[derive(Debug, Clone)]
pub struct RestartGmresOptions {
    /// Upper bound on the number of outer iterations.
    pub max_outer_iters: usize,

    /// Tolerance for the outer solve.
    /// NOTE(review): presumably compared against the relative residual
    /// `||r|| / ||b||` — confirm against the solver body.
    pub rtol: f64,

    /// Per-cycle iteration limit forwarded to the inner GMRES solve.
    pub inner_max_iter: usize,

    /// Extra restart cycles for the inner solve; the inner solver is
    /// configured with `inner_max_restarts + 1` cycles.
    pub inner_max_restarts: usize,

    /// Optional threshold on residual reduction between outer iterations.
    /// NOTE(review): exact semantics are not visible from this definition —
    /// confirm against the solver body.
    pub min_reduction: Option<f64>,

    /// Relative tolerance for the inner solve; `0.1` is used when `None`.
    pub inner_rtol: Option<f64>,

    /// When `true`, progress is written to stderr (also forwarded to the
    /// inner solve).
    pub verbose: bool,
}

impl Default for RestartGmresOptions {
    /// 20 outer iterations, `rtol = 1e-10`, 10 inner iterations per cycle, no
    /// extra inner restarts, no reduction threshold, default inner tolerance,
    /// quiet.
    fn default() -> Self {
        RestartGmresOptions {
            verbose: false,
            min_reduction: None,
            inner_rtol: None,
            rtol: 1e-10,
            max_outer_iters: 20,
            inner_max_iter: 10,
            inner_max_restarts: 0,
        }
    }
}

impl RestartGmresOptions {
    /// Creates options populated with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the options with `max_outer_iters` replaced.
    pub fn with_max_outer_iters(self, max_outer_iters: usize) -> Self {
        Self {
            max_outer_iters,
            ..self
        }
    }

    /// Returns the options with `rtol` replaced.
    pub fn with_rtol(self, rtol: f64) -> Self {
        Self { rtol, ..self }
    }

    /// Returns the options with `inner_max_iter` replaced.
    pub fn with_inner_max_iter(self, inner_max_iter: usize) -> Self {
        Self {
            inner_max_iter,
            ..self
        }
    }

    /// Returns the options with `inner_max_restarts` replaced.
    pub fn with_inner_max_restarts(self, inner_max_restarts: usize) -> Self {
        Self {
            inner_max_restarts,
            ..self
        }
    }

    /// Returns the options with `min_reduction` set to `Some(min_reduction)`.
    pub fn with_min_reduction(self, min_reduction: f64) -> Self {
        Self {
            min_reduction: Some(min_reduction),
            ..self
        }
    }

    /// Returns the options with `inner_rtol` set to `Some(inner_rtol)`.
    pub fn with_inner_rtol(self, inner_rtol: f64) -> Self {
        Self {
            inner_rtol: Some(inner_rtol),
            ..self
        }
    }

    /// Returns the options with `verbose` replaced.
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }
}
1496
/// Outcome of a `restart_gmres_with_truncation` run.
#[derive(Debug, Clone)]
pub struct RestartGmresResult<T> {
    /// The final iterate.
    pub solution: T,

    /// Inner GMRES iterations summed over all outer iterations.
    pub iterations: usize,

    /// Number of outer iterations reported by the solver.
    pub outer_iterations: usize,

    /// Relative residual reported for `solution`
    /// (residual norm divided by the norm of the right-hand side).
    pub residual_norm: f64,

    /// Whether the solver reports the requested tolerance as met.
    pub converged: bool,
}
1536
1537pub fn restart_gmres_with_truncation<T, F, Tr>(
1596 apply_a: F,
1597 b: &T,
1598 x0: Option<&T>,
1599 options: &RestartGmresOptions,
1600 truncate: Tr,
1601) -> Result<RestartGmresResult<T>>
1602where
1603 T: TensorVectorSpace,
1604 F: Fn(&T) -> Result<T>,
1605 Tr: Fn(&mut T) -> Result<()>,
1606{
1607 b.validate()?;
1609 if let Some(x) = x0 {
1610 x.validate()?;
1611 }
1612
1613 let b_norm = b.norm();
1614 if b_norm < 1e-15 {
1615 let solution = match x0 {
1617 Some(x) => x.clone(),
1618 None => b.scale(AnyScalar::new_real(0.0))?,
1619 };
1620 return Ok(RestartGmresResult {
1621 solution,
1622 iterations: 0,
1623 outer_iterations: 0,
1624 residual_norm: 0.0,
1625 converged: true,
1626 });
1627 }
1628
1629 let mut x_is_zero = x0.is_none();
1633 let mut x = match x0 {
1634 Some(x) => x.clone(),
1635 None => b.scale(AnyScalar::new_real(0.0))?,
1636 };
1637
1638 let mut total_inner_iters = 0;
1639 let mut prev_residual_norm = f64::INFINITY;
1640
1641 let inner_options = GmresOptions {
1643 max_iter: options.inner_max_iter,
1644 rtol: options.inner_rtol.unwrap_or(0.1), max_restarts: options.inner_max_restarts + 1, verbose: options.verbose,
1647 check_true_residual: true, };
1649
1650 for outer_iter in 0..options.max_outer_iters {
1651 let ax = apply_a(&x)?;
1653 if outer_iter == 0 {
1655 ax.validate()?;
1656 }
1657 let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
1658 truncate(&mut r)?;
1659
1660 let r_norm = r.norm();
1661 let rel_res = r_norm / b_norm;
1662
1663 if options.verbose {
1664 eprintln!(
1665 "Restart GMRES outer iter {}: true residual = {:.6e}",
1666 outer_iter, rel_res
1667 );
1668 }
1669
1670 if rel_res < options.rtol {
1672 return Ok(RestartGmresResult {
1673 solution: x,
1674 iterations: total_inner_iters,
1675 outer_iterations: outer_iter,
1676 residual_norm: rel_res,
1677 converged: true,
1678 });
1679 }
1680
1681 if let Some(min_reduction) = options.min_reduction {
1683 if outer_iter > 0 && rel_res > prev_residual_norm * min_reduction {
1684 if options.verbose {
1685 eprintln!(
1686 "Restart GMRES stagnated: residual ratio = {:.6e} > {:.6e}",
1687 rel_res / prev_residual_norm,
1688 min_reduction
1689 );
1690 }
1691 return Ok(RestartGmresResult {
1692 solution: x,
1693 iterations: total_inner_iters,
1694 outer_iterations: outer_iter,
1695 residual_norm: rel_res,
1696 converged: false,
1697 });
1698 }
1699 }
1700 prev_residual_norm = rel_res;
1701
1702 let zero = r.scale(AnyScalar::new_real(0.0))?;
1705 let inner_result = gmres_with_truncation(&apply_a, &r, &zero, &inner_options, &truncate)?;
1706
1707 total_inner_iters += inner_result.iterations;
1708
1709 if options.verbose {
1710 eprintln!(
1711 " Inner GMRES: {} iterations, residual = {:.6e}, converged = {}",
1712 inner_result.iterations, inner_result.residual_norm, inner_result.converged
1713 );
1714 }
1715
1716 if x_is_zero {
1720 x = inner_result.solution;
1721 x_is_zero = false;
1722 } else {
1723 x = x.axpby(
1724 AnyScalar::new_real(1.0),
1725 &inner_result.solution,
1726 AnyScalar::new_real(1.0),
1727 )?;
1728 }
1729 truncate(&mut x)?;
1730 }
1731
1732 let ax = apply_a(&x)?;
1735 let mut r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?;
1736 truncate(&mut r)?;
1737 let final_rel_res = r.norm() / b_norm;
1738
1739 Ok(RestartGmresResult {
1740 solution: x,
1741 iterations: total_inner_iters,
1742 outer_iterations: options.max_outer_iters,
1743 residual_norm: final_rel_res,
1744 converged: false,
1745 })
1746}
1747
1748fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) {
1753 let a_abs = a.abs();
1754 let b_abs = b.abs();
1755 let r = (a_abs * a_abs + b_abs * b_abs).sqrt();
1756 if r < 1e-15 {
1757 (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0))
1758 } else if a_abs < 1e-15 {
1759 (
1760 AnyScalar::new_real(0.0),
1761 b.clone().conj() / AnyScalar::new_real(r),
1762 )
1763 } else {
1764 let phase = a.clone() / AnyScalar::new_real(a_abs);
1765 (
1766 AnyScalar::new_real(a_abs / r),
1767 phase * b.clone().conj() / AnyScalar::new_real(r),
1768 )
1769 }
1770}
1771
1772fn apply_givens_rotation(
1778 c: &AnyScalar,
1779 s: &AnyScalar,
1780 x: &AnyScalar,
1781 y: &AnyScalar,
1782) -> (AnyScalar, AnyScalar) {
1783 let new_x = c.clone() * x.clone() + s.clone() * y.clone();
1784 let new_y = -(s.clone().conj() * x.clone()) + c.clone() * y.clone();
1785 (new_x, new_y)
1786}
1787
1788fn solve_upper_triangular(h: &[Vec<AnyScalar>], g: &[AnyScalar]) -> Result<Vec<AnyScalar>> {
1790 let n = g.len();
1791 if n == 0 {
1792 return Ok(vec![]);
1793 }
1794
1795 let mut y = vec![AnyScalar::new_real(0.0); n];
1796
1797 for i in (0..n).rev() {
1798 let mut sum = g[i].clone();
1799
1800 for j in (i + 1)..n {
1801 let prod = h[j][i].clone() * y[j].clone();
1803 sum = sum - prod;
1804 }
1805
1806 let h_ii = &h[i][i];
1807 if h_ii.abs() < 1e-15 {
1808 return Err(anyhow::anyhow!(
1809 "Near-singular upper triangular matrix in GMRES"
1810 ));
1811 }
1812
1813 y[i] = sum / h_ii.clone();
1814 }
1815
1816 Ok(y)
1817}
1818
1819fn update_solution<T: TensorVectorSpace>(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result<T> {
1821 let mut result = x.clone();
1822
1823 for (vi, yi) in v_basis.iter().zip(y.iter()) {
1824 let scaled_vi = vi.scale(yi.clone())?;
1825 result = result.axpby(
1827 AnyScalar::new_real(1.0),
1828 &scaled_vi,
1829 AnyScalar::new_real(1.0),
1830 )?;
1831 }
1832
1833 Ok(result)
1834}
1835
1836fn update_solution_truncated<T, Tr>(
1838 x: &T,
1839 v_basis: &[T],
1840 y: &[AnyScalar],
1841 truncate: &Tr,
1842) -> Result<T>
1843where
1844 T: TensorVectorSpace,
1845 Tr: Fn(&mut T) -> Result<()>,
1846{
1847 let mut result = x.clone();
1848 let mut result_is_zero = x.norm() == 0.0;
1853
1854 for (vi, yi) in v_basis.iter().zip(y.iter()) {
1855 let scaled_vi = vi.scale(yi.clone())?;
1856 if result_is_zero {
1857 result = scaled_vi;
1858 result_is_zero = false;
1859 } else {
1860 result = result.axpby(
1861 AnyScalar::new_real(1.0),
1862 &scaled_vi,
1863 AnyScalar::new_real(1.0),
1864 )?;
1865 }
1866 truncate(&mut result)?;
1868 }
1869
1870 Ok(result)
1871}
1872
1873#[cfg(test)]
1874mod tests;