diff --git a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
index d735283a2918d64cd20f1098473c9fe99ccdbc6a..f083651e3aa10155e6170185a4906563036acc4e 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py
@@ -71,7 +71,7 @@ class EstimatePolynomialError(ConvergenceController):
         self.add_status_variable_to_level('error_embedded_estimate')
         self.add_status_variable_to_level('order_embedded_estimate')
 
-    def matmul(self, A, b):
+    def matmul(self, A, b, xp=np):
         """
         Matrix vector multiplication, possibly MPI parallel.
         The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
@@ -81,10 +81,12 @@ class EstimatePolynomialError(ConvergenceController):
         Args:
             A (2d np.ndarray): Matrix
             b (list): Vector
+            xp: Either numpy or cupy
 
         Returns:
             List: Axb
         """
+
         if self.comm:
             res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
             buf = b[0] * 0.0
@@ -93,13 +95,13 @@ class EstimatePolynomialError(ConvergenceController):
                 send_buf = (
                     (A[i, index] * b[index])
                     if self.comm.rank != self.params.estimate_on_node - 1
-                    else np.zeros_like(res[0])
+                    else xp.zeros_like(res[0])
                 )
                 self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
                 res[i] += buf
             return res
         else:
-            return A @ np.asarray(b)
+            return A @ xp.asarray(b)
 
     def post_iteration_processing(self, controller, S, **kwargs):
         """
@@ -118,19 +120,20 @@ class EstimatePolynomialError(ConvergenceController):
             coll = L.sweep.coll
             nodes = np.append(np.append(0, coll.nodes), 1.0)
             estimate_on_node = self.params.estimate_on_node
+            xp = L.u[0].xp
 
             if self.interpolation_matrix is None:
                 interpolator = LagrangeApproximation(
                     points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
                 )
-                self.interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])
+                self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
 
             u = [
                 L.u[i].flatten() if L.u[i] is not None else L.u[i]
                 for i in range(coll.num_nodes + 1)
                 if i != estimate_on_node
             ]
-            u_inter = self.matmul(self.interpolation_matrix, u)[0].reshape(L.prob.init[0])
+            u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
 
             # compute end point if needed
             if estimate_on_node == len(nodes) - 1:
diff --git a/pySDC/implementations/datatype_classes/cupy_mesh.py b/pySDC/implementations/datatype_classes/cupy_mesh.py
index 1a6d79fe877c5ba0566e731fe65f928dd22071f2..01b9dde1e8bc6d155bd58b04e23113065ba8efd3 100644
--- a/pySDC/implementations/datatype_classes/cupy_mesh.py
+++ b/pySDC/implementations/datatype_classes/cupy_mesh.py
@@ -12,7 +12,10 @@ class cupy_mesh(cp.ndarray):
     CuPy-based datatype for serial or parallel meshes.
     """
 
-    def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
+    comm = None
+    xp = cp
+
+    def __new__(cls, init, val=0.0, **kwargs):
         """
         Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
 
@@ -25,51 +28,30 @@ class cupy_mesh(cp.ndarray):
 
         """
         if isinstance(init, cupy_mesh):
-            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, strides=strides, order=order)
+            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
             obj[:] = init[:]
-            obj._comm = init._comm
         elif (
             isinstance(init, tuple)
             and (init[1] is None or isinstance(init[1], MPI.Intracomm))
             and isinstance(init[2], cp.dtype)
         ):
-            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], strides=strides, order=order)
+            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
             obj.fill(val)
-            obj._comm = init[1]
         else:
             raise NotImplementedError(type(init))
         return obj
 
-    @property
-    def comm(self):
-        """
-        Getter for the communicator
-        """
-        return self._comm
-
-    def __array_finalize__(self, obj):
-        """
-        Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
-        """
-        if obj is None:
-            return
-        self._comm = getattr(obj, '_comm', None)
-
     def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
         """
         Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs
         """
         args = []
-        comm = None
         for _, input_ in enumerate(inputs):
             if isinstance(input_, cupy_mesh):
                 args.append(input_.view(cp.ndarray))
