diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py
index 7381b33e6e0e9c47aa934d0b6b5f2e8810c51a0b..cd175bb738722db435eaf0649bec825ed073b5d0 100644
--- a/pySDC/helpers/spectral_helper.py
+++ b/pySDC/helpers/spectral_helper.py
@@ -882,6 +882,7 @@ class SpectralHelper:
         self.BCs = None
 
         self.fft_cache = {}
+        self.fft_dealias_shape_cache = {}
 
     @property
     def u_init(self):
@@ -1470,7 +1471,9 @@ class SpectralHelper:
 
             if padding is not None:
                 shape = list(v.shape)
-                if self.comm:
+                if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
+                    shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
+                elif self.comm:
                     send_buf = np.array(v.shape[0])
                     recv_buf = np.array(v.shape[0])
                     self.comm.Allreduce(send_buf, recv_buf)
@@ -1645,7 +1648,9 @@ class SpectralHelper:
             if padding is not None:
                 if padding[axis] != 1:
                     shape = list(v.shape)
-                    if self.comm:
+                    if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
+                        shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
+                    elif self.comm:
                         send_buf = np.array(v.shape[0])
                         recv_buf = np.array(v.shape[0])
                         self.comm.Allreduce(send_buf, recv_buf)
@@ -1754,8 +1759,6 @@ class SpectralHelper:
         if self.comm.size == 1:
             return u.copy()
 
-        fft = self.get_fft(**kwargs) if fft is None else fft
-
         global_fft = self.get_fft(**kwargs)
         axisA = [me.axisA for me in global_fft.transfer]
         axisB = [me.axisB for me in global_fft.transfer]
@@ -1793,6 +1796,8 @@ class SpectralHelper:
         else:  # go the potentially slower route of not reusing transfer classes
             from mpi4py_fft import newDistArray
 
+            fft = self.get_fft(**kwargs) if fft is None else fft
+
             _in = newDistArray(fft, forward).redistribute(axis_in)
             _in[...] = u
 
diff --git a/pySDC/implementations/problem_classes/RayleighBenard.py b/pySDC/implementations/problem_classes/RayleighBenard.py
index f75d22ed7d90b901d1113c8e0ddf199352891591..69a93a3d0fd046d184dc207563d332f00e8dd122 100644
--- a/pySDC/implementations/problem_classes/RayleighBenard.py
+++ b/pySDC/implementations/problem_classes/RayleighBenard.py
@@ -6,6 +6,7 @@ from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
 from pySDC.core.convergence_controller import ConvergenceController
 from pySDC.core.hooks import Hooks
 from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
+from pySDC.core.problem import WorkCounter
 
 
 class RayleighBenard(GenericSpectralLinear):
@@ -20,7 +21,7 @@ class RayleighBenard(GenericSpectralLinear):
         v_t - nu (v_xx + v_zz) + p_z - T = -uv_x - vv_z
 
     with u the horizontal velocity, v the vertical velocity (in z-direction), T the temperature, p the pressure, indices
-    denoting derivatives, kappa=(Rayleigh * Prandl)**(-1/2) and nu = (Rayleigh / Prandl)**(-1/2). Everything on the left
+    denoting derivatives, kappa=(Rayleigh * Prandtl)**(-1/2) and nu = (Rayleigh / Prandtl)**(-1/2). Everything on the left
     hand side, that is the viscous part, the pressure gradient and the buoyancy due to temperature are treated
     implicitly, while the non-linear convection part on the right hand side is integrated explicitly.
 
@@ -36,7 +37,7 @@ class RayleighBenard(GenericSpectralLinear):
     facilitate the Dirichlet BCs.
 
     Parameters:
-        Prandl (float): Prandl number
+        Prandtl (float): Prandtl number
         Rayleigh (float): Rayleigh number
         nx (int): Horizontal resolution
         nz (int): Vertical resolution
