From 5b92e943df34fff9ff271930ee1489b59f4cdf10 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Thu, 23 Jan 2025 12:09:03 +0000
Subject: [PATCH] Implemented relative error estimates for adaptive step size
 selection (#515)

---
 .../adaptivity.py                             | 27 +++++-
 .../estimate_embedded_error.py                | 18 +++-
 .../estimate_polynomial_error.py              | 93 +++++++++++++++++--
 .../test_polynomial_error.py                  | 83 ++++++++++++++++-
 4 files changed, 200 insertions(+), 21 deletions(-)

diff --git a/pySDC/implementations/convergence_controller_classes/adaptivity.py b/pySDC/implementations/convergence_controller_classes/adaptivity.py
index 209225d20..666755721 100644
--- a/pySDC/implementations/convergence_controller_classes/adaptivity.py
+++ b/pySDC/implementations/convergence_controller_classes/adaptivity.py
@@ -307,6 +307,7 @@ class Adaptivity(AdaptivityBase):
         """
         defaults = {
             "embedded_error_flavor": 'standard',
+            "rel_error": False,
         }
         return {**defaults, **super().setup(controller, params, description, **kwargs)}
 
@@ -328,6 +329,9 @@ class Adaptivity(AdaptivityBase):
         controller.add_convergence_controller(
             EstimateEmbeddedError.get_implementation(self.params.embedded_error_flavor, self.params.useMPI),
             description=description,
+            params={
+                'rel_error': self.params.rel_error,
+            },
         )
 
         # load contraction factor estimator if necessary
@@ -837,6 +841,8 @@ class AdaptivityPolynomialError(AdaptivityForConvergedCollocationProblems):
 
         defaults = {
             'control_order': -50,
+            'problem_mesh_type': 'numpyesque',
+            'rel_error': False,
             **super().setup(controller, params, description, **kwargs),
             **params,
         }
@@ -858,16 +864,27 @@ class AdaptivityPolynomialError(AdaptivityForConvergedCollocationProblems):
         Returns:
             None
         """
-        from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
-            EstimatePolynomialError,
-        )
+        if self.params.problem_mesh_type.lower() == 'numpyesque':
+            from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
+                EstimatePolynomialError as error_estimation_cls,
+            )
+        elif self.params.problem_mesh_type.lower() == 'firedrake':
+            from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
+                EstimatePolynomialErrorFiredrake as error_estimation_cls,
+            )
+        else:
+            raise NotImplementedError(
+                f'Don\'t know what error estimation class to use for problems with mesh type {self.params.problem_mesh_type}'
+            )
 
         super().dependencies(controller, description, **kwargs)
 
         controller.add_convergence_controller(
-            EstimatePolynomialError,
+            error_estimation_cls,
             description=description,
-            params={},
+            params={
+                'rel_error': self.params.rel_error,
+            },
         )
         return None
 
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
index 08aa14730..ff5714ac2 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
@@ -57,6 +57,7 @@ class EstimateEmbeddedError(ConvergenceController):
         return {
             "control_order": -80,
             "sweeper_type": sweeper_type,
+            "rel_error": False,
             **super().setup(controller, params, description, **kwargs),
         }
 
@@ -94,13 +95,24 @@ class EstimateEmbeddedError(ConvergenceController):
         """
         if self.params.sweeper_type == "RK":
             L.sweep.compute_end_point()
-            return abs(L.uend - L.sweep.u_secondary)
+            if self.params.rel_error:
+                return abs(L.uend - L.sweep.u_secondary) / abs(L.uend)
+            else:
+                return abs(L.uend - L.sweep.u_secondary)
         elif self.params.sweeper_type == "SDC":
             # order rises by one between sweeps
-            return abs(L.uold[-1] - L.u[-1])
+            if self.params.rel_error:
+                return abs(L.uold[-1] - L.u[-1]) / abs(L.u[-1])
+            else:
+                return abs(L.uold[-1] - L.u[-1])
         elif self.params.sweeper_type == 'MPI':
             comm = L.sweep.comm
-            return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
+            if self.params.rel_error:
+                return comm.bcast(
+                    abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]) / abs(L.u[comm.rank + 1]), root=comm.size - 1
+                )
+            else:
+                return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
         else:
             raise NotImplementedError(
                 f"Don't know how to estimate embedded error for sweeper type \
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
index cce409df6..748a2f41d 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
@@ -37,6 +37,7 @@ class EstimatePolynomialError(ConvergenceController):
         defaults = {
             'control_order': -75,
             'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
+            'rel_error': False,
             **super().setup(controller, params, description, **kwargs),
         }
         self.comm = description['sweeper_params'].get('comm', None)
@@ -103,6 +104,23 @@ class EstimatePolynomialError(ConvergenceController):
         else:
             return A @ xp.asarray(b)
 
+    def get_interpolated_solution(self, L, xp):
+        """
+        Get the interpolated solution for numpy or cupy data types
+
+        Args:
+            u_vec (array): Vector of solutions
+            prob (pySDC.problem): Problem
+        """
+        coll = L.sweep.coll
+
+        u = [
+            L.u[i].flatten() if L.u[i] is not None else L.u[i]
+            for i in range(coll.num_nodes + 1)
+            if i != self.params.estimate_on_node
+        ]
+        return self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
+
     def post_iteration_processing(self, controller, S, **kwargs):
         """
         Estimate the error