-                comm = input_.comm
             else:
                 args.append(input_)
         results = super(cupy_mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(cupy_mesh)
-        if not method == 'reduce':
-            results._comm = comm
         return results
 
     def __abs__(self):
diff --git a/pySDC/implementations/datatype_classes/mesh.py b/pySDC/implementations/datatype_classes/mesh.py
index cda77da4a624b870bc521ddd0dff2ca13203ead1..8ad26465c08c24def442621822ff89f6a6f91f5a 100644
--- a/pySDC/implementations/datatype_classes/mesh.py
+++ b/pySDC/implementations/datatype_classes/mesh.py
@@ -19,6 +19,7 @@ class mesh(np.ndarray):
     """
 
     comm = None
+    xp = np
 
     def __new__(cls, init, val=0.0, **kwargs):
         """
diff --git a/pySDC/implementations/problem_classes/polynomial_test_problem.py b/pySDC/implementations/problem_classes/polynomial_test_problem.py
index cc7cada8a84f9d59187c734c631d8861e25b1063..8183e5f989874ca64023dc19cb8ab7eec4532a45 100644
--- a/pySDC/implementations/problem_classes/polynomial_test_problem.py
+++ b/pySDC/implementations/problem_classes/polynomial_test_problem.py
@@ -12,10 +12,19 @@ class polynomial_testequation(Problem):
 
     dtype_u = mesh
     dtype_f = mesh
+    xp = np
 
-    def __init__(self, degree=1, seed=26266):
+    def __init__(self, degree=1, seed=26266, useGPU=False):
         """Initialization routine"""
 
+        if useGPU:
+            import cupy as cp
+            from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
+
+            type(self).xp = cp
+            type(self).dtype_u = cupy_mesh
+            type(self).dtype_f = cupy_mesh
+
         # invoke super init, passing number of dofs, dtype_u and dtype_f
         super().__init__(init=(1, None, np.dtype('float64')))
 
@@ -41,7 +50,7 @@ class polynomial_testequation(Problem):
         """
 
         f = self.dtype_f(self.init)
-        f[:] = self.poly.deriv(m=1)(t)
+        f[:] = self.xp.array(self.poly.deriv(m=1)(t))
         return f
 
     def solve_system(self, rhs, factor, u0, t):
@@ -86,7 +95,7 @@ class polynomial_testequation(Problem):
             The exact solution.
         """
         me = self.dtype_u(self.init)
-        me[:] = self.poly(t)
+        me[:] = self.xp.array(self.poly(t))
         return me
 
 
@@ -116,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation):
         """
 
         f = self.dtype_f(self.init)
-        derivative = self.poly.deriv(m=1)(t)
+        derivative = self.xp.array(self.poly.deriv(m=1)(t))
         f.impl[:] = derivative / 2
         f.expl[:] = derivative / 2
         return f
diff --git a/pySDC/tests/test_convergence_controllers/test_polynomial_error.py b/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
index a9244275e1b01d69fcb2b2ce71e3557e77412653..103bd63803ffcbe029c62b89e4583ec28d996f54 100644
--- a/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
+++ b/pySDC/tests/test_convergence_controllers/test_polynomial_error.py
@@ -1,7 +1,7 @@
 import pytest
 
 
-def get_controller(dt, num_nodes, quad_type, useMPI):
+def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
     """
     Get a controller prepared for polynomial test equation
 
@@ -42,7 +42,10 @@ def get_controller(dt, num_nodes, quad_type, useMPI):
     sweeper_params['num_nodes'] = num_nodes
     sweeper_params['comm'] = comm
 
-    problem_params = {'degree': 12}
+    problem_params = {
+        'degree': 12,
+        'useGPU': useGPU,
+    }
 
     # initialize step parameters
     step_params = {}
@@ -82,6 +85,7 @@ def single_test(**kwargs):
         'quad_type': 'RADAU-RIGHT',
         'useMPI': False,
         'dt': 1.0,
+        'useGPU': False,
         **kwargs,
     }
 
@@ -131,8 +135,6 @@ def multiple_runs(dts, **kwargs):
         dict: Errors for multiple runs
         int: Order of the collocation problem
     """
-    from pySDC.helpers.stats_helper import get_sorted
-
     res = {}
 
     for dt in dts:
@@ -187,6 +189,22 @@ def test_interpolation_error(num_nodes, quad_type):
     check_order(steps, **kwargs)
 
 
+@pytest.mark.cupy
+@pytest.mark.parametrize('num_nodes', [2, 3, 4, 5])
+@pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
+def test_interpolation_error_GPU(num_nodes, quad_type):
+    import numpy as np
+
+    kwargs = {
+        'num_nodes': num_nodes,
+        'quad_type': quad_type,
+        'useMPI': False,
+        'useGPU': True,
+    }
+    steps = np.logspace(-1, -4, 20)
+    check_order(steps, **kwargs)
+
+
 @pytest.mark.mpi4py
 @pytest.mark.parametrize('num_nodes', [2, 5])
 @pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])