@@ -50,26 +51,28 @@ class RayleighBenard(GenericSpectralLinear):
 
     def __init__(
         self,
-        Prandl=1,
+        Prandtl=1,
         Rayleigh=2e6,
         nx=256,
         nz=64,
         BCs=None,
         dealiasing=3 / 2,
         comm=None,
+        Lx=8,
         **kwargs,
     ):
         """
         Constructor. `kwargs` are forwarded to parent class constructor.
 
         Args:
-            Prandl (float): Prandtl number
+            Prandtl (float): Prandtl number
             Rayleigh (float): Rayleigh number
             nx (int): Resolution in x-direction
             nz (int): Resolution in z direction
             BCs (dict): Vertical boundary conditions
             dealiasing (float): Dealiasing for evaluating the non-linear part in real space
             comm (mpi4py.Intracomm): Space communicator
+            Lx (float): Horizontal length of the domain
         """
         BCs = {} if BCs is None else BCs
         BCs = {
@@ -90,18 +93,19 @@ class RayleighBenard(GenericSpectralLinear):
             except ModuleNotFoundError:
                 pass
         self._makeAttributeAndRegister(
-            'Prandl',
+            'Prandtl',
             'Rayleigh',
             'nx',
             'nz',
             'BCs',
             'dealiasing',
             'comm',
+            'Lx',
             localVars=locals(),
             readOnly=True,
         )
 
-        bases = [{'base': 'fft', 'N': nx, 'x0': 0, 'x1': 8}, {'base': 'ultraspherical', 'N': nz}]
+        bases = [{'base': 'fft', 'N': nx, 'x0': 0, 'x1': self.Lx}, {'base': 'ultraspherical', 'N': nz}]
         components = ['u', 'v', 'T', 'p']
         super().__init__(bases, components, comm=comm, **kwargs)
 
@@ -127,15 +131,17 @@ class RayleighBenard(GenericSpectralLinear):
         self.Dz = S1 @ Dz
         self.Dzz = S2 @ Dzz
 
-        kappa = (Rayleigh * Prandl) ** (-1 / 2.0)
-        nu = (Rayleigh / Prandl) ** (-1 / 2.0)
+        # compute rescaled Rayleigh number to extract viscosity and thermal diffusivity
+        Ra = Rayleigh / (max([abs(BCs['T_top'] - BCs['T_bottom']), np.finfo(float).eps]) * self.axes[1].L ** 3)
+        self.kappa = (Ra * Prandtl) ** (-1 / 2.0)
+        self.nu = (Ra / Prandtl) ** (-1 / 2.0)
 
         # construct operators
         L_lhs = {
             'p': {'u': U01 @ Dx, 'v': Dz},  # divergence free constraint
-            'u': {'p': U02 @ Dx, 'u': -nu * (U02 @ Dxx + Dzz)},
-            'v': {'p': U12 @ Dz, 'v': -nu * (U02 @ Dxx + Dzz), 'T': -U02 @ Id},
-            'T': {'T': -kappa * (U02 @ Dxx + Dzz)},
+            'u': {'p': U02 @ Dx, 'u': -self.nu * (U02 @ Dxx + Dzz)},
+            'v': {'p': U12 @ Dz, 'v': -self.nu * (U02 @ Dxx + Dzz), 'T': -U02 @ Id},
+            'T': {'T': -self.kappa * (U02 @ Dxx + Dzz)},
         }
         self.setup_L(L_lhs)
 
@@ -175,6 +181,8 @@ class RayleighBenard(GenericSpectralLinear):
                 )
         self.setup_BCs()
 
+        self.work_counters['rhs'] = WorkCounter()
+
     def eval_f(self, u, *args, **kwargs):
         f = self.f_init
 
@@ -225,6 +233,7 @@ class RayleighBenard(GenericSpectralLinear):
         else:
             f.expl[:] = self.itransform(self.transform(fexpl_pad, padding=padding)).real
 
+        self.work_counters['rhs']()
         return f
 
     def u_exact(self, t=0, noise_level=1e-3, seed=99):
