diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml
index 4ba3557084da3aae8adf9876d31f5e1af0713690..3487393df48ca79c26e9513b707f15cf636bfccf 100644
--- a/.github/workflows/ci_pipeline.yml
+++ b/.github/workflows/ci_pipeline.yml
@@ -169,6 +169,58 @@ jobs:
           path: |
             data_libpressio
             coverage_libpressio.dat
+
+  user_firedrake_tests:
+    runs-on: ubuntu-latest
+    container:
+      image: firedrakeproject/firedrake-vanilla:latest
+      options: --user root
+      volumes:
+        - ${{ github.workspace }}:/repositories
+    defaults:
+      run:
+        shell: bash -l {0}
+    steps:
+      - name: Checkout pySDC
+        uses: actions/checkout@v4
+        with:
+          path: ./pySDC
+      - name: Checkout gusto
+        uses: actions/checkout@v4
+        with:
+          repository: firedrakeproject/gusto
+          path: ./gusto_repo
+      - name: Install pySDC
+        run: |
+          . /home/firedrake/firedrake/bin/activate
+          python -m pip install --no-deps -e /repositories/pySDC
+          python -m pip install qmat
+      - name: Install gusto
+        run: |
+          . /home/firedrake/firedrake/bin/activate
+          python -m pip install -e /repositories/gusto_repo
+      - name: Run pytest
+        run: |
+          . /home/firedrake/firedrake/bin/activate
+          firedrake-clean
+          cd ./pySDC
+          coverage run -m pytest --continue-on-collection-errors -v --durations=0 /repositories/pySDC/pySDC/tests -m firedrake
+        timeout-minutes: 120
+      - name: Make coverage report
+        run: |
+          . /home/firedrake/firedrake/bin/activate
+
+          cd ./pySDC
+          mv data ../data_firedrake
+          coverage combine
+          mv .coverage ../coverage_firedrake.dat
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-artifacts-firedrake
+          path: |
+            data_firedrake
+            coverage_firedrake.dat
          
   user_monodomain_tests_linux:
     runs-on: ubuntu-latest
diff --git a/pySDC/helpers/firedrake_ensemble_communicator.py b/pySDC/helpers/firedrake_ensemble_communicator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b10e79d027b0563869381f256e54dda4d1d244bb
--- /dev/null
+++ b/pySDC/helpers/firedrake_ensemble_communicator.py
@@ -0,0 +1,58 @@
+from mpi4py import MPI
+import firedrake as fd
+import numpy as np
+
+
+class FiredrakeEnsembleCommunicator:
+    """
+    Ensemble communicator for running multiple similar distributed simulations with Firedrake, see https://www.firedrakeproject.org/firedrake/parallelism.html.
+    It is intended for space-time parallelism in pySDC.
+    This class wraps the time communicator: all attribute accesses that are not overloaded here are forwarded to it. For instance, `ensemble.rank` returns the rank in the time communicator.
+    Selected operations keep the interface of an MPI communicator but route the communication through the ensemble communicator instead.
+    """
+
+    def __init__(self, comm, space_size):
+        """
+        Args:
+            comm (MPI.Intracomm): MPI communicator, which will be split into time and space communicators
+            space_size (int): Size of the spatial communicators
+
+        Attributes:
+            ensemble (firedrake.Ensemble): Ensemble communicator
+        """
+        self.ensemble = fd.Ensemble(comm, space_size)
+
+    @property
+    def space_comm(self):
+        return self.ensemble.comm
+
+    @property
+    def time_comm(self):
+        return self.ensemble.ensemble_comm
+
+    def __getattr__(self, name):
+        return getattr(self.time_comm, name)
+
+    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
+        if isinstance(sendbuf, np.ndarray):
+            self.ensemble.ensemble_comm.Reduce(sendbuf, recvbuf, op, root)
+        else:
+            assert op == MPI.SUM
+            self.ensemble.reduce(sendbuf, recvbuf, root=root)
+
+    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
+        if isinstance(sendbuf, np.ndarray):
+            self.ensemble.ensemble_comm.Allreduce(sendbuf, recvbuf, op)
+        else:
+            assert op == MPI.SUM
+            self.ensemble.allreduce(sendbuf, recvbuf)
+
+    def Bcast(self, buf, root=0):
+        if isinstance(buf, np.ndarray):
+            self.ensemble.ensemble_comm.Bcast(buf, root)
+        else:
+            self.ensemble.bcast(buf, root=root)
+
+
+def get_ensemble(comm, space_size):
+    """Get a plain firedrake.Ensemble by splitting ``comm`` into spatial sub-communicators of size ``space_size``."""
+    return fd.Ensemble(comm, space_size)
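+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (editor addition): split the global communicator into
+    # spatial sub-communicators of size 2 and broadcast a numpy buffer across the
+    # time communicator. Assumes an MPI launch with a rank count divisible by 2,
+    # e.g. `mpiexec -np 4 python firedrake_ensemble_communicator.py`.
+    ensemble = FiredrakeEnsembleCommunicator(MPI.COMM_WORLD, space_size=2)
+    data = np.array([float(ensemble.time_comm.rank)])
+    ensemble.Bcast(data, root=0)
+    print(f'space rank {ensemble.space_comm.rank}, time rank {ensemble.time_comm.rank}, data after Bcast: {data}')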
diff --git a/pySDC/implementations/datatype_classes/fenics_mesh.py b/pySDC/implementations/datatype_classes/fenics_mesh.py
index 52b0ad7ae06c1093e4892832ba27bc8305d77eec..6e9c30cf9653d4afbd880c1a03ee616d42268f48 100644
--- a/pySDC/implementations/datatype_classes/fenics_mesh.py
+++ b/pySDC/implementations/datatype_classes/fenics_mesh.py
@@ -80,7 +80,7 @@ class fenics_mesh(object):
         Args:
             other (float): factor
         Raises:
-            DataError: is other is not a float
+            DataError: if other is not a float
         Returns:
             fenics_mesh: copy of original values scaled by factor
         """
diff --git a/pySDC/implementations/datatype_classes/firedrake_mesh.py b/pySDC/implementations/datatype_classes/firedrake_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..783b01c2e150cdad400bf87b137e622ae3e0d048
--- /dev/null
+++ b/pySDC/implementations/datatype_classes/firedrake_mesh.py
@@ -0,0 +1,114 @@
+import firedrake as fd
+
+from pySDC.core.errors import DataError
+
+
+class firedrake_mesh(object):
+    """
+    Wrapper for firedrake function data.
+
+    Attributes:
+        functionspace (firedrake.Function): firedrake data
+    """
+
+    def __init__(self, init, val=0.0):
+        if fd.functionspaceimpl.WithGeometry in type(init).__mro__:
+            self.functionspace = fd.Function(init)
+            self.functionspace.assign(val)
+        elif fd.Function in type(init).__mro__:
+            self.functionspace = fd.Function(init)
+        elif type(init) == firedrake_mesh:
+            self.functionspace = init.functionspace.copy(deepcopy=True)
+        else:
+            raise DataError('something went wrong during %s initialization' % type(init))
+
+    def __getattr__(self, key):
+        return getattr(self.functionspace, key)
+
+    @property
+    def asnumpy(self):
+        """
+        Get a numpy array of the values associated with this data
+        """
+        return self.functionspace.dat._numpy_data
+
+    def __add__(self, other):
+        if isinstance(other, type(self)):
+            me = firedrake_mesh(other)
+            me.functionspace.assign(self.functionspace + other.functionspace)
+            return me
+        else:
+            raise DataError("Type error: cannot add %s to %s" % (type(other), type(self)))
+
+    def __sub__(self, other):
+        if isinstance(other, type(self)):
+            me = firedrake_mesh(other)
+            me.functionspace.assign(self.functionspace - other.functionspace)
+            return me
+        else:
+            raise DataError("Type error: cannot add %s to %s" % (type(other), type(self)))
+
+    def __rmul__(self, other):
+        """
+        Overloading the right multiply by scalar factor
+
+        Args:
+            other (float): factor
+        Raises:
+            DataError: if other is not a float
+        Returns:
+            firedrake_mesh: copy of original values scaled by factor
+        """
+
+        try:
+            me = firedrake_mesh(self)
+            me.functionspace.assign(other * self.functionspace)
+            return me
+        except TypeError as e:
+            raise DataError("Type error: cannot multiply %s to %s" % (type(other), type(self))) from e
+
+    def __abs__(self):
+        """
+        Overloading the abs operator for mesh types
+
+        Returns:
+            float: L2 norm
+        """
+
+        return fd.norm(self.functionspace, 'L2')
+
+
+class IMEX_firedrake_mesh(object):
+    """
+    Datatype for IMEX integration with firedrake data.
+
+    Attributes:
+        impl (firedrake_mesh): implicit part
+        expl (firedrake_mesh): explicit part
+    """
+
+    def __init__(self, init, val=0.0):
+        if type(init) == type(self):
+            self.impl = firedrake_mesh(init.impl)
+            self.expl = firedrake_mesh(init.expl)
+        else:
+            self.impl = firedrake_mesh(init, val=val)
+            self.expl = firedrake_mesh(init, val=val)
+
+    def __add__(self, other):
+        me = IMEX_firedrake_mesh(self)
+        me.impl = self.impl + other.impl
+        me.expl = self.expl + other.expl
+        return me
+
+    def __sub__(self, other):
+        me = IMEX_firedrake_mesh(self)
+        me.impl = self.impl - other.impl
+        me.expl = self.expl - other.expl
+        return me
+
+    def __rmul__(self, other):
+        me = IMEX_firedrake_mesh(self)
+        me.impl = other * self.impl
+        me.expl = other * self.expl
+        return me
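+
+
+if __name__ == '__main__':
+    # Rough smoke test (editor addition, assuming a working Firedrake install):
+    # exercise the wrappers on a tiny mesh.
+    mesh = fd.UnitIntervalMesh(4)
+    V = fd.FunctionSpace(mesh, 'CG', 1)
+    u = firedrake_mesh(V, val=1.0)
+    f = IMEX_firedrake_mesh(V, val=2.0)
+    print(abs(u), abs(0.5 * (f.impl + f.expl)))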
diff --git a/pySDC/implementations/problem_classes/HeatFiredrake.py b/pySDC/implementations/problem_classes/HeatFiredrake.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cdfdaf8af0a20c82f516837fb39ad9d9d8d361
--- /dev/null
+++ b/pySDC/implementations/problem_classes/HeatFiredrake.py
@@ -0,0 +1,211 @@
+from pySDC.core.problem import Problem, WorkCounter
+from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh, IMEX_firedrake_mesh
+import firedrake as fd
+import numpy as np
+from mpi4py import MPI
+
+
+class Heat1DForcedFiredrake(Problem):
+    r"""
+    Example implementing the forced one-dimensional heat equation with Dirichlet boundary conditions
+
+    .. math::
+        \frac{d u}{d t} = \nu \frac{d^2 u}{d x^2} + f
+
+    for :math:`x \in \Omega:=[0,1]`, where the forcing term :math:`f` is defined by
+
+    .. math::
+        f(x, t) = -\sin(\pi x) (\sin(t) - \nu \pi^2 \cos(t)).
+
+    For initial conditions with a constant :math:`c`,
+
+    .. math::
+        u(x, 0) = \sin(\pi x) + c,
+
+    the exact solution is given by
+
+    .. math::
+        u(x, t) = \sin(\pi x)\cos(t) + c.
+
+    Here, the problem is discretized with finite elements using firedrake. Hence, the problem
+    is reformulated to the *weak formulation*
+
+    .. math::
+        \int_\Omega u_t v\,dx = - \nu \int_\Omega \nabla u \cdot \nabla v\,dx + \int_\Omega f v\,dx.
+
+    The Laplacian is treated implicitly and the forcing term explicitly.
+    The solvers for the arising variational problems are cached across collocation nodes and step sizes.
+
+    Parameters
+    ----------
+    n : int, optional
+        Spatial resolution, i.e., number of degrees of freedom in space.
+    nu : float, optional
+        Diffusion coefficient :math:`\nu`.
+    c : float, optional
+        Constant for the Dirichlet boundary condition :math:`c`.
+    LHS_cache_size : int, optional
+        Cache size for variational problem solvers
+    comm : MPI communicator, optional
+        Supply an MPI communicator for spatial parallelism
+    """
+
+    dtype_u = firedrake_mesh
+    dtype_f = IMEX_firedrake_mesh
+
+    def __init__(self, n=30, nu=0.1, c=0.0, LHS_cache_size=12, comm=None):
+        """
+        Initialization
+
+        Args:
+            n (int): Number of degrees of freedom
+            nu (float): Diffusion parameter
+            c (float): Boundary condition constant
+            LHS_cache_size (int): Size of the cache for solvers
+            comm (mpi4py.Intracomm): MPI communicator for spatial parallelism
+        """
+        comm = MPI.COMM_WORLD if comm is None else comm
+
+        # prepare Firedrake mesh and function space
+        self.mesh = fd.UnitIntervalMesh(n, comm=comm)
+        self.V = fd.FunctionSpace(self.mesh, "CG", 4)
+
+        # prepare pySDC problem class infrastructure by passing the function space to super init
+        super().__init__(self.V)
+        self._makeAttributeAndRegister('n', 'nu', 'c', 'LHS_cache_size', 'comm', localVars=locals(), readOnly=True)
+
+        # prepare caches and IO variables for solvers
+        self.solvers = {}
+        self.tmp_in = fd.Function(self.V)
+        self.tmp_out = fd.Function(self.V)
+
+        # prepare work counters
+        self.work_counters['solver_setup'] = WorkCounter()
+        self.work_counters['solves'] = WorkCounter()
+        self.work_counters['rhs'] = WorkCounter()
+
+    def eval_f(self, u, t):
+        """
+        Evaluate the right hand side.
+        The forcing term is simply interpolated to the grid.
+        The Laplacian is evaluated via a variational problem, where the mass matrix is inverted and homogeneous boundary conditions are applied.
+        Note that we cache the solver to obtain much better performance.
+
+        Parameters
+        ----------
+        u : dtype_u
+            Solution at which to evaluate
+        t : float
+            Time at which to evaluate
+
+        Returns
+        -------
+        f : dtype_f
+            The evaluated right hand side
+        """
+        # construct and cache a solver for the implicit part of the right hand side evaluation
+        if not hasattr(self, '_solv_eval_f_implicit'):
+            v = fd.TestFunction(self.V)
+            u_trial = fd.TrialFunction(self.V)
+
+            a = u_trial * v * fd.dx
+            L_impl = -fd.inner(self.nu * fd.nabla_grad(self.tmp_in), fd.nabla_grad(v)) * fd.dx
+
+            bcs = [fd.bcs.DirichletBC(self.V, fd.Constant(0), area) for area in [1, 2]]
+
+            prob = fd.LinearVariationalProblem(a, L_impl, self.tmp_out, bcs=bcs)
+            self._solv_eval_f_implicit = fd.LinearVariationalSolver(prob)
+
+        # copy the solution we want to evaluate at into the input buffer
+        self.tmp_in.assign(u.functionspace)
+
+        # perform the solve using the cached solver
+        self._solv_eval_f_implicit.solve()
+
+        me = self.dtype_f(self.init)
+
+        # copy the result of the solver from the output buffer to the variable this function returns
+        me.impl.assign(self.tmp_out)
+
+        # evaluate explicit part.
+        # Because it does not depend on the current solution, we can simply interpolate the expression
+        x = fd.SpatialCoordinate(self.mesh)
+        me.expl.interpolate(-(np.sin(t) - self.nu * np.pi**2 * np.cos(t)) * fd.sin(np.pi * x[0]))
+
+        self.work_counters['rhs']()
+
+        return me
+
+    def solve_system(self, rhs, factor, *args, **kwargs):
+        r"""
+        Linear solver for :math:`(M - factor \cdot \nu \cdot \Delta) u = rhs`.
+
+        Parameters
+        ----------
+        rhs : dtype_f
+            Right-hand side for the linear system.
+        factor : float
+            Abbrev. for the node-to-node stepsize (or any other factor required).
+        u0 : dtype_u
+            Initial guess for the iterative solver (not used here so far).
+        t : float
+            Current time.
+
+        Returns
+        -------
+        u : dtype_u
+            Solution.
+        """
+
+        # construct and cache a solver for the current factor (preconditioner entry times step size)
+        if factor not in self.solvers:
+
+            # check if we need to evict something from the cache
+            if len(self.solvers) >= self.LHS_cache_size:
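+                # dicts preserve insertion order (Python 3.7+), so this evicts the oldest cached solver first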
+                self.solvers.pop(list(self.solvers.keys())[0])
+
+            u = fd.TrialFunction(self.V)
+            v = fd.TestFunction(self.V)
+
+            a = u * v * fd.dx + fd.Constant(factor) * fd.inner(self.nu * fd.nabla_grad(u), fd.nabla_grad(v)) * fd.dx
+            L = fd.inner(self.tmp_in, v) * fd.dx
+
+            bcs = [fd.bcs.DirichletBC(self.V, fd.Constant(self.c), area) for area in [1, 2]]
+
+            prob = fd.LinearVariationalProblem(a, L, self.tmp_out, bcs=bcs)
+            self.solvers[factor] = fd.LinearVariationalSolver(prob)
+
+            self.work_counters['solver_setup']()
+
+        # copy the solver rhs to the input buffer; copying it to the output buffer as well makes it the initial guess
+        self.tmp_in.assign(rhs.functionspace)
+        self.tmp_out.assign(rhs.functionspace)
+
+        # call the cached solver
+        self.solvers[factor].solve()
+
+        # copy from output buffer to return variable
+        me = self.dtype_u(self.init)
+        me.assign(self.tmp_out)
+
+        self.work_counters['solves']()
+        return me
+
+    def u_exact(self, t):
+        r"""
+        Routine to compute the exact solution at time :math:`t`.
+
+        Parameters
+        ----------
+        t : float
+            Time of the exact solution.
+
+        Returns
+        -------
+        me : dtype_u
+            Exact solution.
+        """
+        me = self.u_init
+        x = fd.SpatialCoordinate(self.mesh)
+        me.interpolate(np.cos(t) * fd.sin(np.pi * x[0]) + self.c)
+        return me
diff --git a/pySDC/implementations/sweeper_classes/Runge_Kutta.py b/pySDC/implementations/sweeper_classes/Runge_Kutta.py
index f6db08086edac9b3e0891d1b55474f5c85f2b048..1ec9d58e3250e5e82d320e5c4e58af042ac5fe28 100644
--- a/pySDC/implementations/sweeper_classes/Runge_Kutta.py
+++ b/pySDC/implementations/sweeper_classes/Runge_Kutta.py
@@ -189,9 +189,9 @@ class RungeKutta(Sweeper):
         Returns:
             mesh: Full right hand side as a mesh
         """
-        if type(f).__name__ in ['mesh', 'cupy_mesh']:
+        if type(f).__name__ in ['mesh', 'cupy_mesh', 'firedrake_mesh']:
             return f
-        elif type(f).__name__ in ['imex_mesh', 'imex_cupy_mesh']:
+        elif type(f).__name__ in ['imex_mesh', 'imex_cupy_mesh', 'IMEX_firedrake_mesh']:
             return f.impl + f.expl
         elif f is None:
             prob = self.level.prob
@@ -249,11 +249,11 @@ class RungeKutta(Sweeper):
 
             # implicit solve with prefactor stemming from the diagonal of Qd, use previous stage as initial guess
             if self.QI[m + 1, m + 1] != 0:
-                lvl.u[m + 1][:] = prob.solve_system(
+                lvl.u[m + 1] = prob.solve_system(
                     rhs, lvl.dt * self.QI[m + 1, m + 1], lvl.u[m], lvl.time + lvl.dt * self.coll.nodes[m + 1]
                 )
             else:
-                lvl.u[m + 1][:] = rhs[:]
+                lvl.u[m + 1] = rhs
 
             # update function values (we don't usually need to evaluate the RHS at the solution of the step)
             lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
@@ -428,11 +428,11 @@ class RungeKuttaIMEX(RungeKutta):
 
             # implicit solve with prefactor stemming from the diagonal of Qd, use previous stage as initial guess
             if self.QI[m + 1, m + 1] != 0:
-                lvl.u[m + 1][:] = prob.solve_system(
+                lvl.u[m + 1] = prob.solve_system(
                     rhs, lvl.dt * self.QI[m + 1, m + 1], lvl.u[m], lvl.time + lvl.dt * self.coll.nodes[m + 1]
                 )
             else:
-                lvl.u[m + 1][:] = rhs[:]
+                lvl.u[m + 1] = rhs
 
             # update function values
             lvl.f[m + 1] = prob.eval_f(lvl.u[m + 1], lvl.time + lvl.dt * self.coll.nodes[m + 1])
diff --git a/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py b/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
index 1ca6d6e894c863929b401bdc0ee84002a095b274..4df14c346d105ffeef9d7ae640f760274eb024d8 100644
--- a/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
+++ b/pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
@@ -55,14 +55,15 @@ class SweeperMPI(Sweeper):
 
         L = self.level
         P = L.prob
-        L.uend = P.dtype_u(P.init, val=0.0)
 
         # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
         if self.coll.right_is_node and not self.params.do_coll_update:
             # a copy is sufficient
             root = self.comm.Get_size() - 1
             if self.comm.rank == root:
-                L.uend[:] = L.u[-1]
+                L.uend = P.dtype_u(L.u[-1])
+            else:
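+                # non-root ranks need an allocated buffer with matching layout to receive the broadcast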
+                L.uend = P.dtype_u(L.u[0])
             self.comm.Bcast(L.uend, root=root)
         else:
             raise NotImplementedError('require last node to be identical with right interval boundary')
@@ -221,7 +222,7 @@ class generic_implicit_MPI(SweeperMPI, generic_implicit):
         # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
 
         # implicit solve with prefactor stemming from the diagonal of Qd
-        L.u[self.rank + 1][:] = P.solve_system(
+        L.u[self.rank + 1] = P.solve_system(
             rhs,
             L.dt * self.QI[self.rank + 1, self.rank + 1],
             L.u[self.rank + 1],
diff --git a/pySDC/tests/test_datatypes/test_firedrake_mesh.py b/pySDC/tests/test_datatypes/test_firedrake_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..781d765a354ddcf782b55574578895e5ddc3951a
--- /dev/null
+++ b/pySDC/tests/test_datatypes/test_firedrake_mesh.py
@@ -0,0 +1,155 @@
+import pytest
+
+
+@pytest.mark.firedrake
+def test_addition(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = firedrake_mesh(V)
+    b = firedrake_mesh(a)
+
+    a.assign(v1)
+    b.assign(v2)
+
+    c = a + b
+
+    assert np.allclose(c.dat._numpy_data, v1 + v2)
+    assert np.allclose(a.dat._numpy_data, v1)
+    assert np.allclose(b.dat._numpy_data, v2)
+
+
+@pytest.mark.firedrake
+def test_subtraction(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = firedrake_mesh(V, val=v1)
+    _b = fd.Function(V)
+    _b.assign(v2)
+    b = firedrake_mesh(_b)
+
+    c = a - b
+
+    assert np.allclose(c.dat._numpy_data, v1 - v2)
+    assert np.allclose(a.dat._numpy_data, v1)
+    assert np.allclose(b.dat._numpy_data, v2)
+
+
+@pytest.mark.firedrake
+def test_right_multiplication(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh
+    from pySDC.core.errors import DataError
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = firedrake_mesh(V)
+
+    a.assign(v1)
+
+    b = v2 * a
+
+    assert np.allclose(b.dat._numpy_data, v1 * v2)
+    assert np.allclose(a.dat._numpy_data, v1)
+
+    with pytest.raises(DataError):
+        'Dat kölsche Dom' * b
+
+
+@pytest.mark.firedrake
+def test_norm(n=3, v1=-1):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 1)
+
+    a = firedrake_mesh(V, val=v1)
+
+    b = abs(a)
+
+    assert np.isclose(b, np.sqrt(2) * abs(v1)), f'{b=}, {v1=}'
+    assert np.allclose(a.dat._numpy_data, v1)
+
+
+@pytest.mark.firedrake
+def test_addition_rhs(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import IMEX_firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = IMEX_firedrake_mesh(V, val=v1)
+    b = IMEX_firedrake_mesh(V, val=v2)
+
+    c = a + b
+
+    assert np.allclose(c.impl.dat._numpy_data, v1 + v2)
+    assert np.allclose(c.expl.dat._numpy_data, v1 + v2)
+    assert np.allclose(a.impl.dat._numpy_data, v1)
+    assert np.allclose(b.impl.dat._numpy_data, v2)
+    assert np.allclose(a.expl.dat._numpy_data, v1)
+    assert np.allclose(b.expl.dat._numpy_data, v2)
+
+
+@pytest.mark.firedrake
+def test_subtraction_rhs(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import IMEX_firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = IMEX_firedrake_mesh(V, val=v1)
+    b = IMEX_firedrake_mesh(V, val=v2)
+
+    c = a - b
+
+    assert np.allclose(c.impl.dat._numpy_data, v1 - v2)
+    assert np.allclose(c.expl.dat._numpy_data, v1 - v2)
+    assert np.allclose(a.impl.dat._numpy_data, v1)
+    assert np.allclose(b.impl.dat._numpy_data, v2)
+    assert np.allclose(a.expl.dat._numpy_data, v1)
+    assert np.allclose(b.expl.dat._numpy_data, v2)
+
+
+@pytest.mark.firedrake
+def test_rmul_rhs(n=3, v1=1, v2=2):
+    from pySDC.implementations.datatype_classes.firedrake_mesh import IMEX_firedrake_mesh
+    import numpy as np
+    import firedrake as fd
+
+    mesh = fd.UnitSquareMesh(n, n)
+    V = fd.VectorFunctionSpace(mesh, "CG", 2)
+
+    a = IMEX_firedrake_mesh(V, val=v1)
+
+    b = v2 * a
+
+    assert np.allclose(a.impl.dat._numpy_data, v1)
+    assert np.allclose(b.impl.dat._numpy_data, v2 * v1)
+    assert np.allclose(a.expl.dat._numpy_data, v1)
+    assert np.allclose(b.expl.dat._numpy_data, v2 * v1)
+
+
+if __name__ == '__main__':
+    test_addition()
diff --git a/pySDC/tests/test_problems/test_heat_firedrake.py b/pySDC/tests/test_problems/test_heat_firedrake.py
new file mode 100644
index 0000000000000000000000000000000000000000..9304dd35ce98a7e9b1ff9598e625d9b2b0aae958
--- /dev/null
+++ b/pySDC/tests/test_problems/test_heat_firedrake.py
@@ -0,0 +1,62 @@
+import pytest
+
+
+@pytest.mark.parametrize('c', [0, 3.14])
+@pytest.mark.firedrake
+def test_solve_system(c):
+    from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
+    import numpy as np
+    import firedrake as fd
+
+    # test we get the initial conditions back when solving with zero step size
+    P = Heat1DForcedFiredrake(n=128, c=c)
+    u0 = P.u_exact(0)
+    un = P.solve_system(u0, 0)
+    assert abs(u0 - un) < 1e-8
+
+    # test we get the expected solution to a Poisson problem by setting very large step size
+    dt = 1e6
+    x = fd.SpatialCoordinate(P.mesh)
+    u0 = P.u_init
+    u0.interpolate(fd.sin(np.pi * x[0]))
+    expect = P.u_init
+    expect.interpolate(1 / (P.nu * np.pi**2 * dt) * fd.sin(np.pi * x[0]) + P.c)
+    un = P.solve_system(u0, dt)
+    error = abs(un - expect) / abs(expect)
+    assert error < 1e-4, error
+
+    # test that we arrive back where we started when going forward with IMEX Euler and backward with explicit Euler
+    dt = 1e0
+    u0 = P.u_exact(0)
+    f = P.eval_f(u0, 0)
+    un2 = P.solve_system(u0 + dt * f.expl, dt)
+    fn2 = P.eval_f(un2, 0)
+    u02 = un2 - dt * (fn2.impl + fn2.expl)
+    error = abs(u02 - u0) / abs(u02)
+    assert error < 1e-8, error
+
+
+@pytest.mark.firedrake
+def test_eval_f():
+    from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
+    import numpy as np
+    import firedrake as fd
+
+    P = Heat1DForcedFiredrake(n=128)
+
+    me = P.u_init
+    x = fd.SpatialCoordinate(P.mesh)
+    me.interpolate(-fd.sin(np.pi * x[0]))
+
+    expect = P.u_init
+    expect.interpolate(P.nu * np.pi**2 * fd.sin(np.pi * x[0]))
+
+    get = P.eval_f(me, 0).impl
+
+    error = abs(expect - get) / abs(expect)
+    assert error < 1e-8, error
+
+
+if __name__ == '__main__':
+    test_solve_system(0)
+    test_eval_f()
diff --git a/pySDC/tests/test_tutorials/test_step_7.py b/pySDC/tests/test_tutorials/test_step_7.py
index cd2f4ec8b37a2e5458f829ad0f29f6aafe2c160e..56aa6cb83f6c0a7a49b34d59d1f27128fb90b89c 100644
--- a/pySDC/tests/test_tutorials/test_step_7.py
+++ b/pySDC/tests/test_tutorials/test_step_7.py
@@ -127,3 +127,27 @@ def test_D():
     from pySDC.tutorial.step_7.D_pySDC_with_PyTorch import train_at_collocation_nodes
 
     train_at_collocation_nodes()
+
+
+@pytest.mark.firedrake
+def test_E():
+    from pySDC.tutorial.step_7.E_pySDC_with_Firedrake import runHeatFiredrake
+
+    runHeatFiredrake(useMPIsweeper=False)
+
+
+@pytest.mark.firedrake
+def test_E_MPI():
+    my_env = os.environ.copy()
+    my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'
+    cwd = '.'
+    num_procs = 3
+    cmd = f'mpiexec -np {num_procs} python pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py'.split()
+
+    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=my_env, cwd=cwd)
+    p.wait()
+    for line in p.stdout:
+        print(line)
+    for line in p.stderr:
+        print(line)
+    assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (p.returncode, num_procs)
diff --git a/pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py b/pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85d9a536b9151694df0fed52d62770811a65354
--- /dev/null
+++ b/pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py
@@ -0,0 +1,115 @@
+"""
+Simple example running a forced heat equation in Firedrake.
+
+The function `setup` generates the description and controller_params dictionaries needed to run SDC with a diagonal preconditioner.
+This proceeds very similarly to earlier tutorials. The interesting part of this tutorial is rather the problem class.
+See `pySDC/implementations/problem_classes/HeatFiredrake` for an easy example of how to use Firedrake within pySDC.
+
+Run in serial simply with `python E_pySDC_with_Firedrake.py`, or with parallel diagonal SDC via `mpiexec -np 3 python E_pySDC_with_Firedrake.py`.
+"""
+
+import numpy as np
+from mpi4py import MPI
+
+
+def setup(useMPIsweeper):
+    """
+    Helper routine to set up parameters
+
+    Returns:
+        description and controller_params parameter dictionaries
+    """
+    from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
+    from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
+    from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI
+    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun
+    from pySDC.implementations.hooks.log_work import LogWork
+    from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator
+
+    # setup space-time parallelism via ensemble for Firedrake, see https://www.firedrakeproject.org/firedrake/parallelism.html
+    num_nodes = 3
+    ensemble = FiredrakeEnsembleCommunicator(MPI.COMM_WORLD, max([MPI.COMM_WORLD.size // num_nodes, 1]))
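+    # e.g. with 3 MPI ranks this gives spatial communicators of size 1, i.e., one rank per collocation node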
+
+    level_params = dict()
+    level_params['restol'] = 5e-10
+    level_params['dt'] = 0.2
+
+    step_params = dict()
+    step_params['maxiter'] = 20
+
+    sweeper_params = dict()
+    sweeper_params['quad_type'] = 'RADAU-RIGHT'
+    sweeper_params['num_nodes'] = num_nodes
+    sweeper_params['QI'] = 'MIN-SR-S'
+    sweeper_params['QE'] = 'PIC'
+    sweeper_params['comm'] = ensemble
+
+    problem_params = dict()
+    problem_params['nu'] = 0.1
+    problem_params['n'] = 128
+    problem_params['c'] = 1.0
+    problem_params['comm'] = ensemble.space_comm
+
+    controller_params = dict()
+    controller_params['logger_level'] = 15 if MPI.COMM_WORLD.rank == 0 else 30
+    controller_params['hook_class'] = [LogGlobalErrorPostRun, LogWork]
+
+    description = dict()
+    description['problem_class'] = Heat1DForcedFiredrake
+    description['problem_params'] = problem_params
+    description['sweeper_class'] = imex_1st_order_MPI if useMPIsweeper else imex_1st_order
+    description['sweeper_params'] = sweeper_params
+    description['level_params'] = level_params
+    description['step_params'] = step_params
+
+    return description, controller_params
+
+
+def runHeatFiredrake(useMPIsweeper):
+    """
+    Run the example defined by the above parameters
+    """
+    from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+    from pySDC.helpers.stats_helper import get_sorted
+
+    Tend = 1.0
+    t0 = 0.0
+
+    description, controller_params = setup(useMPIsweeper)
+
+    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
+
+    # get initial values
+    P = controller.MS[0].levels[0].prob
+    uinit = P.u_exact(0.0)
+
+    # call main function to get things done...
+    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
+
+    # see what we get
+    error = get_sorted(stats, type='e_global_post_run')
+    work_solver_setup = get_sorted(stats, type='work_solver_setup')
+    work_solves = get_sorted(stats, type='work_solves')
+    work_rhs = get_sorted(stats, type='work_rhs')
+    niter = get_sorted(stats, type='niter')
+
+    tot_iter = np.sum([me[1] for me in niter])
+    tot_solver_setup = np.sum([me[1] for me in work_solver_setup])
+    tot_solves = np.sum([me[1] for me in work_solves])
+    tot_rhs = np.sum([me[1] for me in work_rhs])
+
+    print(
+        f'Finished with error {error[0][1]:.2e}. Used {tot_iter} SDC iterations, with {tot_solver_setup} solver setups, {tot_solves} solves and {tot_rhs} right hand side evaluations on time task {description["sweeper_params"]["comm"].rank}.'
+    )
+
+    # check that we get the same results as in previous runs
+    n_nodes = 1 if useMPIsweeper else description['sweeper_params']['num_nodes']
+    assert error[0][1] < 2e-8
+    assert tot_iter == 29
+    assert tot_solver_setup == n_nodes
+    assert tot_solves == n_nodes * tot_iter
+    assert tot_rhs == n_nodes * tot_iter + (n_nodes + 1) * len(niter)
+
+
+if __name__ == "__main__":
+    runHeatFiredrake(useMPIsweeper=MPI.COMM_WORLD.size > 1)
diff --git a/pyproject.toml b/pyproject.toml
index 3e6e776af1643f74ffd4662632ccb12f84a31575..5ea5109db25086878ff9e878d652b797427ffd77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,8 @@ markers = [
     'cupy: tests for cupy on GPUs',
     'libpressio: tests using the libpressio library',
     'monodomain: tests the monodomain project, which requires previous compilation of c++ code',
-    'pytorch: tests for PyTorch related things in pySDC'
+    'pytorch: tests for PyTorch related things in pySDC',
+    'firedrake: tests for firedrake',
     ]
 timeout = 300