From 51aa701f15cd910da155b85bf336740f674352ac Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Sat, 15 Feb 2025 11:56:12 +0000
Subject: [PATCH] Added IMEX Dahlquist equation (#529)

---
 .../problem_classes/TestEquation_0D.py        | 117 +++++++++++++++++-
 .../test_problems/test_Dahlquist_IMEX.py      |  32 +++++
 2 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 pySDC/tests/test_problems/test_Dahlquist_IMEX.py

diff --git a/pySDC/implementations/problem_classes/TestEquation_0D.py b/pySDC/implementations/problem_classes/TestEquation_0D.py
index 811dcf60c..4276e96e9 100644
--- a/pySDC/implementations/problem_classes/TestEquation_0D.py
+++ b/pySDC/implementations/problem_classes/TestEquation_0D.py
@@ -2,7 +2,7 @@ import numpy as np
 import scipy.sparse as nsp
 
 from pySDC.core.problem import Problem, WorkCounter
-from pySDC.implementations.datatype_classes.mesh import mesh
+from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
 
 
 class testequation0d(Problem):
@@ -145,3 +145,118 @@ class testequation0d(Problem):
         me = self.dtype_u(self.init)
         me[:] = u_init * self.xp.exp((t - t_init) * self.lambdas)
         return me
+
+
+class test_equation_IMEX(Problem):
+    dtype_f = imex_mesh
+    dtype_u = mesh
+    xp = np
+    xsp = nsp
+
+    def __init__(self, lambdas_implicit=None, lambdas_explicit=None, u0=0.0):
+        """Initialization routine"""
+
+        if lambdas_implicit is None:
+            re = self.xp.linspace(-30, 19, 50)
+            im = self.xp.linspace(-50, 49, 50)
+            lambdas_implicit = self.xp.array(
+                [[complex(re[i], im[j]) for i in range(len(re))] for j in range(len(im))]
+            ).reshape((len(re) * len(im)))
+        if lambdas_explicit is None:
+            re = self.xp.linspace(-30, 19, 50)
+            im = self.xp.linspace(-50, 49, 50)
+            lambdas_implicit = self.xp.array(
+                [[complex(re[i], im[j]) for i in range(len(re))] for j in range(len(im))]
+            ).reshape((len(re) * len(im)))
+        lambdas_implicit = self.xp.asarray(lambdas_implicit)
+        lambdas_explicit = self.xp.asarray(lambdas_explicit)
+
+        assert lambdas_implicit.ndim == 1, f'expect flat list here, got {lambdas_implicit}'
+        assert lambdas_explicit.shape == lambdas_implicit.shape
+        nvars = lambdas_implicit.size
+        assert nvars > 0, 'expect at least one lambda parameter here'
+
+        # invoke super init, passing number of dofs, dtype_u and dtype_f
+        super().__init__(init=(nvars, None, self.xp.dtype('complex128')))
+
+        self.A = self.xsp.diags(lambdas_implicit)
+        self._makeAttributeAndRegister(
+            'nvars', 'lambdas_implicit', 'lambdas_explicit', 'u0', localVars=locals(), readOnly=True
+        )
+        self.work_counters['rhs'] = WorkCounter()
+
+    def eval_f(self, u, t):
+        """
+        Routine to evaluate the right-hand side of the problem.
+
+        Parameters
+        ----------
+        u : dtype_u
+            Current values of the numerical solution.
+        t : float
+            Current time of the numerical solution is computed.
+
+        Returns
+        -------
+        f : dtype_f
+            The right-hand side of the problem.
+        """
+
+        f = self.dtype_f(self.init)
+        f.impl[:] = u * self.lambdas_implicit
+        f.expl[:] = u * self.lambdas_explicit
+        self.work_counters['rhs']()
+        return f
+
+    def solve_system(self, rhs, factor, u0, t):
+        r"""
+        Simple linear solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`.
+
+        Parameters
+        ----------
+        rhs : dtype_f
+            Right-hand side for the linear system.
+        factor : float
+            Abbrev. for the local stepsize (or any other factor required).
+        u0 : dtype_u
+            Initial guess for the iterative solver.
+        t : float
+            Current time (e.g. for time-dependent BCs).
+
+        Returns
+        -------
+        me : dtype_u
+            The solution as mesh.
+        """
+        me = self.dtype_u(self.init)
+        L = 1 - factor * self.lambdas_implicit
+        L[L == 0] = 1  # to avoid potential divisions by zeros
+        me[:] = rhs
+        me /= L
+        return me
+
+    def u_exact(self, t, u_init=None, t_init=None):
+        """
+        Routine to compute the exact solution at time t.
+
+        Parameters
+        ----------
+        t : float
+            Time of the exact solution.
+        u_init : pySDC.problem.testequation0d.dtype_u
+            Initial solution.
+        t_init : float
+            The initial time.
+
+        Returns
+        -------
+        me : dtype_u
+            The exact solution.
+        """
+
+        u_init = (self.u0 if u_init is None else u_init) * 1.0
+        t_init = 0.0 if t_init is None else t_init * 1.0
+
+        me = self.dtype_u(self.init)
+        me[:] = u_init * self.xp.exp((t - t_init) * (self.lambdas_implicit + self.lambdas_explicit))
+        return me
diff --git a/pySDC/tests/test_problems/test_Dahlquist_IMEX.py b/pySDC/tests/test_problems/test_Dahlquist_IMEX.py
new file mode 100644
index 000000000..ebc2fb73c
--- /dev/null
+++ b/pySDC/tests/test_problems/test_Dahlquist_IMEX.py
@@ -0,0 +1,32 @@
+def test_Dahlquist_IMEX():
+    from pySDC.implementations.problem_classes.TestEquation_0D import test_equation_IMEX
+    import numpy as np
+
+    N = 1
+    dt = 1e-2
+
+    lambdas_implicit = np.ones(N) * -10
+    lambdas_explicit = np.ones(N) * -1e-3
+
+    prob = test_equation_IMEX(lambdas_explicit=lambdas_explicit, lambdas_implicit=lambdas_implicit, u0=1)
+
+    u0 = prob.u_exact(0)
+
+    # do IMEX Euler step forward
+    f0 = prob.eval_f(u0, 0)
+    u1 = prob.solve_system(u0 + dt * f0.expl, dt, u0, 0)
+
+    exact = prob.u_exact(dt)
+    error = abs(u1 - exact)
+    error0 = abs(u0 - exact)
+    assert error < error0 * 1e-1
+
+    # do explicit Euler step backwards
+    f = prob.eval_f(u1, dt)
+    u02 = u1 - dt * (f.impl + f0.expl)
+
+    assert np.allclose(u0, u02)
+
+
+if __name__ == '__main__':
+    test_Dahlquist_IMEX()
-- 
GitLab