@@ -120,7 +138,11 @@ class EstimatePolynomialError(ConvergenceController):
             coll = L.sweep.coll
             nodes = np.append(np.append(0, coll.nodes), 1.0)
             estimate_on_node = self.params.estimate_on_node
-            xp = L.u[0].xp
+
+            if hasattr(L.u[0], 'xp'):
+                xp = L.u[0].xp
+            else:
+                xp = np
 
             if self.interpolation_matrix is None:
                 interpolator = LagrangeApproximation(
@@ -128,12 +150,7 @@ class EstimatePolynomialError(ConvergenceController):
                 )
                 self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
 
-            u = [
-                L.u[i].flatten() if L.u[i] is not None else L.u[i]
-                for i in range(coll.num_nodes + 1)
-                if i != estimate_on_node
-            ]
-            u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
+            u_inter = self.get_interpolated_solution(L, xp)
 
             # compute end point if needed
             if estimate_on_node == len(nodes) - 1:
@@ -147,12 +164,14 @@ class EstimatePolynomialError(ConvergenceController):
                 rank = estimate_on_node - 1
                 L.status.order_embedded_estimate = coll.num_nodes * 1
 
+            rescale = float(abs(u_inter)) if self.params.rel_error else 1
+
             if self.comm:
-                buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
+                buf = np.array(abs(u_inter - high_order_sol) / rescale if self.comm.rank == rank else 0.0)
                 self.comm.Bcast(buf, root=rank)
                 L.status.error_embedded_estimate = float(buf)
             else:
