diff --git a/pySDC/implementations/problem_classes/GrayScott_MPIFFT.py b/pySDC/implementations/problem_classes/GrayScott_MPIFFT.py
index ba14ea762dc37bc2fe59e3015f174db518abd46e..796a359b142101e0c26b440e695dacf88b297d10 100644
--- a/pySDC/implementations/problem_classes/GrayScott_MPIFFT.py
+++ b/pySDC/implementations/problem_classes/GrayScott_MPIFFT.py
@@ -200,7 +200,6 @@ class grayscott_imex_diffusion(IMEX_Laplacian_MPIFFT):
             Exact solution.
         """
         assert t == 0.0, 'Exact solution only valid as initial condition'
-        assert self.ndim == 2, 'The initial conditions are 2D for now..'
 
         xp = self.xp
 
@@ -222,12 +221,13 @@ class grayscott_imex_diffusion(IMEX_Laplacian_MPIFFT):
             _v[...] = xp.sqrt(F) * (A + xp.sqrt(A**2 - 4)) / 2
 
             for _ in range(-self.num_blobs):
-                x0, y0 = rng.random(size=2) * self.L[0] - self.L[0] / 2
-                lx, ly = rng.random(size=2) * self.L[0] / self.nvars[0] * 30
+                x0 = rng.random(size=self.ndim) * self.L[0] - self.L[0] / 2
+                l = rng.random(size=self.ndim) * self.L[0] / self.nvars[0] * 30
 
-                mask_x = xp.logical_and(self.X[0] > x0, self.X[0] < x0 + lx)
-                mask_y = xp.logical_and(self.X[1] > y0, self.X[1] < y0 + ly)
-                mask = xp.logical_and(mask_x, mask_y)
+                masks = [xp.logical_and(self.X[i] > x0[i], self.X[i] < x0[i] + l[i]) for i in range(self.ndim)]
+                mask = masks[0]
+                for m in masks[1:]:
+                    mask = xp.logical_and(mask, m)
 
                 _u[mask] = rng.random()
                 _v[mask] = rng.random()
@@ -236,6 +236,7 @@ class grayscott_imex_diffusion(IMEX_Laplacian_MPIFFT):
             """
             Blobs as in https://www.chebfun.org/examples/pde/GrayScott.html
             """
+            assert self.ndim == 2, 'The initial conditions are 2D for now..'
 
             inc = self.L[0] / (self.num_blobs + 1)
 
diff --git a/pySDC/implementations/problem_classes/generic_MPIFFT_Laplacian.py b/pySDC/implementations/problem_classes/generic_MPIFFT_Laplacian.py
index e1fc8f6812c846479cfc4dad8abf03f5b4fcf7ef..8f4ee38efb759abe0608dc8ed6d4ca7ac077d3c2 100644
--- a/pySDC/implementations/problem_classes/generic_MPIFFT_Laplacian.py
+++ b/pySDC/implementations/problem_classes/generic_MPIFFT_Laplacian.py
@@ -1,18 +1,17 @@
 import numpy as np
 from mpi4py import MPI
-from mpi4py_fft import PFFT
+from mpi4py_fft import PFFT, newDistArray
 
 from pySDC.core.errors import ProblemError
 from pySDC.core.problem import Problem, WorkCounter
 from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
 
-from mpi4py_fft import newDistArray
-
 
 class IMEX_Laplacian_MPIFFT(Problem):
     r"""
     Generic base class for IMEX problems using a spectral method to solve the Laplacian implicitly and a possible rest
     explicitly. The FFTs are done with``mpi4py-fft`` [1]_.
+    Works in two and three dimensions.
 
     Parameters
     ----------
@@ -99,14 +98,24 @@ class IMEX_Laplacian_MPIFFT(Problem):
             'nvars', 'spectral', 'L', 'alpha', 'comm', 'x0', 'useGPU', localVars=locals(), readOnly=True
         )
 
-        # get local mesh
+        self.getLocalGrid()
+        self.getLaplacian()
+
+        # Need this for diagnostics
+        self.dx = self.L[0] / nvars[0]
+        self.dy = self.L[1] / nvars[1]
+
+        # work counters
+        self.work_counters['rhs'] = WorkCounter()
+
+    def getLocalGrid(self):
         X = list(self.xp.ogrid[self.fft.local_slice(False)])
         N = self.fft.global_shape()
         for i in range(len(N)):
