From bf940bd4a5578192fa1c6e0da83e1dbf749aaea8 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Thu, 14 Nov 2024 14:07:56 +0100
Subject: [PATCH] Added `increment` to level status (#504)

---
 .../convergence_controller_classes/adaptivity.py      | 11 ++++++++++-
 .../check_convergence.py                              |  4 +---
 .../estimate_embedded_error.py                        |  4 +++-
 .../estimate_polynomial_error.py                      |  2 +-
 4 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/pySDC/implementations/convergence_controller_classes/adaptivity.py b/pySDC/implementations/convergence_controller_classes/adaptivity.py
index b2ec90e5c..209225d20 100644
--- a/pySDC/implementations/convergence_controller_classes/adaptivity.py
+++ b/pySDC/implementations/convergence_controller_classes/adaptivity.py
@@ -229,7 +229,16 @@ class AdaptivityForConvergedCollocationProblems(AdaptivityBase):
         if self.get_convergence(controller, S, **kwargs):
             self.res_last_iter = np.inf
 
-            if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol:
+            L = S.levels[0]
+            e_tol_converged = (
+                L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
+            )
+
+            if (
+                self.params.restart_at_maxiter
+                and S.levels[0].status.residual > S.levels[0].params.restol
+                and not e_tol_converged
+            ):
                 self.trigger_restart_upon_nonconvergence(S)
             elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
                 S.status.restart = True
diff --git a/pySDC/implementations/convergence_controller_classes/check_convergence.py b/pySDC/implementations/convergence_controller_classes/check_convergence.py
index 9cbe85e25..36cb8e4a2 100644
--- a/pySDC/implementations/convergence_controller_classes/check_convergence.py
+++ b/pySDC/implementations/convergence_controller_classes/check_convergence.py
@@ -75,9 +75,7 @@ class CheckConvergence(ConvergenceController):
         iter_converged = S.status.iter >= S.params.maxiter
         res_converged = L.status.residual <= L.params.restol
         e_tol_converged = (
-            L.status.error_embedded_estimate < L.params.e_tol
-            if (L.params.get('e_tol') and L.status.get('error_embedded_estimate'))
-            else False
+            L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
         )
         converged = (
             iter_converged or res_converged or e_tol_converged or S.status.force_done
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
index 545c2f6ef..08aa14730 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
@@ -109,12 +109,13 @@ class EstimateEmbeddedError(ConvergenceController):
 
     def setup_status_variables(self, controller, **kwargs):
         """
-        Add the embedded error variable to the error function.
+        Add the embedded error to the level status
 
         Args:
             controller (pySDC.Controller): The controller
         """
         self.add_status_variable_to_level('error_embedded_estimate')
+        self.add_status_variable_to_level('increment')
 
     def post_iteration_processing(self, controller, S, **kwargs):
         """
@@ -134,6 +135,7 @@ class EstimateEmbeddedError(ConvergenceController):
         if S.status.iter > 0 or self.params.sweeper_type == "RK":
             for L in S.levels:
                 L.status.error_embedded_estimate = max([self.estimate_embedded_error_serial(L), np.finfo(float).eps])
+                L.status.increment = L.status.error_embedded_estimate * 1
                 self.debug(f'L.status.error_embedded_estimate={L.status.error_embedded_estimate:.5e}', S)
 
         return None
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
index f083651e3..cce409df6 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
@@ -150,7 +150,7 @@ class EstimatePolynomialError(ConvergenceController):
             if self.comm:
                 buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
                 self.comm.Bcast(buf, root=rank)
-                L.status.error_embedded_estimate = buf
+                L.status.error_embedded_estimate = float(buf)
             else:
                 L.status.error_embedded_estimate = abs(u_inter - high_order_sol)
 
-- 
GitLab