-                L.status.error_embedded_estimate = abs(u_inter - high_order_sol)
+                L.status.error_embedded_estimate = abs(u_inter - high_order_sol) / rescale
 
             self.debug(
                 f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
@@ -176,3 +195,59 @@ class EstimatePolynomialError(ConvergenceController):
             return False, 'Need at least two collocation nodes to interpolate to one!'
 
         return True, ""
+
+
+class EstimatePolynomialErrorFiredrake(EstimatePolynomialError):
+    def matmul(self, A, b):
+        """
+        Matrix vector multiplication, possibly MPI parallel.
+        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
+        entire vector once could reduce the number of communications, this way we never need to store the entire vector
+        on any specific rank.
+
+        Args:
+            A (2d np.ndarray): Matrix
+            b (list): Vector
+
+        Returns:
+            List: Axb
+        """
+
+        if self.comm:
+            res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
+            buf = 0 * b[0]
+            for i in range(0, A.shape[0]):
+                index = self.comm.rank + (1 if self.comm.rank < self.params.estimate_on_node - 1 else 0)
+                send_buf = (
+                    (A[i, index] * b[index]) if self.comm.rank != self.params.estimate_on_node - 1 else 0 * res[0]
+                )
+                self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
+                res[i] += buf
+            return res
+        else:
+            res = []
+            for i in range(A.shape[0]):
+                res.append(A[i, 0] * b[0])
+                for j in range(1, A.shape[1]):
+                    res[-1] += A[i, j] * b[j]
+
+            return res
+
+    def get_interpolated_solution(self, L):
+        """
+        Get the interpolated solution for Firedrake data types
+        We are not 100% sure that you don't need to invert the mass matrix here, but should be fine.
+
+        Args:
+            u_vec (array): Vector of solutions
+            prob (pySDC.problem): Problem
+        """
+        coll = L.sweep.coll
+
+        u = [
+            L.u[i] if L.u[i] is not None else L.u[i]
+            for i in range(coll.num_nodes + 1)
+            if i != self.params.estimate_on_node
+        ]
+        return L.prob.dtype_u(self.matmul(self.interpolation_matrix, u)[0])
+        # return L.prob.invert_mass_matrix(self.matmul(self.interpolation_matrix, u)[0])
diff --git a/pySDC/tests/test_convergence_controllers/test_polynomial_error.py b/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
index 103bd6380..db17e3b52 100644
--- a/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
+++ b/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
@@ -1,7 +1,7 @@
 import pytest
 
 
-def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
+def get_controller(dt, num_nodes, quad_type, useMPI, useGPU, rel_error):
     """
     Get a controller prepared for polynomial test equation
 
@@ -64,7 +64,7 @@ def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
     description['sweeper_params'] = sweeper_params
     description['level_params'] = level_params
     description['step_params'] = step_params
-    description['convergence_controllers'] = {EstimatePolynomialError: {}}
+    description['convergence_controllers'] = {EstimatePolynomialError: {'rel_error': rel_error}}
 
     controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
     return controller
@@ -177,13 +177,15 @@ def check_order(dts, **kwargs):
 @pytest.mark.base
 @pytest.mark.parametrize('num_nodes', [2, 3, 4, 5])
 @pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
-def test_interpolation_error(num_nodes, quad_type):
+@pytest.mark.parametrize('rel_error', [True, False])
+def test_interpolation_error(num_nodes, quad_type, rel_error):
     import numpy as np
 
     kwargs = {
         'num_nodes': num_nodes,
         'quad_type': quad_type,
         'useMPI': False,
+        'rel_error': rel_error,
     }
     steps = np.logspace(-1, -4, 20)
     check_order(steps, **kwargs)
@@ -200,6 +202,7 @@ def test_interpolation_error_GPU(num_nodes, quad_type):
         'quad_type': quad_type,
         'useMPI': False,
         'useGPU': True,
+        'rel_error': False,
     }
     steps = np.logspace(-1, -4, 20)
     check_order(steps, **kwargs)
@@ -228,6 +231,77 @@ def test_interpolation_error_MPI(num_nodes, quad_type):
     )
 
 
+@pytest.mark.firedrake
+def test_polynomial_error_firedrake(dt=1.0, num_nodes=3, useMPI=False):
+    from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
+    from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+    from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
+        EstimatePolynomialErrorFiredrake,
+        LagrangeApproximation,
+    )
+    import numpy as np
+
+    if useMPI:
+        from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper_class
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
+    else:
+        from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class
+
+        comm = None
+
+    level_params = {}
+    level_params['dt'] = dt
+    level_params['restol'] = 1.0
+
+    sweeper_params = {}
+    sweeper_params['quad_type'] = 'RADAU-RIGHT'
+    sweeper_params['num_nodes'] = num_nodes
+    sweeper_params['comm'] = comm
+
+    problem_params = {'n': 1}
+
+    step_params = {}
+    step_params['maxiter'] = 0
+
+    controller_params = {}
+    controller_params['logger_level'] = 30
+    controller_params['mssdc_jac'] = False
+
+    description = {}
+    description['problem_class'] = Heat1DForcedFiredrake
+    description['problem_params'] = problem_params
+    description['sweeper_class'] = sweeper_class
+    description['sweeper_params'] = sweeper_params
+    description['level_params'] = level_params
+    description['step_params'] = step_params
+    description['convergence_controllers'] = {EstimatePolynomialErrorFiredrake: {}}
+
+    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
+
+    L = controller.MS[0].levels[0]
+
+    cont = controller.convergence_controllers[
+        np.arange(len(controller.convergence_controllers))[
+            [type(me).__name__ == 'EstimatePolynomialErrorFiredrake' for me in controller.convergence_controllers]
+        ][0]
+    ]
+
+    nodes = np.append(np.append(0, L.sweep.coll.nodes), 1.0)
+    estimate_on_node = cont.params.estimate_on_node
+    interpolator = LagrangeApproximation(points=[nodes[i] for i in range(num_nodes + 1) if i != estimate_on_node])
+    cont.interpolation_matrix = np.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
+
+    for i in range(num_nodes + 1):
+        L.u[i] = L.prob.u_init
+        L.u[i].functionspace.assign(nodes[i])
+
+    u_inter = cont.get_interpolated_solution(L)
+    error = abs(u_inter - L.u[estimate_on_node])
+    assert np.isclose(error, 0)
+
+
 if __name__ == "__main__":
     import sys
     import numpy as np
@@ -238,7 +312,8 @@ if __name__ == "__main__":
         kwargs = {
             'num_nodes': int(sys.argv[1]),
             'quad_type': sys.argv[2],
+            'rel_error': False,
         }
         check_order(steps, useMPI=True, **kwargs)
     else:
-        check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT')
+        check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT', rel_error=False)
-- 
GitLab