diff --git a/pySDC/implementations/problem_classes/generic_spectral.py b/pySDC/implementations/problem_classes/generic_spectral.py
index 556ff467606af60909f45977097da58822a319e4..4f0fa34266c8de2ad76382b875862c00a66d438f 100644
--- a/pySDC/implementations/problem_classes/generic_spectral.py
+++ b/pySDC/implementations/problem_classes/generic_spectral.py
@@ -28,21 +28,26 @@ class GenericSpectralLinear(Problem):
         Pr (sparse matrix): Right preconditioner
     """
 
-    @classmethod
-    def setup_GPU(cls):
-        """switch to GPU modules"""
+    def setup_GPU(self):
+        """Switch this instance to the cupy data types and wrap a given communicator in an NCCL communicator."""
         import cupy as cp
         from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh, imex_cupy_mesh
         from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
 
-        cls.dtype_u = cupy_mesh
+        self.dtype_u = cupy_mesh  # assigned on the instance, so the class attribute is left untouched
 
         GPU_versions = {
             mesh: cupy_mesh,
             imex_mesh: imex_cupy_mesh,
         }
 
-        cls.dtype_f = GPU_versions[cls.dtype_f]
+        self.dtype_f = GPU_versions[self.dtype_f]  # KeyError if no GPU version is listed above
+
+        if self.comm is not None:
+            from pySDC.helpers.NCCL_communicator import NCCLComm
+
+            if not isinstance(self.comm, NCCLComm):  # do not wrap a communicator that is already wrapped
+                self.__dict__['comm'] = NCCLComm(self.comm)  # __dict__ write, presumably as comm is read-only
 
     def __init__(
         self,
@@ -94,6 +99,9 @@ class GenericSpectralLinear(Problem):
 
         if useGPU:
             self.setup_GPU()
+            if self.solver_args is not None:
+                if 'rtol' in self.solver_args.keys():
+                    self.solver_args['tol'] = self.solver_args.pop('rtol')
 
         for base in bases:
             self.spectral.add_axis(**base)
@@ -218,14 +226,21 @@ class GenericSpectralLinear(Problem):
 
         if self.spectral_space:
             rhs_hat = rhs.copy()
+            # without an initial guess, hand None to the iterative solvers rather than leaving u0_hat undefined
+            u0_hat = None if u0 is None else self.Pr.T @ u0.copy().flatten()
         else:
             rhs_hat = self.spectral.transform(rhs)
+            # same fallback as above; a supplied guess has to be transformed to spectral space first
+            u0_hat = None if u0 is None else self.Pr.T @ self.spectral.transform(u0).flatten()
+
+        if self.useGPU:
+            self.xp.cuda.Device().synchronize()
 
         rhs_hat = (self.M @ rhs_hat.flatten()).reshape(rhs_hat.shape)
         rhs_hat = self.spectral.put_BCs_in_rhs_hat(rhs_hat)
         rhs_hat = self.Pl @ rhs_hat.flatten()
 
-        if dt not in self.cached_factorizations.keys():
+        if dt not in self.cached_factorizations.keys() or not self.solver_type.lower() == 'cached_direct':
             A = self.M + dt * self.L
             A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr
 
@@ -255,25 +270,58 @@ class GenericSpectralLinear(Problem):
 
         elif self.solver_type.lower() == 'direct':
             _sol_hat = sp.linalg.spsolve(A, rhs_hat)
+        elif self.solver_type.lower() == 'lsqr':
+            lsqr = sp.linalg.lsqr(
+                A,
+                rhs_hat,
+                x0=u0_hat,
+                **self.solver_args,
+            )
+            _sol_hat = lsqr[0]
         elif self.solver_type.lower() == 'gmres':
             _sol_hat, _ = sp.linalg.gmres(
                 A,
                 rhs_hat,
-                x0=u0.flatten(),
+                x0=u0_hat,
                 **self.solver_args,
                 callback=self.work_counters[self.solver_type],
-                callback_type='legacy',
+                callback_type='pr_norm',
+            )
+        elif self.solver_type.lower() == 'gmres+ilu':
+            linalg = self.spectral.linalg
+
+            if dt not in self.cached_factorizations.keys():
+                if len(self.cached_factorizations) >= self.max_cached_factorizations:
+                    to_evict = list(self.cached_factorizations.keys())[0]
+                    self.cached_factorizations.pop(to_evict)
+                    self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
+                iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100)
+                self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve)
+                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
+                self.work_counters['factorizations']()
+
+            _sol_hat, _ = linalg.gmres(
+                A,
+                rhs_hat,
+                x0=u0_hat,
+                **self.solver_args,
+                callback=self.work_counters[self.solver_type],
+                callback_type='pr_norm',
+                M=self.cached_factorizations[dt],
             )
         elif self.solver_type.lower() == 'cg':
             _sol_hat, _ = sp.linalg.cg(
-                A, rhs_hat, x0=u0.flatten(), **self.solver_args, callback=self.work_counters[self.solver_type]
+                A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
             )
         else:
-            raise NotImplementedError(f'Solver {self.solver_type:!} not implemented in {type(self).__name__}!')
+            raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')
 
         sol_hat = self.spectral.u_init_forward
         sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)
 
+        if self.useGPU:
+            self.xp.cuda.Device().synchronize()
+
         if self.spectral_space:
             return sol_hat
         else:
@@ -319,7 +367,6 @@ def compute_residual_DAE(self, stage=''):
             res[m] += L.tau[m]
         # use abs function from data type here
         res_norm.append(abs(res[m]))
-        # print(m, [abs(me) for me in res[m]], [abs(me) for me in L.u[0] - L.u[m + 1]])
 
     # find maximal residual over the nodes
     if L.params.residual_type == 'full_abs':
diff --git a/pySDC/tests/test_problems/test_RayleighBenard.py b/pySDC/tests/test_problems/test_RayleighBenard.py
index 1d1638d7bb88076abe682eabbf58b3ce61c39bd4..1cc1d85a9e38d7090d406240700fbf129a8e4aa9 100644
--- a/pySDC/tests/test_problems/test_RayleighBenard.py
+++ b/pySDC/tests/test_problems/test_RayleighBenard.py
@@ -15,8 +15,8 @@ def test_eval_f(nx, nz, direction, spectral_space):
     X, Z = P.X, P.Z
     cos, sin = np.cos, np.sin
 
-    kappa = (P.Rayleigh * P.Prandl) ** (-1 / 2)
-    nu = (P.Rayleigh / P.Prandl) ** (-1 / 2)
+    kappa = P.kappa
+    nu = P.nu
 
     if direction == 'x':
         y = sin(X * np.pi)
@@ -181,7 +181,9 @@ def test_Poisson_problems(nx, component):
         'T_top': 0,
         'T_bottom': 0,
     }
-    P = RayleighBenard(nx=nx, nz=6, BCs=BCs, Rayleigh=1.0)
+    P = RayleighBenard(
+        nx=nx, nz=6, BCs=BCs, Rayleigh=(max([abs(BCs['T_top'] - BCs['T_bottom']), np.finfo(float).eps]) * 2**3)
+    )
     rhs = P.u_init
 
     idx = P.index(f'{component}')