diff --git a/pySDC/implementations/problem_classes/Battery.py b/pySDC/implementations/problem_classes/Battery.py index 673c22233ed0a48f92541ccedfbe9237c8ef679c..b1629811f8b103960896cb0d26d429ef4469e512 100644 --- a/pySDC/implementations/problem_classes/Battery.py +++ b/pySDC/implementations/problem_classes/Battery.py @@ -5,47 +5,63 @@ from pySDC.core.Problem import ptype from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh -class battery(ptype): +class battery_n_capacitors(ptype): """ - Example implementing the battery drain model as in the description in the PinTSimE project + Example implementing the battery drain model with N capacitors, where N is an arbitrary integer greater than 0. Attributes: - A: system matrix, representing the 2 ODEs + nswitches: number of switches """ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): """ Initialization routine + Args: problem_params (dict): custom parameters for the example dtype_u: mesh data type for solution dtype_f: mesh data type for RHS """ - problem_params['nvars'] = 2 - # these parameters will be used later, so assert their existence - essential_keys = ['Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch'] + essential_keys = ['ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref'] for key in essential_keys: if key not in problem_params: msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) raise ParameterError(msg) + n = problem_params['ncapacitors'] + problem_params['nvars'] = n + 1 + # invoke super init, passing number of dofs, dtype_u and dtype_f - super(battery, self).__init__( + super(battery_n_capacitors, self).__init__( init=(problem_params['nvars'], None, np.dtype('float64')), dtype_u=dtype_u, dtype_f=dtype_f, params=problem_params, ) - self.A = np.zeros((2, 2)) + self.A = np.zeros((n + 1, n + 1)) + self.switch_A, self.switch_f = self.get_problem_dict() + self.t_switch = None + self.nswitches = 0 def eval_f(self, u, t): """ - Routine to evaluate the RHS + Routine to evaluate the RHS. No Switch Estimator is used: For N = 3 there are N + 1 = 4 different states of the battery: + 1. u[1] > V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C1 supplies energy + 2. u[1] <= V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C2 supplies energy + 3. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] > V_ref[2] -> C3 supplies energy + 4. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] <= V_ref[2] -> Vs supplies energy + max_index is initialized to -1. List "switch" contains a True if u[k] <= V_ref[k-1] is satisfied. + - Is no True there (i.e. max_index = -1), we are in the first case. + - max_index = k >= 0 means we are in the (k+1)-th case. + So, the actual RHS has key max_index-1 in the dictionary self.switch_f. + In case of using the Switch Estimator, we count the number of switches which illustrates in which case of voltage source we are. + Args: u (dtype_u): current values t (float): current time + Returns: dtype_f: the RHS """ @@ -53,50 +69,48 @@ class battery(ptype): f = self.dtype_f(self.init, val=0.0) f.impl[:] = self.A.dot(u) - if u[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - f.expl[0] = self.params.Vs / self.params.L + if self.t_switch is not None: + f.expl[:] = self.switch_f[self.nswitches] - else: - f.expl[0] = 0 + else: + # proof all switching conditions and find largest index where it drops below V_ref + switch = [True if u[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(u))] + max_index = max([k if switch[k] == True else -1 for k in range(len(switch))]) - else: - f.expl[0] = self.params.Vs / self.params.L + if max_index == -1: + f.expl[:] = self.switch_f[0] - else: - f.expl[0] = 0 + else: + f.expl[:] = self.switch_f[max_index + 1] return f def solve_system(self, rhs, factor, u0, t): """ Simple linear solver for (I-factor*A)u = rhs + Args: rhs (dtype_f): right-hand side for the linear system factor (float): abbrev. for the local stepsize (or any other factor required) u0 (dtype_u): initial guess for the iterative solver t (float): current time (e.g. for time-dependent BCs) + Returns: dtype_u: solution as mesh """ - self.A = np.zeros((2, 2)) - if rhs[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + if self.t_switch is not None: + self.A = self.switch_A[self.nswitches] - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) + else: + # proof all switching conditions and find largest index where it drops below V_ref + switch = [True if rhs[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(rhs))] + max_index = max([k if switch[k] == True else -1 for k in range(len(switch))]) + if max_index == -1: + self.A = self.switch_A[0] else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) + self.A = self.switch_A[max_index + 1] me = self.dtype_u(self.init) me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs) @@ -105,8 +119,10 @@ class battery(ptype): def u_exact(self, t): """ Routine to compute the exact solution at time t + Args: t (float): current time + Returns: dtype_u: exact solution """ @@ -115,57 +131,168 @@ class battery(ptype): me = self.dtype_u(self.init) me[0] = 0.0 # cL - me[1] = self.params.alpha * self.params.V_ref # vC - + me[1:] = self.params.alpha * self.params.V_ref # vC's return me + def get_switching_info(self, u, t): + """ + Provides information about a discrete event for one subinterval. + + Args: + u (dtype_u): current values + t (float): current time + + Returns: + switch_detected (bool): Indicates if a switch is found or not + m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found + vC_switch (list): Contains function values of switching condition (for interpolation) + """ + + switch_detected = False + m_guess = -100 + break_flag = False + + for m in range(len(u)): + for k in range(1, self.params.nvars): + if u[m][k] - self.params.V_ref[k - 1] <= 0: + switch_detected = True + m_guess = m - 1 + k_detected = k + break_flag = True + break + + if break_flag: + break -class battery_implicit(ptype): + vC_switch = ( + [u[m][k_detected] - self.params.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else [] + ) + + return switch_detected, m_guess, vC_switch + + def count_switches(self): + """ + Counts the number of switches. This function is called when a switch is found inside the range of tolerance + (in switch_estimator.py) + """ + + self.nswitches += 1 + + def get_problem_dict(self): + """ + Helper to create dictionaries for both the coefficent matrix of the ODE system and the nonhomogeneous part. + """ + + n = self.params.ncapacitors + v = np.zeros(n + 1) + v[0] = 1 + + A, f = dict(), dict() + A = {k: np.diag(-1 / (self.params.C[k] * self.params.R) * np.roll(v, k + 1)) for k in range(n)} + A.update({n: np.diag(-(self.params.Rs + self.params.R) / self.params.L * v)}) + f = {k: np.zeros(n + 1) for k in range(n)} + f.update({n: self.params.Vs / self.params.L * v}) + return A, f + + +class battery(battery_n_capacitors): """ - Example implementing the battery drain model as in the description in the PinTSimE project - Attributes: - A: system matrix, representing the 2 ODEs + Example implementing the battery drain model with one capacitor, inherits from battery_n_capacitors. """ - def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh): + def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): + super(battery, self).__init__(problem_params, dtype_u=dtype_u, dtype_f=dtype_f) + + def eval_f(self, u, t): """ - Initialization routine + Routine to evaluate the RHS + Args: - problem_params (dict): custom parameters for the example - dtype_u: mesh data type for solution - dtype_f: mesh data type for RHS + u (dtype_u): current values + t (float): current time + + Returns: + dtype_f: the RHS """ - problem_params['nvars'] = 2 + f = self.dtype_f(self.init, val=0.0) + f.impl[:] = self.A.dot(u) + + t_switch = np.inf if self.t_switch is None else self.t_switch - # these parameters will be used later, so assert their existence - essential_keys = [ - 'newton_maxiter', - 'newton_tol', - 'Vs', - 'Rs', - 'C', - 'R', - 'L', - 'alpha', - 'V_ref', - 'set_switch', - 't_switch', - ] + if u[1] <= self.params.V_ref[0] or t >= t_switch: + f.expl[0] = self.params.Vs / self.params.L + + else: + f.expl[0] = 0 + + return f + + def solve_system(self, rhs, factor, u0, t): + """ + Simple linear solver for (I-factor*A)u = rhs + + Args: + rhs (dtype_f): right-hand side for the linear system + factor (float): abbrev. for the local stepsize (or any other factor required) + u0 (dtype_u): initial guess for the iterative solver + t (float): current time (e.g. for time-dependent BCs) + + Returns: + dtype_u: solution as mesh + """ + self.A = np.zeros((2, 2)) + + t_switch = np.inf if self.t_switch is None else self.t_switch + + if rhs[1] <= self.params.V_ref[0] or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + + else: + self.A[1, 1] = -1 / (self.params.C[0] * self.params.R) + + me = self.dtype_u(self.init) + me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs) + return me + + def u_exact(self, t): + """ + Routine to compute the exact solution at time t + + Args: + t (float): current time + + Returns: + dtype_u: exact solution + """ + assert t == 0, 'ERROR: u_exact only valid for t=0' + + me = self.dtype_u(self.init) + + me[0] = 0.0 # cL + me[1] = self.params.alpha * self.params.V_ref[0] # vC + + return me + + +class battery_implicit(battery): + def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh): + + essential_keys = ['newton_maxiter', 'newton_tol', 'ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref'] for key in essential_keys: if key not in problem_params: msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) raise ParameterError(msg) + problem_params['nvars'] = problem_params['ncapacitors'] + 1 + # invoke super init, passing number of dofs, dtype_u and dtype_f super(battery_implicit, self).__init__( - init=(problem_params['nvars'], None, np.dtype('float64')), + problem_params, dtype_u=dtype_u, dtype_f=dtype_f, - params=problem_params, ) - self.A = np.zeros((2, 2)) self.newton_itercount = 0 self.lin_itercount = 0 self.newton_ncalls = 0 @@ -174,9 +301,11 @@ class battery_implicit(ptype): def eval_f(self, u, t): """ Routine to evaluate the RHS + Args: u (dtype_u): current values t (float): current time + Returns: dtype_f: the RHS """ @@ -184,37 +313,29 @@ class battery_implicit(ptype): f = self.dtype_f(self.init, val=0.0) non_f = np.zeros(2) - if u[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) - non_f[0] = 0 - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + if u[1] <= self.params.V_ref[0] or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + non_f[0] = self.params.Vs else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) + self.A[1, 1] = -1 / (self.params.C[0] * self.params.R) non_f[0] = 0 f[:] = self.A.dot(u) + non_f - return f def solve_system(self, rhs, factor, u0, t): """ Simple Newton solver + Args: rhs (dtype_f): right-hand side for the linear system factor (float): abbrev. for the local stepsize (or any other factor required) u0 (dtype_u): initial guess for the iterative solver t (float): current time (e.g. for time-dependent BCs) + Returns: dtype_u: solution as mesh """ @@ -223,23 +344,14 @@ class battery_implicit(ptype): non_f = np.zeros(2) self.A = np.zeros((2, 2)) - if rhs[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) - non_f[0] = 0 - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + if rhs[1] <= self.params.V_ref[0] or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + non_f[0] = self.params.Vs else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) + self.A[1, 1] = -1 / (self.params.C[0] * self.params.R) non_f[0] = 0 # start newton iteration @@ -279,20 +391,3 @@ class battery_implicit(ptype): me[:] = u[:] return me - - def u_exact(self, t): - """ - Routine to compute the exact solution at time t - Args: - t (float): current time - Returns: - dtype_u: exact solution - """ - assert t == 0, 'ERROR: u_exact only valid for t=0' - - me = self.dtype_u(self.init) - - me[0] = 0.0 # cL - me[1] = self.params.alpha * self.params.V_ref # vC - - return me diff --git a/pySDC/implementations/problem_classes/Battery_2Condensators.py b/pySDC/implementations/problem_classes/Battery_2Condensators.py deleted file mode 100644 index 1c5eb55030426368d2ceebbdf2b3e1e153a92f00..0000000000000000000000000000000000000000 --- a/pySDC/implementations/problem_classes/Battery_2Condensators.py +++ /dev/null @@ -1,168 +0,0 @@ -import numpy as np - -from pySDC.core.Errors import ParameterError -from pySDC.core.Problem import ptype -from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh - - -class battery_2condensators(ptype): - """ - Example implementing the battery drain model using two capacitors as in the description in the PinTSimE - project - Attributes: - A: system matrix, representing the 3 ODEs - """ - - def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): - """ - Initialization routine - Args: - problem_params (dict): custom parameters for the example - dtype_u: mesh data type for solution - dtype_f: mesh data type for RHS - """ - - problem_params['nvars'] = 3 - - # these parameters will be used later, so assert their existence - essential_keys = ['Vs', 'Rs', 'C1', 'C2', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch'] - - for key in essential_keys: - if key not in problem_params: - msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) - raise ParameterError(msg) - - # invoke super init, passing number of dofs, dtype_u and dtype_f - super(battery_2condensators, self).__init__( - init=(problem_params['nvars'], None, np.dtype('float64')), - dtype_u=dtype_u, - dtype_f=dtype_f, - params=problem_params, - ) - - self.A = np.zeros((3, 3)) - - def eval_f(self, u, t): - """ - Routine to evaluate the RHS - Args: - u (dtype_u): current values - t (float): current time - Returns: - dtype_f: the RHS - """ - - f = self.dtype_f(self.init, val=0.0) - f.impl[:] = self.A.dot(u) - - # switch to C2 - if ( - u[1] <= self.params.V_ref[0] - and u[2] > self.params.V_ref[1] - or self.params.set_switch[0] - and not self.params.set_switch[1] - ): - if self.params.set_switch[0]: - if t >= self.params.t_switch[0]: - f.expl[0] = 0 - - else: - f.expl[0] = 0 - - else: - f.expl[0] = 0 - - # switch to Vs - elif u[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]): - # switch to Vs - if self.params.set_switch[1]: - if t >= self.params.t_switch[1]: - f.expl[0] = self.params.Vs / self.params.L - - else: - f.expl[0] = 0 - - else: - f.expl[0] = self.params.Vs / self.params.L - - elif ( - u[1] > self.params.V_ref[0] - and u[2] > self.params.V_ref[1] - or not self.params.set_switch[0] - and not self.params.set_switch[1] - ): - # C1 supplies energy - f.expl[0] = 0 - - return f - - def solve_system(self, rhs, factor, u0, t): - """ - Simple linear solver for (I-factor*A)u = rhs - Args: - rhs (dtype_f): right-hand side for the linear system - factor (float): abbrev. for the local stepsize (or any other factor required) - u0 (dtype_u): initial guess for the iterative solver - t (float): current time (e.g. for time-dependent BCs) - Returns: - dtype_u: solution as mesh - """ - self.A = np.zeros((3, 3)) - - # switch to C2 - if ( - rhs[1] <= self.params.V_ref[0] - and rhs[2] > self.params.V_ref[1] - or self.params.set_switch[0] - and not self.params.set_switch[1] - ): - if self.params.set_switch[0]: - if t >= self.params.t_switch[0]: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - else: - self.A[1, 1] = -1 / (self.params.C1 * self.params.R) - else: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - # switch to Vs - elif rhs[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]): - if self.params.set_switch[1]: - if t >= self.params.t_switch[1]: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - - else: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - - elif ( - rhs[1] > self.params.V_ref[0] - and rhs[2] > self.params.V_ref[1] - or not self.params.set_switch[0] - and not self.params.set_switch[1] - ): - # C1 supplies energy - self.A[1, 1] = -1 / (self.params.C1 * self.params.R) - - me = self.dtype_u(self.init) - me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs) - return me - - def u_exact(self, t): - """ - Routine to compute the exact solution at time t - Args: - t (float): current time - Returns: - dtype_u: exact solution - """ - assert t == 0, 'ERROR: u_exact only valid for t=0' - - me = self.dtype_u(self.init) - - me[0] = 0.0 # cL - me[1] = self.params.alpha * self.params.V_ref[0] # vC1 - me[2] = self.params.alpha * self.params.V_ref[1] # vC2 - return me diff --git a/pySDC/projects/PinTSimE/battery_2capacitors_model.py b/pySDC/projects/PinTSimE/battery_2capacitors_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d1401c16862ed03ae8979ba32d9dad7675b2226f --- /dev/null +++ b/pySDC/projects/PinTSimE/battery_2capacitors_model.py @@ -0,0 +1,230 @@ +import numpy as np +import dill +from pathlib import Path + +from pySDC.helpers.stats_helper import get_sorted +from pySDC.core.Collocation import CollBase as Collocation +from pySDC.implementations.problem_classes.Battery import battery_n_capacitors +from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order +from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI +from pySDC.projects.PinTSimE.battery_model import ( + controller_run, + generate_description, + get_recomputed, + log_data, + proof_assertions_description, +) +from pySDC.projects.PinTSimE.piline_model import setup_mpl +import pySDC.helpers.plot_helper as plt_helper +from pySDC.core.Hooks import hooks + +from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator + + +def run(): + """ + Executes the simulation for the battery model using the IMEX sweeper and plot the results + as <problem_class>_model_solution_<sweeper_class>.png + """ + + dt = 1e-2 + t0 = 0.0 + Tend = 3.5 + + problem_classes = [battery_n_capacitors] + sweeper_classes = [imex_1st_order] + + ncapacitors = 2 + alpha = 5.0 + V_ref = np.array([1.0, 1.0]) + C = np.array([1.0, 1.0]) + + recomputed = False + use_switch_estimator = [True] + + for problem, sweeper in zip(problem_classes, sweeper_classes): + for use_SE in use_switch_estimator: + description, controller_params = generate_description( + dt, problem, sweeper, log_data, False, use_SE, ncapacitors, alpha, V_ref, C + ) + + # Assertions + proof_assertions_description(description, False, use_SE) + + proof_assertions_time(dt, Tend, V_ref, alpha) + + stats = controller_run(description, controller_params, False, use_SE, t0, Tend) + + check_solution(stats, dt, use_SE) + + plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, False) + + +def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'): + """ + Routine to plot the numerical solution of the model + + Args: + description(dict): contains all information for a controller run + problem (pySDC.core.Problem.ptype): problem class that wants to be simulated + sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically + recomputed (bool): flag if the values after a restart are used or before + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + use_adaptivity (bool): flag if adaptivity wants to be used or not + cwd (str): current working directory + """ + + f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb') + stats = dill.load(f) + f.close() + + # convert filtered statistics to list of iterations count, sorted by process + cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=recomputed)]) + vC1 = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=recomputed)]) + vC2 = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=recomputed)]) + + t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)]) + + setup_mpl() + fig, ax = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) + ax.plot(t, cL, label='$i_L$') + ax.plot(t, vC1, label='$v_{C_1}$') + ax.plot(t, vC2, label='$v_{C_2}$') + + if use_switch_estimator: + switches = get_recomputed(stats, type='switch', sortby='time') + if recomputed is not None: + assert len(switches) >= 2, f"Expected at least 2 switches, got {len(switches)}!" + t_switches = [v[1] for v in switches] + + for i in range(len(t_switches)): + ax.axvline(x=t_switches[i], linestyle='--', color='k', label='Switch {}'.format(i + 1)) + + ax.legend(frameon=False, fontsize=12, loc='upper right') + + ax.set_xlabel('Time') + ax.set_ylabel('Energy') + + fig.savefig('data/battery_2capacitors_model_solution.png', dpi=300, bbox_inches='tight') + plt_helper.plt.close(fig) + + +def check_solution(stats, dt, use_switch_estimator): + """ + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + dt (float): initial time step + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + data = get_data_dict(stats, use_switch_estimator) + + if use_switch_estimator: + msg = f'Error when using the switch estimator for battery_2condensators for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 1.2065280755094876, + 'vC1': 1.0094825899806945, + 'vC2': 1.0050052828742688, + 'switch1': 1.6094379124373626, + 'switch2': 3.209437912457051, + 'restarts': 2.0, + 'sum_niters': 1568, + } + elif dt == 4e-1: + expected = { + 'cL': 1.1842780233981391, + 'vC1': 1.0094891393319418, + 'vC2': 1.00103823232433, + 'switch1': 1.6075867934844466, + 'switch2': 3.209437912436633, + 'restarts': 2.0, + 'sum_niters': 2000, + } + elif dt == 4e-2: + expected = { + 'cL': 1.180493652021971, + 'vC1': 1.0094825917376264, + 'vC2': 1.0007713468084405, + 'switch1': 1.6094074085553605, + 'switch2': 3.209437912440314, + 'restarts': 2.0, + 'sum_niters': 2364, + } + elif dt == 4e-3: + expected = { + 'cL': 1.1537529501025199, + 'vC1': 1.001438946726028, + 'vC2': 1.0004331625246141, + 'switch1': 1.6093728710270467, + 'switch2': 3.217437912434171, + 'restarts': 2.0, + 'sum_niters': 8920, + } + + got = { + 'cL': data['cL'][-1], + 'vC1': data['vC1'][-1], + 'vC2': data['vC2'][-1], + 'switch1': data['switch1'], + 'switch2': data['switch2'], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + +def get_data_dict(stats, use_switch_estimator, recomputed=False): + """ + Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. + Based on @brownbaerchen's get_data function. + + Args: + stats (dict): Raw statistics from a controller run + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + recomputed (bool): flag if the values after a restart are used or before + + Return: + data (dict): contains all information as the statistics dict + """ + + data = dict() + data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) + data['vC1'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) + data['vC2'] = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) + data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1] + data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1] + data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) + data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) + + return data + + +def proof_assertions_time(dt, Tend, V_ref, alpha): + """ + Function to proof the assertions regarding the time domain (in combination with the specific problem): + + Args: + dt (float): time step for computation + Tend (float): end time + V_ref (np.ndarray): Reference values (problem parameter) + alpha (np.float): Multiple used for initial conditions (problem_parameter) + """ + + assert ( + Tend == 3.5 and V_ref[0] == 1.0 and V_ref[1] == 1.0 and alpha == 5.0 + ), "Error! Do not use other parameters for V_ref[:] != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!" + + assert ( + dt == 1e-2 or dt == 4e-1 or dt == 4e-2 or dt == 4e-3 + ), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!" + + +if __name__ == "__main__": + run() diff --git a/pySDC/projects/PinTSimE/battery_2condensators_model.py b/pySDC/projects/PinTSimE/battery_2condensators_model.py deleted file mode 100644 index c41cb8271236f26c8c7c6f827a0a2853e1154c91..0000000000000000000000000000000000000000 --- a/pySDC/projects/PinTSimE/battery_2condensators_model.py +++ /dev/null @@ -1,241 +0,0 @@ -import numpy as np -import dill -from pathlib import Path - -from pySDC.helpers.stats_helper import get_sorted -from pySDC.core.Collocation import CollBase as Collocation -from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators -from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order -from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh -from pySDC.projects.PinTSimE.piline_model import setup_mpl -import pySDC.helpers.plot_helper as plt_helper -from pySDC.core.Hooks import hooks - -from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator - - -class log_data(hooks): - def post_step(self, step, level_number): - - super(log_data, self).post_step(step, level_number) - - # some abbreviations - L = step.levels[level_number] - - L.sweep.compute_end_point() - - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='current L', - value=L.uend[0], - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='voltage C1', - value=L.uend[1], - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='voltage C2', - value=L.uend[2], - ) - self.increment_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='restart', - value=1, - initialize=0, - ) - - -def main(use_switch_estimator=True): - """ - A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators - """ - - # initialize level parameters - level_params = dict() - level_params['restol'] = 1e-13 - level_params['dt'] = 1e-2 - - # initialize sweeper parameters - sweeper_params = dict() - sweeper_params['quad_type'] = 'LOBATTO' - sweeper_params['num_nodes'] = 5 - sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' - - # initialize problem parameters - problem_params = dict() - problem_params['Vs'] = 5.0 - problem_params['Rs'] = 0.5 - problem_params['C1'] = 1.0 - problem_params['C2'] = 1.0 - problem_params['R'] = 1.0 - problem_params['L'] = 1.0 - problem_params['alpha'] = 5.0 - problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2] - problem_params['set_switch'] = np.array([False, False], dtype=bool) - problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0]) - - # initialize step parameters - step_params = dict() - step_params['maxiter'] = 20 - - # initialize controller parameters - controller_params = dict() - controller_params['logger_level'] = 20 - controller_params['hook_class'] = log_data - - # convergence controllers - convergence_controllers = dict() - if use_switch_estimator: - switch_estimator_params = {} - convergence_controllers[SwitchEstimator] = switch_estimator_params - - # fill description dictionary for easy step instantiation - description = dict() - description['problem_class'] = battery_2condensators # pass problem class - description['problem_params'] = problem_params # pass problem parameters - description['sweeper_class'] = imex_1st_order # pass sweeper - description['sweeper_params'] = sweeper_params # pass sweeper parameters - description['level_params'] = level_params # pass level parameters - description['step_params'] = step_params - description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class - - if use_switch_estimator: - description['convergence_controllers'] = convergence_controllers - - proof_assertions_description(description, problem_params) - - # set time parameters - t0 = 0.0 - Tend = 3.5 - - # instantiate controller - controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) - - # get initial values on finest level - P = controller.MS[0].levels[0].prob - uinit = P.u_exact(t0) - - # call main function to get things done... - uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - - Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery_2condensators.dat' - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 - - f = open('battery_2condensators_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - # print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - restarts = np.array(get_sorted(stats, type='restart', recomputed=False))[:, 1] - print("Restarts for dt: ", level_params['dt'], " -- ", np.sum(restarts)) - - assert np.mean(niters) <= 10, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - - plot_voltages(description, use_switch_estimator) - - return np.mean(niters) - - -def plot_voltages(description, use_switch_estimator, cwd='./'): - """ - Routine to plot the numerical solution of the model - """ - - f = open(cwd + 'data/battery_2condensators.dat', 'rb') - stats = dill.load(f) - f.close() - - # convert filtered statistics to list of iterations count, sorted by process - cL = get_sorted(stats, type='current L', sortby='time') - vC1 = get_sorted(stats, type='voltage C1', sortby='time') - vC2 = get_sorted(stats, type='voltage C2', sortby='time') - - times = [v[0] for v in cL] - - setup_mpl() - fig, ax = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) - ax.plot(times, [v[1] for v in cL], label='$i_L$') - ax.plot(times, [v[1] for v in vC1], label='$v_{C_1}$') - ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$') - - if use_switch_estimator: - t_switch_plot = np.zeros(np.shape(description['problem_params']['t_switch'])[0]) - for i in range(np.shape(description['problem_params']['t_switch'])[0]): - t_switch_plot[i] = description['problem_params']['t_switch'][i] - - ax.axvline(x=t_switch_plot[i], linestyle='--', color='k', label='Switch {}'.format(i + 1)) - - ax.legend(frameon=False, fontsize=12, loc='upper right') - - ax.set_xlabel('Time') - ax.set_ylabel('Energy') - - fig.savefig('data/battery_2condensators_model_solution.png', dpi=300, bbox_inches='tight') - plt_helper.plt.close(fig) - - -def proof_assertions_description(description, problem_params): - """ - Function to proof the assertions (function to get cleaner code) - """ - - assert problem_params['alpha'] > problem_params['V_ref'][0], 'Please set "alpha" greater than "V_ref1"' - assert problem_params['alpha'] > problem_params['V_ref'][1], 'Please set "alpha" greater than "V_ref2"' - - assert problem_params['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0' - assert problem_params['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0' - - assert type(problem_params['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)' - assert not problem_params['set_switch'][0], 'First entry of "set_switch" needs to be False' - assert not problem_params['set_switch'][1], 'Second entry of "set_switch" needs to be False' - - assert not type(problem_params['t_switch']) == float, '"t_switch" has to be an array with entry zero' - - assert problem_params['t_switch'][0] == 0, 'First entry of "t_switch" needs to be zero' - assert problem_params['t_switch'][1] == 0, 'Second entry of "t_switch" needs to be zero' - - assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error' - assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters' - assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters' - - -if __name__ == "__main__": - main() diff --git a/pySDC/projects/PinTSimE/battery_model.py b/pySDC/projects/PinTSimE/battery_model.py index a127687eb5a9e04243329905ac3fac41ce215d83..428a202672c0adaad8fdd19c0ce10bd4915abb51 100644 --- a/pySDC/projects/PinTSimE/battery_model.py +++ b/pySDC/projects/PinTSimE/battery_model.py @@ -2,7 +2,7 @@ import numpy as np import dill from pathlib import Path -from pySDC.helpers.stats_helper import get_sorted +from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted from pySDC.core.Collocation import CollBase as Collocation from pySDC.implementations.problem_classes.Battery import battery, battery_implicit from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order @@ -32,26 +32,26 @@ class log_data(hooks): level=L.level_index, iter=0, sweep=L.status.sweep, - type='current L', - value=L.uend[0], + type='u', + value=L.uend, ) self.add_to_stats( process=step.status.slot, - time=L.time + L.dt, + time=L.time, level=L.level_index, iter=0, sweep=L.status.sweep, - type='voltage C', - value=L.uend[1], + type='restart', + value=int(step.status.get('restart')), ) self.add_to_stats( process=step.status.slot, - time=L.time, + time=L.time + L.dt, level=L.level_index, iter=0, sweep=L.status.sweep, - type='restart', - value=int(step.status.get('restart')), + type='dt', + value=L.dt, ) self.add_to_stats( process=step.status.slot, @@ -59,14 +59,41 @@ class log_data(hooks): level=L.level_index, iter=0, sweep=L.status.sweep, - type='dt', - value=L.dt, + type='e_embedded', + value=L.status.get('error_embedded_estimate'), ) -def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): +def generate_description( + dt, + problem, + sweeper, + hook_class, + use_adaptivity, + use_switch_estimator, + ncapacitors, + alpha, + V_ref, + C, + max_restarts=None, +): """ - A simple test program to do SDC/PFASST runs for the battery drain model + Generate a description for the battery models for a controller run. + Args: + dt (float): time step for computation + problem (pySDC.core.Problem.ptype): problem class that wants to be simulated + sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically + hook_class (pySDC.core.Hooks): logged data for a problem + use_adaptivity (bool): flag if the adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + ncapacitors (np.int): number of capacitors used for the battery_model + alpha (np.float): Multiple used for the initial conditions (problem_parameter) + V_ref (np.ndarray): Reference values for the capacitors (problem_parameter) + C (np.ndarray): Capacitances (problem_parameter + + Returns: + description (dict): contains all information for a controller run + controller_params (dict): Parameters needed for a controller run """ # initialize level parameters @@ -78,22 +105,21 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): sweeper_params = dict() sweeper_params['quad_type'] = 'LOBATTO' sweeper_params['num_nodes'] = 5 - # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' + sweeper_params['QI'] = 'IE' + sweeper_params['initial_guess'] = 'spread' # initialize problem parameters problem_params = dict() problem_params['newton_maxiter'] = 200 problem_params['newton_tol'] = 1e-08 + problem_params['ncapacitors'] = ncapacitors # number of condensators problem_params['Vs'] = 5.0 problem_params['Rs'] = 0.5 - problem_params['C'] = 1.0 + problem_params['C'] = C problem_params['R'] = 1.0 problem_params['L'] = 1.0 - problem_params['alpha'] = 1.2 - problem_params['V_ref'] = 1.0 - problem_params['set_switch'] = np.array([False], dtype=bool) - problem_params['t_switch'] = np.zeros(1) + problem_params['alpha'] = alpha + problem_params['V_ref'] = V_ref # initialize step parameters step_params = dict() @@ -101,8 +127,8 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): # initialize controller parameters controller_params = dict() - controller_params['logger_level'] = 20 - controller_params['hook_class'] = log_data + controller_params['logger_level'] = 30 + controller_params['hook_class'] = hook_class controller_params['mssdc_jac'] = False # convergence controllers @@ -113,7 +139,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): if use_adaptivity: adaptivity_params = dict() - adaptivity_params['e_tol'] = 1e-12 + adaptivity_params['e_tol'] = 1e-7 convergence_controllers.update({Adaptivity: adaptivity_params}) # fill description dictionary for easy step instantiation @@ -124,15 +150,28 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): description['sweeper_params'] = sweeper_params # pass sweeper parameters description['level_params'] = level_params # pass level parameters description['step_params'] = step_params + if max_restarts is not None: + description['max_restarts'] = max_restarts if use_switch_estimator or use_adaptivity: description['convergence_controllers'] = convergence_controllers - proof_assertions_description(description, problem_params) + return description, controller_params - # set time parameters - t0 = 0.0 - Tend = 0.3 + +def controller_run(description, controller_params, use_adaptivity, use_switch_estimator, t0, Tend): + """ + Executes a controller run for a problem defined in the description + + Args: + description (dict): contains all information for a controller run + controller_params (dict): Parameters needed for a controller run + use_adaptivity (bool): flag if the adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + + Returns: + stats (dict): Raw statistics from a controller run + """ # instantiate controller controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) @@ -144,35 +183,18 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): # call main function to get things done... uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 + problem = description['problem_class'] + sweeper = description['sweeper_class'] Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper.__name__, use_switch_estimator, use_adaptivity) + fname = 'data/{}_{}_USE{}_USA{}.dat'.format( + problem.__name__, sweeper.__name__, use_switch_estimator, use_adaptivity + ) f = open(fname, 'wb') dill.dump(stats, f) f.close() - f = open('data/battery_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - assert np.mean(niters) <= 5, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - - return description + return stats def run(): @@ -182,53 +204,80 @@ def run(): """ dt = 1e-2 + t0 = 0.0 + Tend = 0.3 + problem_classes = [battery, battery_implicit] sweeper_classes = [imex_1st_order, generic_implicit] + + ncapacitors = 1 + alpha = 1.2 + V_ref = np.array([1.0]) + C = np.array([1.0]) + + recomputed = False use_switch_estimator = [True] use_adaptivity = [True] for problem, sweeper in zip(problem_classes, sweeper_classes): for use_SE in use_switch_estimator: for use_A in use_adaptivity: - description = main( - dt=dt, - problem=problem, - sweeper=sweeper, - use_switch_estimator=use_SE, - use_adaptivity=use_A, + description, controller_params = generate_description( + dt, problem, sweeper, log_data, use_A, use_SE, ncapacitors, alpha, V_ref, C ) - plot_voltages(description, problem.__name__, sweeper.__name__, use_SE, use_A) + # Assertions + proof_assertions_description(description, use_A, use_SE) + + proof_assertions_time(dt, Tend, V_ref, alpha) + + stats = controller_run(description, controller_params, use_A, use_SE, t0, Tend) + check_solution(stats, dt, problem.__name__, use_A, use_SE) -def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adaptivity, cwd='./'): + plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, use_A) + + +def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'): """ Routine to plot the numerical solution of the model + + Args: + description(dict): contains all information for a controller run + problem (pySDC.core.Problem.ptype): problem class that wants to be simulated + sweeper (pySDC.core.Sweeper.sweeper): sweeper class for solving the problem class numerically + recomputed (bool): flag if the values after a restart are used or before + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + use_adaptivity (bool): flag if adaptivity wants to be used or not + cwd (str): current working directory """ - f = open(cwd + 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper, use_switch_estimator, use_adaptivity), 'rb') + f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb') stats = dill.load(f) f.close() # convert filtered statistics to list of iterations count, sorted by process - cL = get_sorted(stats, type='current L', recomputed=False, sortby='time') - vC = get_sorted(stats, type='voltage C', recomputed=False, sortby='time') + cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=recomputed)]) + vC = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=recomputed)]) - times = [v[0] for v in cL] + t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)]) setup_mpl() fig, ax = plt_helper.plt.subplots(1, 1, figsize=(3, 3)) ax.set_title('Simulation of {} using {}'.format(problem, sweeper), fontsize=10) - ax.plot(times, [v[1] for v in cL], label=r'$i_L$') - ax.plot(times, [v[1] for v in vC], label=r'$v_C$') + ax.plot(t, cL, label=r'$i_L$') + ax.plot(t, vC, label=r'$v_C$') if use_switch_estimator: - val_switch = get_sorted(stats, type='switch1', sortby='time') - t_switch = [v[1] for v in val_switch] + switches = get_recomputed(stats, type='switch', sortby='time') + + assert len(switches) >= 1, 'No switches found!' + t_switch = [v[1] for v in switches] ax.axvline(x=t_switch[-1], linestyle='--', linewidth=0.8, color='r', label='Switch') if use_adaptivity: dt = np.array(get_sorted(stats, type='dt', recomputed=False)) + dt_ax = ax.twinx() dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$') dt_ax.set_ylabel(r'$\Delta t$', fontsize=8) @@ -245,23 +294,377 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt plt_helper.plt.close(fig) -def proof_assertions_description(description, problem_params): +def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): + """ + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + dt (float): initial time step + problem (problem_class.__name__): the problem_class that is numerically solved + use_switch_estimator (bool): + use_adaptivity (bool): + """ + + data = get_data_dict(stats, use_adaptivity, use_switch_estimator) + + if problem == 'battery': + if use_switch_estimator and use_adaptivity: + msg = f'Error when using switch estimator and adaptivity for battery for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5474500710994862, + 'vC': 1.0019332967173764, + 'dt': 0.011761752270047832, + 'e_em': 8.001793672107738e-10, + 'switches': 0.18232155791181945, + 'restarts': 3.0, + 'sum_niters': 44, + } + elif dt == 4e-2: + expected = { + 'cL': 0.5525783945667581, + 'vC': 1.00001743462299, + 'dt': 0.03550610373897258, + 'e_em': 6.21240694442804e-08, + 'switches': 0.18231603298272345, + 'restarts': 4.0, + 'sum_niters': 56, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5395601429161445, + 'vC': 1.0000413761942089, + 'dt': 0.028281271825675414, + 'e_em': 2.5628611677319668e-08, + 'switches': 0.18230920573953438, + 'restarts': 3.0, + 'sum_niters': 48, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + elif use_switch_estimator and not use_adaptivity: + msg = f'Error when using switch estimator for battery for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5423033461806986, + 'vC': 1.000118710428906, + 'switches': 0.1823188001399631, + 'restarts': 1.0, + 'sum_niters': 284, + } + elif dt == 4e-2: + expected = { + 'cL': 0.6139093327509394, + 'vC': 1.0010140038721593, + 'switches': 0.1824302065533169, + 'restarts': 1.0, + 'sum_niters': 48, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5429509935448258, + 'vC': 1.0001158309787614, + 'switches': 0.18232183080236553, + 'restarts': 1.0, + 'sum_niters': 392, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif not use_switch_estimator and use_adaptivity: + msg = f'Error when using adaptivity for battery for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5413318777113352, + 'vC': 0.9963444569399663, + 'dt': 0.020451912195976252, + 'e_em': 7.157646031430431e-09, + 'restarts': 4.0, + 'sum_niters': 56, + } + elif dt == 4e-2: + expected = { + 'cL': 0.5966289599915113, + 'vC': 0.9923148791604984, + 'dt': 0.03564958366355817, + 'e_em': 6.210964231812e-08, + 'restarts': 1.0, + 'sum_niters': 36, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5431613774808756, + 'vC': 0.9934307674636834, + 'dt': 0.022880524075396924, + 'e_em': 1.1130212751453428e-08, + 'restarts': 3.0, + 'sum_niters': 52, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif problem == 'battery_implicit': + if use_switch_estimator and use_adaptivity: + msg = f'Error when using switch estimator and adaptivity for battery_implicit for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5424577937840791, + 'vC': 1.0001051105894005, + 'dt': 0.01, + 'e_em': 2.220446049250313e-16, + 'switches': 0.1822923488448394, + 'restarts': 6.0, + 'sum_niters': 60, + } + elif dt == 4e-2: + expected = { + 'cL': 0.6717104472882885, + 'vC': 1.0071670698947914, + 'dt': 0.035896059229296486, + 'e_em': 6.208836400567463e-08, + 'switches': 0.18232158833761175, + 'restarts': 3.0, + 'sum_niters': 36, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5396216192241711, + 'vC': 1.0000561014463172, + 'dt': 0.009904645972832471, + 'e_em': 2.220446049250313e-16, + 'switches': 0.18230549652342606, + 'restarts': 4.0, + 'sum_niters': 44, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + elif use_switch_estimator and not use_adaptivity: + msg = f'Error when using switch estimator for battery_implicit for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5423033363981951, + 'vC': 1.000118715162845, + 'switches': 0.18231880065636324, + 'restarts': 1.0, + 'sum_niters': 284, + } + elif dt == 4e-2: + expected = { + 'cL': 0.613909968362315, + 'vC': 1.0010140112484431, + 'switches': 0.18243023230469263, + 'restarts': 1.0, + 'sum_niters': 48, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5429616576526073, + 'vC': 1.0001158454740509, + 'switches': 0.1823218812753008, + 'restarts': 1.0, + 'sum_niters': 392, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif not use_switch_estimator and use_adaptivity: + msg = f'Error when using adaptivity for battery_implicit for dt={dt:.1e}:' + if dt == 1e-2: + expected = { + 'cL': 0.5490142863996689, + 'vC': 0.997253099984895, + 'dt': 0.024243123245133835, + 'e_em': 1.4052013885823555e-08, + 'restarts': 11.0, + 'sum_niters': 96, + } + elif dt == 4e-2: + expected = { + 'cL': 0.5556563012729733, + 'vC': 0.9930947318467772, + 'dt': 0.035507110551631804, + 'e_em': 6.2098696185231e-08, + 'restarts': 6.0, + 'sum_niters': 64, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5401117929618637, + 'vC': 0.9933888475391347, + 'dt': 0.03176025170463925, + 'e_em': 4.0386798239033794e-08, + 'restarts': 8.0, + 'sum_niters': 80, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + +def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recomputed=False): + """ + Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. + Based on @brownbaerchen's get_data function. + + Args: + stats (dict): Raw statistics from a controller run + use_adaptivity (bool): flag if adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + recomputed (bool): flag if the values after a restart are used or before + + Return: + data (dict): contains all information as the statistics dict + """ + + data = dict() + + data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) + data['vC'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) + if use_adaptivity: + data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1] + data['e_em'] = np.array( + get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed, sortby='time') + )[:, 1] + if use_switch_estimator: + data['switches'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[:, 1] + if use_adaptivity or use_switch_estimator: + data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) + data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) + + return data + + +def get_recomputed(stats, type, sortby): + """ + Function that filters statistics after a recomputation. It stores all value of a type before restart. If there are multiple values + with same time point, it only stores the elements with unique times. + + Args: + stats (dict): Raw statistics from a controller run + type (str): the type the be filtered + sortby (str): string to specify which key to use for sorting + + Returns: + sorted_list (list): list of filtered statistics """ - Function to proof the assertions (function to get cleaner code) + + sorted_nested_list = [] + times_unique = np.unique([me[0] for me in get_sorted(stats, type=type)]) + filtered_list = [ + filter_stats( + stats, + time=t_unique, + num_restarts=max([me.num_restarts for me in filter_stats(stats, type=type, time=t_unique).keys()]), + type=type, + ) + for t_unique in times_unique + ] + for item in filtered_list: + sorted_nested_list.append(sort_stats(item, sortby=sortby)) + sorted_list = [item for sub_item in sorted_nested_list for item in sub_item] + return sorted_list + + +def proof_assertions_description(description, use_adaptivity, use_switch_estimator): """ + Function to proof the assertions in the description. - assert problem_params['alpha'] > problem_params['V_ref'], 'Please set "alpha" greater than "V_ref"' - assert problem_params['V_ref'] > 0, 'Please set "V_ref" greater than 0' - assert type(problem_params['V_ref']) == float, '"V_ref" needs to be of type float' + Args: + description(dict): contains all information for a controller run + use_adaptivity (bool): flag if adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ - assert type(problem_params['set_switch'][0]) == np.bool_, '"set_switch" has to be an bool array' - assert type(problem_params['t_switch']) == np.ndarray, '"t_switch" has to be an array' - assert problem_params['t_switch'][0] == 0, '"t_switch" is only allowed to have entry zero' + n = description['problem_params']['ncapacitors'] + assert ( + description['problem_params']['alpha'] > description['problem_params']['V_ref'][k] for k in range(n) + ), 'Please set "alpha" greater than values of "V_ref"' + assert type(description['problem_params']['V_ref']) == np.ndarray, '"V_ref" needs to be an np.ndarray' + assert type(description['problem_params']['C']) == np.ndarray, '"C" needs to be an np.ndarray ' + assert ( + np.shape(description['problem_params']['V_ref'])[0] == n + ), 'Number of reference values needs to be equal to number of condensators' + assert ( + np.shape(description['problem_params']['C'])[0] == n + ), 'Number of capacitance values needs to be equal to number of condensators' + + assert ( + description['problem_params']['V_ref'][k] > 0 for k in range(n) + ), 'Please set values of "V_ref" greater than 0' assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error' assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters' assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters' + if use_switch_estimator or use_adaptivity: + assert description['level_params']['restol'] == -1, "Please set restol to -1 or omit it" + + +def proof_assertions_time(dt, Tend, V_ref, alpha): + """ + Function to proof the assertions regarding the time domain (in combination with the specific problem): + + Args: + dt (float): time step for computation + Tend (float): end time + V_ref (np.ndarray): Reference values (problem parameter) + alpha (np.float): Multiple used for initial conditions (problem_parameter) + """ + + assert dt < Tend, "Time step is too large for the time domain!" + + assert ( + Tend == 0.3 and V_ref[0] == 1.0 and alpha == 1.2 + ), "Error! Do not use other parameters for V_ref != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!" + assert dt == 1e-2, "Error! Do not use another time step dt!= 1e-2!" + if __name__ == "__main__": run() diff --git a/pySDC/projects/PinTSimE/estimation_check.py b/pySDC/projects/PinTSimE/estimation_check.py index e565bd11d387705ee5f0df74c602077c03c133e4..da012be499e25935134b9b922654a730f7fbed37 100644 --- a/pySDC/projects/PinTSimE/estimation_check.py +++ b/pySDC/projects/PinTSimE/estimation_check.py @@ -9,143 +9,79 @@ from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.projects.PinTSimE.piline_model import setup_mpl -from pySDC.projects.PinTSimE.battery_model import log_data, proof_assertions_description +from pySDC.projects.PinTSimE.battery_model import ( + controller_run, + check_solution, + generate_description, + get_recomputed, + log_data, + proof_assertions_description, +) import pySDC.helpers.plot_helper as plt_helper +from pySDC.core.Hooks import hooks from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity -from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedErrorNonMPI -def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): +def run(cwd='./'): """ - A simple test program to do SDC/PFASST runs for the battery drain model + Routine to check the differences between using a switch estimator or not + + Args: + cwd (str): current working directory """ - # initialize level parameters - level_params = dict() - level_params['restol'] = -1 - level_params['dt'] = dt - - # initialize sweeper parameters - sweeper_params = dict() - sweeper_params['quad_type'] = 'LOBATTO' - sweeper_params['num_nodes'] = 5 - # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' - - # initialize problem parameters - problem_params = dict() - problem_params['newton_maxiter'] = 200 - problem_params['newton_tol'] = 1e-08 - problem_params['Vs'] = 5.0 - problem_params['Rs'] = 0.5 - problem_params['C'] = 1.0 - problem_params['R'] = 1.0 - problem_params['L'] = 1.0 - problem_params['alpha'] = 1.2 - problem_params['V_ref'] = V_ref - problem_params['set_switch'] = np.array([False], dtype=bool) - problem_params['t_switch'] = np.zeros(1) - - # initialize step parameters - step_params = dict() - step_params['maxiter'] = 4 - - # initialize controller parameters - controller_params = dict() - controller_params['logger_level'] = 20 - controller_params['hook_class'] = log_data - controller_params['mssdc_jac'] = False - - # convergence controllers - convergence_controllers = dict() - if use_switch_estimator: - switch_estimator_params = dict() - convergence_controllers.update({SwitchEstimator: switch_estimator_params}) - - if use_adaptivity: - adaptivity_params = dict() - adaptivity_params['e_tol'] = 1e-7 - convergence_controllers.update({Adaptivity: adaptivity_params}) - - # fill description dictionary for easy step instantiation - description = dict() - description['problem_class'] = problem # pass problem class - description['problem_params'] = problem_params # pass problem parameters - description['sweeper_class'] = sweeper # pass sweeper - description['sweeper_params'] = sweeper_params # pass sweeper parameters - description['level_params'] = level_params # pass level parameters - description['step_params'] = step_params - description['max_restarts'] = 1 - - if use_switch_estimator or use_adaptivity: - description['convergence_controllers'] = convergence_controllers - - proof_assertions_description(description, problem_params) - - # set time parameters + dt_list = [4e-2, 4e-3] t0 = 0.0 Tend = 0.3 - # instantiate controller - controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) - - # get initial values on finest level - P = controller.MS[0].levels[0].prob - uinit = P.u_exact(t0) - - # call main function to get things done... - uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - - Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery.dat' - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time') - - # compute and print statistics - f = open('data/battery_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - - assert np.mean(niters) <= 11, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - - return description, stats + problem_classes = [battery, battery_implicit] + sweeper_classes = [imex_1st_order, generic_implicit] + ncapacitors = 1 + alpha = 1.2 + V_ref = np.array([1.0]) + C = np.array([1.0]) -def check(cwd='./'): - """ - Routine to check the differences between using a switch estimator or not - """ - - V_ref = 1.0 - dt_list = [1e-2, 1e-3] + max_restarts = 1 use_switch_estimator = [True, False] use_adaptivity = [True, False] - restarts_true = [] - restarts_false_adapt = [] - restarts_true_adapt = [] - - problem_classes = [battery, battery_implicit] - sweeper_classes = [imex_1st_order, generic_implicit] + restarts_SE = [] + restarts_adapt = [] + restarts_SE_adapt = [] for problem, sweeper in zip(problem_classes, sweeper_classes): for dt_item in dt_list: for use_SE in use_switch_estimator: for use_A in use_adaptivity: - description, stats = run( - dt=dt_item, - problem=problem, - sweeper=sweeper, - use_switch_estimator=use_SE, - use_adaptivity=use_A, - V_ref=V_ref, + description, controller_params = generate_description( + dt_item, + problem, + sweeper, + log_data, + use_A, + use_SE, + ncapacitors, + alpha, + V_ref, + C, + max_restarts, ) + # Assertions + proof_assertions_description(description, use_A, use_SE) + + stats = controller_run(description, controller_params, use_A, use_SE, t0, Tend) + + if use_A or use_SE: + check_solution(stats, dt_item, problem.__name__, use_A, use_SE) + + if use_SE: + assert ( + len(get_recomputed(stats, type='switch', sortby='time')) >= 1 + ), 'No switches found for dt={}!'.format(dt_item) + fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__) f = open(fname, 'wb') dill.dump(stats, f) @@ -153,24 +89,23 @@ def check(cwd='./'): if use_SE or use_A: restarts_sorted = np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1] - print('Restarts for dt={}: {}'.format(dt_item, np.sum(restarts_sorted))) if use_SE and not use_A: - restarts_true.append(np.sum(restarts_sorted)) + restarts_SE.append(np.sum(restarts_sorted)) elif not use_SE and use_A: - restarts_false_adapt.append(np.sum(restarts_sorted)) + restarts_adapt.append(np.sum(restarts_sorted)) elif use_SE and use_A: - restarts_true_adapt.append(np.sum(restarts_sorted)) + restarts_SE_adapt.append(np.sum(restarts_sorted)) accuracy_check(dt_list, problem.__name__, sweeper.__name__, V_ref) differences_around_switch( dt_list, problem.__name__, - restarts_true, - restarts_false_adapt, - restarts_true_adapt, + restarts_SE, + restarts_adapt, + restarts_SE_adapt, sweeper.__name__, V_ref, ) @@ -179,14 +114,21 @@ def check(cwd='./'): iterations_over_time(dt_list, description['step_params']['maxiter'], problem.__name__, sweeper.__name__) - restarts_true = [] - restarts_false_adapt = [] - restarts_true_adapt = [] + restarts_SE = [] + restarts_adapt = [] + restarts_SE_adapt = [] def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): """ Routine to check accuracy for different step sizes in case of using adaptivity + + Args: + dt_list (list): list of considered (initial) step sizes + problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) + sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd (str): current working directory """ if len(dt_list) > 1: @@ -202,45 +144,49 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): count_ax = 0 for dt_item in dt_list: f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - dt_TT_val = get_sorted(stats_TT, type='dt', recomputed=False) - dt_FT_val = get_sorted(stats_FT, type='dt', recomputed=False) + dt_SE_adapt_val = get_sorted(stats_SE_adapt, type='dt', recomputed=False) + dt_adapt_val = get_sorted(stats_adapt, type='dt', recomputed=False) - e_emb_TT_val = get_sorted(stats_TT, type='e_embedded', recomputed=False) - e_emb_FT_val = get_sorted(stats_FT, type='e_embedded', recomputed=False) + e_emb_SE_adapt_val = get_sorted(stats_SE_adapt, type='e_embedded', recomputed=False) + e_emb_adapt_val = get_sorted(stats_adapt, type='e_embedded', recomputed=False) - times_TT = [v[0] for v in e_emb_TT_val] - times_FT = [v[0] for v in e_emb_FT_val] + times_SE_adapt = [v[0] for v in e_emb_SE_adapt_val] + times_adapt = [v[0] for v in e_emb_adapt_val] - e_emb_TT = [v[1] for v in e_emb_TT_val] - e_emb_FT = [v[1] for v in e_emb_FT_val] + e_emb_SE_adapt = [v[1] for v in e_emb_SE_adapt_val] + e_emb_adapt = [v[1] for v in e_emb_adapt_val] if len(dt_list) > 1: - ax_acc[count_ax].set_title(r'$\Delta t$={}'.format(dt_item)) + ax_acc[count_ax].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item) dt1 = ax_acc[count_ax].plot( - [v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$' + [v[0] for v in dt_SE_adapt_val], + [v[1] for v in dt_SE_adapt_val], + 'ko-', + label=r'SE+A - $\Delta t_\mathrm{adapt}$', ) dt2 = ax_acc[count_ax].plot( - [v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'g-', label=r'A - $\Delta t$' + [v[0] for v in dt_adapt_val], [v[1] for v in dt_adapt_val], 'g-', label=r'A - $\Delta t_\mathrm{adapt}$' ) - ax_acc[count_ax].axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc[count_ax].axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc[count_ax].tick_params(axis='both', which='major', labelsize=6) ax_acc[count_ax].set_xlabel('Time', fontsize=6) if count_ax == 0: - ax_acc[count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_acc[count_ax].set_ylabel(r'$\Delta t_\mathrm{adapt}$', fontsize=6) e_ax = ax_acc[count_ax].twinx() - e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$') - e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$') + e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$') + e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$') e_ax.set_yscale('log', base=10) e_ax.set_ylim(1e-16, 1e-7) e_ax.tick_params(labelsize=6) @@ -248,26 +194,37 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): lines = dt1 + e_plt1 + dt2 + e_plt2 labels = [l.get_label() for l in lines] - ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper left') + ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper right') else: - ax_acc.set_title(r'$\Delta t$={}'.format(dt_item)) - dt1 = ax_acc.plot([v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$') - dt2 = ax_acc.plot([v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'go-', label=r'A - $\Delta t$') - ax_acc.axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item) + dt1 = ax_acc.plot( + [v[0] for v in dt_SE_adapt_val], + [v[1] for v in dt_SE_adapt_val], + 'ko-', + label=r'SE+A - $\Delta t_\mathrm{adapt}$', + ) + dt2 = ax_acc.plot( + [v[0] for v in dt_adapt_val], + [v[1] for v in dt_adapt_val], + 'go-', + label=r'A - $\Delta t_\mathrm{adapt}$', + ) + ax_acc.axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc.tick_params(axis='both', which='major', labelsize=6) ax_acc.set_xlabel('Time', fontsize=6) - ax_acc.set_ylabel(r'$Delta t_{adapted}$', fontsize=6) + ax_acc.set_ylabel(r'$Delta t_\mathrm{adapt}$', fontsize=6) e_ax = ax_acc.twinx() - e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$') - e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$') + e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$') + e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$') e_ax.set_yscale('log', base=10) e_ax.tick_params(labelsize=6) lines = dt1 + e_plt1 + dt2 + e_plt2 labels = [l.get_label() for l in lines] - ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper left') + ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper right') count_ax += 1 @@ -276,10 +233,20 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): def differences_around_switch( - dt_list, problem, restarts_true, restarts_false_adapt, restarts_true_adapt, sweeper, V_ref, cwd='./' + dt_list, problem, restarts_SE, restarts_adapt, restarts_SE_adapt, sweeper, V_ref, cwd='./' ): """ Routine to plot the differences before, at, and after the switch. Produces the diffs_estimation_<sweeper_class>.png file + + Args: + dt_list (list): list of considered (initial) step sizes + problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) + restarts_SE (list): Restarts for the solve only using the switch estimator + restarts_adapt (list): Restarts for the solve of only using adaptivity + restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity + sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd (str): current working directory """ diffs_true_at = [] @@ -294,59 +261,77 @@ def differences_around_switch( diffs_false_after_adapt = [] for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch = [v[1] for v in val_switch_TF] + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch = [v[1] for v in switches_SE] t_switch = t_switch[-1] # battery has only one single switch - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time') - vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time') - vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time') - vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time') + vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False)] + vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False)] + vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)] + vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False)] - diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF] - times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF] + diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0] + times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False)] + times = [me[0] for me in get_sorted(stats, type='u', recomputed=False)] - diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT] - times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT] + diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0] + times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False)] + times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)] - for m in range(len(times_TF)): - if np.round(times_TF[m], 15) == np.round(t_switch, 15): - diffs_true_at.append(diff_TF[m]) + diffs_true_at.append( + [diff_SE[m] for m in range(len(times_SE)) if np.isclose(times_SE[m], t_switch, atol=1e-15)] + ) - for m in range(1, len(times_FF)): - if times_FF[m - 1] <= t_switch <= times_FF[m]: - diffs_false_before.append(diff_FF[m - 1]) - diffs_false_after.append(diff_FF[m]) + diffs_false_before.append([diff[m - 1] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]]) + diffs_false_after.append([diff[m] for m in range(1, len(times)) if times[m - 1] <= t_switch <= times[m]]) - for m in range(len(times_TT)): - if np.round(times_TT[m], 13) == np.round(t_switch_adapt, 13): - diffs_true_at_adapt.append(diff_TT[m]) - diffs_true_before_adapt.append(diff_TT[m - 1]) - diffs_true_after_adapt.append(diff_TT[m + 1]) + diffs_true_at_adapt.append( + [ + diff_SE_adapt[m] + for m in range(len(times_SE_adapt)) + if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13) + ] + ) + diffs_true_before_adapt.append( + [ + diff_SE_adapt[m - 1] + for m in range(len(times_SE_adapt)) + if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13) + ] + ) + diffs_true_after_adapt.append( + [ + diff_SE_adapt[m + 1] + for m in range(len(times_SE_adapt)) + if np.isclose(times_SE_adapt[m], t_switch_SE_adapt, atol=1e-13) + ] + ) - for m in range(len(times_FT)): - if times_FT[m - 1] <= t_switch <= times_FT[m]: - diffs_false_before_adapt.append(diff_FT[m - 1]) - diffs_false_after_adapt.append(diff_FT[m]) + diffs_false_before_adapt.append( + [diff_adapt[m - 1] for m in range(len(times_adapt)) if times_adapt[m - 1] <= t_switch <= times_adapt[m]] + ) + diffs_false_after_adapt.append( + [diff_adapt[m] for m in range(len(times_adapt)) if times_adapt[m - 1] <= t_switch <= times_adapt[m]] + ) setup_mpl() fig_around, ax_around = plt_helper.plt.subplots(1, 3, figsize=(9, 3), sharex='col', sharey='row') @@ -356,14 +341,15 @@ def differences_around_switch( pos13 = ax_around[0].plot(dt_list, diffs_true_at, 'ko--', label='at switch') ax_around[0].set_xticks(dt_list) ax_around[0].set_xticklabels(dt_list) + ax_around[0].tick_params(axis='both', which='major', labelsize=6) ax_around[0].set_xscale('log', base=10) ax_around[0].set_yscale('symlog', linthresh=1e-8) ax_around[0].set_ylim(-1, 1) - ax_around[0].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[0].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) ax_around[0].set_ylabel(r'$v_{C}-V_{ref}$', fontsize=6) restart_ax0 = ax_around[0].twinx() - restarts_plt0 = restart_ax0.plot(dt_list, restarts_true, 'cs--', label='Restarts') + restarts_plt0 = restart_ax0.plot(dt_list, restarts_SE, 'cs--', label='Restarts') restart_ax0.tick_params(labelsize=6) lines = pos11 + pos12 + pos13 + restarts_plt0 @@ -375,13 +361,14 @@ def differences_around_switch( pos22 = ax_around[1].plot(dt_list, diffs_false_after_adapt, 'bd--', label='after switch') ax_around[1].set_xticks(dt_list) ax_around[1].set_xticklabels(dt_list) + ax_around[1].tick_params(axis='both', which='major', labelsize=6) ax_around[1].set_xscale('log', base=10) ax_around[1].set_yscale('symlog', linthresh=1e-8) ax_around[1].set_ylim(-1, 1) - ax_around[1].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[1].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) restart_ax1 = ax_around[1].twinx() - restarts_plt1 = restart_ax1.plot(dt_list, restarts_false_adapt, 'cs--', label='Restarts') + restarts_plt1 = restart_ax1.plot(dt_list, restarts_adapt, 'cs--', label='Restarts') restart_ax1.tick_params(labelsize=6) lines = pos21 + pos22 + restarts_plt1 @@ -394,90 +381,103 @@ def differences_around_switch( pos33 = ax_around[2].plot(dt_list, diffs_true_at_adapt, 'ko--', label='at switch') ax_around[2].set_xticks(dt_list) ax_around[2].set_xticklabels(dt_list) + ax_around[2].tick_params(axis='both', which='major', labelsize=6) ax_around[2].set_xscale('log', base=10) ax_around[2].set_yscale('symlog', linthresh=1e-8) ax_around[2].set_ylim(-1, 1) - ax_around[2].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[2].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) restart_ax2 = ax_around[2].twinx() - restarts_plt2 = restart_ax2.plot(dt_list, restarts_true_adapt, 'cs--', label='Restarts') + restarts_plt2 = restart_ax2.plot(dt_list, restarts_SE_adapt, 'cs--', label='Restarts') restart_ax2.tick_params(labelsize=6) lines = pos31 + pos32 + pos33 + restarts_plt2 labels = [l.get_label() for l in lines] ax_around[2].legend(frameon=False, fontsize=6, loc='lower right') - fig_around.savefig('data/diffs_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') + fig_around.savefig('data/diffs_around_switch_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') plt_helper.plt.close(fig_around) def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): """ Routine to plot the differences in time using the switch estimator or not. Produces the difference_estimation_<sweeper_class>.png file + + Args: + dt_list (list): list of considered (initial) step sizes + problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) + sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd (str): current working directory """ if len(dt_list) > 1: setup_mpl() fig_diffs, ax_diffs = plt_helper.plt.subplots( - 2, len(dt_list), figsize=(3 * len(dt_list), 4), sharex='col', sharey='row' + 2, len(dt_list), figsize=(4 * len(dt_list), 6), sharex='col', sharey='row' ) else: setup_mpl() - fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(3, 3)) + fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(4, 6)) count_ax = 0 for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch_TF = [v[1] for v in val_switch_TF] - t_switch_TF = t_switch_TF[-1] # battery has only one single switch + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch_SE = [v[1] for v in switches_SE] + t_switch_SE = t_switch_SE[-1] # battery has only one single switch - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - dt_FT = np.array(get_sorted(stats_FT, type='dt', recomputed=False, sortby='time')) - dt_TT = np.array(get_sorted(stats_TT, type='dt', recomputed=False, sortby='time')) + dt_adapt = np.array(get_sorted(stats_adapt, type='dt', recomputed=False)) + dt_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='dt', recomputed=False)) - restart_FT = np.array(get_sorted(stats_FT, type='restart', recomputed=None, sortby='time')) - restart_TT = np.array(get_sorted(stats_TT, type='restart', recomputed=None, sortby='time')) + restart_adapt = np.array(get_sorted(stats_adapt, type='restart', recomputed=None)) + restart_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='restart', recomputed=None)) - vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time') - vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time') - vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time') - vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time') + vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False)] + vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False)] + vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)] + vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False)] - diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF] - times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF] + diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0] + times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False)] + times = [me[0] for me in get_sorted(stats, type='u', recomputed=False)] - diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT] - times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT] + diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0] + times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False)] + times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False)] if len(dt_list) > 1: - ax_diffs[0, count_ax].set_title(r'$\Delta t$={}'.format(dt_item)) - ax_diffs[0, count_ax].plot(times_TF, diff_TF, label='SE=True, A=False', color='#ff7f0e') - ax_diffs[0, count_ax].plot(times_FF, diff_FF, label='SE=False, A=False', color='#1f77b4') - ax_diffs[0, count_ax].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[0, count_ax].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.') - ax_diffs[0, count_ax].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_diffs[0, count_ax].set_title(r'$\Delta t$=%s' % dt_item) + ax_diffs[0, count_ax].plot(times_SE, diff_SE, label='SE=True, A=False', color='#ff7f0e') + ax_diffs[0, count_ax].plot(times, diff, label='SE=False, A=False', color='#1f77b4') + ax_diffs[0, count_ax].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[0, count_ax].plot( + times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.' + ) + ax_diffs[0, count_ax].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch') ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='lower left') ax_diffs[0, count_ax].set_yscale('symlog', linthresh=1e-5) + ax_diffs[0, count_ax].tick_params(axis='both', which='major', labelsize=6) if count_ax == 0: ax_diffs[0, count_ax].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6) @@ -488,110 +488,128 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='upper right') ax_diffs[1, count_ax].plot( - dt_FT[:, 0], dt_FT[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--' + dt_adapt[:, 0], dt_adapt[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--' ) ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=F, A=T', color='grey', linestyle='-.') - for i in range(len(restart_FT)): - if restart_FT[i, 1] > 0: - ax_diffs[1, count_ax].axvline(restart_FT[i, 0], color='grey', linestyle='-.') + for i in range(len(restart_adapt)): + if restart_adapt[i, 1] > 0: + ax_diffs[1, count_ax].axvline(restart_adapt[i, 0], color='grey', linestyle='-.') ax_diffs[1, count_ax].plot( - dt_TT[:, 0], dt_TT[:, 1], label=r'$ \Delta t$ - SE=T, A=T', color='limegreen', linestyle='-.' + dt_SE_adapt[:, 0], + dt_SE_adapt[:, 1], + label=r'$ \Delta t$ - SE=T, A=T', + color='limegreen', + linestyle='-.', ) ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=T, A=T', color='black', linestyle='-.') - for i in range(len(restart_TT)): - if restart_TT[i, 1] > 0: - ax_diffs[1, count_ax].axvline(restart_TT[i, 0], color='black', linestyle='-.') + for i in range(len(restart_SE_adapt)): + if restart_SE_adapt[i, 1] > 0: + ax_diffs[1, count_ax].axvline(restart_SE_adapt[i, 0], color='black', linestyle='-.') ax_diffs[1, count_ax].set_xlabel('Time', fontsize=6) + ax_diffs[1, count_ax].tick_params(axis='both', which='major', labelsize=6) if count_ax == 0: - ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6) - ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='upper left') + ax_diffs[1, count_ax].set_yscale('log', base=10) + ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='lower left') else: - ax_diffs[0].set_title(r'$\Delta t$={}'.format(dt_item)) - ax_diffs[0].plot(times_TF, diff_TF, label='SE=True', color='#ff7f0e') - ax_diffs[0].plot(times_FF, diff_FF, label='SE=False', color='#1f77b4') - ax_diffs[0].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[0].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.') - ax_diffs[0].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch') - ax_diffs[0].legend(frameon=False, fontsize=6, loc='lower left') + ax_diffs[0].set_title(r'$\Delta t$=%s' % dt_item) + ax_diffs[0].plot(times_SE, diff_SE, label='SE=True', color='#ff7f0e') + ax_diffs[0].plot(times, diff, label='SE=False', color='#1f77b4') + ax_diffs[0].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[0].plot(times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.') + ax_diffs[0].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_diffs[0].tick_params(axis='both', which='major', labelsize=6) ax_diffs[0].set_yscale('symlog', linthresh=1e-5) ax_diffs[0].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6) ax_diffs[0].legend(frameon=False, fontsize=6, loc='center right') - ax_diffs[1].plot(dt_FT[:, 0], dt_FT[:, 1], label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[1].plot(dt_TT[:, 0], dt_TT[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.') + ax_diffs[1].plot(dt_adapt[:, 0], dt_adapt[:, 1], label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[1].plot( + dt_SE_adapt[:, 0], dt_SE_adapt[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.' + ) + ax_diffs[1].tick_params(axis='both', which='major', labelsize=6) ax_diffs[1].set_xlabel('Time', fontsize=6) - ax_diffs[1].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_diffs[1].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6) + ax_diffs[1].set_yscale('log', base=10) ax_diffs[1].legend(frameon=False, fontsize=6, loc='upper right') count_ax += 1 plt_helper.plt.tight_layout() - fig_diffs.savefig('data/difference_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') + fig_diffs.savefig('data/diffs_over_time_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') plt_helper.plt.close(fig_diffs) def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): """ Routine to plot the number of iterations over time using switch estimator or not. Produces the iters_<sweeper_class>.png file + + Args: + dt_list (list): list of considered (initial) step sizes + maxiter (np.int): maximum number of iterations + problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name) + sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name) + cwd (str): current working directory """ - iters_time_TF = [] - iters_time_FF = [] - iters_time_TT = [] - iters_time_FT = [] - times_TF = [] - times_FF = [] - times_TT = [] - times_FT = [] - t_switches_TF = [] - t_switches_adapt = [] + iters_time_SE = [] + iters_time = [] + iters_time_SE_adapt = [] + iters_time_adapt = [] + times_SE = [] + times = [] + times_SE_adapt = [] + times_adapt = [] + t_switches_SE = [] + t_switches_SE_adapt = [] for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - iter_counts_TF_val = get_sorted(stats_TF, type='niter', recomputed=False, sortby='time') - iter_counts_TT_val = get_sorted(stats_TT, type='niter', recomputed=False, sortby='time') - iter_counts_FT_val = get_sorted(stats_FT, type='niter', recomputed=False, sortby='time') - iter_counts_FF_val = get_sorted(stats_FF, type='niter', recomputed=False, sortby='time') + # consider iterations before restarts to see what happens + iter_counts_SE_val = get_sorted(stats_SE, type='niter') + iter_counts_SE_adapt_val = get_sorted(stats_SE_adapt, type='niter') + iter_counts_adapt_val = get_sorted(stats_adapt, type='niter') + iter_counts_val = get_sorted(stats, type='niter') - iters_time_TF.append([v[1] for v in iter_counts_TF_val]) - iters_time_TT.append([v[1] for v in iter_counts_TT_val]) - iters_time_FT.append([v[1] for v in iter_counts_FT_val]) - iters_time_FF.append([v[1] for v in iter_counts_FF_val]) + iters_time_SE.append([v[1] for v in iter_counts_SE_val]) + iters_time_SE_adapt.append([v[1] for v in iter_counts_SE_adapt_val]) + iters_time_adapt.append([v[1] for v in iter_counts_adapt_val]) + iters_time.append([v[1] for v in iter_counts_val]) - times_TF.append([v[0] for v in iter_counts_TF_val]) - times_TT.append([v[0] for v in iter_counts_TT_val]) - times_FT.append([v[0] for v in iter_counts_FT_val]) - times_FF.append([v[0] for v in iter_counts_FF_val]) + times_SE.append([v[0] for v in iter_counts_SE_val]) + times_SE_adapt.append([v[0] for v in iter_counts_SE_adapt_val]) + times_adapt.append([v[0] for v in iter_counts_adapt_val]) + times.append([v[0] for v in iter_counts_val]) - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch_TF = [v[1] for v in val_switch_TF] - t_switches_TF.append(t_switch_TF[-1]) + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch_SE = [v[1] for v in switches_SE] + t_switches_SE.append(t_switch_SE[-1]) - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switches_adapt.append(t_switch_adapt[-1]) + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switches_SE_adapt.append(t_switch_SE_adapt[-1]) if len(dt_list) > 1: setup_mpl() @@ -599,18 +617,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): nrows=1, ncols=len(dt_list), figsize=(2 * len(dt_list) - 1, 3), sharex='col', sharey='row' ) for col in range(len(dt_list)): - ax_iter_all[col].plot(times_FF[col], iters_time_FF[col], label='SE=F, A=F') - ax_iter_all[col].plot(times_TF[col], iters_time_TF[col], label='SE=T, A=F') - ax_iter_all[col].plot(times_TT[col], iters_time_TT[col], '--', label='SE=T, A=T') - ax_iter_all[col].plot(times_FT[col], iters_time_FT[col], '--', label='SE=F, A=T') - ax_iter_all[col].axvline(x=t_switches_TF[col], linestyle='--', linewidth=0.5, color='k', label='Switch') - if t_switches_adapt[col] != t_switches_TF[col]: - ax_iter_all[col].axvline( - x=t_switches_adapt[col], linestyle='--', linewidth=0.5, color='k', label='Switch' - ) - ax_iter_all[col].set_title('dt={}'.format(dt_list[col])) + ax_iter_all[col].plot(times[col], iters_time[col], label='SE=F, A=F') + ax_iter_all[col].plot(times_SE[col], iters_time_SE[col], label='SE=T, A=F') + ax_iter_all[col].plot(times_SE_adapt[col], iters_time_SE_adapt[col], '--', label='SE=T, A=T') + ax_iter_all[col].plot(times_adapt[col], iters_time_adapt[col], '--', label='SE=F, A=T') + ax_iter_all[col].axvline(x=t_switches_SE[col], linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_iter_all[col].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[col]) ax_iter_all[col].set_ylim(0, maxiter + 2) ax_iter_all[col].set_xlabel('Time', fontsize=6) + ax_iter_all[col].tick_params(axis='both', which='major', labelsize=6) if col == 0: ax_iter_all[col].set_ylabel('Number iterations', fontsize=6) @@ -620,16 +635,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): setup_mpl() fig_iter_all, ax_iter_all = plt_helper.plt.subplots(nrows=1, ncols=1, figsize=(3, 3)) - ax_iter_all.plot(times_FF[0], iters_time_FF[0], label='SE=False') - ax_iter_all.plot(times_TF[0], iters_time_TF[0], label='SE=True') - ax_iter_all.plot(times_TT[0], iters_time_TT[0], '--', label='SE=T, A=T') - ax_iter_all.plot(times_FT[0], iters_time_FT[0], '--', label='SE=F, A=T') - ax_iter_all.axvline(x=t_switches_TF[0], linestyle='--', linewidth=0.5, color='k', label='Switch') - if t_switches_adapt[0] != t_switches_TF[0]: - ax_iter_all.axvline(x=t_switches_adapt[0], linestyle='--', linewidth=0.5, color='k', label='Switch') - ax_iter_all.set_title('dt={}'.format(dt_list[0])) + ax_iter_all.plot(times[0], iters_time[0], label='SE=False') + ax_iter_all.plot(times_SE[0], iters_time_SE[0], label='SE=True') + ax_iter_all.plot(times_SE_adapt[0], iters_time_SE_adapt[0], '--', label='SE=T, A=T') + ax_iter_all.plot(times_adapt[0], iters_time_adapt[0], '--', label='SE=F, A=T') + ax_iter_all.axvline(x=t_switches_SE[0], linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_iter_all.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[0]) ax_iter_all.set_ylim(0, maxiter + 2) ax_iter_all.set_xlabel('Time', fontsize=6) + ax_iter_all.tick_params(axis='both', which='major', labelsize=6) ax_iter_all.set_ylabel('Number iterations', fontsize=6) ax_iter_all.legend(frameon=False, fontsize=6, loc='upper right') @@ -640,4 +654,4 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): if __name__ == "__main__": - check() + run() diff --git a/pySDC/projects/PinTSimE/estimation_check_2capacitors.py b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf747af665d1fb89b483e73c3b9df6c27cd4596 --- /dev/null +++ b/pySDC/projects/PinTSimE/estimation_check_2capacitors.py @@ -0,0 +1,236 @@ +import numpy as np +import dill +from pathlib import Path + +from pySDC.helpers.stats_helper import get_sorted +from pySDC.core.Collocation import CollBase as Collocation +from pySDC.implementations.problem_classes.Battery import battery_n_capacitors +from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order +from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI +from pySDC.projects.PinTSimE.battery_model import controller_run, generate_description, get_recomputed, log_data +from pySDC.projects.PinTSimE.piline_model import setup_mpl +from pySDC.projects.PinTSimE.battery_2capacitors_model import ( + check_solution, + proof_assertions_description, + proof_assertions_time, +) +import pySDC.helpers.plot_helper as plt_helper + +from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator + + +def run(cwd='./'): + """ + Routine to check the differences between using a switch estimator or not + + Args: + cwd (str): current working directory + """ + + dt_list = [4e-1, 4e-2, 4e-3] + t0 = 0.0 + Tend = 3.5 + + problem_classes = [battery_n_capacitors] + sweeper_classes = [imex_1st_order] + + ncapacitors = 2 + alpha = 5.0 + V_ref = np.array([1.0, 1.0]) + C = np.array([1.0, 1.0]) + + use_switch_estimator = [True, False] + restarts_all = [] + restarts_dict = dict() + for problem, sweeper in zip(problem_classes, sweeper_classes): + for dt_item in dt_list: + for use_SE in use_switch_estimator: + description, controller_params = generate_description( + dt_item, + problem, + sweeper, + log_data, + False, + use_SE, + ncapacitors, + alpha, + V_ref, + C, + ) + + # Assertions + proof_assertions_description(description, False, use_SE) + + proof_assertions_time(dt_item, Tend, V_ref, alpha) + + stats = controller_run(description, controller_params, False, use_SE, t0, Tend) + + if use_SE: + switches = get_recomputed(stats, type='switch', sortby='time') + assert len(switches) >= 2, f"Expected at least 2 switches for dt: {dt_item}, got {len(switches)}!" + + check_solution(stats, dt_item, use_SE) + + fname = 'data/{}_dt{}_USE{}.dat'.format(problem.__name__, dt_item, use_SE) + f = open(fname, 'wb') + dill.dump(stats, f) + f.close() + + if use_SE: + restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None)) + restarts = restarts_dict[dt_item][:, 1] + restarts_all.append(np.sum(restarts)) + print("Restarts for dt: ", dt_item, " -- ", np.sum(restarts)) + + V_ref = description['problem_params']['V_ref'] + + val_switch_all = [] + diff_true_all1 = [] + diff_false_all_before1 = [] + diff_false_all_after1 = [] + diff_true_all2 = [] + diff_false_all_before2 = [] + diff_false_all_after2 = [] + restarts_dt_switch1 = [] + restarts_dt_switch2 = [] + for dt_item in dt_list: + f1 = open(cwd + 'data/{}_dt{}_USETrue.dat'.format(problem.__name__, dt_item), 'rb') + stats_true = dill.load(f1) + f1.close() + + f2 = open(cwd + 'data/{}_dt{}_USEFalse.dat'.format(problem.__name__, dt_item), 'rb') + stats_false = dill.load(f2) + f2.close() + + switches = get_recomputed(stats_true, type='switch', sortby='time') + t_switch = [v[1] for v in switches] + + val_switch_all.append([t_switch[0], t_switch[1]]) + + vC1_true = [me[1][1] for me in get_sorted(stats_true, type='u', recomputed=False)] + vC2_true = [me[1][2] for me in get_sorted(stats_true, type='u', recomputed=False)] + vC1_false = [me[1][1] for me in get_sorted(stats_false, type='u', recomputed=False)] + vC2_false = [me[1][2] for me in get_sorted(stats_false, type='u', recomputed=False)] + + diff_true1 = vC1_true - V_ref[0] + diff_true2 = vC2_true - V_ref[1] + diff_false1 = vC1_false - V_ref[0] + diff_false2 = vC2_false - V_ref[1] + + t_true = [me[0] for me in get_sorted(stats_true, type='u', recomputed=False)] + t_false = [me[0] for me in get_sorted(stats_false, type='u', recomputed=False)] + + diff_true_all1.append( + [diff_true1[m] for m in range(len(t_true)) if np.isclose(t_true[m], t_switch[0], atol=1e-15)] + ) + diff_true_all2.append( + [diff_true2[m] for m in range(len(t_true)) if np.isclose(t_true[m], t_switch[1], atol=1e-15)] + ) + + diff_false_all_before1.append( + [diff_false1[m - 1] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[0] < t_false[m]] + ) + diff_false_all_after1.append( + [diff_false1[m] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[0] < t_false[m]] + ) + + diff_false_all_before2.append( + [diff_false2[m - 1] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[1] < t_false[m]] + ) + diff_false_all_after2.append( + [diff_false2[m] for m in range(1, len(t_false)) if t_false[m - 1] < t_switch[1] < t_false[m]] + ) + + restarts_dt = restarts_dict[dt_item] + restarts_dt_switch1.append( + [ + np.sum(restarts_dt[0 : i - 1, 1]) + for i in range(len(restarts_dt[:, 0])) + if np.isclose(restarts_dt[i, 0], t_switch[0], atol=1e-13) + ] + ) + restarts_dt_switch2.append( + [ + np.sum(restarts_dt[i - 2 :, 1]) + for i in range(len(restarts_dt[:, 0])) + if np.isclose(restarts_dt[i, 0], t_switch[1], atol=1e-13) + ] + ) + + setup_mpl() + fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) + ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$') + ax1.plot(t_true, diff_true1, label='SE=True', color='#ff7f0e') + ax1.plot(t_false, diff_false1, label='SE=False', color='#1f77b4') + ax1.axvline(x=t_switch[0], linestyle='--', color='k', label='Switch1') + ax1.legend(frameon=False, fontsize=10, loc='lower left') + ax1.set_yscale('symlog', linthresh=1e-5) + ax1.set_xlabel('Time') + + fig1.savefig('data/difference_estimation_vC1_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight') + plt_helper.plt.close(fig1) + + setup_mpl() + fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) + ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$') + ax2.plot(t_true, diff_true2, label='SE=True', color='#ff7f0e') + ax2.plot(t_false, diff_false2, label='SE=False', color='#1f77b4') + ax2.axvline(x=t_switch[1], linestyle='--', color='k', label='Switch2') + ax2.legend(frameon=False, fontsize=10, loc='lower left') + ax2.set_yscale('symlog', linthresh=1e-5) + ax2.set_xlabel('Time') + + fig2.savefig('data/difference_estimation_vC2_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight') + plt_helper.plt.close(fig2) + + setup_mpl() + fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(3, 3)) + ax1.set_title("Difference $v_{C_{1}}-V_{ref1}$") + pos1 = ax1.plot(dt_list, diff_false_all_before1, 'rs-', label='SE=False - before switch1') + pos2 = ax1.plot(dt_list, diff_false_all_after1, 'bd-', label='SE=False - after switch1') + pos3 = ax1.plot(dt_list, diff_true_all1, 'kd-', label='SE=True') + ax1.set_xticks(dt_list) + ax1.set_xticklabels(dt_list) + ax1.set_xscale('log', base=10) + ax1.set_yscale('symlog', linthresh=1e-10) + ax1.set_ylim(-2, 2) + ax1.set_xlabel(r'$\Delta t$') + + restart_ax = ax1.twinx() + restarts = restart_ax.plot(dt_list, restarts_dt_switch1, 'cs--', label='Restarts') + restart_ax.set_ylabel('Restarts') + + lines = pos1 + pos2 + pos3 + restarts + labels = [l.get_label() for l in lines] + ax1.legend(lines, labels, frameon=False, fontsize=8, loc='center right') + + fig1.savefig('data/diffs_estimation_vC1.png', dpi=300, bbox_inches='tight') + plt_helper.plt.close(fig1) + + setup_mpl() + fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(3, 3)) + ax2.set_title("Difference $v_{C_{2}}-V_{ref2}$") + pos1 = ax2.plot(dt_list, diff_false_all_before2, 'rs-', label='SE=False - before switch2') + pos2 = ax2.plot(dt_list, diff_false_all_after2, 'bd-', label='SE=False - after switch2') + pos3 = ax2.plot(dt_list, diff_true_all2, 'kd-', label='SE=True') + ax2.set_xticks(dt_list) + ax2.set_xticklabels(dt_list) + ax2.set_xscale('log', base=10) + ax2.set_yscale('symlog', linthresh=1e-10) + ax2.set_ylim(-2, 2) + ax2.set_xlabel(r'$\Delta t$') + + restart_ax = ax2.twinx() + restarts = restart_ax.plot(dt_list, restarts_dt_switch2, 'cs--', label='Restarts') + restart_ax.set_ylabel('Restarts') + + lines = pos1 + pos2 + pos3 + restarts + labels = [l.get_label() for l in lines] + ax2.legend(lines, labels, frameon=False, fontsize=8, loc='center right') + + fig2.savefig('data/diffs_estimation_vC2.png', dpi=300, bbox_inches='tight') + plt_helper.plt.close(fig2) + + +if __name__ == "__main__": + run() diff --git a/pySDC/projects/PinTSimE/estimation_check_extended.py b/pySDC/projects/PinTSimE/estimation_check_extended.py deleted file mode 100644 index f075fe2de47cd65765f4a6c4b2b505b5bd3c0e53..0000000000000000000000000000000000000000 --- a/pySDC/projects/PinTSimE/estimation_check_extended.py +++ /dev/null @@ -1,291 +0,0 @@ -import numpy as np -import dill -from pathlib import Path - -from pySDC.helpers.stats_helper import get_sorted -from pySDC.core.Collocation import CollBase as Collocation -from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators -from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order -from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh -from pySDC.projects.PinTSimE.piline_model import setup_mpl -from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description -import pySDC.helpers.plot_helper as plt_helper - -from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator - - -def run(dt, use_switch_estimator=True): - - # initialize level parameters - level_params = dict() - level_params['restol'] = 1e-13 - level_params['dt'] = dt - - # initialize sweeper parameters - sweeper_params = dict() - sweeper_params['quad_type'] = 'LOBATTO' - sweeper_params['num_nodes'] = 5 - sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' - - # initialize problem parameters - problem_params = dict() - problem_params['Vs'] = 5.0 - problem_params['Rs'] = 0.5 - problem_params['C1'] = 1.0 - problem_params['C2'] = 1.0 - problem_params['R'] = 1.0 - problem_params['L'] = 1.0 - problem_params['alpha'] = 5.0 - problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2] - problem_params['set_switch'] = np.array([False, False], dtype=bool) - problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0]) - - # initialize step parameters - step_params = dict() - step_params['maxiter'] = 20 - - # initialize controller parameters - controller_params = dict() - controller_params['logger_level'] = 20 - controller_params['hook_class'] = log_data - - # convergence controllers - convergence_controllers = dict() - if use_switch_estimator: - switch_estimator_params = {} - convergence_controllers[SwitchEstimator] = switch_estimator_params - - # fill description dictionary for easy step instantiation - description = dict() - description['problem_class'] = battery_2condensators # pass problem class - description['problem_params'] = problem_params # pass problem parameters - description['sweeper_class'] = imex_1st_order # pass sweeper - description['sweeper_params'] = sweeper_params # pass sweeper parameters - description['level_params'] = level_params # pass level parameters - description['step_params'] = step_params - description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class - - if use_switch_estimator: - description['convergence_controllers'] = convergence_controllers - - proof_assertions_description(description, problem_params) - - # set time parameters - t0 = 0.0 - Tend = 3.5 - - # instantiate controller - controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) - - # get initial values on finest level - P = controller.MS[0].levels[0].prob - uinit = P.u_exact(t0) - - # call main function to get things done... - uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - - Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery_2condensators.dat' - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 - - f = open('data/battery_2condensators_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - # print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - assert np.mean(niters) <= 12, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - - return stats, description - - -def check(cwd='./'): - """ - Routine to check the differences between using a switch estimator or not - """ - - dt_list = [4e-1, 4e-2, 4e-3] - use_switch_estimator = [True, False] - restarts_all = [] - restarts_dict = dict() - for dt_item in dt_list: - for item in use_switch_estimator: - stats, description = run(dt=dt_item, use_switch_estimator=item) - - fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, item) - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - if item: - restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time')) - restarts = restarts_dict[dt_item][:, 1] - restarts_all.append(np.sum(restarts)) - print("Restarts for dt: ", dt_item, " -- ", np.sum(restarts)) - - V_ref = description['problem_params']['V_ref'] - - val_switch_all = [] - diff_true_all1 = [] - diff_false_all_before1 = [] - diff_false_all_after1 = [] - diff_true_all2 = [] - diff_false_all_before2 = [] - diff_false_all_after2 = [] - restarts_dt_switch1 = [] - restarts_dt_switch2 = [] - for dt_item in dt_list: - f1 = open(cwd + 'data/battery_2condensators_dt{}_USETrue.dat'.format(dt_item), 'rb') - stats_true = dill.load(f1) - f1.close() - - f2 = open(cwd + 'data/battery_2condensators_dt{}_USEFalse.dat'.format(dt_item), 'rb') - stats_false = dill.load(f2) - f2.close() - - val_switch1 = get_sorted(stats_true, type='switch1', sortby='time') - val_switch2 = get_sorted(stats_true, type='switch2', sortby='time') - t_switch1 = [v[1] for v in val_switch1] - t_switch2 = [v[1] for v in val_switch2] - - t_switch1 = t_switch1[-1] - t_switch2 = t_switch2[-1] - - val_switch_all.append([t_switch1, t_switch2]) - - vC1_true = get_sorted(stats_true, type='voltage C1', recomputed=False, sortby='time') - vC2_true = get_sorted(stats_true, type='voltage C2', recomputed=False, sortby='time') - vC1_false = get_sorted(stats_false, type='voltage C1', sortby='time') - vC2_false = get_sorted(stats_false, type='voltage C2', sortby='time') - - diff_true1 = [v[1] - V_ref[0] for v in vC1_true] - diff_true2 = [v[1] - V_ref[1] for v in vC2_true] - diff_false1 = [v[1] - V_ref[0] for v in vC1_false] - diff_false2 = [v[1] - V_ref[1] for v in vC2_false] - - times_true1 = [v[0] for v in vC1_true] - times_true2 = [v[0] for v in vC2_true] - times_false1 = [v[0] for v in vC1_false] - times_false2 = [v[0] for v in vC2_false] - - for m in range(len(times_true1)): - if np.round(times_true1[m], 15) == np.round(t_switch1, 15): - diff_true_all1.append(diff_true1[m]) - - for m in range(len(times_true2)): - if np.round(times_true2[m], 15) == np.round(t_switch2, 15): - diff_true_all2.append(diff_true2[m]) - - for m in range(1, len(times_false1)): - if times_false1[m - 1] < t_switch1 < times_false1[m]: - diff_false_all_before1.append(diff_false1[m - 1]) - diff_false_all_after1.append(diff_false1[m]) - - for m in range(1, len(times_false2)): - if times_false2[m - 1] < t_switch2 < times_false2[m]: - diff_false_all_before2.append(diff_false2[m - 1]) - diff_false_all_after2.append(diff_false2[m]) - - restarts_dt = restarts_dict[dt_item] - for i in range(len(restarts_dt[:, 0])): - if round(restarts_dt[i, 0], 13) == round(t_switch1, 13): - restarts_dt_switch1.append(np.sum(restarts_dt[0:i, 1])) - - if round(restarts_dt[i, 0], 13) == round(t_switch2, 13): - restarts_dt_switch2.append(np.sum(restarts_dt[i - 1 :, 1])) - - setup_mpl() - fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) - ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$') - ax1.plot(times_true1, diff_true1, label='SE=True', color='#ff7f0e') - ax1.plot(times_false1, diff_false1, label='SE=False', color='#1f77b4') - ax1.axvline(x=t_switch1, linestyle='--', color='k', label='Switch1') - ax1.legend(frameon=False, fontsize=10, loc='lower left') - ax1.set_yscale('symlog', linthresh=1e-5) - ax1.set_xlabel('Time') - - fig1.savefig('data/difference_estimation_vC1_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight') - plt_helper.plt.close(fig1) - - setup_mpl() - fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) - ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$') - ax2.plot(times_true2, diff_true2, label='SE=True', color='#ff7f0e') - ax2.plot(times_false2, diff_false2, label='SE=False', color='#1f77b4') - ax2.axvline(x=t_switch2, linestyle='--', color='k', label='Switch2') - ax2.legend(frameon=False, fontsize=10, loc='lower left') - ax2.set_yscale('symlog', linthresh=1e-5) - ax2.set_xlabel('Time') - - fig2.savefig('data/difference_estimation_vC2_dt{}.png'.format(dt_item), dpi=300, bbox_inches='tight') - plt_helper.plt.close(fig2) - - setup_mpl() - fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(3, 3)) - ax1.set_title("Difference $v_{C_{1}}-V_{ref1}$") - pos1 = ax1.plot(dt_list, diff_false_all_before1, 'rs-', label='SE=False - before switch1') - pos2 = ax1.plot(dt_list, diff_false_all_after1, 'bd-', label='SE=False - after switch1') - pos3 = ax1.plot(dt_list, diff_true_all1, 'kd-', label='SE=True') - ax1.set_xticks(dt_list) - ax1.set_xticklabels(dt_list) - ax1.set_xscale('log', base=10) - ax1.set_yscale('symlog', linthresh=1e-10) - ax1.set_ylim(-2, 2) - ax1.set_xlabel(r'$\Delta t$') - - restart_ax = ax1.twinx() - restarts = restart_ax.plot(dt_list, restarts_dt_switch1, 'cs--', label='Restarts') - restart_ax.set_ylabel('Restarts') - - lines = pos1 + pos2 + pos3 + restarts - labels = [l.get_label() for l in lines] - ax1.legend(lines, labels, frameon=False, fontsize=8, loc='center right') - - fig1.savefig('data/diffs_estimation_vC1.png', dpi=300, bbox_inches='tight') - plt_helper.plt.close(fig1) - - setup_mpl() - fig2, ax2 = plt_helper.plt.subplots(1, 1, figsize=(3, 3)) - ax2.set_title("Difference $v_{C_{2}}-V_{ref2}$") - pos1 = ax2.plot(dt_list, diff_false_all_before2, 'rs-', label='SE=False - before switch2') - pos2 = ax2.plot(dt_list, diff_false_all_after2, 'bd-', label='SE=False - after switch2') - pos3 = ax2.plot(dt_list, diff_true_all2, 'kd-', label='SE=True') - ax2.set_xticks(dt_list) - ax2.set_xticklabels(dt_list) - ax2.set_xscale('log', base=10) - ax2.set_yscale('symlog', linthresh=1e-10) - ax2.set_ylim(-2, 2) - ax2.set_xlabel(r'$\Delta t$') - - restart_ax = ax2.twinx() - restarts = restart_ax.plot(dt_list, restarts_dt_switch2, 'cs--', label='Restarts') - restart_ax.set_ylabel('Restarts') - - lines = pos1 + pos2 + pos3 + restarts - labels = [l.get_label() for l in lines] - ax2.legend(lines, labels, frameon=False, fontsize=8, loc='center right') - - fig2.savefig('data/diffs_estimation_vC2.png', dpi=300, bbox_inches='tight') - plt_helper.plt.close(fig2) - - -if __name__ == "__main__": - check() diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py index 67addb60c2ea5f532aa02dbcdb1b81a2fd02fd0a..f31a290ea35f6e9b44560fc4134a88abeb30cbcd 100644 --- a/pySDC/projects/PinTSimE/switch_estimator.py +++ b/pySDC/projects/PinTSimE/switch_estimator.py @@ -2,7 +2,7 @@ import numpy as np import scipy as sp from pySDC.core.Collocation import CollBase -from pySDC.core.ConvergenceController import ConvergenceController +from pySDC.core.ConvergenceController import ConvergenceController, Status class SwitchEstimator(ConvergenceController): @@ -31,13 +31,34 @@ class SwitchEstimator(ConvergenceController): num_nodes=description['sweeper_params']['num_nodes'], quad_type=description['sweeper_params']['quad_type'], ) - self.coll_nodes_local = coll.nodes - self.switch_detected = False - self.switch_detected_step = False - self.t_switch = None - self.count_switches = 0 - self.dt_initial = description['level_params']['dt'] - return {'control_order': 100, **params} + + defaults = { + 'control_order': 100, + 'tol': description['level_params']['dt'], + 'coll_nodes': coll.nodes, + 'dt_initial': description['level_params']['dt'], + } + return {**defaults, **params} + + def setup_status_variables(self, controller, **kwargs): + """ + Adds switching specific variables to status variables. + + Args: + controller (pySDC.Controller): The controller + """ + + self.status = Status(['t_switch', 'switch_detected', 'switch_detected_step']) + + def reset_status_variables(self, controller, **kwargs): + """ + Resets status variables. + + Args: + controller (pySDC.Controller): The controller + """ + + self.setup_status_variables(controller, **kwargs) def get_new_step_size(self, controller, S): """ @@ -51,76 +72,62 @@ class SwitchEstimator(ConvergenceController): None """ - self.switch_detected = False # reset between steps - L = S.levels[0] - if not type(L.prob.params.V_ref) == int and not type(L.prob.params.V_ref) == float: - # if V_ref is not a scalar, but an (np.)array - V_ref = np.zeros(np.shape(L.prob.params.V_ref)[0], dtype=float) - for m in range(np.shape(L.prob.params.V_ref)[0]): - V_ref[m] = L.prob.params.V_ref[m] - else: - V_ref = np.array([L.prob.params.V_ref], dtype=float) + if S.status.iter == S.params.maxiter: - if S.status.iter > 0 and self.count_switches < np.shape(V_ref)[0]: - for m in range(len(L.u)): - if L.u[m][self.count_switches + 1] - V_ref[self.count_switches] <= 0: - self.switch_detected = True - m_guess = m - 1 - break + self.status.switch_detected, m_guess, vC_switch = L.prob.get_switching_info(L.u, L.time) - if self.switch_detected: - t_interp = [L.time + L.dt * self.coll_nodes_local[m] for m in range(len(self.coll_nodes_local))] - - vC_switch = [] - for m in range(1, len(L.u)): - vC_switch.append(L.u[m][self.count_switches + 1] - V_ref[self.count_switches]) + if self.status.switch_detected: + t_interp = [L.time + L.dt * self.params.coll_nodes[m] for m in range(len(self.params.coll_nodes))] # only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem) if vC_switch[0] * vC_switch[-1] < 0: - self.t_switch = self.get_switch(t_interp, vC_switch, m_guess) + self.status.t_switch = self.get_switch(t_interp, vC_switch, m_guess) - # if the switch is not find, we need to do ... ? - if L.time < self.t_switch < L.time + L.dt: - r = 1 - tol = self.dt_initial / r + if L.time < self.status.t_switch < L.time + L.dt: - if not np.isclose(self.t_switch - L.time, L.dt, atol=tol): - dt_search = self.t_switch - L.time + dt_switch = self.status.t_switch - L.time + if not np.isclose(self.status.t_switch - L.time, L.dt, atol=self.params.tol): + self.log( + f"Located Switch at time {self.status.t_switch:.6f} is outside the range of tol={self.params.tol:.4e}", + S, + ) else: - print('Switch located at time: {}'.format(self.t_switch)) - dt_search = self.t_switch - L.time - L.prob.params.set_switch[self.count_switches] = self.switch_detected - L.prob.params.t_switch[self.count_switches] = self.t_switch + self.log( + f"Switch located at time {self.status.t_switch:.6f} inside tol={self.params.tol:.4e}", S + ) + + L.prob.t_switch = self.status.t_switch controller.hooks[0].add_to_stats( process=S.status.slot, time=L.time, level=L.level_index, iter=0, sweep=L.status.sweep, - type='switch{}'.format(self.count_switches + 1), - value=self.t_switch, + type='switch', + value=self.status.t_switch, ) - self.switch_detected_step = True + L.prob.count_switches() + self.status.switch_detected_step = True dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt # when a switch is found, time step to match with switch should be preferred - if self.switch_detected: - L.status.dt_new = dt_search + if self.status.switch_detected: + L.status.dt_new = dt_switch else: - L.status.dt_new = min([dt_planned, dt_search]) + L.status.dt_new = min([dt_planned, dt_switch]) else: - self.switch_detected = False + self.status.switch_detected = False else: - self.switch_detected = False + self.status.switch_detected = False def determine_restart(self, controller, S): """ @@ -134,8 +141,7 @@ class SwitchEstimator(ConvergenceController): None """ - if self.switch_detected: - print("Restart") + if self.status.switch_detected: S.status.restart = True S.status.force_done = True @@ -156,13 +162,9 @@ class SwitchEstimator(ConvergenceController): L = S.levels[0] - if self.switch_detected_step: - if L.prob.params.set_switch[self.count_switches] and L.time + L.dt >= self.t_switch: - self.count_switches += 1 - self.t_switch = None - self.switch_detected_step = False - - L.status.dt_new = self.dt_initial + if self.status.switch_detected_step: + if L.time + L.dt >= self.params.t_switch: + L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt super(SwitchEstimator, self).post_step_processing(controller, S) diff --git a/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py b/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7dccd0cf28d0e04893fd10242b85ba55030669 --- /dev/null +++ b/pySDC/tests/test_projects/test_pintsime/test_battery_2capacitors_model.py @@ -0,0 +1,8 @@ +import pytest + + +@pytest.mark.base +def test_main(): + from pySDC.projects.PinTSimE.battery_2capacitors_model import run + + run() diff --git a/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py b/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py deleted file mode 100644 index 4965514ebaaf32bae3b1049fa24f608d8b7b1b27..0000000000000000000000000000000000000000 --- a/pySDC/tests/test_projects/test_pintsime/test_battery_2condensators_model.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - - -@pytest.mark.base -def test_main(): - from pySDC.projects.PinTSimE.battery_2condensators_model import main - - main() diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py index 375c365982e0e327ba827559df6de46d8709af15..01ec3fed73b7dcb01dd284afd6d9468aad0a6939 100644 --- a/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py +++ b/pySDC/tests/test_projects/test_pintsime/test_estimation_check.py @@ -3,6 +3,6 @@ import pytest @pytest.mark.base def test_main(): - from pySDC.projects.PinTSimE.estimation_check import check + from pySDC.projects.PinTSimE.estimation_check import run - check() + run() diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py new file mode 100644 index 0000000000000000000000000000000000000000..09751a7897487e72518bcf75ebc16aceb6fef6f3 --- /dev/null +++ b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_2capacitors.py @@ -0,0 +1,8 @@ +import pytest + + +@pytest.mark.base +def test_main(): + from pySDC.projects.PinTSimE.estimation_check_2capacitors import run + + run() diff --git a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py b/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py deleted file mode 100644 index 03f44a23e31d108cbfa2de1de1e7017204dcfdd0..0000000000000000000000000000000000000000 --- a/pySDC/tests/test_projects/test_pintsime/test_estimation_check_extended.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - - -@pytest.mark.base -def test_main(): - from pySDC.projects.PinTSimE.estimation_check_extended import check - - check()