diff --git a/pySDC/implementations/problem_classes/Auzinger_implicit.py b/pySDC/implementations/problem_classes/Auzinger_implicit.py index 708ee4b3b7c543acc025df0ec4076083730c0348..93e749064c36f2fd1e483ef436e2f8e1c72ecd4e 100644 --- a/pySDC/implementations/problem_classes/Auzinger_implicit.py +++ b/pySDC/implementations/problem_classes/Auzinger_implicit.py @@ -6,7 +6,7 @@ from pySDC.implementations.datatype_classes.mesh import mesh # noinspection PyUnusedLocal class auzinger(ptype): - """ + r""" This class implements the Auzinger equation as initial value problem. It can be found in doi.org/10.2140/camcos.2015.10.1. The system of two ordinary differential equations (ODEs) is given by diff --git a/pySDC/implementations/problem_classes/Battery.py b/pySDC/implementations/problem_classes/Battery.py index 1c2fc59281617581fa78338418a4dc71cb5645c0..361533e973ef2ca28fcf10a2f03c6b811a44bfd3 100644 --- a/pySDC/implementations/problem_classes/Battery.py +++ b/pySDC/implementations/problem_classes/Battery.py @@ -77,8 +77,9 @@ class battery_n_capacitors(ptype): 'nvars', 'ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref', localVars=locals(), readOnly=True ) - self.A = np.zeros((n + 1, n + 1)) self.switch_A, self.switch_f = self.get_problem_dict() + self.A = self.switch_A[0] + self.t_switch = None self.nswitches = 0 @@ -222,17 +223,18 @@ class battery_n_capacitors(ptype): Indicates if a switch is found or not. m_guess : int Index of collocation node inside one subinterval of where the discrete event was found. - vC_switch : list - Contains function values of switching condition (for interpolation). + state_function : list + Contains values of the state function (for interpolation). """ switch_detected = False m_guess = -100 break_flag = False - for m in range(1, len(u)): for k in range(1, self.nvars): - if u[m][k] - self.V_ref[k - 1] <= 0: + h_prev_node = u[m - 1][k] - self.V_ref[k - 1] + h_curr_node = u[m][k] - self.V_ref[k - 1] + if h_prev_node > 0 and h_curr_node <= 0: switch_detected = True m_guess = m - 1 k_detected = k @@ -242,9 +244,11 @@ class battery_n_capacitors(ptype): if break_flag: break - vC_switch = [u[m][k_detected] - self.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else [] + state_function = ( + [u[m][k_detected] - self.V_ref[k_detected - 1] for m in range(len(u))] if switch_detected else [] + ) - return switch_detected, m_guess, vC_switch + return switch_detected, m_guess, state_function def count_switches(self): """ @@ -262,7 +266,6 @@ class battery_n_capacitors(ptype): n = self.ncapacitors v = np.zeros(n + 1) v[0] = 1 - A, f = dict(), dict() A = {k: np.diag(-1 / (self.C[k] * self.R) * np.roll(v, k + 1)) for k in range(n)} A.update({n: np.diag(-(self.Rs + self.R) / self.L * v)}) @@ -273,8 +276,12 @@ class battery_n_capacitors(ptype): class battery(battery_n_capacitors): r""" - Example implementing the battery drain model with :math:`N=1` capacitor, inherits from battery_n_capacitors. The ODE system - of this model is given by the following equations. If :math:`v_1 > V_{ref, 0}:` + Example implementing the battery drain model with :math:`N=1` capacitor, inherits from battery_n_capacitors. This model is an example + of a discontinuous problem. The state function :math:`decides` which differential equation is solved. When the state function has + a sign change the dynamics of the solution changes by changing the differential equation. The ODE system of this model is given by + the following equations: + + If :math:`h(v_1) := v_1 - V_{ref, 0} > 0:` .. math:: \frac{d i_L (t)}{dt} = 0, @@ -282,14 +289,15 @@ class battery(battery_n_capacitors): .. math:: \frac{d v_1 (t)}{dt} = -\frac{1}{CR}v_1 (t), - where :math:`i_L` denotes the function of the current over time :math:`t`. - If :math:`v_1 \leq V_{ref, 0}:` + else: .. math:: \frac{d i_L(t)}{dt} = -\frac{R_s + R}{L}i_L (t) + \frac{1}{L} V_s, .. math:: - \frac{d v_1(t)}{dt} = 0. + \frac{d v_1(t)}{dt} = 0, + + where :math:`i_L` denotes the function of the current over time :math:`t`. Note ---- @@ -320,7 +328,7 @@ class battery(battery_n_capacitors): t_switch = np.inf if self.t_switch is None else self.t_switch - if u[1] <= self.V_ref[0] or t >= t_switch: + if u[1] - self.V_ref[0] <= 0 or t >= t_switch: f.expl[0] = self.Vs / self.L else: @@ -352,7 +360,7 @@ class battery(battery_n_capacitors): t_switch = np.inf if self.t_switch is None else self.t_switch - if rhs[1] <= self.V_ref[0] or t >= t_switch: + if rhs[1] - self.V_ref[0] <= 0 or t >= t_switch: self.A[0, 0] = -(self.Rs + self.R) / self.L else: @@ -432,8 +440,8 @@ class battery_implicit(battery): L=1.0, alpha=1.2, V_ref=None, - newton_maxiter=200, - newton_tol=1e-8, + newton_maxiter=100, + newton_tol=1e-11, ): if C is None: C = np.array([1.0]) @@ -469,7 +477,7 @@ class battery_implicit(battery): t_switch = np.inf if self.t_switch is None else self.t_switch - if u[1] <= self.V_ref[0] or t >= t_switch: + if u[1] - self.V_ref[0] <= 0 or t >= t_switch: self.A[0, 0] = -(self.Rs + self.R) / self.L non_f[0] = self.Vs @@ -507,7 +515,7 @@ class battery_implicit(battery): t_switch = np.inf if self.t_switch is None else self.t_switch - if rhs[1] <= self.V_ref[0] or t >= t_switch: + if rhs[1] - self.V_ref[0] <= 0 or t >= t_switch: self.A[0, 0] = -(self.Rs + self.R) / self.L non_f[0] = self.Vs diff --git a/pySDC/implementations/problem_classes/DiscontinuousTestODE.py b/pySDC/implementations/problem_classes/DiscontinuousTestODE.py new file mode 100644 index 0000000000000000000000000000000000000000..fee8527839c3836f3cb31ee5826afa0401b7fc86 --- /dev/null +++ b/pySDC/implementations/problem_classes/DiscontinuousTestODE.py @@ -0,0 +1,226 @@ +import numpy as np + +from pySDC.core.Errors import ParameterError, ProblemError +from pySDC.core.Problem import ptype +from pySDC.implementations.datatype_classes.mesh import mesh + + +class DiscontinuousTestODE(ptype): + r""" + This class implements a very simple test case of a ordinary differential equation consisting of one discrete event. The dynamics of + the solution changes when the state function :math:`h(u) := u - 5` changes the sign. The problem is defined by: + + if :math:`u - 5 < 0:` + + .. math:: + \fra{d u}{dt} = u + + else: + + .. math:: + \frac{d u}{dt} = \frac{4}{t^*}, + + where :math:`t^* = \log(5) \approx 1.6094379`. For :math:`h(u) < 0`, i.e., :math:`t \leq t^*` the exact solution is + :math:`u(t) = exp(t)`; for :math:`h(u) \geq 0`, i.e., :math:`t \geq t^*` the exact solution is :math:`u(t) = \frac{4 t}{t^*} + 1`. + + Attributes + ---------- + t_switch_exact : float + Exact event time with :math:`t^* = \log(5)`. + t_switch: float + Time point of the discrete event found by switch estimation. + nswitches: int + Number of switches found by switch estimation. + newton_itercount: int + Counts the number of Newton iterations. + newton_ncalls: int + Counts the number of how often Newton is called in the simulation of the problem. + """ + + dtype_u = mesh + dtype_f = mesh + + def __init__(self, newton_maxiter=100, newton_tol=1e-8): + """Initialization routine""" + nvars = 1 + super().__init__(init=(nvars, None, np.dtype('float64'))) + self._makeAttributeAndRegister('nvars', localVars=locals(), readOnly=True) + self._makeAttributeAndRegister('newton_maxiter', 'newton_tol', localVars=locals()) + + if self.nvars != 1: + raise ParameterError('nvars has to be equal to 1!') + + self.t_switch_exact = np.log(5) + self.t_switch = None + self.nswitches = 0 + self.newton_itercount = 0 + self.newton_ncalls = 0 + + def eval_f(self, u, t): + """ + Routine to evaluate the right-hand side of the problem. + + Parameters + ---------- + u : dtype_u + Current values of the numerical solution. + t : float + Current time of the numerical solution is computed. + + Returns + ------- + f : dtype_f + The right-hand side of the problem. + """ + + t_switch = np.inf if self.t_switch is None else self.t_switch + + f = self.dtype_f(self.init, val=0.0) + h = u[0] - 5 + if h >= 0 or t >= t_switch: + f[:] = 4 / self.t_switch_exact + else: + f[:] = u + return f + + def solve_system(self, rhs, dt, u0, t): + r""" + Simple Newton solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`. + + Parameters + ---------- + rhs : dtype_f + Right-hand side for the linear system. + dt : float + Abbrev. for the local stepsize (or any other factor required). + u0 : dtype_u + Initial guess for the iterative solver. + t : float + Current time (e.g. for time-dependent BCs). + + Returns + ------- + me : dtype_u + The solution as mesh. + """ + + t_switch = np.inf if self.t_switch is None else self.t_switch + + h = rhs[0] - 5 + u = self.dtype_u(u0) + + n = 0 + res = 99 + while n < self.newton_maxiter: + # form function g with g(u) = 0 + if h >= 0 or t >= t_switch: + g = u - dt * (4 / self.t_switch_exact) - rhs + else: + g = u - dt * u - rhs + + # if g is close to 0, then we are done + res = np.linalg.norm(g, np.inf) + + if res < self.newton_tol: + break + + if h >= 0 or t >= t_switch: + dg = 1 + else: + dg = 1 - dt + + # newton update + u -= 1.0 / dg * g + + n += 1 + + if np.isnan(res) and self.stop_at_nan: + raise ProblemError('Newton got nan after %i iterations, aborting...' % n) + elif np.isnan(res): + self.logger.warning('Newton got nan after %i iterations...' % n) + + if n == self.newton_maxiter: + self.logger.warning('Newton did not converge after %i iterations, error is %s' % (n, res)) + + self.newton_ncalls += 1 + self.newton_itercount += n + + me = self.dtype_u(self.init) + me[:] = u[:] + + return me + + def u_exact(self, t, u_init=None, t_init=None): + """ + Routine to compute the exact solution at time t. + + Parameters + ---------- + t : float + Time of the exact solution. + u_init : pySDC.problem.DiscontinuousTestODE.dtype_u + Initial conditions for getting the exact solution. + t_init : float + The starting time. + + Returns + ------- + me : dtype_u + The exact solution. + """ + + if t_init is not None and u_init is not None: + self.logger.warning( + f'{type(self).__name__} uses an analytic exact solution from t=0. If you try to compute the local error, you will get the global error instead!' + ) + + me = self.dtype_u(self.init) + if t <= self.t_switch_exact: + me[:] = np.exp(t) + else: + me[:] = (4 * t) / self.t_switch_exact + 1 + return me + + def get_switching_info(self, u, t): + """ + Provides information about the state function of the problem. When the state function changes its sign, + typically an event occurs. So the check for an event should be done in the way that the state function + is checked for a sign change. If this is the case, the intermediate value theorem states a root in this + step. + + Parameters + ---------- + u : dtype_u + Current values of the numerical solution at time t. + t : float + Current time of the numerical solution. + + Returns + ------- + switch_detected : bool + Indicates whether a discrete event is found or not. + m_guess : int + The index before the sign changes. + state_function : list + Defines the values of the state function at collocation nodes where it changes the sign. + """ + + switch_detected = False + m_guess = -100 + + for m in range(1, len(u)): + h_prev_node = u[m - 1][0] - 5 + h_curr_node = u[m][0] - 5 + if h_prev_node < 0 and h_curr_node >= 0: + switch_detected = True + m_guess = m - 1 + break + + state_function = [u[m][0] - 5 for m in range(len(u))] if switch_detected else [] + return switch_detected, m_guess, state_function + + def count_switches(self): + """ + Setter to update the number of switches if one is found. + """ + self.nswitches += 1 diff --git a/pySDC/implementations/problem_classes/Lorenz.py b/pySDC/implementations/problem_classes/Lorenz.py index 98cf5c0a677ff88dbf14f45ea89ef44297f20304..b4223baa66888e0d58b6cd1571a5a35261cb9ccf 100644 --- a/pySDC/implementations/problem_classes/Lorenz.py +++ b/pySDC/implementations/problem_classes/Lorenz.py @@ -192,6 +192,7 @@ class LorenzAttractor(ptype): me : dtype_u The approximated exact solution. """ + me = self.dtype_u(self.init) if t > 0: diff --git a/pySDC/implementations/problem_classes/TestEquation_0D.py b/pySDC/implementations/problem_classes/TestEquation_0D.py index cfb2b832b81eeabd95f86e26c881bc8e8cb4ea77..b4a09a97ef80d668a27d7e7a94371aa8c9deddd8 100644 --- a/pySDC/implementations/problem_classes/TestEquation_0D.py +++ b/pySDC/implementations/problem_classes/TestEquation_0D.py @@ -125,6 +125,7 @@ class testequation0d(ptype): me : dtype_u The exact solution. """ + u_init = (self.u0 if u_init is None else u_init) * 1.0 t_init = 0.0 if t_init is None else t_init * 1.0 diff --git a/pySDC/projects/PinTSimE/battery_2capacitors_model.py b/pySDC/projects/PinTSimE/battery_2capacitors_model.py index f89b7cadc838d3eb180c06fb886bfe488e037692..c1769e232bf0367e45fae12efa2c905c9da4ecca 100644 --- a/pySDC/projects/PinTSimE/battery_2capacitors_model.py +++ b/pySDC/projects/PinTSimE/battery_2capacitors_model.py @@ -3,7 +3,6 @@ import dill from pathlib import Path from pySDC.helpers.stats_helper import get_sorted -from pySDC.core.Collocation import CollBase as Collocation from pySDC.implementations.problem_classes.Battery import battery_n_capacitors from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI @@ -11,16 +10,51 @@ from pySDC.projects.PinTSimE.battery_model import ( controller_run, generate_description, get_recomputed, - log_data, proof_assertions_description, ) from pySDC.projects.PinTSimE.piline_model import setup_mpl import pySDC.helpers.plot_helper as plt_helper + from pySDC.core.Hooks import hooks +from pySDC.implementations.hooks.log_solution import LogSolution +from pySDC.implementations.hooks.default_hook import DefaultHooks from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator +class LogEvent(hooks): + """ + Logs the problem dependent state function of the battery drain model. + """ + + def post_step(self, step, level_number): + super(LogEvent, self).post_step(step, level_number) + + L = step.levels[level_number] + P = L.prob + + L.sweep.compute_end_point() + + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=0, + sweep=L.status.sweep, + type='state_function_1', + value=L.uend[1] - P.V_ref[0], + ) + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=0, + sweep=L.status.sweep, + type='state_function_2', + value=L.uend[2] - P.V_ref[1], + ) + + def run(): """ Executes the simulation for the battery model using the IMEX sweeper and plot the results @@ -33,19 +67,43 @@ def run(): problem_classes = [battery_n_capacitors] sweeper_classes = [imex_1st_order] + num_nodes = 4 + restol = -1 + maxiter = 8 ncapacitors = 2 alpha = 5.0 V_ref = np.array([1.0, 1.0]) C = np.array([1.0, 1.0]) + problem_params = dict() + problem_params['ncapacitors'] = ncapacitors + problem_params['C'] = C + problem_params['alpha'] = alpha + problem_params['V_ref'] = V_ref + recomputed = False use_switch_estimator = [True] + max_restarts = 1 + tol_event = 1e-8 + + hook_class = [DefaultHooks, LogSolution, LogEvent] for problem, sweeper in zip(problem_classes, sweeper_classes): for use_SE in use_switch_estimator: description, controller_params = generate_description( - dt, problem, sweeper, log_data, False, use_SE, ncapacitors, alpha, V_ref, C + dt, + problem, + sweeper, + num_nodes, + hook_class, + False, + use_SE, + problem_params, + restol, + maxiter, + max_restarts, + tol_event, ) # Assertions @@ -62,16 +120,24 @@ def run(): def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'): """ - Routine to plot the numerical solution of the model - - Args: - description(dict): contains all information for a controller run - problem (pySDC.core.Problem.ptype): problem class that wants to be simulated - sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically - recomputed (bool): flag if the values after a restart are used or before - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - use_adaptivity (bool): flag if adaptivity wants to be used or not - cwd (str): current working directory + Routine to plot the numerical solution of the model. + + Parameters + ---------- + description : dict + Contains all information for a controller run. + problem : pySDC.core.Problem.ptype + Problem class that wants to be simulated. + sweeper : pySDC.core.Sweeper.sweeper + Sweeper class for solving the problem class numerically. + recomputed : bool + Flag if the values after a restart are used or before. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + use_adaptivity : bool + Flag if adaptivity wants to be used or not. + cwd : str + Current working directory. """ f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb') @@ -111,12 +177,17 @@ def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimato def check_solution(stats, dt, use_switch_estimator): """ - Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. - - Args: - stats (dict): Raw statistics from a controller run - dt (float): initial time step - use_switch_estimator (bool): flag if the switch estimator wants to be used or not + Function that checks the solution based on a hardcoded reference solution. + Based on check_solution function from @brownbaerchen. + + Parameters + ---------- + stats : dict + Raw statistics from a controller run. + dt : float + Initial time step. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. """ data = get_data_dict(stats, use_switch_estimator) @@ -125,51 +196,41 @@ def check_solution(stats, dt, use_switch_estimator): msg = f'Error when using the switch estimator for battery_2capacitors for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 1.207906161238752, - 'vC1': 1.0094825899806945, - 'vC2': 1.00000000000412, - 'switch1': 1.6094379124373626, - 'switch2': 3.209437912437337, - 'restarts': 1.0, - 'sum_niters': 1412.0, + 'cL': 1.1783297877614183, + 'vC1': 0.9999999999967468, + 'vC2': 0.999999999996747, + 'state_function_1': -3.2531755067566337e-12, + 'state_function_2': -3.2529534621517087e-12, + 'restarts': 2.0, + 'sum_niters': 2824.0, } elif dt == 4e-1: expected = { - 'cL': 1.5090409300896785, - 'vC1': 1.0094891393319418, - 'vC2': 1.0018593331860708, - 'switch1': 1.6075867934844466, - 'switch2': 3.2094445842818007, - 'restarts': 2.0, - 'sum_niters': 52.0, + 'cL': 1.5039617338098907, + 'vC1': 0.9999999968387812, + 'vC2': 0.9999999968387812, + 'state_function_1': -3.161218842251401e-09, + 'state_function_2': -3.161218842251401e-09, + 'restarts': 10.0, + 'sum_niters': 200.0, } elif dt == 4e-2: expected = { - 'cL': 1.2708164018400792, - 'vC1': 1.0094825917376264, - 'vC2': 1.000030506091851, - 'switch1': 1.6094074085553605, - 'switch2': 3.209437914186951, - 'restarts': 2.0, - 'sum_niters': 368.0, - } - elif dt == 4e-3: - expected = { - 'cL': 1.1564912472685411, - 'vC1': 1.001438946726028, - 'vC2': 1.0000650435224532, - 'switch1': 1.6093728710270467, - 'switch2': 3.217437912434931, - 'restarts': 2.0, - 'sum_niters': 3516.0, + 'cL': 1.2707220273133215, + 'vC1': 1.0000000041344774, + 'vC2': 0.999999999632751, + 'state_function_1': 4.134477427086836e-09, + 'state_function_2': -3.672490089812186e-10, + 'restarts': 6.0, + 'sum_niters': 792.0, } got = { 'cL': data['cL'][-1], 'vC1': data['vC1'][-1], 'vC2': data['vC2'][-1], - 'switch1': data['switch1'], - 'switch2': data['switch2'], + 'state_function_1': data['state_function_1'][-1], + 'state_function_2': data['state_function_2'][-1], 'restarts': data['restarts'], 'sum_niters': data['sum_niters'], } @@ -185,21 +246,31 @@ def get_data_dict(stats, use_switch_estimator, recomputed=False): Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. Based on @brownbaerchen's get_data function. - Args: - stats (dict): Raw statistics from a controller run - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - recomputed (bool): flag if the values after a restart are used or before - - Return: - data (dict): contains all information as the statistics dict + Parameters + ---------- + stats : dict + Raw statistics from a controller run. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + recomputed : bool + Flag if the values after a restart are used or before. + + Returns + ------- + data : dict + Contains all information as the statistics dict. """ data = dict() data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) data['vC1'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) data['vC2'] = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) - data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1] - data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1] + data['state_function_1'] = np.array(get_sorted(stats, type='state_function_1', sortby='time', recomputed=False))[ + :, 1 + ] + data['state_function_2'] = np.array(get_sorted(stats, type='state_function_2', sortby='time', recomputed=False))[ + :, 1 + ] data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) @@ -208,13 +279,18 @@ def get_data_dict(stats, use_switch_estimator, recomputed=False): def proof_assertions_time(dt, Tend, V_ref, alpha): """ - Function to proof the assertions regarding the time domain (in combination with the specific problem): - - Args: - dt (float): time step for computation - Tend (float): end time - V_ref (np.ndarray): Reference values (problem parameter) - alpha (np.float): Multiple used for initial conditions (problem_parameter) + Function to proof the assertions regarding the time domain (in combination with the specific problem). + + Parameters + ---------- + dt : float + Time step for computation. + Tend : float + End time. + V_ref : np.ndarray + Reference values (problem parameter). + alpha : np.float + Multiple used for initial conditions (problem_parameter). """ assert ( @@ -222,7 +298,7 @@ def proof_assertions_time(dt, Tend, V_ref, alpha): ), "Error! Do not use other parameters for V_ref[:] != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!" assert ( - dt == 1e-2 or dt == 4e-1 or dt == 4e-2 or dt == 4e-3 + dt == 1e-2 or dt == 4e-1 or dt == 4e-2 ), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!" diff --git a/pySDC/projects/PinTSimE/battery_model.py b/pySDC/projects/PinTSimE/battery_model.py index a9974c8280d67a12546e8d6623725811817d1bd5..7de772ba4ff6e096b75ab0cf926e013d7a326cec 100644 --- a/pySDC/projects/PinTSimE/battery_model.py +++ b/pySDC/projects/PinTSimE/battery_model.py @@ -7,21 +7,30 @@ from pySDC.implementations.problem_classes.Battery import battery, battery_impli from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI + from pySDC.projects.PinTSimE.piline_model import setup_mpl import pySDC.helpers.plot_helper as plt_helper + from pySDC.core.Hooks import hooks +from pySDC.implementations.hooks.log_solution import LogSolution +from pySDC.implementations.hooks.log_step_size import LogStepSize +from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity +from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI -class log_data(hooks): +class LogEvent(hooks): + """ + Logs the problem dependent state function of the battery drain model. + """ + def post_step(self, step, level_number): - super(log_data, self).post_step(step, level_number) + super(LogEvent, self).post_step(step, level_number) - # some abbreviations L = step.levels[level_number] + P = L.prob L.sweep.compute_end_point() @@ -31,26 +40,8 @@ class log_data(hooks): level=L.level_index, iter=0, sweep=L.status.sweep, - type='u', - value=L.uend, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='dt', - value=L.dt, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_embedded', - value=L.status.get('error_embedded_estimate'), + type='state_function', + value=L.uend[1] - P.V_ref[0], ) @@ -58,56 +49,73 @@ def generate_description( dt, problem, sweeper, + num_nodes, hook_class, use_adaptivity, use_switch_estimator, - ncapacitors, - alpha, - V_ref, - C, + problem_params, + restol, + maxiter, max_restarts=None, + tol_event=1e-10, ): """ Generate a description for the battery models for a controller run. - Args: - dt (float): time step for computation - problem (pySDC.core.Problem.ptype): problem class that wants to be simulated - sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically - hook_class (pySDC.core.Hooks): logged data for a problem - use_adaptivity (bool): flag if the adaptivity wants to be used or not - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - ncapacitors (np.int): number of capacitors used for the battery_model - alpha (np.float): Multiple used for the initial conditions (problem_parameter) - V_ref (np.ndarray): Reference values for the capacitors (problem_parameter) - C (np.ndarray): Capacitances (problem_parameter - - Returns: - description (dict): contains all information for a controller run - controller_params (dict): Parameters needed for a controller run + + Parameters + ---------- + dt : float + Time step for computation. + problem : pySDC.core.Problem.ptype + Problem class that wants to be simulated. + sweeper : pySDC.core.Sweeper.sweeper + Sweeper class for solving the problem class numerically. + num_nodes : int + Number of collocation nodes. + hook_class : pySDC.core.Hooks + Logged data for a problem. + use_adaptivity : bool + Flag if the adaptivity wants to be used or not. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + ncapacitors : int + Number of capacitors used for the battery_model. + alpha : float + Multiple used for the initial conditions (problem_parameter). + problem_params : dict + Dictionary containing the problem parameters. + restol : float + Residual tolerance to terminate. + maxiter : int + Maximum number of iterations to be done. + max_restarts : int, optional + Maximum number of restarts per step. + tol_event : float, optional + Tolerance for switch estimation to terminate. + + Returns + ------- + description : dict + Contains all information for a controller run. + controller_params : dict + Parameters needed for a controller run. """ # initialize level parameters level_params = dict() - level_params['restol'] = -1 + level_params['restol'] = -1 if use_adaptivity else restol level_params['dt'] = dt # initialize sweeper parameters sweeper_params = dict() sweeper_params['quad_type'] = 'LOBATTO' - sweeper_params['num_nodes'] = 5 + sweeper_params['num_nodes'] = num_nodes sweeper_params['QI'] = 'IE' sweeper_params['initial_guess'] = 'spread' - # initialize problem parameters - problem_params = dict() - problem_params['ncapacitors'] = ncapacitors # number of capacitors - problem_params['C'] = C - problem_params['alpha'] = alpha - problem_params['V_ref'] = V_ref - # initialize step parameters step_params = dict() - step_params['maxiter'] = 4 + step_params['maxiter'] = maxiter # initialize controller parameters controller_params = dict() @@ -119,6 +127,7 @@ def generate_description( convergence_controllers = dict() if use_switch_estimator: switch_estimator_params = {} + switch_estimator_params['tol'] = tol_event convergence_controllers.update({SwitchEstimator: switch_estimator_params}) if use_adaptivity: @@ -147,16 +156,23 @@ def generate_description( def controller_run(description, controller_params, use_adaptivity, use_switch_estimator, t0, Tend): """ - Executes a controller run for a problem defined in the description - - Args: - description (dict): contains all information for a controller run - controller_params (dict): Parameters needed for a controller run - use_adaptivity (bool): flag if the adaptivity wants to be used or not - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - - Returns: - stats (dict): Raw statistics from a controller run + Executes a controller run for a problem defined in the description. + + Parameters + ---------- + description : dict + Contains all information for a controller run. + controller_params : dict + Parameters needed for a controller run. + use_adaptivity : bool + Flag if the adaptivity wants to be used or not. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + + Returns + ------- + stats : dict + Raw statistics from a controller run. """ # instantiate controller @@ -195,22 +211,46 @@ def run(): problem_classes = [battery, battery_implicit] sweeper_classes = [imex_1st_order, generic_implicit] + num_nodes = 4 + restol = -1 + maxiter = 8 ncapacitors = 1 alpha = 1.2 V_ref = np.array([1.0]) C = np.array([1.0]) + problem_params = dict() + problem_params['ncapacitors'] = ncapacitors + problem_params['C'] = C + problem_params['alpha'] = alpha + problem_params['V_ref'] = V_ref + max_restarts = 1 recomputed = False use_switch_estimator = [True] use_adaptivity = [True] + hook_class = [LogSolution, LogEvent, LogEmbeddedErrorEstimate, LogStepSize] + for problem, sweeper in zip(problem_classes, sweeper_classes): for use_SE in use_switch_estimator: for use_A in use_adaptivity: + tol_event = 1e-10 if sweeper.__name__ == 'generic_implicit' else 1e-17 + description, controller_params = generate_description( - dt, problem, sweeper, log_data, use_A, use_SE, ncapacitors, alpha, V_ref, C, max_restarts + dt, + problem, + sweeper, + num_nodes, + hook_class, + use_A, + use_SE, + problem_params, + restol, + maxiter, + max_restarts, + tol_event, ) # Assertions @@ -227,16 +267,24 @@ def run(): def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'): """ - Routine to plot the numerical solution of the model - - Args: - description(dict): contains all information for a controller run - problem (pySDC.core.Problem.ptype): problem class that wants to be simulated - sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically - recomputed (bool): flag if the values after a restart are used or before - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - use_adaptivity (bool): flag if adaptivity wants to be used or not - cwd (str): current working directory + Routine to plot the numerical solution of the model. + + Parameters + ---------- + description : dict + Contains all information for a controller run. + problem : pySDC.core.Problem.ptype + Problem class that wants to be simulated. + sweeper : pySDC.core.Sweeper.sweeper + Sweeper class for solving the problem class numerically. + recomputed : bool + Flag if the values after a restart are used or before. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + use_adaptivity : bool + Flag if adaptivity wants to be used or not. + cwd : str + Current working directory. """ f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb') @@ -284,14 +332,21 @@ def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimato def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): """ - Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. - - Args: - stats (dict): Raw statistics from a controller run - dt (float): initial time step - problem (problem_class.__name__): the problem_class that is numerically solved - use_switch_estimator (bool): - use_adaptivity (bool): + Function that checks the solution based on a hardcoded reference solution. + Based on check_solution function from @brownbaerchen. + + Parameters + ---------- + stats : dict + Raw statistics from a controller run. + dt : float + Initial time step. + problem : problem_class.__name__ + The problem_class that is numerically solved + use_switch_estimator : bool + Indicates if switch detection is used or not. + use_adaptivity : bool + Indicate if adaptivity is used or not. """ data = get_data_dict(stats, use_adaptivity, use_switch_estimator) @@ -301,33 +356,23 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using switch estimator and adaptivity for battery for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5474500710994862, - 'vC': 1.0019332967173764, - 'dt': 0.011761752270047832, - 'e_em': 8.001793672107738e-10, - 'switches': 0.18232155791181945, + 'cL': 0.5446532674094873, + 'vC': 0.9999999999883544, + 'dt': 0.01, + 'e_em': 2.220446049250313e-16, + 'state_function': -1.1645573394503117e-11, 'restarts': 3.0, - 'sum_niters': 44.0, + 'sum_niters': 136.0, } - elif dt == 4e-2: + elif dt == 1e-3: expected = { - 'cL': 0.5525783945667581, - 'vC': 1.00001743462299, - 'dt': 0.03550610373897258, - 'e_em': 6.21240694442804e-08, - 'switches': 0.18231603298272345, - 'restarts': 4.0, - 'sum_niters': 56.0, - } - elif dt == 4e-3: - expected = { - 'cL': 0.5395601429161445, - 'vC': 1.0000413761942089, - 'dt': 0.028281271825675414, - 'e_em': 2.5628611677319668e-08, - 'switches': 0.18230920573953438, - 'restarts': 3.0, - 'sum_niters': 48.0, + 'cL': 0.539386744746365, + 'vC': 0.9999999710472945, + 'dt': 0.005520873635314061, + 'e_em': 2.220446049250313e-16, + 'state_function': -2.8952705455331795e-08, + 'restarts': 11.0, + 'sum_niters': 264.0, } got = { @@ -335,7 +380,7 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): 'vC': data['vC'][-1], 'dt': data['dt'][-1], 'e_em': data['e_em'][-1], - 'switches': data['switches'][-1], + 'state_function': data['state_function'][-1], 'restarts': data['restarts'], 'sum_niters': data['sum_niters'], } @@ -343,33 +388,25 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using switch estimator for battery for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5495834172613568, - 'vC': 1.000118710428906, - 'switches': 0.1823188001399631, - 'restarts': 1.0, - 'sum_niters': 128.0, - } - elif dt == 4e-2: - expected = { - 'cL': 0.553775247309617, - 'vC': 1.0010140038721593, - 'switches': 0.1824302065533169, - 'restarts': 1.0, - 'sum_niters': 36.0, + 'cL': 0.5456190026495924, + 'vC': 0.999166666941434, + 'state_function': -0.0008333330585660326, + 'restarts': 4.0, + 'sum_niters': 296.0, } - elif dt == 4e-3: + elif dt == 1e-3: expected = { - 'cL': 0.5495840499078819, - 'vC': 1.0001158309787614, - 'switches': 0.18232183080236553, - 'restarts': 1.0, - 'sum_niters': 308.0, + 'cL': 0.5403849766797957, + 'vC': 0.9999166666752302, + 'state_function': -8.33333247698409e-05, + 'restarts': 2.0, + 'sum_niters': 2424.0, } got = { 'cL': data['cL'][-1], 'vC': data['vC'][-1], - 'switches': data['switches'][-1], + 'state_function': data['state_function'][-1], 'restarts': data['restarts'], 'sum_niters': data['sum_niters'], } @@ -378,30 +415,21 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using adaptivity for battery for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5401449976237487, - 'vC': 0.9944656165121677, - 'dt': 0.013143356036619536, - 'e_em': 1.2462494369813726e-09, - 'restarts': 3.0, - 'sum_niters': 52.0, - } - elif dt == 4e-2: - expected = { - 'cL': 0.5966289599915113, - 'vC': 0.9923148791604984, - 'dt': 0.03564958366355817, - 'e_em': 6.210964231812e-08, - 'restarts': 1.0, - 'sum_niters': 36.0, + 'cL': 0.4433805288639916, + 'vC': 0.90262388393713, + 'dt': 0.18137307612335937, + 'e_em': 2.7177844974524135e-09, + 'restarts': 0.0, + 'sum_niters': 24.0, } - elif dt == 4e-3: + elif dt == 1e-3: expected = { - 'cL': 0.5431613774808756, - 'vC': 0.9934307674636834, - 'dt': 0.022880524075396924, - 'e_em': 1.1130212751453428e-08, - 'restarts': 3.0, - 'sum_niters': 52.0, + 'cL': 0.3994744179584864, + 'vC': 0.9679037468770668, + 'dt': 0.1701392217033212, + 'e_em': 2.0992988458701234e-09, + 'restarts': 0.0, + 'sum_niters': 32.0, } got = { @@ -418,33 +446,23 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using switch estimator and adaptivity for battery_implicit for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5395401085152521, - 'vC': 1.00003663985255, - 'dt': 0.011465727118881608, + 'cL': 0.5446675396652545, + 'vC': 0.9999999999883541, + 'dt': 0.01, 'e_em': 2.220446049250313e-16, - 'switches': 0.18231044486762837, - 'restarts': 4.0, - 'sum_niters': 44.0, - } - elif dt == 4e-2: - expected = { - 'cL': 0.6717104472882885, - 'vC': 1.0071670698947914, - 'dt': 0.035896059229296486, - 'e_em': 6.208836400567463e-08, - 'switches': 0.18232158833761175, + 'state_function': -1.1645906461410505e-11, 'restarts': 3.0, - 'sum_niters': 36.0, + 'sum_niters': 136.0, } - elif dt == 4e-3: + elif dt == 1e-3: expected = { - 'cL': 0.5396216192241711, - 'vC': 1.0000561014463172, - 'dt': 0.009904645972832471, + 'cL': 0.5393867447463223, + 'vC': 0.9999999710472952, + 'dt': 0.005520876908755634, 'e_em': 2.220446049250313e-16, - 'switches': 0.18230549652342606, - 'restarts': 4.0, - 'sum_niters': 44.0, + 'state_function': -2.895270478919798e-08, + 'restarts': 11.0, + 'sum_niters': 264.0, } got = { @@ -452,7 +470,7 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): 'vC': data['vC'][-1], 'dt': data['dt'][-1], 'e_em': data['e_em'][-1], - 'switches': data['switches'][-1], + 'state_function': data['state_function'][-1], 'restarts': data['restarts'], 'sum_niters': data['sum_niters'], } @@ -460,33 +478,25 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using switch estimator for battery_implicit for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5495834122430945, - 'vC': 1.000118715162845, - 'switches': 0.18231880065636324, - 'restarts': 1.0, - 'sum_niters': 128.0, - } - elif dt == 4e-2: - expected = { - 'cL': 0.5537752525450169, - 'vC': 1.0010140112484431, - 'switches': 0.18243023230469263, - 'restarts': 1.0, - 'sum_niters': 36.0, + 'cL': 0.5456190026495138, + 'vC': 0.9991666669414431, + 'state_function': -0.0008333330585569287, + 'restarts': 4.0, + 'sum_niters': 296.0, } - elif dt == 4e-3: + elif dt == 1e-3: expected = { - 'cL': 0.5495840604357269, - 'vC': 1.0001158454740509, - 'switches': 0.1823218812753008, - 'restarts': 1.0, - 'sum_niters': 308.0, + 'cL': 0.5403849766797896, + 'vC': 0.9999166666752302, + 'state_function': -8.33333247698409e-05, + 'restarts': 2.0, + 'sum_niters': 2424.0, } got = { 'cL': data['cL'][-1], 'vC': data['vC'][-1], - 'switches': data['switches'][-1], + 'state_function': data['state_function'][-1], 'restarts': data['restarts'], 'sum_niters': data['sum_niters'], } @@ -495,30 +505,21 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): msg = f'Error when using adaptivity for battery_implicit for dt={dt:.1e}:' if dt == 1e-2: expected = { - 'cL': 0.5569818284195267, - 'vC': 0.9846733115433628, - 'dt': 0.01, - 'e_em': 2.220446049250313e-16, - 'restarts': 9.0, - 'sum_niters': 88.0, + 'cL': 0.4694087102919169, + 'vC': 0.9026238839371302, + 'dt': 0.18137307612335937, + 'e_em': 2.3469713394952407e-09, + 'restarts': 0.0, + 'sum_niters': 24.0, } - elif dt == 4e-2: + elif dt == 1e-3: expected = { - 'cL': 0.5556563012729733, - 'vC': 0.9930947318467772, - 'dt': 0.035507110551631804, - 'e_em': 6.2098696185231e-08, - 'restarts': 6.0, - 'sum_niters': 64.0, - } - elif dt == 4e-3: - expected = { - 'cL': 0.5401117929618637, - 'vC': 0.9933888475391347, - 'dt': 0.03176025170463925, - 'e_em': 4.0386798239033794e-08, - 'restarts': 8.0, - 'sum_niters': 80.0, + 'cL': 0.39947441811958956, + 'vC': 0.9679037468770735, + 'dt': 0.1701392217033212, + 'e_em': 1.147640815712947e-09, + 'restarts': 0.0, + 'sum_niters': 32.0, } got = { @@ -531,9 +532,11 @@ def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): } for key in expected.keys(): - assert np.isclose( - expected[key], got[key], rtol=1e-4 - ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + err_msg = f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + if key == 'cL': + assert abs(expected[key] - got[key]) <= 1e-2, err_msg + else: + assert np.isclose(expected[key], got[key], rtol=1e-3), err_msg def get_data_dict(stats, use_adaptivity, use_switch_estimator, recomputed=False): @@ -541,27 +544,35 @@ def get_data_dict(stats, use_adaptivity, use_switch_estimator, recomputed=False) Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. Based on @brownbaerchen's get_data function. - Args: - stats (dict): Raw statistics from a controller run - use_adaptivity (bool): flag if adaptivity wants to be used or not - use_switch_estimator (bool): flag if the switch estimator wants to be used or not - recomputed (bool): flag if the values after a restart are used or before - - Return: - data (dict): contains all information as the statistics dict + Parameters + ---------- + stats : dict + Raw statistics from a controller run. + use_adaptivity : bool + Flag if adaptivity wants to be used or not. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. + recomputed : bool + Flag if the values after a restart are used or before. + + Returns + ------- + data : dict + Contains all information as the statistics dict. """ data = dict() - - data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=recomputed, sortby='time')]) - data['vC'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=recomputed, sortby='time')]) + data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', sortby='time', recomputed=recomputed)]) + data['vC'] = np.array([me[1][1] for me in get_sorted(stats, type='u', sortby='time', recomputed=recomputed)]) if use_adaptivity: - data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1] + data['dt'] = np.array(get_sorted(stats, type='dt', sortby='time', recomputed=recomputed))[:, 1] data['e_em'] = np.array( - get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed, sortby='time') + get_sorted(stats, type='error_embedded_estimate', sortby='time', recomputed=recomputed) )[:, 1] if use_switch_estimator: - data['switches'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[:, 1] + data['state_function'] = np.array( + get_sorted(stats, type='state_function', sortby='time', recomputed=recomputed) + )[:, 1] if use_adaptivity or use_switch_estimator: data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) @@ -574,13 +585,19 @@ def get_recomputed(stats, type, sortby): Function that filters statistics after a recomputation. It stores all value of a type before restart. If there are multiple values with same time point, it only stores the elements with unique times. - Args: - stats (dict): Raw statistics from a controller run - type (str): the type the be filtered - sortby (str): string to specify which key to use for sorting - - Returns: - sorted_list (list): list of filtered statistics + Parameters + ---------- + stats : dict + Raw statistics from a controller run. + type : str + The type the be filtered. + sortby : str + String to specify which key to use for sorting. + + Returns + ------- + sorted_list : list + List of filtered statistics. """ sorted_nested_list = [] @@ -604,10 +621,14 @@ def proof_assertions_description(description, use_adaptivity, use_switch_estimat """ Function to proof the assertions in the description. - Args: - description(dict): contains all information for a controller run - use_adaptivity (bool): flag if adaptivity wants to be used or not - use_switch_estimator (bool): flag if the switch estimator wants to be used or not + Parameters + ---------- + description : dict + Contains all information for a controller run. + use_adaptivity : bool + Flag if adaptivity wants to be used or not. + use_switch_estimator : bool + Flag if the switch estimator wants to be used or not. """ n = description['problem_params']['ncapacitors'] @@ -631,19 +652,24 @@ def proof_assertions_description(description, use_adaptivity, use_switch_estimat assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters' assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters' - if use_switch_estimator or use_adaptivity: + if use_adaptivity: assert description['level_params']['restol'] == -1, "Please set restol to -1 or omit it" def proof_assertions_time(dt, Tend, V_ref, alpha): """ - Function to proof the assertions regarding the time domain (in combination with the specific problem): - - Args: - dt (float): time step for computation - Tend (float): end time - V_ref (np.ndarray): Reference values (problem parameter) - alpha (np.float): Multiple used for initial conditions (problem_parameter) + Function to proof the assertions regarding the time domain (in combination with the specific problem). + + Parameters + ---------- + dt : float + Time step for computation. + Tend : float + End time. + V_ref : np.ndarray + Reference values (problem parameter). + alpha : float + Multiple used for initial conditions (problem_parameter). """ assert dt < Tend, "Time step is too large for the time domain!" diff --git a/pySDC/projects/PinTSimE/discontinuous_test_ODE.py b/pySDC/projects/PinTSimE/discontinuous_test_ODE.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b343bf926f915b6032be694bf4e9bbed28da48 --- /dev/null +++ b/pySDC/projects/PinTSimE/discontinuous_test_ODE.py @@ -0,0 +1,278 @@ +import numpy as np +import dill +from pathlib import Path + +from pySDC.helpers.stats_helper import get_sorted +from pySDC.implementations.problem_classes.DiscontinuousTestODE import DiscontinuousTestODE +from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit +from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI +from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI +from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator +from pySDC.projects.PinTSimE.battery_model import get_recomputed, generate_description +import pySDC.helpers.plot_helper as plt_helper +from pySDC.core.Hooks import hooks +from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostStep +from pySDC.implementations.hooks.log_solution import LogSolution + + +class LogEvent(hooks): + """ + Logs the problem dependent state function of the discontinuous test ODE. + """ + + def post_step(self, step, level_number): + super(LogEvent, self).post_step(step, level_number) + + L = step.levels[level_number] + + L.sweep.compute_end_point() + + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=0, + sweep=L.status.sweep, + type='state_function', + value=L.uend[0] - 5, + ) + + +def main(): + """ + Executes the main stuff of the file. + """ + + Path("data").mkdir(parents=True, exist_ok=True) + + hookclass = [LogEvent, LogSolution, LogGlobalErrorPostStep] + + problem_class = DiscontinuousTestODE + + sweeper = generic_implicit + nnodes = [2, 3, 4] + maxiter = 8 + newton_tol = 1e-11 + + problem_params = dict() + problem_params['newton_maxiter'] = 50 + problem_params['newton_tol'] = newton_tol + + use_detection = [True, False] + + t0 = 1.4 + Tend = 1.7 + dt_list = [1e-2, 1e-3] + + for dt in dt_list: + for num_nodes in nnodes: + for use_SE in use_detection: + print(f'Controller run -- Simulation for step size: {dt}') + + restol = 1e-14 + recomputed = False if use_SE else None + + description, controller_params = generate_description( + dt, + problem_class, + sweeper, + num_nodes, + hookclass, + False, + use_SE, + problem_params, + restol, + maxiter, + max_restarts=None, + tol_event=1e-10, + ) + + proof_assertions(description, t0, Tend, recomputed, use_SE) + + stats, t_switch_exact = controller_run(t0, Tend, controller_params, description) + + if use_SE: + switches = get_recomputed(stats, type='switch', sortby='time') + assert len(switches) >= 1, 'No events found!' + test_event_error(stats, dt, t_switch_exact, num_nodes) + + test_error(stats, dt, num_nodes, use_SE, recomputed) + + +def controller_run(t0, Tend, controller_params, description): + """ + Executes a controller run for time interval to be specified in the arguments. + + Parameters + ---------- + t0 : float + Initial time of simulation. + Tend : float + End time of simulation. + controller_params : dict + Parameters needed for the controller. + description : dict + Contains all information for a controller run. + + Returns + ------- + stats : dict + Raw statistics from a controller run. + """ + + # instantiate the controller + controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) + + # get initial values on finest level + P = controller.MS[0].levels[0].prob + uinit = P.u_exact(t0) + t_switch_exact = P.t_switch_exact + + # call main function to get things done... + uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) + + return stats, t_switch_exact + + +def test_event_error(stats, dt, t_switch_exact, num_nodes): + """ + Tests the error between the exact event time and the event time founded by switch estimation. + + The errors to the exact event time are very small. The higher the number of collocation nodes + is the smaller the error to the exact event time is. + + Parameter + --------- + stats : dict + Raw statistics from a controller run. + dt : float + Current step size. + t_switch_exact : float + Exact event time of the problem. + num_nodes : int + Number of collocation nodes used. + """ + + switches = get_recomputed(stats, type='switch', sortby='time') + assert len(switches) >= 1, 'No switches found!' + t_switches = [v[1] for v in switches] + + # dict with hardcoded solution for event time error + t_event_err = { + 1e-2: { + 2: 1.7020665390443668e-06, + 3: 5.532252433937401e-10, + 4: 6.2776006615195e-11, + }, + 1e-3: { + 2: 1.7060500789867206e-08, + 3: 5.890081755666188e-10, + 4: 0.00057634035536136, + }, + } + + t_event_err_got = abs(t_switch_exact - t_switches[-1]) + t_event_err_expected = t_event_err[dt][num_nodes] + + msg = f'Expected event time error {t_event_err_expected:.5f}, got {t_event_err_got:.5f}' + assert np.isclose(t_event_err_got, t_event_err_expected, atol=5e-3), msg + + +def test_error(stats, dt, num_nodes, use_SE, recomputed): + """ + Tests the error between the exact event solution and the numerical solution founded. + + In the dictionary containing the errors it can be clearly seen that errors are inherently reduced + using the switch estimator to predict the event and adapt the time step to resolve the event in + an more accurate way! + + Parameter + --------- + stats : dict + Raw statistics from a controller run. + dt : float + Current step size. + num_nodes : int + Number of collocation nodes used. + use_SE : bool + Indicates whether switch detection is used or not. + recomputed : bool + Indicates whether the values after a restart will be used. + """ + + err = get_sorted(stats, type='e_global_post_step', sortby='time', recomputed=recomputed) + err_norm = max([me[1] for me in err]) + + u_err = { + True: { + 1e-2: { + 2: 8.513409107457903e-06, + 3: 4.046930790480019e-09, + 4: 3.8459546658486943e-10, + }, + 1e-3: { + 2: 8.9185828500149e-08, + 3: 4.459276503609999e-09, + 4: 0.00015611434317808204, + }, + }, + False: { + 1e-2: { + 2: 0.014137551021780936, + 3: 0.009855041165877765, + 4: 0.006289698596543047, + }, + 1e-3: { + 2: 0.002674332734426521, + 3: 0.00057634035536136, + 4: 0.00015611434317808204, + }, + }, + } + + u_err_expected = u_err[use_SE][dt][num_nodes] + u_err_got = err_norm + + msg = f'Expected event time error {u_err_expected:.7f}, got {u_err_got:.7f}' + assert np.isclose(u_err_got - u_err_expected, 0, atol=1e-11), msg + + +def proof_assertions(description, t0, Tend, recomputed, use_detection): + """ + Tests the parameters if they would not change. + + Parameters + ---------- + description : dict + Contains all information for a controller run. + t0 : float + Starting time. + Tend : float + End time. + recomputed : bool + Indicates whether the values after a restart are considered. + use_detection : bool + Indicates whether switch estimation is used. + """ + + newton_tol = description['problem_params']['newton_tol'] + msg = 'Newton tolerance should be set as small as possible to get best possible resolution of solution' + assert newton_tol <= 1e-8, msg + + assert t0 >= 1.0, 'Problem is only defined for t >= 1!' + assert Tend >= np.log(5), f'To investigate event, please set Tend larger than {np.log(5):.5f}' + + num_nodes = description['sweeper_params']['num_nodes'] + for M in [2, 3, 4]: + if num_nodes not in [2, 3, 4]: + assert num_nodes == M, f'Hardcoded solutions are only for M={M}!' + + sweeper = description['sweeper_class'].__name__ + assert sweeper == 'generic_implicit', 'Only generic_implicit sweeper is tested!' + + if use_detection: + assert recomputed == False, 'Be aware that recomputed is set to False by using switch detection!' + + +if __name__ == "__main__": + main() diff --git a/pySDC/projects/PinTSimE/estimation_check.py b/pySDC/projects/PinTSimE/estimation_check.py index 5fa61f7b3788d71479004d702cd905f3e7d398a4..6fef419cb0a90dc4a44918c2e7e31f289fb58c2f 100644 --- a/pySDC/projects/PinTSimE/estimation_check.py +++ b/pySDC/projects/PinTSimE/estimation_check.py @@ -3,22 +3,26 @@ import dill from pathlib import Path from pySDC.helpers.stats_helper import get_sorted -from pySDC.core.Collocation import CollBase as Collocation from pySDC.implementations.problem_classes.Battery import battery, battery_implicit from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.projects.PinTSimE.piline_model import setup_mpl from pySDC.projects.PinTSimE.battery_model import ( controller_run, check_solution, generate_description, get_recomputed, - log_data, + LogEvent, proof_assertions_description, ) + +from pySDC.projects.PinTSimE.piline_model import setup_mpl import pySDC.helpers.plot_helper as plt_helper + from pySDC.core.Hooks import hooks +from pySDC.implementations.hooks.log_solution import LogSolution +from pySDC.implementations.hooks.log_step_size import LogStepSize +from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity @@ -28,22 +32,35 @@ def run(cwd='./'): """ Routine to check the differences between using a switch estimator or not - Args: - cwd (str): current working directory + Parameters + ---------- + cwd : str + Current working directory. """ - dt_list = [4e-2, 4e-3] + dt_list = [1e-2, 1e-3] t0 = 0.0 Tend = 0.3 problem_classes = [battery, battery_implicit] sweeper_classes = [imex_1st_order, generic_implicit] + num_nodes = 4 + restol = -1 + maxiter = 8 ncapacitors = 1 alpha = 1.2 V_ref = np.array([1.0]) C = np.array([1.0]) + problem_params = dict() + problem_params['ncapacitors'] = ncapacitors + problem_params['C'] = C + problem_params['alpha'] = alpha + problem_params['V_ref'] = V_ref + + hook_class = [LogSolution, LogEvent, LogStepSize, LogEmbeddedErrorEstimate] + max_restarts = 1 use_switch_estimator = [True, False] use_adaptivity = [True, False] @@ -55,18 +72,20 @@ def run(cwd='./'): for dt_item in dt_list: for use_SE in use_switch_estimator: for use_A in use_adaptivity: + tol_event = 1e-10 if sweeper.__name__ == 'generic_implicit' else 1e-17 description, controller_params = generate_description( dt_item, problem, sweeper, - log_data, + num_nodes, + hook_class, use_A, use_SE, - ncapacitors, - alpha, - V_ref, - C, + problem_params, + restol, + maxiter, max_restarts, + tol_event, ) # Assertions @@ -77,11 +96,6 @@ def run(cwd='./'): if use_A or use_SE: check_solution(stats, dt_item, problem.__name__, use_A, use_SE) - if use_SE: - assert ( - len(get_recomputed(stats, type='switch', sortby='time')) >= 1 - ), 'No switches found for dt={}!'.format(dt_item) - fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__) f = open(fname, 'wb') dill.dump(stats, f) @@ -121,14 +135,20 @@ def run(cwd='./'): def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): """ - Routine to check accuracy for different step sizes in case of using adaptivity - - Args: - dt_list (list): list of considered (initial) step sizes - problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) - sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) - V_ref (np.float): reference value for the switch - cwd (str): current working directory + Routine to check accuracy for different step sizes in case of using adaptivity. + + Parameters + ---------- + dt_list : list + List of considered (initial) step sizes. + problem : pySDC.core.Problem.ptype + Problem class used to consider (the class name). + sweeper : pySDC.core.Sweeper.sweeper + Sweeper used to solve (the class name). + V_ref : float + Reference value for the switch. + cwd : str + Current working directory. """ if len(dt_list) > 1: @@ -236,17 +256,27 @@ def differences_around_switch( dt_list, problem, restarts_SE, restarts_adapt, restarts_SE_adapt, sweeper, V_ref, cwd='./' ): """ - Routine to plot the differences before, at, and after the switch. Produces the diffs_estimation_<sweeper_class>.png file - - Args: - dt_list (list): list of considered (initial) step sizes - problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) - restarts_SE (list): Restarts for the solve only using the switch estimator - restarts_adapt (list): Restarts for the solve of only using adaptivity - restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity - sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) - V_ref (np.float): reference value for the switch - cwd (str): current working directory + Routine to plot the differences before, at, and after the switch. Produces the + diffs_estimation_<sweeper_class>.png file + + Parameters + ---------- + dt_list : list + List of considered (initial) step sizes. + problem : pySDC.core.Problem.ptype + Problem class used to consider (the class name). + restarts_SE : list + Restarts for the solve only using the switch estimator. + restarts_adapt : list + Restarts for the solve of only using adaptivity. + restarts_SE_adapt : list + Restarts for the solve of using both, switch estimator and adaptivity. + sweeper : pySDC.core.Sweeper.sweeper + Sweeper used to solve (the class name). + V_ref float + Reference value for the switch. + cwd : str + Current working directory. """ diffs_true_at = [] @@ -297,9 +327,7 @@ def differences_around_switch( times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False)] times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)] - diffs_true_at.append( - [diff_SE[m] for m in range(len(times_SE)) if np.isclose(times_SE[m], t_switch, atol=1e-15)][0] - ) + diffs_true_at.append([diff_SE[m] for m in range(len(times_SE)) if abs(times_SE[m] - t_switch) <= 1e-7][0]) diffs_false_before.append( [diff[m - 1] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]][0] @@ -307,7 +335,7 @@ def differences_around_switch( diffs_false_after.append([diff[m] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]][0]) for m in range(len(times_SE_adapt)): - if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-10): + if abs(times_SE_adapt[m] - t_switch_SE_adapt) <= 1e-10: diffs_true_at_adapt.append(diff_SE_adapt[m]) diffs_true_before_adapt.append(diff_SE_adapt[m - 1]) diffs_true_after_adapt.append(diff_SE_adapt[m + 1]) @@ -387,14 +415,21 @@ def differences_around_switch( def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): """ - Routine to plot the differences in time using the switch estimator or not. Produces the difference_estimation_<sweeper_class>.png file - - Args: - dt_list (list): list of considered (initial) step sizes - problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) - sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) - V_ref (np.float): reference value for the switch - cwd (str): current working directory + Routine to plot the differences in time using the switch estimator or not. Produces the + difference_estimation_<sweeper_class>.png file. + + Parameters + ---------- + dt_list : list + List of considered (initial) step sizes. + problem : pySDC.core.Problem.ptype + Problem class used to consider (the class name). + sweeper : pySDC.core.Sweeper.sweeper + Sweeper used to solve (the class name). + V_ref : float + Reference value for the switch. + cwd : str + Current working directory. """ if len(dt_list) > 1: @@ -535,14 +570,21 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): """ - Routine to plot the number of iterations over time using switch estimator or not. Produces the iters_<sweeper_class>.png file - - Args: - dt_list (list): list of considered (initial) step sizes - maxiter (np.int): maximum number of iterations - problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) - sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) - cwd (str): current working directory + Routine to plot the number of iterations over time using switch estimator or not. Produces the + iters_<sweeper_class>.png file. + + Parameters + ---------- + dt_list : list + List of considered (initial) step sizes. + maxiter : int + Maximum number of iterations. + problem : pySDC.core.Problem.ptype + Problem class used to consider (the class name). + sweeper : pySDC.core.Sweeper.sweeper + Sweeper used to solve (the class name). + cwd : str + Current working directory. """ iters_time_SE = [] diff --git a/pySDC/projects/PinTSimE/estimation_check_2capacitors.py b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py index 259b45ea9a2367cb3a9fc0a4c907c16c84732e0f..ac269d744c3e22cf18d5fa042585233a1bf6be17 100644 --- a/pySDC/projects/PinTSimE/estimation_check_2capacitors.py +++ b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py @@ -3,42 +3,59 @@ import dill from pathlib import Path from pySDC.helpers.stats_helper import get_sorted -from pySDC.core.Collocation import CollBase as Collocation from pySDC.implementations.problem_classes.Battery import battery_n_capacitors from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.projects.PinTSimE.battery_model import controller_run, generate_description, get_recomputed, log_data -from pySDC.projects.PinTSimE.piline_model import setup_mpl +from pySDC.projects.PinTSimE.battery_model import controller_run, generate_description, get_recomputed + from pySDC.projects.PinTSimE.battery_2capacitors_model import ( + LogEvent, check_solution, proof_assertions_description, proof_assertions_time, ) + +from pySDC.projects.PinTSimE.piline_model import setup_mpl import pySDC.helpers.plot_helper as plt_helper +from pySDC.implementations.hooks.log_solution import LogSolution + from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator def run(cwd='./'): """ - Routine to check the differences between using a switch estimator or not + Routine to check the differences between using a switch estimator or not. - Args: - cwd (str): current working directory + Parameters + ---------- + cwd : str + Current working directory. """ - dt_list = [4e-1, 4e-2, 4e-3] + dt_list = [4e-1, 4e-2] t0 = 0.0 Tend = 3.5 problem_classes = [battery_n_capacitors] sweeper_classes = [imex_1st_order] + num_nodes = 4 + restol = -1 + maxiter = 8 ncapacitors = 2 alpha = 5.0 V_ref = np.array([1.0, 1.0]) C = np.array([1.0, 1.0]) + problem_params = dict() + problem_params['ncapacitors'] = ncapacitors + problem_params['C'] = C + problem_params['alpha'] = alpha + problem_params['V_ref'] = V_ref + + hook_class = [LogSolution, LogEvent] + use_switch_estimator = [True, False] restarts_all = [] restarts_dict = dict() @@ -49,13 +66,15 @@ def run(cwd='./'): dt_item, problem, sweeper, - log_data, + num_nodes, + hook_class, False, use_SE, - ncapacitors, - alpha, - V_ref, - C, + problem_params, + restol, + maxiter, + 1, + 1e-8, ) # Assertions @@ -104,7 +123,6 @@ def run(cwd='./'): switches = get_recomputed(stats_true, type='switch', sortby='time') t_switch = [v[1] for v in switches] - val_switch_all.append([t_switch[0], t_switch[1]]) vC1_true = [me[1][1] for me in get_sorted(stats_true, type='u', recomputed=False)] @@ -120,9 +138,7 @@ def run(cwd='./'): t_true = [me[0] for me in get_sorted(stats_true, type='u', recomputed=False)] t_false = [me[0] for me in get_sorted(stats_false, type='u', recomputed=False)] - diff_true_all1.append( - [diff_true1[m] for m in range(len(t_true)) if np.isclose(t_true[m], t_switch[0], atol=1e-15)] - ) + diff_true_all1.append([diff_true1[m] for m in range(len(t_true)) if abs(t_true[m] - t_switch[0]) <= 1e-17]) diff_true_all2.append([diff_true2[np.argmin([abs(me - t_switch[1]) for me in t_true])]]) diff_false_all_before1.append( diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py index 7cb4e4674d7bc61e8e6bf3abfff8f697c5feeded..d9059b91c115877e193fd1d39ed6c9349e4dfb5a 100644 --- a/pySDC/projects/PinTSimE/switch_estimator.py +++ b/pySDC/projects/PinTSimE/switch_estimator.py @@ -3,26 +3,32 @@ import scipy as sp from pySDC.core.Collocation import CollBase from pySDC.core.ConvergenceController import ConvergenceController, Status +from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence class SwitchEstimator(ConvergenceController): """ - Class to predict the time point of the switch and setting a new step size - - For the first time, this is a nonMPI version, because a MPI version is not yet developed. + Class to predict the time point of the event and setting a new step size. For the first time, this is a nonMPI version, + because a MPI version is not yet developed. """ def setup(self, controller, params, description): """ Function sets default variables to handle with the switch at the beginning. - Args: - controller (pySDC.Controller): The controller - params (dict): The params passed for this specific convergence controller - description (dict): The description object used to instantiate the controller - - Returns: - (dict): The updated params dictionary + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. + params : dict + The parameters passed for this specific convergence controller. + description : dict + The description object used to instantiate the controller. + + Returns + ------- + convergence_controller_params : dict + The updated params dictionary. """ # for RK4 sweeper, sweep.coll.nodes now consists of values of ButcherTableau @@ -34,7 +40,6 @@ class SwitchEstimator(ConvergenceController): defaults = { 'control_order': 100, - 'tol': description['level_params']['dt'], 'nodes': coll.nodes, } return {**defaults, **params} @@ -43,97 +48,121 @@ class SwitchEstimator(ConvergenceController): """ Adds switching specific variables to status variables. - Args: - controller (pySDC.Controller): The controller + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. """ - self.status = Status(['switch_detected', 't_switch']) + self.status = Status(['is_zero', 'switch_detected', 't_switch']) def reset_status_variables(self, controller, **kwargs): """ Resets status variables. - Args: - controller (pySDC.Controller): The controller + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. """ self.setup_status_variables(controller, **kwargs) def get_new_step_size(self, controller, S, **kwargs): """ - Determine a new step size when a switch is found such that the switch happens at the time step. - - Args: - controller (pySDC.Controller): The controller - S (pySDC.Step): The current step - - Returns: - None + Determine a new step size when an event is found such that the event occurs at the time step. + + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. + S : pySDC.Step + The current step. """ L = S.levels[0] - if S.status.iter == S.params.maxiter: - self.status.switch_detected, m_guess, vC_switch = L.prob.get_switching_info(L.u, L.time) + if CheckConvergence.check_convergence(S): + self.status.switch_detected, m_guess, state_function = L.prob.get_switching_info(L.u, L.time) if self.status.switch_detected: t_interp = [L.time + L.dt * self.params.nodes[m] for m in range(len(self.params.nodes))] - - # only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem) - if vC_switch[0] * vC_switch[-1] < 0: - self.status.t_switch = self.get_switch(t_interp, vC_switch, m_guess) - - if L.time <= self.status.t_switch <= L.time + L.dt: + t_interp, state_function = self.adapt_interpolation_info( + L.time, L.sweep.coll.left_is_node, t_interp, state_function + ) + + # when the state function is already close to zero the event is already resolved well + if abs(state_function[-1]) <= self.params.tol: + self.log("Is already close enough to one of the end point!", S) + self.log_event_time( + controller.hooks[0], S.status.slot, L.time, L.level_index, L.status.sweep, t_interp[-1] + ) + L.prob.count_switches() + self.status.is_zero = True + + # intermediate value theorem states that a root is contained in current step + if state_function[0] * state_function[-1] < 0 and self.status.is_zero is None: + self.status.t_switch = self.get_switch(t_interp, state_function, m_guess) + + if L.time < self.status.t_switch < L.time + L.dt: dt_switch = self.status.t_switch - L.time - if not np.isclose(self.status.t_switch - L.time, L.dt, atol=self.params.tol): - self.log( - f"Located Switch at time {self.status.t_switch:.6f} is outside the range of tol={self.params.tol:.4e}", - S, - ) - - else: - self.log( - f"Switch located at time {self.status.t_switch:.6f} inside tol={self.params.tol:.4e}", S - ) + if ( + abs(self.status.t_switch - L.time) <= self.params.tol + or abs((L.time + L.dt) - self.status.t_switch) <= self.params.tol + ): + self.log(f"Switch located at time {self.status.t_switch:.12f}", S) L.prob.t_switch = self.status.t_switch - controller.hooks[0].add_to_stats( - process=S.status.slot, - time=L.time, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='switch', - value=self.status.t_switch, + self.log_event_time( + controller.hooks[0], + S.status.slot, + L.time, + L.level_index, + L.status.sweep, + self.status.t_switch, ) L.prob.count_switches() - dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt + else: + self.log(f"Located Switch at time {self.status.t_switch:.12f} is outside the range", S) - # when a switch is found, time step to match with switch should be preferred + # when an event is found, step size matching with this event should be preferred + dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt if self.status.switch_detected: L.status.dt_new = dt_switch - else: L.status.dt_new = min([dt_planned, dt_switch]) else: + # event occurs on L.time or L.time + L.dt; no restart necessary + boundary = 'left boundary' if self.status.t_switch == L.time else 'right boundary' + self.log(f"Estimated switch {self.status.t_switch:.12f} occurs at {boundary}", S) + + self.log_event_time( + controller.hooks[0], + S.status.slot, + L.time, + L.level_index, + L.status.sweep, + self.status.t_switch, + ) + L.prob.count_switches() self.status.switch_detected = False - else: + else: # intermediate value theorem is not satisfied self.status.switch_detected = False def determine_restart(self, controller, S, **kwargs): """ Check if the step needs to be restarted due to a predicting switch. - Args: - controller (pySDC.Controller): The controller - S (pySDC.Step): The current step - - Returns: - None + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. + S : pySDC.Step + The current step. """ if self.status.switch_detected: @@ -147,12 +176,12 @@ class SwitchEstimator(ConvergenceController): After a step is done, some variables will be prepared for predicting a possibly new switch. If no Adaptivity is used, the next time step will be set as the default one from the front end. - Args: - controller (pySDC.Controller): The controller - S (pySDC.Step): The current step - - Returns: - None + Parameters + ---------- + controller : pySDC.Controller + The controller doing all the stuff in a computation. + S : pySDC.Step + The current step. """ L = S.levels[0] @@ -163,28 +192,168 @@ class SwitchEstimator(ConvergenceController): super().post_step_processing(controller, S, **kwargs) @staticmethod - def get_switch(t_interp, vC_switch, m_guess): + def log_event_time(controller_hooks, process, time, level, sweep, t_switch): + """ + Logs the event time of an event satisfying an appropriate criterion, e.g., event is already resolved well, + event time satisfies tolerance. + + Parameters + ---------- + controller_hooks : pySDC.Controller.hooks + Controller with access to the hooks. + process : int + Process for logging. + time : float + Time at which the event time is logged (denotes the current step). + level : int + Level at which event is found. + sweep : int + Denotes the number of sweep. + t_switch : float + Event time founded by switch estimation. """ - Routine to do the interpolation and root finding stuff. - Args: - t_interp (list): collocation nodes in a step - vC_switch (list): differences vC - V_ref at these collocation nodes - m_guess (np.float): Index at which the difference drops below zero + controller_hooks.add_to_stats( + process=process, + time=time, + level=level, + iter=0, + sweep=sweep, + type='switch', + value=t_switch, + ) - Returns: - t_switch (np.float): time point of th switch + @staticmethod + def get_switch(t_interp, state_function, m_guess): """ + Routine to do the interpolation and root finding stuff. - p = sp.interpolate.interp1d(t_interp, vC_switch, 'cubic', bounds_error=False) + Parameters + ---------- + t_interp : list + Collocation nodes in a step. + state_function : list + Contains values of state function at these collocation nodes. + m_guess : float + Index at which the difference drops below zero. + + Returns + ------- + t_switch : float + Time point of the founded switch. + """ - SwitchResults = sp.optimize.root_scalar( - p, - method='brentq', - bracket=[t_interp[0], t_interp[m_guess]], - x0=t_interp[m_guess], - xtol=1e-10, - ) - t_switch = SwitchResults.root + Interpolator = sp.interpolate.BarycentricInterpolator(t_interp, state_function) + + def p(t): + """ + Simplifies the call of the interpolant. + + Parameters + ---------- + t : float + Time t at which the interpolant is called. + + Returns + ------- + p(t) : float + The value of the interpolated function at time t. + """ + return Interpolator.__call__(t) + + def fprime(t): + """ + Computes the derivative of the scalar interpolant using finite differences. + + Parameters + ---------- + t : float + Time where the derivatives is computed. + + Returns + ------- + dp : float + Derivative of interpolation p at time t. + """ + dt = 1e-8 + dp = (p(t + dt) - p(t)) / dt + return dp + + newton_tol, newton_maxiter = 1e-8, 50 + t_switch = newton(t_interp[m_guess], p, fprime, newton_tol, newton_maxiter) return t_switch + + @staticmethod + def adapt_interpolation_info(t, left_is_node, t_interp, state_function): + """ + Adapts the x- and y-axis for interpolation. For SDC, it is proven whether the left boundary is a + collocation node or not. In case it is, the first entry of the state function has to be removed, + because it would otherwise contain double values on starting time and the first node. Otherwise, + starting time L.time has to be added to t_interp to also take this value in the interpolation + into account. + + Parameters + ---------- + t : float + Starting time of the step. + left_is_node : bool + Indicates whether the left boundary is a collocation node or not. + t_interp : list + x-values for interpolation containing collocation nodes. + state_function : list + y-values for interpolation containing values of state function. + + Returns + ------- + t_interp : list + Adapted x-values for interpolation containing collocation nodes. + state_function : list + Adapted y-values for interpolation containing values of state function. + """ + + if not left_is_node: + t_interp.insert(0, t) + else: + del state_function[0] + + return t_interp, state_function + + +def newton(x0, p, fprime, newton_tol, newton_maxiter): + """ + Newton's method fo find the root of interpolant p. + + Parameters + ---------- + x0 : float + Initial guess. + p : callable + Interpolated function where Newton's method is applied at. + fprime : callable + Approximated erivative of p using finite differences. + newton_tol : float + Tolerance for termination. + newton_maxiter : int + Maximum of iterations the method should execute. + + Returns + ------- + root : float + Root of function p. + """ + + n = 0 + while n < newton_maxiter: + if abs(p(x0)) < newton_tol or np.isnan(p(x0)) and np.isnan(fprime(x0)): + break + + x0 -= 1.0 / fprime(x0) * p(x0) + + n += 1 + + root = x0 + msg = "Newton's method took {} iterations".format(n) + print(msg) + + return root diff --git a/pySDC/tests/test_projects/test_pintsime/test_discontinuous_test_ODE.py b/pySDC/tests/test_projects/test_pintsime/test_discontinuous_test_ODE.py new file mode 100644 index 0000000000000000000000000000000000000000..73828487ffc407b1c918d068e0ba6aea58abf2a0 --- /dev/null +++ b/pySDC/tests/test_projects/test_pintsime/test_discontinuous_test_ODE.py @@ -0,0 +1,8 @@ +import pytest + + +@pytest.mark.base +def test_main(): + from pySDC.projects.PinTSimE.discontinuous_test_ODE import main + + main()