diff --git a/pySDC/implementations/problem_classes/Battery.py b/pySDC/implementations/problem_classes/Battery.py
index 673c22233ed0a48f92541ccedfbe9237c8ef679c..b1629811f8b103960896cb0d26d429ef4469e512 100644
--- a/pySDC/implementations/problem_classes/Battery.py
+++ b/pySDC/implementations/problem_classes/Battery.py
@@ -5,47 +5,63 @@ from pySDC.core.Problem import ptype
 from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
 
 
-class battery(ptype):
+class battery_n_capacitors(ptype):
     """
-    Example implementing the battery drain model as in the description in the PinTSimE project
+    Example implementing the battery drain model with N capacitors, where N is an arbitrary integer greater than 0.
     Attributes:
-        A: system matrix, representing the 2 ODEs
+        nswitches: number of switches
     """
 
     def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
         """
         Initialization routine
+
         Args:
             problem_params (dict): custom parameters for the example
             dtype_u: mesh data type for solution
             dtype_f: mesh data type for RHS
         """
 
-        problem_params['nvars'] = 2
-
         # these parameters will be used later, so assert their existence
-        essential_keys = ['Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch']
+        essential_keys = ['ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
         for key in essential_keys:
             if key not in problem_params:
                 msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
                 raise ParameterError(msg)
 
+        n = problem_params['ncapacitors']
+        problem_params['nvars'] = n + 1
+
         # invoke super init, passing number of dofs, dtype_u and dtype_f
-        super(battery, self).__init__(
+        super(battery_n_capacitors, self).__init__(
             init=(problem_params['nvars'], None, np.dtype('float64')),
             dtype_u=dtype_u,
             dtype_f=dtype_f,
             params=problem_params,
         )
 
-        self.A = np.zeros((2, 2))
+        self.A = np.zeros((n + 1, n + 1))
+        self.switch_A, self.switch_f = self.get_problem_dict()
+        self.t_switch = None
+        self.nswitches = 0
 
     def eval_f(self, u, t):
         """
-        Routine to evaluate the RHS
+        Routine to evaluate the RHS. No Switch Estimator is used: For N = 3 there are N + 1 = 4 different states of the battery:
+            1. u[1] > V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2]    -> C1 supplies energy
+            2. u[1] <= V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2]   -> C2 supplies energy
+            3. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] > V_ref[2]  -> C3 supplies energy
+            4. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] <= V_ref[2] -> Vs supplies energy
+        max_index is initialized to -1. List "switch" contains a True if u[k] <= V_ref[k-1] is satisfied.
+            - Is no True there (i.e. max_index = -1), we are in the first case.
+            - max_index = k >= 0 means we are in the (k+1)-th case.
+              So, the actual RHS has key max_index-1 in the dictionary self.switch_f.
+        In case of using the Switch Estimator, we count the number of switches which illustrates in which case of voltage source we are.
+
         Args:
             u (dtype_u): current values
             t (float): current time
+
         Returns:
             dtype_f: the RHS
         """
@@ -53,50 +69,48 @@ class battery(ptype):
         f = self.dtype_f(self.init, val=0.0)
         f.impl[:] = self.A.dot(u)
 
-        if u[1] <= self.params.V_ref or self.params.set_switch:
-            # switching need to happen on exact time point
-            if self.params.set_switch:
-                if t >= self.params.t_switch:
-                    f.expl[0] = self.params.Vs / self.params.L
+        if self.t_switch is not None:
+            f.expl[:] = self.switch_f[self.nswitches]
 
-                else:
-                    f.expl[0] = 0
+        else:
+            # proof all switching conditions and find largest index where it drops below V_ref
+            switch = [True if u[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(u))]
+            max_index = max([k if switch[k] == True else -1 for k in range(len(switch))])
 
-            else:
-                f.expl[0] = self.params.Vs / self.params.L
+            if max_index == -1:
+                f.expl[:] = self.switch_f[0]
 
-        else:
-            f.expl[0] = 0
+            else:
+                f.expl[:] = self.switch_f[max_index + 1]
 
         return f
 
     def solve_system(self, rhs, factor, u0, t):
         """
         Simple linear solver for (I-factor*A)u = rhs
+
         Args:
             rhs (dtype_f): right-hand side for the linear system
             factor (float): abbrev. for the local stepsize (or any other factor required)
             u0 (dtype_u): initial guess for the iterative solver
             t (float): current time (e.g. for time-dependent BCs)
+
         Returns:
             dtype_u: solution as mesh
         """
-        self.A = np.zeros((2, 2))
 
-        if rhs[1] <= self.params.V_ref or self.params.set_switch:
-            # switching need to happen on exact time point
-            if self.params.set_switch:
-                if t >= self.params.t_switch:
-                    self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+        if self.t_switch is not None:
+            self.A = self.switch_A[self.nswitches]
 
-                else:
-                    self.A[1, 1] = -1 / (self.params.C * self.params.R)
+        else:
+            # proof all switching conditions and find largest index where it drops below V_ref
+            switch = [True if rhs[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(rhs))]
+            max_index = max([k if switch[k] == True else -1 for k in range(len(switch))])
+            if max_index == -1:
+                self.A = self.switch_A[0]
 
             else:
-                self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-
-        else:
-            self.A[1, 1] = -1 / (self.params.C * self.params.R)
+                self.A = self.switch_A[max_index + 1]
 
         me = self.dtype_u(self.init)
         me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
@@ -105,8 +119,10 @@ class battery(ptype):
     def u_exact(self, t):
         """
         Routine to compute the exact solution at time t
+
         Args:
             t (float): current time
+
         Returns:
             dtype_u: exact solution
         """
@@ -115,57 +131,168 @@ class battery(ptype):
         me = self.dtype_u(self.init)
 
         me[0] = 0.0  # cL
-        me[1] = self.params.alpha * self.params.V_ref  # vC
-
+        me[1:] = self.params.alpha * self.params.V_ref  # vC's
         return me
 
+    def get_switching_info(self, u, t):
+        """
+        Provides information about a discrete event for one subinterval.
+
+        Args:
+            u (dtype_u): current values
+            t (float): current time
+
+        Returns:
+            switch_detected (bool): Indicates if a switch is found or not
+            m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found
+            vC_switch (list): Contains function values of switching condition (for interpolation)
+        """
+
+        switch_detected = False
+        m_guess = -100
+        break_flag = False
+
+        for m in range(len(u)):
+            for k in range(1, self.params.nvars):
+                if u[m][k] - self.params.V_ref[k - 1] <= 0:
+                    switch_detected = True
+                    m_guess = m - 1
+                    k_detected = k
+                    break_flag = True
+                    break
+
+            if break_flag:
+                break
 
-class battery_implicit(ptype):
+        vC_switch = (
+            [u[m][k_detected] - self.params.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else []
+        )
+
+        return switch_detected, m_guess, vC_switch
+
+    def count_switches(self):
+        """
+        Counts the number of switches. This function is called when a switch is found inside the range of tolerance
+        (in switch_estimator.py)
+        """
+
+        self.nswitches += 1
+
+    def get_problem_dict(self):
+        """
+        Helper to create dictionaries for both the coefficent matrix of the ODE system and the nonhomogeneous part.
+        """
+
+        n = self.params.ncapacitors
+        v = np.zeros(n + 1)
+        v[0] = 1
+
+        A, f = dict(), dict()
+        A = {k: np.diag(-1 / (self.params.C[k] * self.params.R) * np.roll(v, k + 1)) for k in range(n)}
+        A.update({n: np.diag(-(self.params.Rs + self.params.R) / self.params.L * v)})
+        f = {k: np.zeros(n + 1) for k in range(n)}
+        f.update({n: self.params.Vs / self.params.L * v})
+        return A, f
+
+
+class battery(battery_n_capacitors):
     """
-    Example implementing the battery drain model as in the description in the PinTSimE project
-    Attributes:
-        A: system matrix, representing the 2 ODEs
+    Example implementing the battery drain model with one capacitor, inherits from battery_n_capacitors.
     """
 
-    def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
+    def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
+        super(battery, self).__init__(problem_params, dtype_u=dtype_u, dtype_f=dtype_f)
+
+    def eval_f(self, u, t):
         """
-        Initialization routine
+        Routine to evaluate the RHS
+
         Args:
-            problem_params (dict): custom parameters for the example
-            dtype_u: mesh data type for solution
-            dtype_f: mesh data type for RHS
+            u (dtype_u): current values
+            t (float): current time
+
+        Returns:
+            dtype_f: the RHS
         """
 
-        problem_params['nvars'] = 2
+        f = self.dtype_f(self.init, val=0.0)
+        f.impl[:] = self.A.dot(u)
+
+        t_switch = np.inf if self.t_switch is None else self.t_switch
 
-        # these parameters will be used later, so assert their existence
-        essential_keys = [
-            'newton_maxiter',
-            'newton_tol',
-            'Vs',
-            'Rs',
-            'C',
-            'R',
-            'L',
-            'alpha',
-            'V_ref',
-            'set_switch',
-            't_switch',
-        ]
+        if u[1] <= self.params.V_ref[0] or t >= t_switch:
+            f.expl[0] = self.params.Vs / self.params.L
+
+        else:
+            f.expl[0] = 0
+
+        return f
+
+    def solve_system(self, rhs, factor, u0, t):
+        """
+        Simple linear solver for (I-factor*A)u = rhs
+
+        Args:
+            rhs (dtype_f): right-hand side for the linear system
+            factor (float): abbrev. for the local stepsize (or any other factor required)
+            u0 (dtype_u): initial guess for the iterative solver
+            t (float): current time (e.g. for time-dependent BCs)
+
+        Returns:
+            dtype_u: solution as mesh
+        """
+        self.A = np.zeros((2, 2))
+
+        t_switch = np.inf if self.t_switch is None else self.t_switch
+
+        if rhs[1] <= self.params.V_ref[0] or t >= t_switch:
+            self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+
+        else:
+            self.A[1, 1] = -1 / (self.params.C[0] * self.params.R)
+
+        me = self.dtype_u(self.init)
+        me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
+        return me
+
+    def u_exact(self, t):
+        """
+        Routine to compute the exact solution at time t
+
+        Args:
+            t (float): current time
+
+        Returns:
+            dtype_u: exact solution
+        """
+        assert t == 0, 'ERROR: u_exact only valid for t=0'
+
+        me = self.dtype_u(self.init)
+
+        me[0] = 0.0  # cL
+        me[1] = self.params.alpha * self.params.V_ref[0]  # vC
+
+        return me
+
+
+class battery_implicit(battery):
+    def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
+
+        essential_keys = ['newton_maxiter', 'newton_tol', 'ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
         for key in essential_keys:
             if key not in problem_params:
                 msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
                 raise ParameterError(msg)
 
+        problem_params['nvars'] = problem_params['ncapacitors'] + 1
+
         # invoke super init, passing number of dofs, dtype_u and dtype_f
         super(battery_implicit, self).__init__(
-            init=(problem_params['nvars'], None, np.dtype('float64')),
+            problem_params,
             dtype_u=dtype_u,
             dtype_f=dtype_f,
-            params=problem_params,
         )
 
-        self.A = np.zeros((2, 2))
         self.newton_itercount = 0
         self.lin_itercount = 0
         self.newton_ncalls = 0
@@ -174,9 +301,11 @@ class battery_implicit(ptype):
     def eval_f(self, u, t):
         """
         Routine to evaluate the RHS
+
         Args:
             u (dtype_u): current values
             t (float): current time
+
         Returns:
             dtype_f: the RHS
         """
@@ -184,37 +313,29 @@ class battery_implicit(ptype):
         f = self.dtype_f(self.init, val=0.0)
         non_f = np.zeros(2)
 
-        if u[1] <= self.params.V_ref or self.params.set_switch:
-            # switching need to happen on exact time point
-            if self.params.set_switch:
-                if t >= self.params.t_switch:
-                    self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-                    non_f[0] = self.params.Vs / self.params.L
+        t_switch = np.inf if self.t_switch is None else self.t_switch
 
-                else:
-                    self.A[1, 1] = -1 / (self.params.C * self.params.R)
-                    non_f[0] = 0
-
-            else:
-                self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-                non_f[0] = self.params.Vs / self.params.L
+        if u[1] <= self.params.V_ref[0] or t >= t_switch:
+            self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+            non_f[0] = self.params.Vs
 
         else:
-            self.A[1, 1] = -1 / (self.params.C * self.params.R)
+            self.A[1, 1] = -1 / (self.params.C[0] * self.params.R)
             non_f[0] = 0
 
         f[:] = self.A.dot(u) + non_f
-
         return f
 
     def solve_system(self, rhs, factor, u0, t):
         """
         Simple Newton solver
+
         Args:
             rhs (dtype_f): right-hand side for the linear system
             factor (float): abbrev. for the local stepsize (or any other factor required)
             u0 (dtype_u): initial guess for the iterative solver
             t (float): current time (e.g. for time-dependent BCs)
+
         Returns:
             dtype_u: solution as mesh
         """
@@ -223,23 +344,14 @@ class battery_implicit(ptype):
         non_f = np.zeros(2)
         self.A = np.zeros((2, 2))
 
-        if rhs[1] <= self.params.V_ref or self.params.set_switch:
-            # switching need to happen on exact time point
-            if self.params.set_switch:
-                if t >= self.params.t_switch:
-                    self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-                    non_f[0] = self.params.Vs / self.params.L
+        t_switch = np.inf if self.t_switch is None else self.t_switch
 
-                else:
-                    self.A[1, 1] = -1 / (self.params.C * self.params.R)
-                    non_f[0] = 0
-
-            else:
-                self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-                non_f[0] = self.params.Vs / self.params.L
+        if rhs[1] <= self.params.V_ref[0] or t >= t_switch:
+            self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+            non_f[0] = self.params.Vs
 
         else:
-            self.A[1, 1] = -1 / (self.params.C * self.params.R)
+            self.A[1, 1] = -1 / (self.params.C[0] * self.params.R)
             non_f[0] = 0
 
         # start newton iteration
@@ -279,20 +391,3 @@ class battery_implicit(ptype):
         me[:] = u[:]
 
         return me
-
-    def u_exact(self, t):
-        """
-        Routine to compute the exact solution at time t
-        Args:
-            t (float): current time
-        Returns:
-            dtype_u: exact solution
-        """
-        assert t == 0, 'ERROR: u_exact only valid for t=0'
-
-        me = self.dtype_u(self.init)
-
-        me[0] = 0.0  # cL
-        me[1] = self.params.alpha * self.params.V_ref  # vC
-
-        return me
diff --git a/pySDC/implementations/problem_classes/Battery_2Condensators.py b/pySDC/implementations/problem_classes/Battery_2Condensators.py
deleted file mode 100644
index 1c5eb55030426368d2ceebbdf2b3e1e153a92f00..0000000000000000000000000000000000000000
--- a/pySDC/implementations/problem_classes/Battery_2Condensators.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import numpy as np
-
-from pySDC.core.Errors import ParameterError
-from pySDC.core.Problem import ptype
-from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
-
-
-class battery_2condensators(ptype):
-    """
-    Example implementing the battery drain model using two capacitors as in the description in the PinTSimE
-    project
-    Attributes:
-        A: system matrix, representing the 3 ODEs
-    """
-
-    def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
-        """
-        Initialization routine
-        Args:
-            problem_params (dict): custom parameters for the example
-            dtype_u: mesh data type for solution
-            dtype_f: mesh data type for RHS
-        """
-
-        problem_params['nvars'] = 3
-
-        # these parameters will be used later, so assert their existence
-        essential_keys = ['Vs', 'Rs', 'C1', 'C2', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch']
-
-        for key in essential_keys:
-            if key not in problem_params:
-                msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
-                raise ParameterError(msg)
-
-        # invoke super init, passing number of dofs, dtype_u and dtype_f
-        super(battery_2condensators, self).__init__(
-            init=(problem_params['nvars'], None, np.dtype('float64')),
-            dtype_u=dtype_u,
-            dtype_f=dtype_f,
-            params=problem_params,
-        )
-
-        self.A = np.zeros((3, 3))
-
-    def eval_f(self, u, t):
-        """
-        Routine to evaluate the RHS
-        Args:
-            u (dtype_u): current values
-            t (float): current time
-        Returns:
-            dtype_f: the RHS
-        """
-
-        f = self.dtype_f(self.init, val=0.0)
-        f.impl[:] = self.A.dot(u)
-
-        # switch to C2
-        if (
-            u[1] <= self.params.V_ref[0]
-            and u[2] > self.params.V_ref[1]
-            or self.params.set_switch[0]
-            and not self.params.set_switch[1]
-        ):
-            if self.params.set_switch[0]:
-                if t >= self.params.t_switch[0]:
-                    f.expl[0] = 0
-
-                else:
-                    f.expl[0] = 0
-
-            else:
-                f.expl[0] = 0
-
-        # switch to Vs
-        elif u[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]):
-            # switch to Vs
-            if self.params.set_switch[1]:
-                if t >= self.params.t_switch[1]:
-                    f.expl[0] = self.params.Vs / self.params.L
-
-                else:
-                    f.expl[0] = 0
-
-            else:
-                f.expl[0] = self.params.Vs / self.params.L
-
-        elif (
-            u[1] > self.params.V_ref[0]
-            and u[2] > self.params.V_ref[1]
-            or not self.params.set_switch[0]
-            and not self.params.set_switch[1]
-        ):
-            # C1 supplies energy
-            f.expl[0] = 0
-
-        return f
-
-    def solve_system(self, rhs, factor, u0, t):
-        """
-        Simple linear solver for (I-factor*A)u = rhs
-        Args:
-            rhs (dtype_f): right-hand side for the linear system
-            factor (float): abbrev. for the local stepsize (or any other factor required)
-            u0 (dtype_u): initial guess for the iterative solver
-            t (float): current time (e.g. for time-dependent BCs)
-        Returns:
-            dtype_u: solution as mesh
-        """
-        self.A = np.zeros((3, 3))
-
-        # switch to C2
-        if (
-            rhs[1] <= self.params.V_ref[0]
-            and rhs[2] > self.params.V_ref[1]
-            or self.params.set_switch[0]
-            and not self.params.set_switch[1]
-        ):
-            if self.params.set_switch[0]:
-                if t >= self.params.t_switch[0]:
-                    self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
-                else:
-                    self.A[1, 1] = -1 / (self.params.C1 * self.params.R)
-            else:
-                self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
-        # switch to Vs
-        elif rhs[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]):
-            if self.params.set_switch[1]:
-                if t >= self.params.t_switch[1]:
-                    self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-
-                else:
-                    self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
-            else:
-                self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-
-        elif (
-            rhs[1] > self.params.V_ref[0]
-            and rhs[2] > self.params.V_ref[1]
-            or not self.params.set_switch[0]
-            and not self.params.set_switch[1]
-        ):
-            # C1 supplies energy
-            self.A[1, 1] = -1 / (self.params.C1 * self.params.R)
-
-        me = self.dtype_u(self.init)
-        me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
-        return me
-
-    def u_exact(self, t):
-        """
-        Routine to compute the exact solution at time t
-        Args:
-            t (float): current time
-        Returns:
-            dtype_u: exact solution
-        """
-        assert t == 0, 'ERROR: u_exact only valid for t=0'
-
-        me = self.dtype_u(self.init)
-
-        me[0] = 0.0  # cL
-        me[1] = self.params.alpha * self.params.V_ref[0]  # vC1
-        me[2] = self.params.alpha * self.params.V_ref[1]  # vC2
-        return me
diff --git a/pySDC/projects/PinTSimE/battery_2capacitors_model.py b/pySDC/projects/PinTSimE/battery_2capacitors_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1401c16862ed03ae8979ba32d9dad7675b2226f
--- /dev/null
+++ b/pySDC/projects/PinTSimE/battery_2capacitors_model.py
@@ -0,0 +1,230 @@
+import numpy as np
+import dill
+from pathlib import Path
+
+from pySDC.helpers.stats_helper import get_sorted
+from pySDC.core.Collocation import CollBase as Collocation
+from pySDC.implementations.problem_classes.Battery import battery_n_capacitors
+from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
+from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+from pySDC.projects.PinTSimE.battery_model import (
+    controller_run,
+    generate_description,
+    get_recomputed,
+    log_data,
+    proof_assertions_description,
+)
+from pySDC.projects.PinTSimE.piline_model import setup_mpl
+import pySDC.helpers.plot_helper as plt_helper
+from pySDC.core.Hooks import hooks
+
+from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
+
+
+def run():
+    """
+    Executes the simulation for the battery model using the IMEX sweeper and plot the results
+    as <problem_class>_model_solution_<sweeper_class>.png
+    """
+
+    dt = 1e-2
+    t0 = 0.0
+    Tend = 3.5
+
+    problem_classes = [battery_n_capacitors]
+    sweeper_classes = [imex_1st_order]
+
+    ncapacitors = 2
+    alpha = 5.0
+    V_ref = np.array([1.0, 1.0])
+    C = np.array([1.0, 1.0])
+
+    recomputed = False
+    use_switch_estimator = [True]
+
+    for problem, sweeper in zip(problem_classes, sweeper_classes):
+        for use_SE in use_switch_estimator:
+            description, controller_params = generate_description(
+                dt, problem, sweeper, log_data, False, use_SE, ncapacitors, alpha, V_ref, C
+            )
+
+            # Assertions
+            proof_assertions_description(description, False, use_SE)
+
+            proof_assertions_time(dt, Tend, V_ref, alpha)
+
+            stats = controller_run(description, controller_params, False, use_SE, t0, Tend)
+
+            check_solution(stats, dt, use_SE)
+
+            plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, False)
+
+
+def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'):
+    """
+    Routine to plot the numerical solution of the model
+
+    Args:
+        description(dict): contains all information for a controller run
+        problem (pySDC.core.Problem.ptype): problem class that wants to be simulated
+        sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically
+        recomputed (bool): flag if the values after a restart are used or before
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+        use_adaptivity (bool): flag if adaptivity wants to be used or not
+        cwd (str): current working directory
+    """
+
+    f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb')
+    stats = dill.load(f)
+    f.close()
+
+    # convert filtered statistics to list of iterations count, sorted by process
+    cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=recomputed)])
+    vC1 = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=recomputed)])
+    vC2 = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=recomputed)])
+
+    t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)])
+
+    setup_mpl()
+    fig, ax = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
+    ax.plot(t, cL, label='$i_L$')
+    ax.plot(t, vC1, label='$v_{C_1}$')
+    ax.plot(t, vC2, label='$v_{C_2}$')
+
+    if use_switch_estimator:
+        switches = get_recomputed(stats, type='switch', sortby='time')
+        if recomputed is not None:
+            assert len(switches) >= 2, f"Expected at least 2 switches, got {len(switches)}!"
+        t_switches = [v[1] for v in switches]
+
+        for i in range(len(t_switches)):
+            ax.axvline(x=t_switches[i], linestyle='--', color='k', label='Switch {}'.format(i + 1))
+
+    ax.legend(frameon=False, fontsize=12, loc='upper right')
+
+    ax.set_xlabel('Time')
+    ax.set_ylabel('Energy')
+
+    fig.savefig('data/battery_2capacitors_model_solution.png', dpi=300, bbox_inches='tight')
+    plt_helper.plt.close(fig)
+
+
+def check_solution(stats, dt, use_switch_estimator):
+    """
+    Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+    Args:
+        stats (dict): Raw statistics from a controller run
+        dt (float): initial time step
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+    """
+
+    data = get_data_dict(stats, use_switch_estimator)
+
+    if use_switch_estimator:
+        msg = f'Error when using the switch estimator for battery_2condensators for dt={dt:.1e}:'
+        if dt == 1e-2:
+            expected = {
+                'cL': 1.2065280755094876,
+                'vC1': 1.0094825899806945,
+                'vC2': 1.0050052828742688,
+                'switch1': 1.6094379124373626,
+                'switch2': 3.209437912457051,
+                'restarts': 2.0,
+                'sum_niters': 1568,
+            }
+        elif dt == 4e-1:
+            expected = {
+                'cL': 1.1842780233981391,
+                'vC1': 1.0094891393319418,
+                'vC2': 1.00103823232433,
+                'switch1': 1.6075867934844466,
+                'switch2': 3.209437912436633,
+                'restarts': 2.0,
+                'sum_niters': 2000,
+            }
+        elif dt == 4e-2:
+            expected = {
+                'cL': 1.180493652021971,
+                'vC1': 1.0094825917376264,
+                'vC2': 1.0007713468084405,
+                'switch1': 1.6094074085553605,
+                'switch2': 3.209437912440314,
+                'restarts': 2.0,
+                'sum_niters': 2364,
+            }
+        elif dt == 4e-3:
+            expected = {
+                'cL': 1.1537529501025199,
+                'vC1': 1.001438946726028,
+                'vC2': 1.0004331625246141,
+                'switch1': 1.6093728710270467,
+                'switch2': 3.217437912434171,
+                'restarts': 2.0,
+                'sum_niters': 8920,
+            }
+
+    got = {
+        'cL': data['cL'][-1],
+        'vC1': data['vC1'][-1],
+        'vC2': data['vC2'][-1],
+        'switch1': data['switch1'],
+        'switch2': data['switch2'],
+        'restarts': data['restarts'],
+        'sum_niters': data['sum_niters'],
+    }
+
+    for key in expected.keys():
+        assert np.isclose(
+            expected[key], got[key], rtol=1e-4
+        ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
+def get_data_dict(stats, use_switch_estimator, recomputed=False):
+    """
+    Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
+    Based on @brownbaerchen's get_data function.
+
+    Args:
+        stats (dict): Raw statistics from a controller run
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+        recomputed (bool): flag if the values after a restart are used or before
+
+    Return:
+        data (dict): contains all information as the statistics dict
+    """
+
+    data = dict()
+    data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
+    data['vC1'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
+    data['vC2'] = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
+    data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1]
+    data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1]
+    data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])
+    data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1])
+
+    return data
+
+
+def proof_assertions_time(dt, Tend, V_ref, alpha):
+    """
+    Function to proof the assertions regarding the time domain (in combination with the specific problem):
+
+    Args:
+        dt (float): time step for computation
+        Tend (float): end time
+        V_ref (np.ndarray): Reference values (problem parameter)
+        alpha (np.float): Multiple used for initial conditions (problem_parameter)
+    """
+
+    assert (
+        Tend == 3.5 and V_ref[0] == 1.0 and V_ref[1] == 1.0 and alpha == 5.0
+    ), "Error! Do not use other parameters for V_ref[:] != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!"
+
+    assert (
+        dt == 1e-2 or dt == 4e-1 or dt == 4e-2 or dt == 4e-3
+    ), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!"
+
+
+if __name__ == "__main__":
+    run()
diff --git a/pySDC/projects/PinTSimE/battery_2condensators_model.py b/pySDC/projects/PinTSimE/battery_2condensators_model.py
deleted file mode 100644
index c41cb8271236f26c8c7c6f827a0a2853e1154c91..0000000000000000000000000000000000000000
--- a/pySDC/projects/PinTSimE/battery_2condensators_model.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import numpy as np
-import dill
-from pathlib import Path
-
-from pySDC.helpers.stats_helper import get_sorted
-from pySDC.core.Collocation import CollBase as Collocation
-from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators
-from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
-from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
-from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
-from pySDC.projects.PinTSimE.piline_model import setup_mpl
-import pySDC.helpers.plot_helper as plt_helper
-from pySDC.core.Hooks import hooks
-
-from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
-
-
-class log_data(hooks):
-    def post_step(self, step, level_number):
-
-        super(log_data, self).post_step(step, level_number)
-
-        # some abbreviations
-        L = step.levels[level_number]
-
-        L.sweep.compute_end_point()
-
-        self.add_to_stats(
-            process=step.status.slot,
-            time=L.time + L.dt,
-            level=L.level_index,
-            iter=0,
-            sweep=L.status.sweep,
-            type='current L',
-            value=L.uend[0],
-        )
-        self.add_to_stats(
-            process=step.status.slot,
-            time=L.time + L.dt,
-            level=L.level_index,
-            iter=0,
-            sweep=L.status.sweep,
-            type='voltage C1',
-            value=L.uend[1],
-        )
-        self.add_to_stats(
-            process=step.status.slot,
-            time=L.time + L.dt,
-            level=L.level_index,
-            iter=0,
-            sweep=L.status.sweep,
-            type='voltage C2',
-            value=L.uend[2],
-        )
-        self.increment_stats(
-            process=step.status.slot,
-            time=L.time,
-            level=L.level_index,
-            iter=0,
-            sweep=L.status.sweep,
-            type='restart',
-            value=1,
-            initialize=0,
-        )
-
-
-def main(use_switch_estimator=True):
-    """
-    A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators
-    """
-
-    # initialize level parameters
-    level_params = dict()
-    level_params['restol'] = 1e-13
-    level_params['dt'] = 1e-2
-
-    # initialize sweeper parameters
-    sweeper_params = dict()
-    sweeper_params['quad_type'] = 'LOBATTO'
-    sweeper_params['num_nodes'] = 5
-    sweeper_params['QI'] = 'LU'  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
-    sweeper_params['initial_guess'] = 'zero'
-
-    # initialize problem parameters
-    problem_params = dict()
-    problem_params['Vs'] = 5.0
-    problem_params['Rs'] = 0.5
-    problem_params['C1'] = 1.0
-    problem_params['C2'] = 1.0
-    problem_params['R'] = 1.0
-    problem_params['L'] = 1.0
-    problem_params['alpha'] = 5.0
-    problem_params['V_ref'] = np.array([1.0, 1.0])  # [V_ref1, V_ref2]
-    problem_params['set_switch'] = np.array([False, False], dtype=bool)
-    problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0])
-
-    # initialize step parameters
-    step_params = dict()
-    step_params['maxiter'] = 20
-
-    # initialize controller parameters
-    controller_params = dict()
-    controller_params['logger_level'] = 20
-    controller_params['hook_class'] = log_data
-
-    # convergence controllers
-    convergence_controllers = dict()
-    if use_switch_estimator:
-        switch_estimator_params = {}
-        convergence_controllers[SwitchEstimator] = switch_estimator_params
-
-    # fill description dictionary for easy step instantiation
-    description = dict()
-    description['problem_class'] = battery_2condensators  # pass problem class
-    description['problem_params'] = problem_params  # pass problem parameters
-    description['sweeper_class'] = imex_1st_order  # pass sweeper
-    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
-    description['level_params'] = level_params  # pass level parameters
-    description['step_params'] = step_params
-    description['space_transfer_class'] = mesh_to_mesh  # pass spatial transfer class
-
-    if use_switch_estimator:
-        description['convergence_controllers'] = convergence_controllers
-
-    proof_assertions_description(description, problem_params)
-
-    # set time parameters
-    t0 = 0.0
-    Tend = 3.5
-
-    # instantiate controller
-    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
-
-    # get initial values on finest level
-    P = controller.MS[0].levels[0].prob
-    uinit = P.u_exact(t0)
-
-    # call main function to get things done...
-    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
-
-    Path("data").mkdir(parents=True, exist_ok=True)
-    fname = 'data/battery_2condensators.dat'
-    f = open(fname, 'wb')
-    dill.dump(stats, f)
-    f.close()
-
-    # filter statistics by number of iterations
-    iter_counts = get_sorted(stats, type='niter', sortby='time')
-
-    # compute and print statistics
-    min_iter = 20
-    max_iter = 0
-
-    f = open('battery_2condensators_out.txt', 'w')
-    niters = np.array([item[1] for item in iter_counts])
-    out = '   Mean number of iterations: %4.2f' % np.mean(niters)
-    f.write(out + '\n')
-    print(out)
-    for item in iter_counts:
-        out = 'Number of iterations for time %4.2f: %1i' % item
-        f.write(out + '\n')
-        # print(out)
-        min_iter = min(min_iter, item[1])
-        max_iter = max(max_iter, item[1])
-
-    restarts = np.array(get_sorted(stats, type='restart', recomputed=False))[:, 1]
-    print("Restarts for dt: ", level_params['dt'], " -- ", np.sum(restarts))
-
-    assert np.mean(niters) <= 10, "Mean number of iterations is too high, got %s" % np.mean(niters)
-    f.close()
-
-    plot_voltages(description, use_switch_estimator)
-
-    return np.mean(niters)
-
-
-def plot_voltages(description, use_switch_estimator, cwd='./'):
-    """
-    Routine to plot the numerical solution of the model
-    """
-
-    f = open(cwd + 'data/battery_2condensators.dat', 'rb')
-    stats = dill.load(f)
-    f.close()
-
-    # convert filtered statistics to list of iterations count, sorted by process
-    cL = get_sorted(stats, type='current L', sortby='time')
-    vC1 = get_sorted(stats, type='voltage C1', sortby='time')
-    vC2 = get_sorted(stats, type='voltage C2', sortby='time')
-
-    times = [v[0] for v in cL]
-
-    setup_mpl()
-    fig, ax = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
-    ax.plot(times, [v[1] for v in cL], label='$i_L$')
-    ax.plot(times, [v[1] for v in vC1], label='$v_{C_1}$')
-    ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$')
-
-    if use_switch_estimator:
-        t_switch_plot = np.zeros(np.shape(description['problem_params']['t_switch'])[0])
-        for i in range(np.shape(description['problem_params']['t_switch'])[0]):
-            t_switch_plot[i] = description['problem_params']['t_switch'][i]
-
-            ax.axvline(x=t_switch_plot[i], linestyle='--', color='k', label='Switch {}'.format(i + 1))
-
-    ax.legend(frameon=False, fontsize=12, loc='upper right')
-
-    ax.set_xlabel('Time')
-    ax.set_ylabel('Energy')
-
-    fig.savefig('data/battery_2condensators_model_solution.png', dpi=300, bbox_inches='tight')
-    plt_helper.plt.close(fig)
-
-
-def proof_assertions_description(description, problem_params):
-    """
-    Function to proof the assertions (function to get cleaner code)
-    """
-
-    assert problem_params['alpha'] > problem_params['V_ref'][0], 'Please set "alpha" greater than "V_ref1"'
-    assert problem_params['alpha'] > problem_params['V_ref'][1], 'Please set "alpha" greater than "V_ref2"'
-
-    assert problem_params['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0'
-    assert problem_params['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0'
-
-    assert type(problem_params['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)'
-    assert not problem_params['set_switch'][0], 'First entry of "set_switch" needs to be False'
-    assert not problem_params['set_switch'][1], 'Second entry of "set_switch" needs to be False'
-
-    assert not type(problem_params['t_switch']) == float, '"t_switch" has to be an array with entry zero'
-
-    assert problem_params['t_switch'][0] == 0, 'First entry of "t_switch" needs to be zero'
-    assert problem_params['t_switch'][1] == 0, 'Second entry of "t_switch" needs to be zero'
-
-    assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error'
-    assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters'
-    assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters'
-
-
-if __name__ == "__main__":
-    main()
diff --git a/pySDC/projects/PinTSimE/battery_model.py b/pySDC/projects/PinTSimE/battery_model.py
index a127687eb5a9e04243329905ac3fac41ce215d83..428a202672c0adaad8fdd19c0ce10bd4915abb51 100644
--- a/pySDC/projects/PinTSimE/battery_model.py
+++ b/pySDC/projects/PinTSimE/battery_model.py
@@ -2,7 +2,7 @@ import numpy as np
 import dill
 from pathlib import Path
 
-from pySDC.helpers.stats_helper import get_sorted
+from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted
 from pySDC.core.Collocation import CollBase as Collocation
 from pySDC.implementations.problem_classes.Battery import battery, battery_implicit
 from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
@@ -32,26 +32,26 @@ class log_data(hooks):
             level=L.level_index,
             iter=0,
             sweep=L.status.sweep,
-            type='current L',
-            value=L.uend[0],
+            type='u',
+            value=L.uend,
         )
         self.add_to_stats(
             process=step.status.slot,
-            time=L.time + L.dt,
+            time=L.time,
             level=L.level_index,
             iter=0,
             sweep=L.status.sweep,
-            type='voltage C',
-            value=L.uend[1],
+            type='restart',
+            value=int(step.status.get('restart')),
         )
         self.add_to_stats(
             process=step.status.slot,
-            time=L.time,
+            time=L.time + L.dt,
             level=L.level_index,
             iter=0,
             sweep=L.status.sweep,
-            type='restart',
-            value=int(step.status.get('restart')),
+            type='dt',
+            value=L.dt,
         )
         self.add_to_stats(
             process=step.status.slot,
@@ -59,14 +59,41 @@ class log_data(hooks):
             level=L.level_index,
             iter=0,
             sweep=L.status.sweep,
-            type='dt',
-            value=L.dt,
+            type='e_embedded',
+            value=L.status.get('error_embedded_estimate'),
         )
 
 
-def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
+def generate_description(
+    dt,
+    problem,
+    sweeper,
+    hook_class,
+    use_adaptivity,
+    use_switch_estimator,
+    ncapacitors,
+    alpha,
+    V_ref,
+    C,
+    max_restarts=None,
+):
     """
-    A simple test program to do SDC/PFASST runs for the battery drain model
+    Generate a description for the battery models for a controller run.
+    Args:
+        dt (float): time step for computation
+        problem (pySDC.core.Problem.ptype): problem class that wants to be simulated
+        sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically
+        hook_class (pySDC.core.Hooks): logged data for a problem
+        use_adaptivity (bool): flag if the adaptivity wants to be used or not
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+        ncapacitors (np.int): number of capacitors used for the battery_model
+        alpha (np.float): Multiple used for the initial conditions (problem_parameter)
+        V_ref (np.ndarray): Reference values for the capacitors (problem_parameter)
+        C (np.ndarray): Capacitances (problem_parameter
+
+    Returns:
+        description (dict): contains all information for a controller run
+        controller_params (dict): Parameters needed for a controller run
     """
 
     # initialize level parameters
@@ -78,22 +105,21 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
     sweeper_params = dict()
     sweeper_params['quad_type'] = 'LOBATTO'
     sweeper_params['num_nodes'] = 5
-    # sweeper_params['QI'] = 'LU'  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
-    sweeper_params['initial_guess'] = 'zero'
+    sweeper_params['QI'] = 'IE'
+    sweeper_params['initial_guess'] = 'spread'
 
     # initialize problem parameters
     problem_params = dict()
     problem_params['newton_maxiter'] = 200
     problem_params['newton_tol'] = 1e-08
+    problem_params['ncapacitors'] = ncapacitors  # number of condensators
     problem_params['Vs'] = 5.0
     problem_params['Rs'] = 0.5
-    problem_params['C'] = 1.0
+    problem_params['C'] = C
     problem_params['R'] = 1.0
     problem_params['L'] = 1.0
-    problem_params['alpha'] = 1.2
-    problem_params['V_ref'] = 1.0
-    problem_params['set_switch'] = np.array([False], dtype=bool)
-    problem_params['t_switch'] = np.zeros(1)
+    problem_params['alpha'] = alpha
+    problem_params['V_ref'] = V_ref
 
     # initialize step parameters
     step_params = dict()
@@ -101,8 +127,8 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
 
     # initialize controller parameters
     controller_params = dict()
-    controller_params['logger_level'] = 20
-    controller_params['hook_class'] = log_data
+    controller_params['logger_level'] = 30
+    controller_params['hook_class'] = hook_class
     controller_params['mssdc_jac'] = False
 
     # convergence controllers
@@ -113,7 +139,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
 
     if use_adaptivity:
         adaptivity_params = dict()
-        adaptivity_params['e_tol'] = 1e-12
+        adaptivity_params['e_tol'] = 1e-7
         convergence_controllers.update({Adaptivity: adaptivity_params})
 
     # fill description dictionary for easy step instantiation
@@ -124,15 +150,28 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
     description['sweeper_params'] = sweeper_params  # pass sweeper parameters
     description['level_params'] = level_params  # pass level parameters
     description['step_params'] = step_params
+    if max_restarts is not None:
+        description['max_restarts'] = max_restarts
 
     if use_switch_estimator or use_adaptivity:
         description['convergence_controllers'] = convergence_controllers
 
-    proof_assertions_description(description, problem_params)
+    return description, controller_params
 
-    # set time parameters
-    t0 = 0.0
-    Tend = 0.3
+
+def controller_run(description, controller_params, use_adaptivity, use_switch_estimator, t0, Tend):
+    """
+    Executes a controller run for a problem defined in the description
+
+    Args:
+        description (dict): contains all information for a controller run
+        controller_params (dict): Parameters needed for a controller run
+        use_adaptivity (bool): flag if the adaptivity wants to be used or not
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+
+    Returns:
+        stats (dict): Raw statistics from a controller run
+    """
 
     # instantiate controller
     controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
@@ -144,35 +183,18 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
     # call main function to get things done...
     uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
 
-    # filter statistics by number of iterations
-    iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time')
-
-    # compute and print statistics
-    min_iter = 20
-    max_iter = 0
+    problem = description['problem_class']
+    sweeper = description['sweeper_class']
 
     Path("data").mkdir(parents=True, exist_ok=True)
-    fname = 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper.__name__, use_switch_estimator, use_adaptivity)
+    fname = 'data/{}_{}_USE{}_USA{}.dat'.format(
+        problem.__name__, sweeper.__name__, use_switch_estimator, use_adaptivity
+    )
     f = open(fname, 'wb')
     dill.dump(stats, f)
     f.close()
 
-    f = open('data/battery_out.txt', 'w')
-    niters = np.array([item[1] for item in iter_counts])
-    out = '   Mean number of iterations: %4.2f' % np.mean(niters)
-    f.write(out + '\n')
-    print(out)
-    for item in iter_counts:
-        out = 'Number of iterations for time %4.2f: %1i' % item
-        f.write(out + '\n')
-        print(out)
-        min_iter = min(min_iter, item[1])
-        max_iter = max(max_iter, item[1])
-
-    assert np.mean(niters) <= 5, "Mean number of iterations is too high, got %s" % np.mean(niters)
-    f.close()
-
-    return description
+    return stats
 
 
 def run():
@@ -182,53 +204,80 @@ def run():
     """
 
     dt = 1e-2
+    t0 = 0.0
+    Tend = 0.3
+
     problem_classes = [battery, battery_implicit]
     sweeper_classes = [imex_1st_order, generic_implicit]
+
+    ncapacitors = 1
+    alpha = 1.2
+    V_ref = np.array([1.0])
+    C = np.array([1.0])
+
+    recomputed = False
     use_switch_estimator = [True]
     use_adaptivity = [True]
 
     for problem, sweeper in zip(problem_classes, sweeper_classes):
         for use_SE in use_switch_estimator:
             for use_A in use_adaptivity:
-                description = main(
-                    dt=dt,
-                    problem=problem,
-                    sweeper=sweeper,
-                    use_switch_estimator=use_SE,
-                    use_adaptivity=use_A,
+                description, controller_params = generate_description(
+                    dt, problem, sweeper, log_data, use_A, use_SE, ncapacitors, alpha, V_ref, C
                 )
 
-            plot_voltages(description, problem.__name__, sweeper.__name__, use_SE, use_A)
+                # Assertions
+                proof_assertions_description(description, use_A, use_SE)
+
+                proof_assertions_time(dt, Tend, V_ref, alpha)
+
+                stats = controller_run(description, controller_params, use_A, use_SE, t0, Tend)
 
+            check_solution(stats, dt, problem.__name__, use_A, use_SE)
 
-def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adaptivity, cwd='./'):
+            plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, use_A)
+
+
+def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'):
     """
     Routine to plot the numerical solution of the model
+
+    Args:
+        description(dict): contains all information for a controller run
+        problem (pySDC.core.Problem.ptype): problem class that wants to be simulated
+        sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically
+        recomputed (bool): flag if the values after a restart are used or before
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+        use_adaptivity (bool): flag if adaptivity wants to be used or not
+        cwd (str): current working directory
     """
 
-    f = open(cwd + 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper, use_switch_estimator, use_adaptivity), 'rb')
+    f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb')
     stats = dill.load(f)
     f.close()
 
     # convert filtered statistics to list of iterations count, sorted by process
-    cL = get_sorted(stats, type='current L', recomputed=False, sortby='time')
-    vC = get_sorted(stats, type='voltage C', recomputed=False, sortby='time')
+    cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=recomputed)])
+    vC = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=recomputed)])
 
-    times = [v[0] for v in cL]
+    t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)])
 
     setup_mpl()
     fig, ax = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
     ax.set_title('Simulation of {} using {}'.format(problem, sweeper), fontsize=10)
-    ax.plot(times, [v[1] for v in cL], label=r'$i_L$')
-    ax.plot(times, [v[1] for v in vC], label=r'$v_C$')
+    ax.plot(t, cL, label=r'$i_L$')
+    ax.plot(t, vC, label=r'$v_C$')
 
     if use_switch_estimator:
-        val_switch = get_sorted(stats, type='switch1', sortby='time')
-        t_switch = [v[1] for v in val_switch]
+        switches = get_recomputed(stats, type='switch', sortby='time')
+
+        assert len(switches) >= 1, 'No switches found!'
+        t_switch = [v[1] for v in switches]
         ax.axvline(x=t_switch[-1], linestyle='--', linewidth=0.8, color='r', label='Switch')
 
     if use_adaptivity:
         dt = np.array(get_sorted(stats, type='dt', recomputed=False))
+
         dt_ax = ax.twinx()
         dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$')
         dt_ax.set_ylabel(r'$\Delta t$', fontsize=8)
@@ -245,23 +294,377 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt
     plt_helper.plt.close(fig)
 
 
-def proof_assertions_description(description, problem_params):
+def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator):
+    """
+    Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+    Args:
+        stats (dict): Raw statistics from a controller run
+        dt (float): initial time step
+        problem (problem_class.__name__): the problem_class that is numerically solved
+        use_switch_estimator (bool):
+        use_adaptivity (bool):
+    """
+
+    data = get_data_dict(stats, use_adaptivity, use_switch_estimator)
+
+    if problem == 'battery':
+        if use_switch_estimator and use_adaptivity:
+            msg = f'Error when using switch estimator and adaptivity for battery for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5474500710994862,
+                    'vC': 1.0019332967173764,
+                    'dt': 0.011761752270047832,
+                    'e_em': 8.001793672107738e-10,
+                    'switches': 0.18232155791181945,
+                    'restarts': 3.0,
+                    'sum_niters': 44,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.5525783945667581,
+                    'vC': 1.00001743462299,
+                    'dt': 0.03550610373897258,
+                    'e_em': 6.21240694442804e-08,
+                    'switches': 0.18231603298272345,
+                    'restarts': 4.0,
+                    'sum_niters': 56,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5395601429161445,
+                    'vC': 1.0000413761942089,
+                    'dt': 0.028281271825675414,
+                    'e_em': 2.5628611677319668e-08,
+                    'switches': 0.18230920573953438,
+                    'restarts': 3.0,
+                    'sum_niters': 48,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'dt': data['dt'][-1],
+                'e_em': data['e_em'][-1],
+                'switches': data['switches'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+        elif use_switch_estimator and not use_adaptivity:
+            msg = f'Error when using switch estimator for battery for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5423033461806986,
+                    'vC': 1.000118710428906,
+                    'switches': 0.1823188001399631,
+                    'restarts': 1.0,
+                    'sum_niters': 284,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.6139093327509394,
+                    'vC': 1.0010140038721593,
+                    'switches': 0.1824302065533169,
+                    'restarts': 1.0,
+                    'sum_niters': 48,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5429509935448258,
+                    'vC': 1.0001158309787614,
+                    'switches': 0.18232183080236553,
+                    'restarts': 1.0,
+                    'sum_niters': 392,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'switches': data['switches'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+
+        elif not use_switch_estimator and use_adaptivity:
+            msg = f'Error when using adaptivity for battery for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5413318777113352,
+                    'vC': 0.9963444569399663,
+                    'dt': 0.020451912195976252,
+                    'e_em': 7.157646031430431e-09,
+                    'restarts': 4.0,
+                    'sum_niters': 56,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.5966289599915113,
+                    'vC': 0.9923148791604984,
+                    'dt': 0.03564958366355817,
+                    'e_em': 6.210964231812e-08,
+                    'restarts': 1.0,
+                    'sum_niters': 36,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5431613774808756,
+                    'vC': 0.9934307674636834,
+                    'dt': 0.022880524075396924,
+                    'e_em': 1.1130212751453428e-08,
+                    'restarts': 3.0,
+                    'sum_niters': 52,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'dt': data['dt'][-1],
+                'e_em': data['e_em'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+
+    elif problem == 'battery_implicit':
+        if use_switch_estimator and use_adaptivity:
+            msg = f'Error when using switch estimator and adaptivity for battery_implicit for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5424577937840791,
+                    'vC': 1.0001051105894005,
+                    'dt': 0.01,
+                    'e_em': 2.220446049250313e-16,
+                    'switches': 0.1822923488448394,
+                    'restarts': 6.0,
+                    'sum_niters': 60,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.6717104472882885,
+                    'vC': 1.0071670698947914,
+                    'dt': 0.035896059229296486,
+                    'e_em': 6.208836400567463e-08,
+                    'switches': 0.18232158833761175,
+                    'restarts': 3.0,
+                    'sum_niters': 36,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5396216192241711,
+                    'vC': 1.0000561014463172,
+                    'dt': 0.009904645972832471,
+                    'e_em': 2.220446049250313e-16,
+                    'switches': 0.18230549652342606,
+                    'restarts': 4.0,
+                    'sum_niters': 44,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'dt': data['dt'][-1],
+                'e_em': data['e_em'][-1],
+                'switches': data['switches'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+        elif use_switch_estimator and not use_adaptivity:
+            msg = f'Error when using switch estimator for battery_implicit for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5423033363981951,
+                    'vC': 1.000118715162845,
+                    'switches': 0.18231880065636324,
+                    'restarts': 1.0,
+                    'sum_niters': 284,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.613909968362315,
+                    'vC': 1.0010140112484431,
+                    'switches': 0.18243023230469263,
+                    'restarts': 1.0,
+                    'sum_niters': 48,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5429616576526073,
+                    'vC': 1.0001158454740509,
+                    'switches': 0.1823218812753008,
+                    'restarts': 1.0,
+                    'sum_niters': 392,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'switches': data['switches'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+
+        elif not use_switch_estimator and use_adaptivity:
+            msg = f'Error when using adaptivity for battery_implicit for dt={dt:.1e}:'
+            if dt == 1e-2:
+                expected = {
+                    'cL': 0.5490142863996689,
+                    'vC': 0.997253099984895,
+                    'dt': 0.024243123245133835,
+                    'e_em': 1.4052013885823555e-08,
+                    'restarts': 11.0,
+                    'sum_niters': 96,
+                }
+            elif dt == 4e-2:
+                expected = {
+                    'cL': 0.5556563012729733,
+                    'vC': 0.9930947318467772,
+                    'dt': 0.035507110551631804,
+                    'e_em': 6.2098696185231e-08,
+                    'restarts': 6.0,
+                    'sum_niters': 64,
+                }
+            elif dt == 4e-3:
+                expected = {
+                    'cL': 0.5401117929618637,
+                    'vC': 0.9933888475391347,
+                    'dt': 0.03176025170463925,
+                    'e_em': 4.0386798239033794e-08,
+                    'restarts': 8.0,
+                    'sum_niters': 80,
+                }
+
+            got = {
+                'cL': data['cL'][-1],
+                'vC': data['vC'][-1],
+                'dt': data['dt'][-1],
+                'e_em': data['e_em'][-1],
+                'restarts': data['restarts'],
+                'sum_niters': data['sum_niters'],
+            }
+
+    for key in expected.keys():
+        assert np.isclose(
+            expected[key], got[key], rtol=1e-4
+        ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
+def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recomputed=False):
+    """
+    Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
+    Based on @brownbaerchen's get_data function.
+
+    Args:
+        stats (dict): Raw statistics from a controller run
+        use_adaptivity (bool): flag if adaptivity wants to be used or not
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+        recomputed (bool): flag if the values after a restart are used or before
+
+    Return:
+        data (dict): contains all information as the statistics dict
+    """
+
+    data = dict()
+
+    data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
+    data['vC'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
+    if use_adaptivity:
+        data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1]
+        data['e_em'] = np.array(
+            get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed, sortby='time')
+        )[:, 1]
+    if use_switch_estimator:
+        data['switches'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[:, 1]
+    if use_adaptivity or use_switch_estimator:
+        data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])
+    data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1])
+
+    return data
+
+
+def get_recomputed(stats, type, sortby):
+    """
+    Function that filters statistics after a recomputation. It stores all value of a type before restart. If there are multiple values
+    with same time point, it only stores the elements with unique times.
+
+    Args:
+        stats (dict): Raw statistics from a controller run
+        type (str): the type the be filtered
+        sortby (str): string to specify which key to use for sorting
+
+    Returns:
+        sorted_list (list): list of filtered statistics
     """
-    Function to proof the assertions (function to get cleaner code)
+
+    sorted_nested_list = []
+    times_unique = np.unique([me[0] for me in get_sorted(stats, type=type)])
+    filtered_list = [
+        filter_stats(
+            stats,
+            time=t_unique,
+            num_restarts=max([me.num_restarts for me in filter_stats(stats, type=type, time=t_unique).keys()]),
+            type=type,
+        )
+        for t_unique in times_unique
+    ]
+    for item in filtered_list:
+        sorted_nested_list.append(sort_stats(item, sortby=sortby))
+    sorted_list = [item for sub_item in sorted_nested_list for item in sub_item]
+    return sorted_list
+
+
+def proof_assertions_description(description, use_adaptivity, use_switch_estimator):
     """
+    Function to proof the assertions in the description.
 
-    assert problem_params['alpha'] > problem_params['V_ref'], 'Please set "alpha" greater than "V_ref"'
-    assert problem_params['V_ref'] > 0, 'Please set "V_ref" greater than 0'
-    assert type(problem_params['V_ref']) == float, '"V_ref" needs to be of type float'
+    Args:
+        description(dict): contains all information for a controller run
+        use_adaptivity (bool): flag if adaptivity wants to be used or not
+        use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+    """
 
-    assert type(problem_params['set_switch'][0]) == np.bool_, '"set_switch" has to be an bool array'
-    assert type(problem_params['t_switch']) == np.ndarray, '"t_switch" has to be an array'
-    assert problem_params['t_switch'][0] == 0, '"t_switch" is only allowed to have entry zero'
+    n = description['problem_params']['ncapacitors']
+    assert (
+        description['problem_params']['alpha'] > description['problem_params']['V_ref'][k] for k in range(n)
+    ), 'Please set "alpha" greater than values of "V_ref"'
+    assert type(description['problem_params']['V_ref']) == np.ndarray, '"V_ref" needs to be an np.ndarray'
+    assert type(description['problem_params']['C']) == np.ndarray, '"C" needs to be an np.ndarray '
+    assert (
+        np.shape(description['problem_params']['V_ref'])[0] == n
+    ), 'Number of reference values needs to be equal to number of condensators'
+    assert (
+        np.shape(description['problem_params']['C'])[0] == n
+    ), 'Number of capacitance values needs to be equal to number of condensators'
+
+    assert (
+        description['problem_params']['V_ref'][k] > 0 for k in range(n)
+    ), 'Please set values of "V_ref" greater than 0'
 
     assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error'
     assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters'
     assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters'
 
+    if use_switch_estimator or use_adaptivity:
+        assert description['level_params']['restol'] == -1, "Please set restol to -1 or omit it"
+
+
+def proof_assertions_time(dt, Tend, V_ref, alpha):
+    """
+    Function to proof the assertions regarding the time domain (in combination with the specific problem):
+
+    Args:
+        dt (float): time step for computation
+        Tend (float): end time
+        V_ref (np.ndarray): Reference values (problem parameter)
+        alpha (np.float): Multiple used for initial conditions (problem_parameter)
+    """
+
+    assert dt < Tend, "Time step is too large for the time domain!"
+
+    assert (
+        Tend == 0.3 and V_ref[0] == 1.0 and alpha == 1.2
+    ), "Error! Do not use other parameters for V_ref != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!"
+    assert dt == 1e-2, "Error! Do not use another time step dt!= 1e-2!"
+
 
 if __name__ == "__main__":
     run()
diff --git a/pySDC/projects/PinTSimE/estimation_check.py b/pySDC/projects/PinTSimE/estimation_check.py
index e565bd11d387705ee5f0df74c602077c03c133e4..da012be499e25935134b9b922654a730f7fbed37 100644
--- a/pySDC/projects/PinTSimE/estimation_check.py
+++ b/pySDC/projects/PinTSimE/estimation_check.py
@@ -9,143 +9,79 @@ from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
 from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
 from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
 from pySDC.projects.PinTSimE.piline_model import setup_mpl
-from pySDC.projects.PinTSimE.battery_model import log_data, proof_assertions_description
+from pySDC.projects.PinTSimE.battery_model import (
+    controller_run,
+    check_solution,
+    generate_description,
+    get_recomputed,
+    log_data,
+    proof_assertions_description,
+)
 import pySDC.helpers.plot_helper as plt_helper
+from pySDC.core.Hooks import hooks
 
 from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
 from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
-from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedErrorNonMPI
 
 
-def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
+def run(cwd='./'):
     """
-    A simple test program to do SDC/PFASST runs for the battery drain model
+    Routine to check the differences between using a switch estimator or not
+
+    Args:
+        cwd (str): current working directory
     """
 
-    # initialize level parameters
-    level_params = dict()
-    level_params['restol'] = -1
-    level_params['dt'] = dt
-
-    # initialize sweeper parameters
-    sweeper_params = dict()
-    sweeper_params['quad_type'] = 'LOBATTO'
-    sweeper_params['num_nodes'] = 5
-    # sweeper_params['QI'] = 'LU'  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
-    sweeper_params['initial_guess'] = 'zero'
-
-    # initialize problem parameters
-    problem_params = dict()
-    problem_params['newton_maxiter'] = 200
-    problem_params['newton_tol'] = 1e-08
-    problem_params['Vs'] = 5.0
-    problem_params['Rs'] = 0.5
-    problem_params['C'] = 1.0
-    problem_params['R'] = 1.0
-    problem_params['L'] = 1.0
-    problem_params['alpha'] = 1.2
-    problem_params['V_ref'] = V_ref
-    problem_params['set_switch'] = np.array([False], dtype=bool)
-    problem_params['t_switch'] = np.zeros(1)
-
-    # initialize step parameters
-    step_params = dict()
-    step_params['maxiter'] = 4
-
-    # initialize controller parameters
-    controller_params = dict()
-    controller_params['logger_level'] = 20
-    controller_params['hook_class'] = log_data
-    controller_params['mssdc_jac'] = False
-
-    # convergence controllers
-    convergence_controllers = dict()
-    if use_switch_estimator:
-        switch_estimator_params = dict()
-        convergence_controllers.update({SwitchEstimator: switch_estimator_params})
-
-    if use_adaptivity:
-        adaptivity_params = dict()
-        adaptivity_params['e_tol'] = 1e-7
-        convergence_controllers.update({Adaptivity: adaptivity_params})
-
-    # fill description dictionary for easy step instantiation
-    description = dict()
-    description['problem_class'] = problem  # pass problem class
-    description['problem_params'] = problem_params  # pass problem parameters
-    description['sweeper_class'] = sweeper  # pass sweeper
-    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
-    description['level_params'] = level_params  # pass level parameters
-    description['step_params'] = step_params
-    description['max_restarts'] = 1
-
-    if use_switch_estimator or use_adaptivity:
-        description['convergence_controllers'] = convergence_controllers
-
-    proof_assertions_description(description, problem_params)
-
-    # set time parameters
+    dt_list = [4e-2, 4e-3]
     t0 = 0.0
     Tend = 0.3
 
-    # instantiate controller
-    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
-
-    # get initial values on finest level
-    P = controller.MS[0].levels[0].prob
-    uinit = P.u_exact(t0)
-
-    # call main function to get things done...
-    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
-
-    Path("data").mkdir(parents=True, exist_ok=True)
-    fname = 'data/battery.dat'
-    f = open(fname, 'wb')
-    dill.dump(stats, f)
-    f.close()
-
-    # filter statistics by number of iterations
-    iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time')
-
-    # compute and print statistics
-    f = open('data/battery_out.txt', 'w')
-    niters = np.array([item[1] for item in iter_counts])
-
-    assert np.mean(niters) <= 11, "Mean number of iterations is too high, got %s" % np.mean(niters)
-    f.close()
-
-    return description, stats
+    problem_classes = [battery, battery_implicit]
+    sweeper_classes = [imex_1st_order, generic_implicit]
 
+    ncapacitors = 1
+    alpha = 1.2
+    V_ref = np.array([1.0])
+    C = np.array([1.0])
 
-def check(cwd='./'):
-    """
-    Routine to check the differences between using a switch estimator or not
-    """
-
-    V_ref = 1.0
-    dt_list = [1e-2, 1e-3]
+    max_restarts = 1
     use_switch_estimator = [True, False]
     use_adaptivity = [True, False]
-    restarts_true = []
-    restarts_false_adapt = []
-    restarts_true_adapt = []
-
-    problem_classes = [battery, battery_implicit]
-    sweeper_classes = [imex_1st_order, generic_implicit]
+    restarts_SE = []
+    restarts_adapt = []
+    restarts_SE_adapt = []
 
     for problem, sweeper in zip(problem_classes, sweeper_classes):
         for dt_item in dt_list:
             for use_SE in use_switch_estimator:
                 for use_A in use_adaptivity:
-                    description, stats = run(
-                        dt=dt_item,
-                        problem=problem,
-                        sweeper=sweeper,
-                        use_switch_estimator=use_SE,
-                        use_adaptivity=use_A,
-                        V_ref=V_ref,
+                    description, controller_params = generate_description(
+                        dt_item,
+                        problem,
+                        sweeper,
+                        log_data,
+                        use_A,
+                        use_SE,
+                        ncapacitors,
+                        alpha,
+                        V_ref,
+                        C,
+                        max_restarts,
                     )
 
+                    # Assertions
+                    proof_assertions_description(description, use_A, use_SE)
+
+                    stats = controller_run(description, controller_params, use_A, use_SE, t0, Tend)
+
+                    if use_A or use_SE:
+                        check_solution(stats, dt_item, problem.__name__, use_A, use_SE)
+
+                    if use_SE:
+                        assert (
+                            len(get_recomputed(stats, type='switch', sortby='time')) >= 1
+                        ), 'No switches found for dt={}!'.format(dt_item)
+
                     fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__)
                     f = open(fname, 'wb')
                     dill.dump(stats, f)
@@ -153,24 +89,23 @@ def check(cwd='./'):
 
                     if use_SE or use_A:
                         restarts_sorted = np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1]
-                        print('Restarts for dt={}: {}'.format(dt_item, np.sum(restarts_sorted)))
                         if use_SE and not use_A:
-                            restarts_true.append(np.sum(restarts_sorted))
+                            restarts_SE.append(np.sum(restarts_sorted))
 
                         elif not use_SE and use_A:
-                            restarts_false_adapt.append(np.sum(restarts_sorted))
+                            restarts_adapt.append(np.sum(restarts_sorted))
 
                         elif use_SE and use_A:
-                            restarts_true_adapt.append(np.sum(restarts_sorted))
+                            restarts_SE_adapt.append(np.sum(restarts_sorted))
 
         accuracy_check(dt_list, problem.__name__, sweeper.__name__, V_ref)
 
         differences_around_switch(
             dt_list,
             problem.__name__,
-            restarts_true,
-            restarts_false_adapt,
-            restarts_true_adapt,
+            restarts_SE,
+            restarts_adapt,
+            restarts_SE_adapt,
             sweeper.__name__,
             V_ref,
         )
@@ -179,14 +114,21 @@ def check(cwd='./'):
 
         iterations_over_time(dt_list, description['step_params']['maxiter'], problem.__name__, sweeper.__name__)
 
-        restarts_true = []
-        restarts_false_adapt = []
-        restarts_true_adapt = []
+        restarts_SE = []
+        restarts_adapt = []
+        restarts_SE_adapt = []
 
 
 def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
     """
     Routine to check accuracy for different step sizes in case of using adaptivity
+
+    Args:
+        dt_list (list): list of considered (initial) step sizes
+        problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
+        sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
+        V_ref (np.float): reference value for the switch
+        cwd (str): current working directory
     """
 
     if len(dt_list) > 1:
@@ -202,45 +144,49 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
     count_ax = 0
     for dt_item in dt_list:
         f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TT = dill.load(f3)
+        stats_SE_adapt = dill.load(f3)
         f3.close()
 
         f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FT = dill.load(f4)
+        stats_adapt = dill.load(f4)
         f4.close()
 
-        val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
-        t_switch_adapt = [v[1] for v in val_switch_TT]
-        t_switch_adapt = t_switch_adapt[-1]
+        switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+        t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+        t_switch_SE_adapt = t_switch_SE_adapt[-1]
 
-        dt_TT_val = get_sorted(stats_TT, type='dt', recomputed=False)
-        dt_FT_val = get_sorted(stats_FT, type='dt', recomputed=False)
+        dt_SE_adapt_val = get_sorted(stats_SE_adapt, type='dt', recomputed=False)
+        dt_adapt_val = get_sorted(stats_adapt, type='dt', recomputed=False)
 
-        e_emb_TT_val = get_sorted(stats_TT, type='e_embedded', recomputed=False)
-        e_emb_FT_val = get_sorted(stats_FT, type='e_embedded', recomputed=False)
+        e_emb_SE_adapt_val = get_sorted(stats_SE_adapt, type='e_embedded', recomputed=False)
+        e_emb_adapt_val = get_sorted(stats_adapt, type='e_embedded', recomputed=False)
 
-        times_TT = [v[0] for v in e_emb_TT_val]
-        times_FT = [v[0] for v in e_emb_FT_val]
+        times_SE_adapt = [v[0] for v in e_emb_SE_adapt_val]
+        times_adapt = [v[0] for v in e_emb_adapt_val]
 
-        e_emb_TT = [v[1] for v in e_emb_TT_val]
-        e_emb_FT = [v[1] for v in e_emb_FT_val]
+        e_emb_SE_adapt = [v[1] for v in e_emb_SE_adapt_val]
+        e_emb_adapt = [v[1] for v in e_emb_adapt_val]
 
         if len(dt_list) > 1:
-            ax_acc[count_ax].set_title(r'$\Delta t$={}'.format(dt_item))
+            ax_acc[count_ax].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item)
             dt1 = ax_acc[count_ax].plot(
-                [v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$'
+                [v[0] for v in dt_SE_adapt_val],
+                [v[1] for v in dt_SE_adapt_val],
+                'ko-',
+                label=r'SE+A - $\Delta t_\mathrm{adapt}$',
             )
             dt2 = ax_acc[count_ax].plot(
-                [v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'g-', label=r'A - $\Delta t$'
+                [v[0] for v in dt_adapt_val], [v[1] for v in dt_adapt_val], 'g-', label=r'A - $\Delta t_\mathrm{adapt}$'
             )
-            ax_acc[count_ax].axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+            ax_acc[count_ax].axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+            ax_acc[count_ax].tick_params(axis='both', which='major', labelsize=6)
             ax_acc[count_ax].set_xlabel('Time', fontsize=6)
             if count_ax == 0:
-                ax_acc[count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+                ax_acc[count_ax].set_ylabel(r'$\Delta t_\mathrm{adapt}$', fontsize=6)
 
             e_ax = ax_acc[count_ax].twinx()
-            e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$')
-            e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$')
+            e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$')
+            e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$')
             e_ax.set_yscale('log', base=10)
             e_ax.set_ylim(1e-16, 1e-7)
             e_ax.tick_params(labelsize=6)
@@ -248,26 +194,37 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
             lines = dt1 + e_plt1 + dt2 + e_plt2
             labels = [l.get_label() for l in lines]
 
-            ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper left')
+            ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper right')
 
         else:
-            ax_acc.set_title(r'$\Delta t$={}'.format(dt_item))
-            dt1 = ax_acc.plot([v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$')
-            dt2 = ax_acc.plot([v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'go-', label=r'A - $\Delta t$')
-            ax_acc.axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+            ax_acc.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item)
+            dt1 = ax_acc.plot(
+                [v[0] for v in dt_SE_adapt_val],
+                [v[1] for v in dt_SE_adapt_val],
+                'ko-',
+                label=r'SE+A - $\Delta t_\mathrm{adapt}$',
+            )
+            dt2 = ax_acc.plot(
+                [v[0] for v in dt_adapt_val],
+                [v[1] for v in dt_adapt_val],
+                'go-',
+                label=r'A - $\Delta t_\mathrm{adapt}$',
+            )
+            ax_acc.axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+            ax_acc.tick_params(axis='both', which='major', labelsize=6)
             ax_acc.set_xlabel('Time', fontsize=6)
-            ax_acc.set_ylabel(r'$Delta t_{adapted}$', fontsize=6)
+            ax_acc.set_ylabel(r'$Delta t_\mathrm{adapt}$', fontsize=6)
 
             e_ax = ax_acc.twinx()
-            e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$')
-            e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$')
+            e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$')
+            e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$')
             e_ax.set_yscale('log', base=10)
             e_ax.tick_params(labelsize=6)
 
             lines = dt1 + e_plt1 + dt2 + e_plt2
             labels = [l.get_label() for l in lines]
 
-            ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper left')
+            ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper right')
 
         count_ax += 1
 
@@ -276,10 +233,20 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
 
 
 def differences_around_switch(
-    dt_list, problem, restarts_true, restarts_false_adapt, restarts_true_adapt, sweeper, V_ref, cwd='./'
+    dt_list, problem, restarts_SE, restarts_adapt, restarts_SE_adapt, sweeper, V_ref, cwd='./'
 ):
     """
     Routine to plot the differences before, at, and after the switch. Produces the diffs_estimation_<sweeper_class>.png file
+
+    Args:
+        dt_list (list): list of considered (initial) step sizes
+        problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
+        restarts_SE (list): Restarts for the solve only using the switch estimator
+        restarts_adapt (list): Restarts for the solve of only using adaptivity
+        restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity
+        sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
+        V_ref (np.float): reference value for the switch
+        cwd (str): current working directory
     """
 
     diffs_true_at = []
@@ -294,59 +261,77 @@ def differences_around_switch(
     diffs_false_after_adapt = []
     for dt_item in dt_list:
         f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TF = dill.load(f1)
+        stats_SE = dill.load(f1)
         f1.close()
 
         f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FF = dill.load(f2)
+        stats = dill.load(f2)
         f2.close()
 
         f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TT = dill.load(f3)
+        stats_SE_adapt = dill.load(f3)
         f3.close()
 
         f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FT = dill.load(f4)
+        stats_adapt = dill.load(f4)
         f4.close()
 
-        val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
-        t_switch = [v[1] for v in val_switch_TF]
+        switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+        t_switch = [v[1] for v in switches_SE]
         t_switch = t_switch[-1]  # battery has only one single switch
 
-        val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
-        t_switch_adapt = [v[1] for v in val_switch_TT]
-        t_switch_adapt = t_switch_adapt[-1]
+        switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+        t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+        t_switch_SE_adapt = t_switch_SE_adapt[-1]
 
-        vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time')
-        vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time')
-        vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time')
-        vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time')
+        vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False)]
+        vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False)]
+        vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)]
+        vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False)]
 
-        diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF]
-        times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF]
+        diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0]
+        times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False)]
+        times = [me[0] for me in get_sorted(stats, type='u', recomputed=False)]
 
-        diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT]
-        times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT]
+        diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0]
+        times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False)]
+        times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)]
 
-        for m in range(len(times_TF)):
-            if np.round(times_TF[m], 15) == np.round(t_switch, 15):
-                diffs_true_at.append(diff_TF[m])
+        diffs_true_at.append(
+            [diff_SE[m] for m in range(len(times_SE)) if np.isclose(times_SE[m], t_switch, atol=1e-15)]
+        )
 
-        for m in range(1, len(times_FF)):
-            if times_FF[m - 1] <= t_switch <= times_FF[m]:
-                diffs_false_before.append(diff_FF[m - 1])
-                diffs_false_after.append(diff_FF[m])
+        diffs_false_before.append([diff[m - 1] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]])
+        diffs_false_after.append([diff[m] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]])
 
-        for m in range(len(times_TT)):
-            if np.round(times_TT[m], 13) == np.round(t_switch_adapt, 13):
-                diffs_true_at_adapt.append(diff_TT[m])
-                diffs_true_before_adapt.append(diff_TT[m - 1])
-                diffs_true_after_adapt.append(diff_TT[m + 1])
+        diffs_true_at_adapt.append(
+            [
+                diff_SE_adapt[m]
+                for m in range(len(times_SE_adapt))
+                if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13)
+            ]
+        )
+        diffs_true_before_adapt.append(
+            [
+                diff_SE_adapt[m - 1]
+                for m in range(len(times_SE_adapt))
+                if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13)
+            ]
+        )
+        diffs_true_after_adapt.append(
+            [
+                diff_SE_adapt[m + 1]
+                for m in range(len(times_SE_adapt))
+                if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13)
+            ]
+        )
 
-        for m in range(len(times_FT)):
-            if times_FT[m - 1] <= t_switch <= times_FT[m]:
-                diffs_false_before_adapt.append(diff_FT[m - 1])
-                diffs_false_after_adapt.append(diff_FT[m])
+        diffs_false_before_adapt.append(
+            [diff_adapt[m - 1] for m in range(len(times_adapt)) if times_adapt[m - 1] <= t_switch <= times_adapt[m]]
+        )
+        diffs_false_after_adapt.append(
+            [diff_adapt[m] for m in range(len(times_adapt)) if times_adapt[m - 1] <= t_switch <= times_adapt[m]]
+        )
 
     setup_mpl()
     fig_around, ax_around = plt_helper.plt.subplots(1, 3, figsize=(9, 3), sharex='col', sharey='row')
@@ -356,14 +341,15 @@ def differences_around_switch(
     pos13 = ax_around[0].plot(dt_list, diffs_true_at, 'ko--', label='at switch')
     ax_around[0].set_xticks(dt_list)
     ax_around[0].set_xticklabels(dt_list)
+    ax_around[0].tick_params(axis='both', which='major', labelsize=6)
     ax_around[0].set_xscale('log', base=10)
     ax_around[0].set_yscale('symlog', linthresh=1e-8)
     ax_around[0].set_ylim(-1, 1)
-    ax_around[0].set_xlabel(r'$\Delta t$', fontsize=6)
+    ax_around[0].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
     ax_around[0].set_ylabel(r'$v_{C}-V_{ref}$', fontsize=6)
 
     restart_ax0 = ax_around[0].twinx()
-    restarts_plt0 = restart_ax0.plot(dt_list, restarts_true, 'cs--', label='Restarts')
+    restarts_plt0 = restart_ax0.plot(dt_list, restarts_SE, 'cs--', label='Restarts')
     restart_ax0.tick_params(labelsize=6)
 
     lines = pos11 + pos12 + pos13 + restarts_plt0
@@ -375,13 +361,14 @@ def differences_around_switch(
     pos22 = ax_around[1].plot(dt_list, diffs_false_after_adapt, 'bd--', label='after switch')
     ax_around[1].set_xticks(dt_list)
     ax_around[1].set_xticklabels(dt_list)
+    ax_around[1].tick_params(axis='both', which='major', labelsize=6)
     ax_around[1].set_xscale('log', base=10)
     ax_around[1].set_yscale('symlog', linthresh=1e-8)
     ax_around[1].set_ylim(-1, 1)
-    ax_around[1].set_xlabel(r'$\Delta t$', fontsize=6)
+    ax_around[1].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
 
     restart_ax1 = ax_around[1].twinx()
-    restarts_plt1 = restart_ax1.plot(dt_list, restarts_false_adapt, 'cs--', label='Restarts')
+    restarts_plt1 = restart_ax1.plot(dt_list, restarts_adapt, 'cs--', label='Restarts')
     restart_ax1.tick_params(labelsize=6)
 
     lines = pos21 + pos22 + restarts_plt1
@@ -394,90 +381,103 @@ def differences_around_switch(
     pos33 = ax_around[2].plot(dt_list, diffs_true_at_adapt, 'ko--', label='at switch')
     ax_around[2].set_xticks(dt_list)
     ax_around[2].set_xticklabels(dt_list)
+    ax_around[2].tick_params(axis='both', which='major', labelsize=6)
     ax_around[2].set_xscale('log', base=10)
     ax_around[2].set_yscale('symlog', linthresh=1e-8)
     ax_around[2].set_ylim(-1, 1)
-    ax_around[2].set_xlabel(r'$\Delta t$', fontsize=6)
+    ax_around[2].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
 
     restart_ax2 = ax_around[2].twinx()
-    restarts_plt2 = restart_ax2.plot(dt_list, restarts_true_adapt, 'cs--', label='Restarts')
+    restarts_plt2 = restart_ax2.plot(dt_list, restarts_SE_adapt, 'cs--', label='Restarts')
     restart_ax2.tick_params(labelsize=6)
 
     lines = pos31 + pos32 + pos33 + restarts_plt2
     labels = [l.get_label() for l in lines]
     ax_around[2].legend(frameon=False, fontsize=6, loc='lower right')
 
-    fig_around.savefig('data/diffs_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
+    fig_around.savefig('data/diffs_around_switch_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
     plt_helper.plt.close(fig_around)
 
 
 def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
     """
     Routine to plot the differences in time using the switch estimator or not. Produces the difference_estimation_<sweeper_class>.png file
+
+    Args:
+        dt_list (list): list of considered (initial) step sizes
+        problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
+        sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
+        V_ref (np.float): reference value for the switch
+        cwd (str): current working directory
     """
 
     if len(dt_list) > 1:
         setup_mpl()
         fig_diffs, ax_diffs = plt_helper.plt.subplots(
-            2, len(dt_list), figsize=(3 * len(dt_list), 4), sharex='col', sharey='row'
+            2, len(dt_list), figsize=(4 * len(dt_list), 6), sharex='col', sharey='row'
         )
 
     else:
         setup_mpl()
-        fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(3, 3))
+        fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(4, 6))
 
     count_ax = 0
     for dt_item in dt_list:
         f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TF = dill.load(f1)
+        stats_SE = dill.load(f1)
         f1.close()
 
         f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FF = dill.load(f2)
+        stats = dill.load(f2)
         f2.close()
 
         f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TT = dill.load(f3)
+        stats_SE_adapt = dill.load(f3)
         f3.close()
 
         f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FT = dill.load(f4)
+        stats_adapt = dill.load(f4)
         f4.close()
 
-        val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
-        t_switch_TF = [v[1] for v in val_switch_TF]
-        t_switch_TF = t_switch_TF[-1]  # battery has only one single switch
+        switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+        t_switch_SE = [v[1] for v in switches_SE]
+        t_switch_SE = t_switch_SE[-1]  # battery has only one single switch
 
-        val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
-        t_switch_adapt = [v[1] for v in val_switch_TT]
-        t_switch_adapt = t_switch_adapt[-1]
+        switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+        t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+        t_switch_SE_adapt = t_switch_SE_adapt[-1]
 
-        dt_FT = np.array(get_sorted(stats_FT, type='dt', recomputed=False, sortby='time'))
-        dt_TT = np.array(get_sorted(stats_TT, type='dt', recomputed=False, sortby='time'))
+        dt_adapt = np.array(get_sorted(stats_adapt, type='dt', recomputed=False))
+        dt_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='dt', recomputed=False))
 
-        restart_FT = np.array(get_sorted(stats_FT, type='restart', recomputed=None, sortby='time'))
-        restart_TT = np.array(get_sorted(stats_TT, type='restart', recomputed=None, sortby='time'))
+        restart_adapt = np.array(get_sorted(stats_adapt, type='restart', recomputed=None))
+        restart_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='restart', recomputed=None))
 
-        vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time')
-        vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time')
-        vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time')
-        vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time')
+        vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False)]
+        vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False)]
+        vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)]
+        vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False)]
 
-        diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF]
-        times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF]
+        diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0]
+        times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False)]
+        times = [me[0] for me in get_sorted(stats, type='u', recomputed=False)]
 
-        diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT]
-        times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT]
+        diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0]
+        times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False)]
+        times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)]
 
         if len(dt_list) > 1:
-            ax_diffs[0, count_ax].set_title(r'$\Delta t$={}'.format(dt_item))
-            ax_diffs[0, count_ax].plot(times_TF, diff_TF, label='SE=True, A=False', color='#ff7f0e')
-            ax_diffs[0, count_ax].plot(times_FF, diff_FF, label='SE=False, A=False', color='#1f77b4')
-            ax_diffs[0, count_ax].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--')
-            ax_diffs[0, count_ax].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.')
-            ax_diffs[0, count_ax].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch')
+            ax_diffs[0, count_ax].set_title(r'$\Delta t$=%s' % dt_item)
+            ax_diffs[0, count_ax].plot(times_SE, diff_SE, label='SE=True, A=False', color='#ff7f0e')
+            ax_diffs[0, count_ax].plot(times, diff, label='SE=False, A=False', color='#1f77b4')
+            ax_diffs[0, count_ax].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--')
+            ax_diffs[0, count_ax].plot(
+                times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.'
+            )
+            ax_diffs[0, count_ax].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch')
             ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='lower left')
             ax_diffs[0, count_ax].set_yscale('symlog', linthresh=1e-5)
+            ax_diffs[0, count_ax].tick_params(axis='both', which='major', labelsize=6)
             if count_ax == 0:
                 ax_diffs[0, count_ax].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6)
 
@@ -488,110 +488,128 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
                 ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='upper right')
 
             ax_diffs[1, count_ax].plot(
-                dt_FT[:, 0], dt_FT[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--'
+                dt_adapt[:, 0], dt_adapt[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--'
             )
             ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=F, A=T', color='grey', linestyle='-.')
 
-            for i in range(len(restart_FT)):
-                if restart_FT[i, 1] > 0:
-                    ax_diffs[1, count_ax].axvline(restart_FT[i, 0], color='grey', linestyle='-.')
+            for i in range(len(restart_adapt)):
+                if restart_adapt[i, 1] > 0:
+                    ax_diffs[1, count_ax].axvline(restart_adapt[i, 0], color='grey', linestyle='-.')
 
             ax_diffs[1, count_ax].plot(
-                dt_TT[:, 0], dt_TT[:, 1], label=r'$ \Delta t$ - SE=T, A=T', color='limegreen', linestyle='-.'
+                dt_SE_adapt[:, 0],
+                dt_SE_adapt[:, 1],
+                label=r'$ \Delta t$ - SE=T, A=T',
+                color='limegreen',
+                linestyle='-.',
             )
             ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=T, A=T', color='black', linestyle='-.')
 
-            for i in range(len(restart_TT)):
-                if restart_TT[i, 1] > 0:
-                    ax_diffs[1, count_ax].axvline(restart_TT[i, 0], color='black', linestyle='-.')
+            for i in range(len(restart_SE_adapt)):
+                if restart_SE_adapt[i, 1] > 0:
+                    ax_diffs[1, count_ax].axvline(restart_SE_adapt[i, 0], color='black', linestyle='-.')
 
             ax_diffs[1, count_ax].set_xlabel('Time', fontsize=6)
+            ax_diffs[1, count_ax].tick_params(axis='both', which='major', labelsize=6)
             if count_ax == 0:
-                ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+                ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6)
 
-            ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='upper left')
+            ax_diffs[1, count_ax].set_yscale('log', base=10)
+            ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='lower left')
 
         else:
-            ax_diffs[0].set_title(r'$\Delta t$={}'.format(dt_item))
-            ax_diffs[0].plot(times_TF, diff_TF, label='SE=True', color='#ff7f0e')
-            ax_diffs[0].plot(times_FF, diff_FF, label='SE=False', color='#1f77b4')
-            ax_diffs[0].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--')
-            ax_diffs[0].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.')
-            ax_diffs[0].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch')
-            ax_diffs[0].legend(frameon=False, fontsize=6, loc='lower left')
+            ax_diffs[0].set_title(r'$\Delta t$=%s' % dt_item)
+            ax_diffs[0].plot(times_SE, diff_SE, label='SE=True', color='#ff7f0e')
+            ax_diffs[0].plot(times, diff, label='SE=False', color='#1f77b4')
+            ax_diffs[0].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--')
+            ax_diffs[0].plot(times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.')
+            ax_diffs[0].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch')
+            ax_diffs[0].tick_params(axis='both', which='major', labelsize=6)
             ax_diffs[0].set_yscale('symlog', linthresh=1e-5)
             ax_diffs[0].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6)
             ax_diffs[0].legend(frameon=False, fontsize=6, loc='center right')
 
-            ax_diffs[1].plot(dt_FT[:, 0], dt_FT[:, 1], label='SE=False, A=True', color='red', linestyle='--')
-            ax_diffs[1].plot(dt_TT[:, 0], dt_TT[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.')
+            ax_diffs[1].plot(dt_adapt[:, 0], dt_adapt[:, 1], label='SE=False, A=True', color='red', linestyle='--')
+            ax_diffs[1].plot(
+                dt_SE_adapt[:, 0], dt_SE_adapt[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.'
+            )
+            ax_diffs[1].tick_params(axis='both', which='major', labelsize=6)
             ax_diffs[1].set_xlabel('Time', fontsize=6)
-            ax_diffs[1].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+            ax_diffs[1].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6)
+            ax_diffs[1].set_yscale('log', base=10)
 
             ax_diffs[1].legend(frameon=False, fontsize=6, loc='upper right')
 
         count_ax += 1
 
     plt_helper.plt.tight_layout()
-    fig_diffs.savefig('data/difference_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
+    fig_diffs.savefig('data/diffs_over_time_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
     plt_helper.plt.close(fig_diffs)
 
 
 def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
     """
     Routine  to plot the number of iterations over time using switch estimator or not. Produces the iters_<sweeper_class>.png file
+
+    Args:
+        dt_list (list): list of considered (initial) step sizes
+        maxiter (np.int): maximum number of iterations
+        problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
+        sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
+        cwd (str): current working directory
     """
 
-    iters_time_TF = []
-    iters_time_FF = []
-    iters_time_TT = []
-    iters_time_FT = []
-    times_TF = []
-    times_FF = []
-    times_TT = []
-    times_FT = []
-    t_switches_TF = []
-    t_switches_adapt = []
+    iters_time_SE = []
+    iters_time = []
+    iters_time_SE_adapt = []
+    iters_time_adapt = []
+    times_SE = []
+    times = []
+    times_SE_adapt = []
+    times_adapt = []
+    t_switches_SE = []
+    t_switches_SE_adapt = []
 
     for dt_item in dt_list:
         f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TF = dill.load(f1)
+        stats_SE = dill.load(f1)
         f1.close()
 
         f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FF = dill.load(f2)
+        stats = dill.load(f2)
         f2.close()
 
         f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_TT = dill.load(f3)
+        stats_SE_adapt = dill.load(f3)
         f3.close()
 
         f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
-        stats_FT = dill.load(f4)
+        stats_adapt = dill.load(f4)
         f4.close()
 
-        iter_counts_TF_val = get_sorted(stats_TF, type='niter', recomputed=False, sortby='time')
-        iter_counts_TT_val = get_sorted(stats_TT, type='niter', recomputed=False, sortby='time')
-        iter_counts_FT_val = get_sorted(stats_FT, type='niter', recomputed=False, sortby='time')
-        iter_counts_FF_val = get_sorted(stats_FF, type='niter', recomputed=False, sortby='time')
+        # consider iterations before restarts to see what happens
+        iter_counts_SE_val = get_sorted(stats_SE, type='niter')
+        iter_counts_SE_adapt_val = get_sorted(stats_SE_adapt, type='niter')
+        iter_counts_adapt_val = get_sorted(stats_adapt, type='niter')
+        iter_counts_val = get_sorted(stats, type='niter')
 
-        iters_time_TF.append([v[1] for v in iter_counts_TF_val])
-        iters_time_TT.append([v[1] for v in iter_counts_TT_val])
-        iters_time_FT.append([v[1] for v in iter_counts_FT_val])
-        iters_time_FF.append([v[1] for v in iter_counts_FF_val])
+        iters_time_SE.append([v[1] for v in iter_counts_SE_val])
+        iters_time_SE_adapt.append([v[1] for v in iter_counts_SE_adapt_val])
+        iters_time_adapt.append([v[1] for v in iter_counts_adapt_val])
+        iters_time.append([v[1] for v in iter_counts_val])
 
-        times_TF.append([v[0] for v in iter_counts_TF_val])
-        times_TT.append([v[0] for v in iter_counts_TT_val])
-        times_FT.append([v[0] for v in iter_counts_FT_val])
-        times_FF.append([v[0] for v in iter_counts_FF_val])
+        times_SE.append([v[0] for v in iter_counts_SE_val])
+        times_SE_adapt.append([v[0] for v in iter_counts_SE_adapt_val])
+        times_adapt.append([v[0] for v in iter_counts_adapt_val])
+        times.append([v[0] for v in iter_counts_val])
 
-        val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
-        t_switch_TF = [v[1] for v in val_switch_TF]
-        t_switches_TF.append(t_switch_TF[-1])
+        switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+        t_switch_SE = [v[1] for v in switches_SE]
+        t_switches_SE.append(t_switch_SE[-1])
 
-        val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
-        t_switch_adapt = [v[1] for v in val_switch_TT]
-        t_switches_adapt.append(t_switch_adapt[-1])
+        switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+        t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+        t_switches_SE_adapt.append(t_switch_SE_adapt[-1])
 
     if len(dt_list) > 1:
         setup_mpl()
@@ -599,18 +617,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
             nrows=1, ncols=len(dt_list), figsize=(2 * len(dt_list) - 1, 3), sharex='col', sharey='row'
         )
         for col in range(len(dt_list)):
-            ax_iter_all[col].plot(times_FF[col], iters_time_FF[col], label='SE=F, A=F')
-            ax_iter_all[col].plot(times_TF[col], iters_time_TF[col], label='SE=T, A=F')
-            ax_iter_all[col].plot(times_TT[col], iters_time_TT[col], '--', label='SE=T, A=T')
-            ax_iter_all[col].plot(times_FT[col], iters_time_FT[col], '--', label='SE=F, A=T')
-            ax_iter_all[col].axvline(x=t_switches_TF[col], linestyle='--', linewidth=0.5, color='k', label='Switch')
-            if t_switches_adapt[col] != t_switches_TF[col]:
-                ax_iter_all[col].axvline(
-                    x=t_switches_adapt[col], linestyle='--', linewidth=0.5, color='k', label='Switch'
-                )
-            ax_iter_all[col].set_title('dt={}'.format(dt_list[col]))
+            ax_iter_all[col].plot(times[col], iters_time[col], label='SE=F, A=F')
+            ax_iter_all[col].plot(times_SE[col], iters_time_SE[col], label='SE=T, A=F')
+            ax_iter_all[col].plot(times_SE_adapt[col], iters_time_SE_adapt[col], '--', label='SE=T, A=T')
+            ax_iter_all[col].plot(times_adapt[col], iters_time_adapt[col], '--', label='SE=F, A=T')
+            ax_iter_all[col].axvline(x=t_switches_SE[col], linestyle='--', linewidth=0.5, color='k', label='Switch')
+            ax_iter_all[col].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[col])
             ax_iter_all[col].set_ylim(0, maxiter + 2)
             ax_iter_all[col].set_xlabel('Time', fontsize=6)
+            ax_iter_all[col].tick_params(axis='both', which='major', labelsize=6)
 
             if col == 0:
                 ax_iter_all[col].set_ylabel('Number iterations', fontsize=6)
@@ -620,16 +635,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
         setup_mpl()
         fig_iter_all, ax_iter_all = plt_helper.plt.subplots(nrows=1, ncols=1, figsize=(3, 3))
 
-        ax_iter_all.plot(times_FF[0], iters_time_FF[0], label='SE=False')
-        ax_iter_all.plot(times_TF[0], iters_time_TF[0], label='SE=True')
-        ax_iter_all.plot(times_TT[0], iters_time_TT[0], '--', label='SE=T, A=T')
-        ax_iter_all.plot(times_FT[0], iters_time_FT[0], '--', label='SE=F, A=T')
-        ax_iter_all.axvline(x=t_switches_TF[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
-        if t_switches_adapt[0] != t_switches_TF[0]:
-            ax_iter_all.axvline(x=t_switches_adapt[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
-        ax_iter_all.set_title('dt={}'.format(dt_list[0]))
+        ax_iter_all.plot(times[0], iters_time[0], label='SE=False')
+        ax_iter_all.plot(times_SE[0], iters_time_SE[0], label='SE=True')
+        ax_iter_all.plot(times_SE_adapt[0], iters_time_SE_adapt[0], '--', label='SE=T, A=T')
+        ax_iter_all.plot(times_adapt[0], iters_time_adapt[0], '--', label='SE=F, A=T')
+        ax_iter_all.axvline(x=t_switches_SE[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
+        ax_iter_all.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[0])
         ax_iter_all.set_ylim(0, maxiter + 2)
         ax_iter_all.set_xlabel('Time', fontsize=6)
+        ax_iter_all.tick_params(axis='both', which='major', labelsize=6)
 
         ax_iter_all.set_ylabel('Number iterations', fontsize=6)
         ax_iter_all.legend(frameon=False, fontsize=6, loc='upper right')
@@ -640,4 +654,4 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
 
 
 if __name__ == "__main__":
-    check()
+    run()
diff --git a/pySDC/projects/PinTSimE/estimation_check_2capacitors.py b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf747af665d1fb89b483e73c3b9df6c27cd4596
--- /dev/null
+++ b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py
@@ -0,0 +1,236 @@
+import numpy as np
+import dill
+from pathlib import Path
+
+from pySDC.helpers.stats_helper import get_sorted
+from pySDC.core.Collocation import CollBase as Collocation
+from pySDC.implementations.problem_classes.Battery import battery_n_capacitors
+from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
+from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+from pySDC.projects.PinTSimE.battery_model import controller_run, generate_description, get_recomputed, log_data
+from pySDC.projects.PinTSimE.piline_model import setup_mpl
+from pySDC.projects.PinTSimE.battery_2capacitors_model import (
+    check_solution,
+    proof_assertions_description,
+    proof_assertions_time,
+)
+import pySDC.helpers.plot_helper as plt_helper
+
+from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
+
+
+def run(cwd='./'):
+    """
+    Routine to check the differences between using a switch estimator or not
+
+    Args:
+        cwd (str): current working directory
+    """
+
+    dt_list = [4e-1, 4e-2, 4e-3]
+    t0 = 0.0
+    Tend = 3.5
+
+    problem_classes = [battery_n_capacitors]
+    sweeper_classes = [imex_1st_order]
+
+    ncapacitors = 2
+    alpha = 5.0
+    V_ref = np.array([1.0, 1.0])
+    C = np.array([1.0, 1.0])
+
+    use_switch_estimator = [True, False]
+    restarts_all = []
+    restarts_dict = dict()
+    for problem, sweeper in zip(problem_classes, sweeper_classes):
+        for dt_item in dt_list:
+            for use_SE in use_switch_estimator:
+                description, controller_params = generate_description(
+                    dt_item,
+                    problem,
+                    sweeper,
+                    log_data,
+                    False,
+                    use_SE,
+                    ncapacitors,
+                    alpha,
+                    V_ref,
+                    C,
+                )
+
+                # Assertions
+                proof_assertions_description(description, False, use_SE)
+
+                proof_assertions_time(dt_item, Tend, V_ref, alpha)
+
+                stats = controller_run(description, controller_params, False, use_SE, t0, Tend)
+
+                if use_SE:
+                    switches = get_recomputed(stats, type='switch', sortby='time')
+                    assert len(switches) >= 2, f"Expected at least 2 switches for dt: {dt_item}, got {len(switches)}!"
+
+                    check_solution(stats, dt_item, use_SE)
+
+                fname = 'data/{}_dt{}_USE{}.dat'.format(problem.__name__, dt_item, use_SE)
+                f = open(fname, 'wb')
+                dill.dump(stats, f)
+                f.close()
+
+                if use_SE:
+                    restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None))
+                    restarts = restarts_dict[dt_item][:, 1]
+                    restarts_all.append(np.sum(restarts))
+                    print("Restarts for dt: ", dt_item, " -- ", np.sum(restarts))
+
+    V_ref = description['problem_params']['V_ref']
+
+    val_switch_all = []
+    diff_true_all1 = []
+    diff_false_all_before1 = []
+    diff_false_all_after1 = []
+    diff_true_all2 = []
+    diff_false_all_before2 = []
+    diff_false_all_after2 = []
+    restarts_dt_switch1 = []
+    restarts_dt_switch2 = []
+    for dt_item in dt_list:
+        f1 = open(cwd + 'data/{}_dt{}_USETrue.dat'.format(problem.__name__, dt_item), 'rb')
+        stats_true = dill.load(f1)
+        f1.close()
+
+        f2 = open(cwd + 'data/{}_dt{}_USEFalse.dat'.format(problem.__name__, dt_item), 'rb')
+        stats_false = dill.load(f2)
+        f2.close()
+
+        switches = get_recomputed(stats_true, type='switch', sortby='time')
+        t_switch = [v[1] for v in switches]
+
+        val_switch_all.append([t_switch[0], t_switch[1]])
+
+        vC1_true = [me[1][1] for me in get_sorted(stats_true, type='u', recomputed=False)]
+        vC2_true = [me[1][2] for me in get_sorted(stats_true, type='u', recomputed=False)]
+        vC1_false = [me[1][1] for me in get_sorted(stats_false, type='u', recomputed=False)]
+        vC2_false = [me[1][2] for me in get_sorted(stats_false, type='u', recomputed=False)]
+
+        diff_true1 = vC1_true - V_ref[0]
+        diff_true2 = vC2_true - V_ref[1]
+        diff_false1 = vC1_false - V_ref[0]
+        diff_false2 = vC2_false - V_ref[1]
+
+        t_true = [me[0] for me in get_sorted(stats_true, type='u', recomputed=False)]
+        t_false = [me[0] for me in get_sorted(stats_false, type='u', recomputed=False)]
+
+        diff_true_all1.append(
+            [diff_true1[m] for m in range(len(t_true)) if np.isclose(t_true[m], t_switch[0], atol=1e-15)]
+        )
+        diff_true_all2.append(
+            [diff_true2[m] for m in range(len(t_true)) if np.isclose(t_true[m], t_switch[1], atol=1e-15)]
+        )
+
+        diff_false_all_before1.append(
+            [diff_false1[m - 1] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[0] < t_false[m]]
+        )
+        diff_false_all_after1.append(
+            [diff_false1[m] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[0] < t_false[m]]
+        )
+
+        diff_false_all_before2.append(
+            [diff_false2[m - 1] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[1] < t_false[m]]
+        )
+        diff_false_all_after2.append(
+            [diff_false2[m] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[1] < t_false[m]]
+        )
+
+        restarts_dt = restarts_dict[dt_item]
+        restarts_dt_switch1.append(
+            [
+                np.sum(restarts_dt[0 : i - 1, 1])
+                for i in range(len(restarts_dt[:, 0]))
+                if np.isclose(restarts_dt[i, 0], t_switch[0], atol=1e-13)
+            ]
+        )
+        restarts_dt_switch2.append(
+            [
+                np.sum(restarts_dt[i - 2 :, 1])
+                for i in range(len(restarts_dt[:, 0]))
+                if np.isclose(restarts_dt[i, 0], t_switch[1], atol=1e-13)
+            ]
+        )
+
+        setup_mpl()
+        fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
+        ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$')
+        ax1.plot(t_true, diff_true1, label='SE=True', color='#ff7f0e')
+        ax1.plot(t_false, diff_false1, label='SE=False', color='#1f77b4')
+        ax1.axvline(x=t_switch[0], linestyle='--', color='k', label='Switch1')
+        ax1.legend(frameon=False, fontsize=10, loc='lower left')
+        ax1.set_yscale('symlog', linthresh=1e-5)
+        ax1.set_xlabel('Time')
+
+        fig1.savefig('data/difference_estimation_vC1_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight')
+        plt_helper.plt.close(fig1)
+
+        setup_mpl()
+        fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
+        ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$')
+        ax2.plot(t_true, diff_true2, label='SE=True', color='#ff7f0e')
+        ax2.plot(t_false, diff_false2, label='SE=False', color='#1f77b4')
+        ax2.axvline(x=t_switch[1], linestyle='--', color='k', label='Switch2')
+        ax2.legend(frameon=False, fontsize=10, loc='lower left')
+        ax2.set_yscale('symlog', linthresh=1e-5)
+        ax2.set_xlabel('Time')
+
+        fig2.savefig('data/difference_estimation_vC2_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight')
+        plt_helper.plt.close(fig2)
+
+    setup_mpl()
+    fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
+    ax1.set_title("Difference $v_{C_{1}}-V_{ref1}$")
+    pos1 = ax1.plot(dt_list, diff_false_all_before1, 'rs-', label='SE=False - before switch1')
+    pos2 = ax1.plot(dt_list, diff_false_all_after1, 'bd-', label='SE=False - after switch1')
+    pos3 = ax1.plot(dt_list, diff_true_all1, 'kd-', label='SE=True')
+    ax1.set_xticks(dt_list)
+    ax1.set_xticklabels(dt_list)
+    ax1.set_xscale('log', base=10)
+    ax1.set_yscale('symlog', linthresh=1e-10)
+    ax1.set_ylim(-2, 2)
+    ax1.set_xlabel(r'$\Delta t$')
+
+    restart_ax = ax1.twinx()
+    restarts = restart_ax.plot(dt_list, restarts_dt_switch1, 'cs--', label='Restarts')
+    restart_ax.set_ylabel('Restarts')
+
+    lines = pos1 + pos2 + pos3 + restarts
+    labels = [l.get_label() for l in lines]
+    ax1.legend(lines, labels, frameon=False, fontsize=8, loc='center right')
+
+    fig1.savefig('data/diffs_estimation_vC1.png', dpi=300, bbox_inches='tight')
+    plt_helper.plt.close(fig1)
+
+    setup_mpl()
+    fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
+    ax2.set_title("Difference $v_{C_{2}}-V_{ref2}$")
+    pos1 = ax2.plot(dt_list, diff_false_all_before2, 'rs-', label='SE=False - before switch2')
+    pos2 = ax2.plot(dt_list, diff_false_all_after2, 'bd-', label='SE=False - after switch2')
+    pos3 = ax2.plot(dt_list, diff_true_all2, 'kd-', label='SE=True')
+    ax2.set_xticks(dt_list)
+    ax2.set_xticklabels(dt_list)
+    ax2.set_xscale('log', base=10)
+    ax2.set_yscale('symlog', linthresh=1e-10)
+    ax2.set_ylim(-2, 2)
+    ax2.set_xlabel(r'$\Delta t$')
+
+    restart_ax = ax2.twinx()
+    restarts = restart_ax.plot(dt_list, restarts_dt_switch2, 'cs--', label='Restarts')
+    restart_ax.set_ylabel('Restarts')
+
+    lines = pos1 + pos2 + pos3 + restarts
+    labels = [l.get_label() for l in lines]
+    ax2.legend(lines, labels, frameon=False, fontsize=8, loc='center right')
+
+    fig2.savefig('data/diffs_estimation_vC2.png', dpi=300, bbox_inches='tight')
+    plt_helper.plt.close(fig2)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/pySDC/projects/PinTSimE/estimation_check_extended.py b/pySDC/projects/PinTSimE/estimation_check_extended.py
deleted file mode 100644
index f075fe2de47cd65765f4a6c4b2b505b5bd3c0e53..0000000000000000000000000000000000000000
--- a/pySDC/projects/PinTSimE/estimation_check_extended.py
+++ /dev/null
@@ -1,291 +0,0 @@
-import numpy as np
-import dill
-from pathlib import Path
-
-from pySDC.helpers.stats_helper import get_sorted
-from pySDC.core.Collocation import CollBase as Collocation
-from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators
-from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
-from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
-from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
-from pySDC.projects.PinTSimE.piline_model import setup_mpl
-from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description
-import pySDC.helpers.plot_helper as plt_helper
-
-from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
-
-
-def run(dt, use_switch_estimator=True):
-
-    # initialize level parameters
-    level_params = dict()
-    level_params['restol'] = 1e-13
-    level_params['dt'] = dt
-
-    # initialize sweeper parameters
-    sweeper_params = dict()
-    sweeper_params['quad_type'] = 'LOBATTO'
-    sweeper_params['num_nodes'] = 5
-    sweeper_params['QI'] = 'LU'  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
-    sweeper_params['initial_guess'] = 'zero'
-
-    # initialize problem parameters
-    problem_params = dict()
-    problem_params['Vs'] = 5.0
-    problem_params['Rs'] = 0.5
-    problem_params['C1'] = 1.0
-    problem_params['C2'] = 1.0
-    problem_params['R'] = 1.0
-    problem_params['L'] = 1.0
-    problem_params['alpha'] = 5.0
-    problem_params['V_ref'] = np.array([1.0, 1.0])  # [V_ref1, V_ref2]
-    problem_params['set_switch'] = np.array([False, False], dtype=bool)
-    problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0])
-
-    # initialize step parameters
-    step_params = dict()
-    step_params['maxiter'] = 20
-
-    # initialize controller parameters
-    controller_params = dict()
-    controller_params['logger_level'] = 20
-    controller_params['hook_class'] = log_data
-
-    # convergence controllers
-    convergence_controllers = dict()
-    if use_switch_estimator:
-        switch_estimator_params = {}
-        convergence_controllers[SwitchEstimator] = switch_estimator_params
-
-    # fill description dictionary for easy step instantiation
-    description = dict()
-    description['problem_class'] = battery_2condensators  # pass problem class
-    description['problem_params'] = problem_params  # pass problem parameters
-    description['sweeper_class'] = imex_1st_order  # pass sweeper
-    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
-    description['level_params'] = level_params  # pass level parameters
-    description['step_params'] = step_params
-    description['space_transfer_class'] = mesh_to_mesh  # pass spatial transfer class
-
-    if use_switch_estimator:
-        description['convergence_controllers'] = convergence_controllers
-
-    proof_assertions_description(description, problem_params)
-
-    # set time parameters
-    t0 = 0.0
-    Tend = 3.5
-
-    # instantiate controller
-    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
-
-    # get initial values on finest level
-    P = controller.MS[0].levels[0].prob
-    uinit = P.u_exact(t0)
-
-    # call main function to get things done...
-    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
-
-    Path("data").mkdir(parents=True, exist_ok=True)
-    fname = 'data/battery_2condensators.dat'
-    f = open(fname, 'wb')
-    dill.dump(stats, f)
-    f.close()
-
-    # filter statistics by number of iterations
-    iter_counts = get_sorted(stats, type='niter', sortby='time')
-
-    # compute and print statistics
-    min_iter = 20
-    max_iter = 0
-
-    f = open('data/battery_2condensators_out.txt', 'w')
-    niters = np.array([item[1] for item in iter_counts])
-    out = '   Mean number of iterations: %4.2f' % np.mean(niters)
-    f.write(out + '\n')
-    print(out)
-    for item in iter_counts:
-        out = 'Number of iterations for time %4.2f: %1i' % item
-        f.write(out + '\n')
-        # print(out)
-        min_iter = min(min_iter, item[1])
-        max_iter = max(max_iter, item[1])
-
-    assert np.mean(niters) <= 12, "Mean number of iterations is too high, got %s" % np.mean(niters)
-    f.close()
-
-    return stats, description
-
-
-def check(cwd='./'):
-    """
-    Routine to check the differences between using a switch estimator or not
-    """
-
-    dt_list = [4e-1, 4e-2, 4e-3]
-    use_switch_estimator = [True, False]
-    restarts_all = []
-    restarts_dict = dict()
-    for dt_item in dt_list:
-        for item in use_switch_estimator:
-            stats, description = run(dt=dt_item, use_switch_estimator=item)
-
-            fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, item)
-            f = open(fname, 'wb')
-            dill.dump(stats, f)
-            f.close()
-
-            if item:
-                restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))
-                restarts = restarts_dict[dt_item][:, 1]
-                restarts_all.append(np.sum(restarts))
-                print("Restarts for dt: ", dt_item, " -- ", np.sum(restarts))
-
-    V_ref = description['problem_params']['V_ref']
-
-    val_switch_all = []
-    diff_true_all1 = []
-    diff_false_all_before1 = []
-    diff_false_all_after1 = []
-    diff_true_all2 = []
-    diff_false_all_before2 = []
-    diff_false_all_after2 = []
-    restarts_dt_switch1 = []
-    restarts_dt_switch2 = []
-    for dt_item in dt_list:
-        f1 = open(cwd + 'data/battery_2condensators_dt{}_USETrue.dat'.format(dt_item), 'rb')
-        stats_true = dill.load(f1)
-        f1.close()
-
-        f2 = open(cwd + 'data/battery_2condensators_dt{}_USEFalse.dat'.format(dt_item), 'rb')
-        stats_false = dill.load(f2)
-        f2.close()
-
-        val_switch1 = get_sorted(stats_true, type='switch1', sortby='time')
-        val_switch2 = get_sorted(stats_true, type='switch2', sortby='time')
-        t_switch1 = [v[1] for v in val_switch1]
-        t_switch2 = [v[1] for v in val_switch2]
-
-        t_switch1 = t_switch1[-1]
-        t_switch2 = t_switch2[-1]
-
-        val_switch_all.append([t_switch1, t_switch2])
-
-        vC1_true = get_sorted(stats_true, type='voltage C1', recomputed=False, sortby='time')
-        vC2_true = get_sorted(stats_true, type='voltage C2', recomputed=False, sortby='time')
-        vC1_false = get_sorted(stats_false, type='voltage C1', sortby='time')
-        vC2_false = get_sorted(stats_false, type='voltage C2', sortby='time')
-
-        diff_true1 = [v[1] - V_ref[0] for v in vC1_true]
-        diff_true2 = [v[1] - V_ref[1] for v in vC2_true]
-        diff_false1 = [v[1] - V_ref[0] for v in vC1_false]
-        diff_false2 = [v[1] - V_ref[1] for v in vC2_false]
-
-        times_true1 = [v[0] for v in vC1_true]
-        times_true2 = [v[0] for v in vC2_true]
-        times_false1 = [v[0] for v in vC1_false]
-        times_false2 = [v[0] for v in vC2_false]
-
-        for m in range(len(times_true1)):
-            if np.round(times_true1[m], 15) == np.round(t_switch1, 15):
-                diff_true_all1.append(diff_true1[m])
-
-        for m in range(len(times_true2)):
-            if np.round(times_true2[m], 15) == np.round(t_switch2, 15):
-                diff_true_all2.append(diff_true2[m])
-
-        for m in range(1, len(times_false1)):
-            if times_false1[m - 1] < t_switch1 < times_false1[m]:
-                diff_false_all_before1.append(diff_false1[m - 1])
-                diff_false_all_after1.append(diff_false1[m])
-
-        for m in range(1, len(times_false2)):
-            if times_false2[m - 1] < t_switch2 < times_false2[m]:
-                diff_false_all_before2.append(diff_false2[m - 1])
-                diff_false_all_after2.append(diff_false2[m])
-
-        restarts_dt = restarts_dict[dt_item]
-        for i in range(len(restarts_dt[:, 0])):
-            if round(restarts_dt[i, 0], 13) == round(t_switch1, 13):
-                restarts_dt_switch1.append(np.sum(restarts_dt[0:i, 1]))
-
-            if round(restarts_dt[i, 0], 13) == round(t_switch2, 13):
-                restarts_dt_switch2.append(np.sum(restarts_dt[i - 1 :, 1]))
-
-        setup_mpl()
-        fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
-        ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$')
-        ax1.plot(times_true1, diff_true1, label='SE=True', color='#ff7f0e')
-        ax1.plot(times_false1, diff_false1, label='SE=False', color='#1f77b4')
-        ax1.axvline(x=t_switch1, linestyle='--', color='k', label='Switch1')
-        ax1.legend(frameon=False, fontsize=10, loc='lower left')
-        ax1.set_yscale('symlog', linthresh=1e-5)
-        ax1.set_xlabel('Time')
-
-        fig1.savefig('data/difference_estimation_vC1_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight')
-        plt_helper.plt.close(fig1)
-
-        setup_mpl()
-        fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
-        ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$')
-        ax2.plot(times_true2, diff_true2, label='SE=True', color='#ff7f0e')
-        ax2.plot(times_false2, diff_false2, label='SE=False', color='#1f77b4')
-        ax2.axvline(x=t_switch2, linestyle='--', color='k', label='Switch2')
-        ax2.legend(frameon=False, fontsize=10, loc='lower left')
-        ax2.set_yscale('symlog', linthresh=1e-5)
-        ax2.set_xlabel('Time')
-
-        fig2.savefig('data/difference_estimation_vC2_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight')
-        plt_helper.plt.close(fig2)
-
-    setup_mpl()
-    fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
-    ax1.set_title("Difference $v_{C_{1}}-V_{ref1}$")
-    pos1 = ax1.plot(dt_list, diff_false_all_before1, 'rs-', label='SE=False - before switch1')
-    pos2 = ax1.plot(dt_list, diff_false_all_after1, 'bd-', label='SE=False - after switch1')
-    pos3 = ax1.plot(dt_list, diff_true_all1, 'kd-', label='SE=True')
-    ax1.set_xticks(dt_list)
-    ax1.set_xticklabels(dt_list)
-    ax1.set_xscale('log', base=10)
-    ax1.set_yscale('symlog', linthresh=1e-10)
-    ax1.set_ylim(-2, 2)
-    ax1.set_xlabel(r'$\Delta t$')
-
-    restart_ax = ax1.twinx()
-    restarts = restart_ax.plot(dt_list, restarts_dt_switch1, 'cs--', label='Restarts')
-    restart_ax.set_ylabel('Restarts')
-
-    lines = pos1 + pos2 + pos3 + restarts
-    labels = [l.get_label() for l in lines]
-    ax1.legend(lines, labels, frameon=False, fontsize=8, loc='center right')
-
-    fig1.savefig('data/diffs_estimation_vC1.png', dpi=300, bbox_inches='tight')
-    plt_helper.plt.close(fig1)
-
-    setup_mpl()
-    fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
-    ax2.set_title("Difference $v_{C_{2}}-V_{ref2}$")
-    pos1 = ax2.plot(dt_list, diff_false_all_before2, 'rs-', label='SE=False - before switch2')
-    pos2 = ax2.plot(dt_list, diff_false_all_after2, 'bd-', label='SE=False - after switch2')
-    pos3 = ax2.plot(dt_list, diff_true_all2, 'kd-', label='SE=True')
-    ax2.set_xticks(dt_list)
-    ax2.set_xticklabels(dt_list)
-    ax2.set_xscale('log', base=10)
-    ax2.set_yscale('symlog', linthresh=1e-10)
-    ax2.set_ylim(-2, 2)
-    ax2.set_xlabel(r'$\Delta t$')
-
-    restart_ax = ax2.twinx()
-    restarts = restart_ax.plot(dt_list, restarts_dt_switch2, 'cs--', label='Restarts')
-    restart_ax.set_ylabel('Restarts')
-
-    lines = pos1 + pos2 + pos3 + restarts
-    labels = [l.get_label() for l in lines]
-    ax2.legend(lines, labels, frameon=False, fontsize=8, loc='center right')
-
-    fig2.savefig('data/diffs_estimation_vC2.png', dpi=300, bbox_inches='tight')
-    plt_helper.plt.close(fig2)
-
-
-if __name__ == "__main__":
-    check()
diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py
index 67addb60c2ea5f532aa02dbcdb1b81a2fd02fd0a..f31a290ea35f6e9b44560fc4134a88abeb30cbcd 100644
--- a/pySDC/projects/PinTSimE/switch_estimator.py
+++ b/pySDC/projects/PinTSimE/switch_estimator.py
@@ -2,7 +2,7 @@ import numpy as np
 import scipy as sp
 
 from pySDC.core.Collocation import CollBase
-from pySDC.core.ConvergenceController import ConvergenceController
+from pySDC.core.ConvergenceController import ConvergenceController, Status
 
 
 class SwitchEstimator(ConvergenceController):
@@ -31,13 +31,34 @@ class SwitchEstimator(ConvergenceController):
             num_nodes=description['sweeper_params']['num_nodes'],
             quad_type=description['sweeper_params']['quad_type'],
         )
-        self.coll_nodes_local = coll.nodes
-        self.switch_detected = False
-        self.switch_detected_step = False
-        self.t_switch = None
-        self.count_switches = 0
-        self.dt_initial = description['level_params']['dt']
-        return {'control_order': 100, **params}
+
+        defaults = {
+            'control_order': 100,
+            'tol': description['level_params']['dt'],
+            'coll_nodes': coll.nodes,
+            'dt_initial': description['level_params']['dt'],
+        }
+        return {**defaults, **params}
+
+    def setup_status_variables(self, controller, **kwargs):
+        """
+        Adds switching specific variables to status variables.
+
+        Args:
+            controller (pySDC.Controller): The controller
+        """
+
+        self.status = Status(['t_switch', 'switch_detected', 'switch_detected_step'])
+
+    def reset_status_variables(self, controller, **kwargs):
+        """
+        Resets status variables.
+
+        Args:
+            controller (pySDC.Controller): The controller
+        """
+
+        self.setup_status_variables(controller, **kwargs)
 
     def get_new_step_size(self, controller, S):
         """
@@ -51,76 +72,62 @@ class SwitchEstimator(ConvergenceController):
             None
         """
 
-        self.switch_detected = False  # reset between steps
-
         L = S.levels[0]
 
-        if not type(L.prob.params.V_ref) == int and not type(L.prob.params.V_ref) == float:
-            # if V_ref is not a scalar, but an (np.)array
-            V_ref = np.zeros(np.shape(L.prob.params.V_ref)[0], dtype=float)
-            for m in range(np.shape(L.prob.params.V_ref)[0]):
-                V_ref[m] = L.prob.params.V_ref[m]
-        else:
-            V_ref = np.array([L.prob.params.V_ref], dtype=float)
+        if S.status.iter == S.params.maxiter:
 
-        if S.status.iter > 0 and self.count_switches < np.shape(V_ref)[0]:
-            for m in range(len(L.u)):
-                if L.u[m][self.count_switches + 1] - V_ref[self.count_switches] <= 0:
-                    self.switch_detected = True
-                    m_guess = m - 1
-                    break
+            self.status.switch_detected, m_guess, vC_switch = L.prob.get_switching_info(L.u, L.time)
 
-            if self.switch_detected:
-                t_interp = [L.time + L.dt * self.coll_nodes_local[m] for m in range(len(self.coll_nodes_local))]
-
-                vC_switch = []
-                for m in range(1, len(L.u)):
-                    vC_switch.append(L.u[m][self.count_switches + 1] - V_ref[self.count_switches])
+            if self.status.switch_detected:
+                t_interp = [L.time + L.dt * self.params.coll_nodes[m] for m in range(len(self.params.coll_nodes))]
 
                 # only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem)
                 if vC_switch[0] * vC_switch[-1] < 0:
 
-                    self.t_switch = self.get_switch(t_interp, vC_switch, m_guess)
+                    self.status.t_switch = self.get_switch(t_interp, vC_switch, m_guess)
 
-                    # if the switch is not find, we need to do ... ?
-                    if L.time < self.t_switch < L.time + L.dt:
-                        r = 1
-                        tol = self.dt_initial / r
+                    if L.time < self.status.t_switch < L.time + L.dt:
 
-                        if not np.isclose(self.t_switch - L.time, L.dt, atol=tol):
-                            dt_search = self.t_switch - L.time
+                        dt_switch = self.status.t_switch - L.time
+                        if not np.isclose(self.status.t_switch - L.time, L.dt, atol=self.params.tol):
+                            self.log(
+                                f"Located Switch at time {self.status.t_switch:.6f} is outside the range of tol={self.params.tol:.4e}",
+                                S,
+                            )
 
                         else:
-                            print('Switch located at time: {}'.format(self.t_switch))
-                            dt_search = self.t_switch - L.time
-                            L.prob.params.set_switch[self.count_switches] = self.switch_detected
-                            L.prob.params.t_switch[self.count_switches] = self.t_switch
+                            self.log(
+                                f"Switch located at time {self.status.t_switch:.6f} inside tol={self.params.tol:.4e}", S
+                            )
+
+                            L.prob.t_switch = self.status.t_switch
                             controller.hooks[0].add_to_stats(
                                 process=S.status.slot,
                                 time=L.time,
                                 level=L.level_index,
                                 iter=0,
                                 sweep=L.status.sweep,
-                                type='switch{}'.format(self.count_switches + 1),
-                                value=self.t_switch,
+                                type='switch',
+                                value=self.status.t_switch,
                             )
 
-                            self.switch_detected_step = True
+                            L.prob.count_switches()
+                            self.status.switch_detected_step = True
 
                         dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt
 
                         # when a switch is found, time step to match with switch should be preferred
-                        if self.switch_detected:
-                            L.status.dt_new = dt_search
+                        if self.status.switch_detected:
+                            L.status.dt_new = dt_switch
 
                         else:
-                            L.status.dt_new = min([dt_planned, dt_search])
+                            L.status.dt_new = min([dt_planned, dt_switch])
 
                     else:
-                        self.switch_detected = False
+                        self.status.switch_detected = False
 
                 else:
-                    self.switch_detected = False
+                    self.status.switch_detected = False
 
     def determine_restart(self, controller, S):
         """
@@ -134,8 +141,7 @@ class SwitchEstimator(ConvergenceController):
             None
         """
 
-        if self.switch_detected:
-            print("Restart")
+        if self.status.switch_detected:
             S.status.restart = True
             S.status.force_done = True
 
@@ -156,13 +162,9 @@ class SwitchEstimator(ConvergenceController):
 
         L = S.levels[0]
 
-        if self.switch_detected_step:
-            if L.prob.params.set_switch[self.count_switches] and L.time + L.dt >= self.t_switch:
-                self.count_switches += 1
-                self.t_switch = None
-                self.switch_detected_step = False
-
-                L.status.dt_new = self.dt_initial
+        if self.status.switch_detected_step:
+            if L.time + L.dt >= self.params.t_switch:
+                L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt
 
         super(SwitchEstimator, self).post_step_processing(controller, S)
 
diff --git a/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py b/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7dccd0cf28d0e04893fd10242b85ba55030669
--- /dev/null
+++ b/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py
@@ -0,0 +1,8 @@
+import pytest
+
+
+@pytest.mark.base
+def test_main():
+    from pySDC.projects.PinTSimE.battery_2capacitors_model import run
+
+    run()
diff --git a/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py b/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py
deleted file mode 100644
index 4965514ebaaf32bae3b1049fa24f608d8b7b1b27..0000000000000000000000000000000000000000
--- a/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import pytest
-
-
-@pytest.mark.base
-def test_main():
-    from pySDC.projects.PinTSimE.battery_2condensators_model import main
-
-    main()
diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py
index 375c365982e0e327ba827559df6de46d8709af15..01ec3fed73b7dcb01dd284afd6d9468aad0a6939 100644
--- a/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py
+++ b/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py
@@ -3,6 +3,6 @@ import pytest
 
 @pytest.mark.base
 def test_main():
-    from pySDC.projects.PinTSimE.estimation_check import check
+    from pySDC.projects.PinTSimE.estimation_check import run
 
-    check()
+    run()
diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py
new file mode 100644
index 0000000000000000000000000000000000000000..09751a7897487e72518bcf75ebc16aceb6fef6f3
--- /dev/null
+++ b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py
@@ -0,0 +1,8 @@
+import pytest
+
+
+@pytest.mark.base
+def test_main():
+    from pySDC.projects.PinTSimE.estimation_check_2capacitors import run
+
+    run()
diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py
deleted file mode 100644
index 03f44a23e31d108cbfa2de1de1e7017204dcfdd0..0000000000000000000000000000000000000000
--- a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import pytest
-
-
-@pytest.mark.base
-def test_main():
-    from pySDC.projects.PinTSimE.estimation_check_extended import check
-
-    check()