-            X[i] = x0 + (X[i] * L[i] / N[i])
+            X[i] = self.x0 + (X[i] * self.L[i] / N[i])
         self.X = [self.xp.broadcast_to(x, self.fft.shape(False)) for x in X]
 
-        # get local wavenumbers and Laplace operator
+    def getLaplacian(self):
         s = self.fft.local_slice()
         N = self.fft.global_shape()
         k = [self.xp.fft.fftfreq(n, 1.0 / n).astype(int) for n in N]
@@ -117,14 +126,7 @@ class IMEX_Laplacian_MPIFFT(Problem):
             Ks[i] = (Ks[i] * Lp[i]).astype(float)
         K = [self.xp.broadcast_to(k, self.fft.shape(True)) for k in Ks]
         K = self.xp.array(K).astype(float)
-        self.K2 = self.xp.sum(K * K, 0, dtype=float)  # Laplacian in spectral space
-
-        # Need this for diagnostics
-        self.dx = self.L[0] / nvars[0]
-        self.dy = self.L[1] / nvars[1]
-
-        # work counters
-        self.work_counters['rhs'] = WorkCounter()
+        self.K2 = self.xp.sum(K * K, 0, dtype=float)
 
     def eval_f(self, u, t):
         """
diff --git a/pySDC/projects/GPU/configs/GS_configs.py b/pySDC/projects/GPU/configs/GS_configs.py
index 6ee1374d5d1b0dd93354459218a6b27c5413b17b..f57cdc7ea2b6746b732a72f8182096beb85e6113 100644
--- a/pySDC/projects/GPU/configs/GS_configs.py
+++ b/pySDC/projects/GPU/configs/GS_configs.py
@@ -27,6 +27,7 @@ class GrayScott(Config):
     num_frames = 200
     sweeper_type = 'IMEX'
     res_per_blob = 2**7
+    ndim = 3
 
     def get_LogToFile(self, ranks=None):
         import numpy as np
@@ -49,16 +50,14 @@ class GrayScott(Config):
                     't': L.time + L.dt,
                     'u': uend[0].get().view(np.ndarray),
                     'v': uend[1].get().view(np.ndarray),
-                    'X': L.prob.X[0].get().view(np.ndarray),
-                    'Y': L.prob.X[1].get().view(np.ndarray),
+                    'X': L.prob.X.get().view(np.ndarray),
                 }
             else:
                 return {
                     't': L.time + L.dt,
                     'u': uend[0],
                     'v': uend[1],
-                    'X': L.prob.X[0],
-                    'Y': L.prob.X[1],
+                    'X': L.prob.X,
                 }
 
         def logging_condition(L):
@@ -75,7 +74,7 @@ class GrayScott(Config):
         LogToFile.logging_condition = logging_condition
         return LogToFile
 
-    def plot(self, P, idx, n_procs_list):  # pragma: no cover
+    def plot(self, P, idx, n_procs_list, projection='xy'):  # pragma: no cover
         import numpy as np
         from matplotlib import ticker as tkr
 
@@ -99,19 +98,49 @@ class GrayScott(Config):
             vmax['u'] = max([vmax['u'], buffer[f'u-{rank}']['u'].real.max()])
 
         for rank in range(n_procs_list[2]):
-            im = ax.pcolormesh(
-                buffer[f'u-{rank}']['X'],
-                buffer[f'u-{rank}']['Y'],
-                buffer[f'u-{rank}']['v'].real,
-                vmin=vmin['v'],
-                vmax=vmax['v'],
-                cmap='binary',
-            )
+            if len(buffer[f'u-{rank}']['X']) == 2:
+                ax.set_xlabel('$x$')
+                ax.set_ylabel('$y$')
+                im = ax.pcolormesh(
+                    buffer[f'u-{rank}']['X'][0],
+                    buffer[f'u-{rank}']['X'][1],
+                    buffer[f'u-{rank}']['v'].real,
+                    vmin=vmin['v'],
+                    vmax=vmax['v'],
+                    cmap='binary',
+                )
+            else:
+                v3d = buffer[f'u-{rank}']['v'].real
+
+                if projection == 'xy':
+                    slices = [slice(None), slice(None), v3d.shape[2] // 2]
+                    x = buffer[f'u-{rank}']['X'][0][*slices]
+                    y = buffer[f'u-{rank}']['X'][1][*slices]
+                    ax.set_xlabel('$x$')
+                    ax.set_ylabel('$y$')
+                elif projection == 'xz':
+                    slices = [slice(None), v3d.shape[1] // 2, slice(None)]
+                    x = buffer[f'u-{rank}']['X'][0][*slices]
+                    y = buffer[f'u-{rank}']['X'][2][*slices]
+                    ax.set_xlabel('$x$')
+                    ax.set_ylabel('$z$')
+                elif projection == 'yz':
+                    slices = [v3d.shape[0] // 2, slice(None), slice(None)]
+                    x = buffer[f'u-{rank}']['X'][1][*slices]
+                    y = buffer[f'u-{rank}']['X'][2][*slices]
+                    ax.set_xlabel('$y$')
+                    ax.set_ylabel('$z$')
+
+                im = ax.pcolormesh(
+                    x,
+                    y,
+                    v3d[*slices],
+                    vmin=vmin['v'],
+                    vmax=vmax['v'],
+                    cmap='binary',
+                )
             fig.colorbar(im, cax, format=tkr.FormatStrFormatter('%.1f'))
             ax.set_title(f't={buffer[f"u-{rank}"]["t"]:.2f}')
-            ax.set_xlabel('$x$')
-            ax.set_ylabel('$y$')
-            ax.set_aspect(1.0)
             ax.set_aspect(1.0)
         return fig
 
@@ -130,7 +159,7 @@ class GrayScott(Config):
         desc['sweeper_params']['QI'] = 'MIN-SR-S'
         desc['sweeper_params']['QE'] = 'PIC'
 
-        desc['problem_params']['nvars'] = (2**8 if res == -1 else res,) * 2
+        desc['problem_params']['nvars'] = (2**8 if res == -1 else res,) * self.ndim
         desc['problem_params']['Du'] = 0.00002
         desc['problem_params']['Dv'] = 0.00001
         desc['problem_params']['A'] = 0.04
diff --git a/pySDC/tests/test_problems/test_generic_MPIFFT.py b/pySDC/tests/test_problems/test_generic_MPIFFT.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd82cf41157a4fd09c6695f3c4c205c09711a816
--- /dev/null
+++ b/pySDC/tests/test_problems/test_generic_MPIFFT.py
@@ -0,0 +1,57 @@
+import pytest
+
+
+@pytest.mark.mpi4py
+@pytest.mark.parametrize('nx', [8, 16])
+@pytest.mark.parametrize('ny', [8, 16])
+@pytest.mark.parametrize('nz', [0, 8])
+@pytest.mark.parametrize('f', [1, 3])
+@pytest.mark.parametrize('spectral', [True, False])
+@pytest.mark.parametrize('direction', [0, 1, 10])
+def test_derivative(nx, ny, nz, f, spectral, direction):
+    from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT
+
+    nvars = (nx, ny)
+    if nz > 0:
+        nvars += (nz,)
+    prob = IMEX_Laplacian_MPIFFT(nvars=nvars, spectral=spectral)
+
+    xp = prob.xp
+
+    if direction == 0:
+        _u = xp.sin(f * prob.X[0])
+        du_expect = -(f**2) * xp.sin(f * prob.X[0])
+    elif direction == 1:
+        _u = xp.sin(f * prob.X[1])
+        du_expect = -(f**2) * xp.sin(f * prob.X[1])
+    elif direction == 10:
+        _u = xp.sin(f * prob.X[1]) + xp.cos(f * prob.X[0])
+        du_expect = -(f**2) * xp.sin(f * prob.X[1]) - f**2 * xp.cos(f * prob.X[0])
+    else:
+        raise
+
+    if spectral:
+        u = prob.fft.forward(_u)
+    else:
+        u = _u
+
+    _du = prob.eval_f(u, 0).impl
+
+    if spectral:
+        du = prob.fft.backward(_du)
+    else:
+        du = _du
+    assert xp.allclose(du, du_expect), 'Got unexpected derivative'
+
+    u2 = prob.solve_system(_du, factor=1e8, u0=du, t=0) * -1e8
+
+    if spectral:
+        _u2 = prob.fft.backward(u2)
+    else:
+        _u2 = u2
+
+    assert xp.allclose(_u2, _u, atol=1e-7), 'Got unexpected inverse derivative'
+
+
+if __name__ == '__main__':
+    test_derivative(6, 6, 6, 3, False, 1)