diff --git a/conftest.py b/conftest.py index c59ec4dc8a0ff0883a422c19107f803bdd860432..0726ea7cf9dbd259913c22cb87f83cb47ad5f40c 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,5 @@ import os +import re import shutil @@ -20,5 +21,25 @@ def pytest_runtest_teardown(item, nextitem): shutil.rmtree(os.path.join(path, "data"), ignore_errors=True) if "TestExperiment" in list_dir: shutil.rmtree(os.path.join(path, "TestExperiment"), ignore_errors=True) + # remove all tracking json + remove_files_from_regex(list_dir, path, re.compile(r"tracking_\d*\.json")) + # remove all tracking pdf + remove_files_from_regex(list_dir, path, re.compile(r"tracking\.pdf")) + # remove all tracking json + remove_files_from_regex(list_dir, path, re.compile(r"logging_\d*\.log")) else: pass # nothing to do if next test is from same test class + + +def remove_files_from_regex(list_dir, path, regex): + r = list(filter(regex.search, list_dir)) + if len(r) > 0: + for e in r: + del_path = os.path.join(path, e) + try: + if os.path.isfile(del_path): + os.remove(del_path) + else: + shutil.rmtree(os.path.join(path, e), ignore_errors=True) + except: + pass diff --git a/src/helpers/__init__.py b/src/helpers/__init__.py index 4a428fd2fbbde81111305e9feca4bb4cbc1fb324..546713b3f18f2cb64c1527b57d1e9e2138e927aa 100644 --- a/src/helpers/__init__.py +++ b/src/helpers/__init__.py @@ -1,6 +1,6 @@ """Collection of different supporting functions and classes.""" -from .testing import PyTestRegex, PyTestAllEqual, xr_all_equal +from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger from .helpers import remove_items, float_round, dict_to_xarray, to_list diff --git a/src/helpers/datastore.py b/src/helpers/datastore.py index a540d6f864775a2b333ecd544d507f28244b137c..b4615216000d887f16e6ed30d97215a261e12c6d 100644 --- a/src/helpers/datastore.py +++ b/src/helpers/datastore.py @@ -43,6 +43,9 @@ class CorrectScope: def __init__(self, func): """Construct decorator.""" + setattr(self, "wrapper", func) + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ wraps(func)(self) def __call__(self, *args, **kwargs): @@ -59,7 +62,7 @@ class CorrectScope: args = self.update_tuple(args, new_arg, pos_scope) else: args = self.update_tuple(args, args[pos_scope], pos_scope, update=True) - return self.__wrapped__(*args, **kwargs) + return self.wrapper(*args, **kwargs) def __get__(self, instance, cls): """Create bound method object and supply self argument to the decorated method.""" @@ -97,6 +100,41 @@ class CorrectScope: return t_new +class TrackParameter: + + def __init__(self, func): + """Construct decorator.""" + wraps(func)(self) + + def __call__(self, *args, **kwargs): + """ + Call method of decorator. + """ + self.track(*args) + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + """Create bound method object and supply self argument to the decorated method.""" + return types.MethodType(self, instance) + + def track(self, tracker_obj, *args): + name, obj, scope = self._decrypt_args(*args) + logging.debug(f"{self.__wrapped__.__name__}: {name}({scope})={obj}") + tracker = tracker_obj.tracker[-1] + new_entry = {"method": self.__wrapped__.__name__, "scope": scope} + if name in tracker: + tracker[name].append(new_entry) + else: + tracker[name] = [new_entry] + + @staticmethod + def _decrypt_args(*args): + if len(args) == 2: + return args[0], None, args[1] + else: + return args + + class AbstractDataStore(ABC): """ Abstract data store for all settings for the experiment workflow. @@ -106,6 +144,8 @@ class AbstractDataStore(ABC): adjustments. """ + tracker = [{}] + def __init__(self): """Initialise by creating empty data store.""" self._store: Dict = {} @@ -131,6 +171,27 @@ class AbstractDataStore(ABC): """ pass + @CorrectScope + def get_default(self, name: str, scope: str, default: Any) -> Any: + """ + Retrieve an object with `name` from `scope` and return given default if object wasn't found. + + Same functionality like the standard get method. But this method adds a default argument that is returned if no + data was stored in the data store. Use this function with care, because it will not report any errors and just + return the given default value. Currently, there is no statement that reports, if the returned value comes from + the data store or the default value. + + :param name: Name to look for + :param scope: scope to search the name for + :param default: default value that is return, if no data was found for given name and scope + + :return: the stored object or the default value + """ + try: + return self.get(name, scope) + except (NameNotFoundInDataStore, NameNotFoundInScope): + return default + def search_name(self, name: str) -> None: """ Abstract method to search for all occurrences of given `name` in the entire data store. @@ -235,6 +296,7 @@ class DataStoreByVariable(AbstractDataStore): """ @CorrectScope + @TrackParameter def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ Store an object `obj` with given `name` under `scope`. @@ -254,6 +316,7 @@ class DataStoreByVariable(AbstractDataStore): logging.debug(f"set: {name}({scope})={obj}") @CorrectScope + @TrackParameter def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. @@ -270,28 +333,6 @@ class DataStoreByVariable(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] - @CorrectScope - def get_default(self, name: str, scope: str, default: Any) -> Any: - """ - - Retrieve an object with `name` from `scope` and return given default if object wasn't found. - - Same functionality like the standard get method. But this method adds a default argument that is returned if no - data was stored in the data store. Use this function with care, because it will not report any errors and just - return the given default value. Currently, there is no statement that reports, if the returned value comes from - the data store or the default value. - - :param name: Name to look for - :param scope: scope to search the name for - :param default: default value that is return, if no data was found for given name and scope - - :return: the stored object or the default value - """ - try: - return self._stride_through_scopes(name, scope)[2] - except (NameNotFoundInDataStore, NameNotFoundInScope): - return default - @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): @@ -407,6 +448,7 @@ class DataStoreByScope(AbstractDataStore): """ @CorrectScope + @TrackParameter def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ Store an object `obj` with given `name` under `scope`. @@ -425,6 +467,7 @@ class DataStoreByScope(AbstractDataStore): logging.debug(f"set: {name}({scope})={obj}") @CorrectScope + @TrackParameter def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. @@ -441,27 +484,6 @@ class DataStoreByScope(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] - @CorrectScope - def get_default(self, name: str, scope: str, default: Any) -> Any: - """ - Retrieve an object with `name` from `scope` and return given default if object wasn't found. - - Same functionality like the standard get method. But this method adds a default argument that is returned if no - data was stored in the data store. Use this function with care, because it will not report any errors and just - return the given default value. Currently, there is no statement that reports, if the returned value comes from - the data store or the default value. - - :param name: Name to look for - :param scope: scope to search the name for - :param default: default value that is return, if no data was found for given name and scope - - :return: the stored object or the default value - """ - try: - return self._stride_through_scopes(name, scope)[2] - except (NameNotFoundInDataStore, NameNotFoundInScope): - return default - @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): diff --git a/src/helpers/testing.py b/src/helpers/testing.py index 3eea56bd9e1b748d1adc4c7d87af6f968f437734..244eb69fdc46dcadaeb3ada5779f09d44aa83e2a 100644 --- a/src/helpers/testing.py +++ b/src/helpers/testing.py @@ -2,6 +2,7 @@ import re from typing import Union, Pattern, List +import numpy as np import xarray as xr @@ -44,6 +45,13 @@ class PyTestAllEqual: def __init__(self, check_list: List): """Construct class.""" self._list = check_list + self._test_function = None + + def _set_test_function(self): + if isinstance(self._list[0], np.ndarray): + self._test_function = np.testing.assert_array_equal + else: + self._test_function = xr.testing.assert_equal def _check_all_equal(self) -> bool: """ @@ -52,8 +60,9 @@ class PyTestAllEqual: :return boolean if elements are equal """ equal = True + self._set_test_function() for b in self._list: - equal *= xr.testing.assert_equal(self._list[0], b) is None + equal *= self._test_function(self._list[0], b) is None return bool(equal == 1) def is_true(self) -> bool: diff --git a/src/plotting/tracker_plot.py b/src/plotting/tracker_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7b06cb6d7430be80eeae8ecedf811a3f2dc37c --- /dev/null +++ b/src/plotting/tracker_plot.py @@ -0,0 +1,378 @@ +from collections import OrderedDict + +import numpy as np +import os +from typing import Union, List, Optional, Dict + +from src.helpers import to_list + +from matplotlib import pyplot as plt, lines as mlines, ticker as ticker +from matplotlib.patches import Rectangle + + +class TrackObject: + + """ + A TrackObject can be used to create simple chains of objects. + + :param name: string or list of strings with a name describing the track object + :param stage: additional meta information (can be used to highlight different blocks inside a chain) + """ + + def __init__(self, name: Union[List[str], str], stage: str): + self.name = to_list(name) + self.stage = stage + self.precursor: Optional[List[TrackObject]] = None + self.successor: Optional[List[TrackObject]] = None + self.x: Optional[float] = None + self.y: Optional[float] = None + + def __repr__(self): + return str("/".join(self.name)) + + @property + def x(self): + """Get x value.""" + return self._x + + @x.setter + def x(self, value: float): + """Set x value.""" + self._x = value + + @property + def y(self): + """Get y value.""" + return self._y + + @y.setter + def y(self, value: float): + """Set y value.""" + self._y = value + + def add_precursor(self, precursor: "TrackObject"): + """Add a precursory track object.""" + if self.precursor is None: + self.precursor = [precursor] + else: + if precursor not in self.precursor: + self.precursor.append(precursor) + else: + return + precursor.add_successor(self) + + def add_successor(self, successor: "TrackObject"): + """Add a successive track object.""" + if self.successor is None: + self.successor = [successor] + else: + if successor not in self.successor: + self.successor.append(successor) + else: + return + successor.add_precursor(self) + + +class TrackChain: + + def __init__(self, track_list): + self.track_list = track_list + self.scopes = self.get_all_scopes(self.track_list) + self.dims = self.get_all_dims(self.scopes) + + def get_all_scopes(self, track_list) -> Dict: + """Return dictionary with all distinct variables as keys and its unique scopes as values.""" + dims = {} + for track_dict in track_list: # all stages + for track in track_dict.values(): # single stage, all variables + for k, v in track.items(): # single variable + scopes = self.get_unique_scopes(v) + if dims.get(k) is None: + dims[k] = scopes + else: + dims[k] = np.unique(scopes + dims[k]).tolist() + return OrderedDict(sorted(dims.items())) + + @staticmethod + def get_all_dims(scopes): + dims = {} + for k, v in scopes.items(): + dims[k] = len(v) + return dims + + def create_track_chain(self): + control = self.control_dict(self.scopes) + track_chain_dict = OrderedDict() + for track_dict in self.track_list: + stage, stage_track = list(track_dict.items())[0] + track_chain, control = self._create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage) + control = self.clean_control(control) + track_chain_dict[stage] = track_chain + return track_chain_dict + + def _create_track_chain(self, control, sorted_track_dict, stage): + track_objects = [] + for variable, all_variable_tracks in sorted_track_dict.items(): + for track_details in all_variable_tracks: + method, scope = track_details["method"], track_details["scope"] + tr = TrackObject([variable, method, scope], stage) + control_obj = control[variable][scope] + if method == "set": + track_objects = self._add_set_object(track_objects, tr, control_obj) + elif method == "get": + track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj, + control, scope, variable) + if skip_control_update is True: + continue + self._update_control(control, variable, scope, tr) + return track_objects, control + + @staticmethod + def _update_control(control, variable, scope, tr_obj): + control[variable][scope] = tr_obj + + @staticmethod + def _add_track_object(track_objects, tr_obj, prev_obj): + if tr_obj.stage != prev_obj.stage: + track_objects.append(prev_obj) + return track_objects + + def _add_precursor(self, track_objects, tr_obj, prev_obj): + tr_obj.add_precursor(prev_obj) + return self._add_track_object(track_objects, tr_obj, prev_obj) + + def _add_set_object(self, track_objects, tr_obj, control_obj): + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + track_objects.append(tr_obj) + return track_objects + + def _recursive_decent(self, scope, control_obj_var): + scope = scope.rsplit(".", 1) + if len(scope) > 1: + scope = scope[0] + control_obj = control_obj_var[scope] + if control_obj is not None: + pre, candidate = control_obj, control_obj + while pre.precursor is not None and pre.name[1] != "set": + # change candidate on stage border + if pre.name[2] != pre.precursor[0].name[2]: + candidate = pre + pre = pre.precursor[0] + # correct pre if candidate is from same scope + if candidate.name[2] == pre.name[2]: + pre = candidate + return pre + else: + return self._recursive_decent(scope, control_obj_var) + + def _add_get_object(self, track_objects, tr_obj, control_obj, control, scope, variable): + skip_control_update = False + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + pre = self._recursive_decent(scope, control[variable]) + if pre is not None: + track_objects = self._add_precursor(track_objects, tr_obj, pre) + else: + skip_control_update = True + return track_objects, skip_control_update + + @staticmethod + def control_dict(scopes): + """Create empty control dictionary with variables and scopes as keys and None as default for all values.""" + control = {} + for variable, scope_names in scopes.items(): + control[variable] = {} + for s in scope_names: + update = {s: None} + if len(control[variable].keys()) == 0: + control[variable] = update + else: + control[variable].update(update) + return control + + @staticmethod + def clean_control(control): + for k, v in control.items(): # var. scopes + for kv, vv in v.items(): # scope tr_obj + try: + if vv.precursor[0].name[2] != vv.name[2]: + control[k][kv] = None + except (TypeError, AttributeError): + pass + return control + + @staticmethod + def get_unique_scopes(track_list: List[Dict]) -> List[str]: + """Get list with all unique elements from input including general scope if missing.""" + scopes = [e["scope"] for e in track_list] + ["general"] + return np.unique(scopes).tolist() + + +class TrackPlot: + + def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True): + + self.width = 0.6 + self.height = 0.5 + self.space_intern_y = 0.2 + self.space_extern_y = 1 + self.space_intern_x = 0.4 + self.space_extern_x = 0.6 + self.y_pos = None + self.anchor = None + self.x_max = None + + track_chain_obj = TrackChain(tracker_list) + track_chain_dict = track_chain_obj.create_track_chain() + self.set_ypos_anchor(track_chain_obj.scopes, track_chain_obj.dims) + self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2, (self.anchor.max() - self.anchor.min()) / 3)) + self._plot(track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder) + + def _plot(self, track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder): + stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode, + skip_run_env=skip_run_env) + self.set_lims() + self.add_variable_names() + self.add_stages(v_lines, stages) + plt.tight_layout() + plot_name = os.path.join(os.path.abspath(plot_folder), "tracking.pdf") + plt.savefig(plot_name, dpi=600) + + def line(self, start_x, end_x, y, color="darkgrey"): + """Draw grey horizontal connection line from start_x to end_x on y-pos.""" + # draw white border line + l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color="white", + linewidth=2.5) + self.ax.add_line(l) + # draw grey line + l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color=color, + linewidth=1.4) + self.ax.add_line(l) + + def step(self, start_x, end_x, start_y, end_y, color="black"): + """Draw black connection step line from start_xy to end_xy. Step is taken shortly before end position.""" + # adjust start and end by width height + start_x += self.width + start_y += self.height / 2 + end_y += self.height / 2 + step_x = end_x - (self.space_intern_x) / 2 # step is taken shortly before end + pos_x = [start_x, step_x, step_x, end_x] + pos_y = [start_y, start_y, end_y, end_y] + # draw white border line + l = mlines.Line2D(pos_x, pos_y, color="white", linewidth=2.5) + self.ax.add_line(l) + # draw black line + l = mlines.Line2D(pos_x, pos_y, color=color, linewidth=1.4) + self.ax.add_line(l) + + def rect(self, x, y, method="get"): + """Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method.""" + # draw rectangle + color = {"get": "orange"}.get(method, "lightblue") + r = Rectangle((x, y), self.width, self.height, color=color) + self.ax.add_artist(r) + # add label + rx, ry = r.get_xy() + cx = rx + r.get_width() / 2.0 + cy = ry + r.get_height() / 2.0 + self.ax.annotate(method, (cx, cy), color='w', weight='bold', fontsize=6, ha='center', va='center') + + def set_ypos_anchor(self, scopes, dims): + anchor = sum(dims.values()) + pos_dict = {} + d_y = 0 + for k, v in scopes.items(): + pos_dict[k] = {} + for e in v: + update = {e: anchor + d_y} + if len(pos_dict[k].keys()) == 0: + pos_dict[k] = update + else: + pos_dict[k].update(update) + d_y -= (self.space_intern_y + self.height) + d_y -= (self.space_extern_y - self.space_intern_y) + self.y_pos = pos_dict + self.anchor = np.array((d_y, self.height + self.space_extern_y)) + anchor + + def plot_track_chain(self, chain, y_pos, x_pos=0, prev=None, stage=None, sparse_conn_mode=False): + if (chain.successor is None) or (chain.stage == stage): + var, method, scope = chain.name + x, y = x_pos, y_pos[var][scope] + self.rect(x, y, method=method) + chain.x, chain.y = x, y + if prev is not None and prev[0] is not None: + if (sparse_conn_mode is True) and (method == "set"): + pass + else: + if y == prev[1]: + self.line(prev[0], x, prev[1]) + else: + self.step(prev[0], x, prev[1], y) + else: + x, y = chain.x, chain.y + + x_max = None + if chain.successor is not None: + for e in chain.successor: + if e.stage == stage: + shift = self.width + self.space_intern_x if chain.stage == e.stage else 0 + x_tmp = self.plot_track_chain(e, y_pos, x_pos + shift, prev=(x, y), + stage=stage, sparse_conn_mode=sparse_conn_mode) + x_max = np.nanmax(np.array([x_tmp, x_max], dtype=np.float64)) + else: + x_max = np.nanmax(np.array([x, x_max, x_pos], dtype=np.float64)) + else: + x_max = x + + return x_max + + def add_variable_names(self): + labels = [] + pos = [] + labels_major = [] + pos_major = [] + for k, v in self.y_pos.items(): + for kv, vv in v.items(): + if kv == "general": + labels_major.append(k) + pos_major.append(vv + self.height / 2) + else: + labels.append(kv.split(".", 1)[1]) + pos.append(vv + self.height / 2) + self.ax.tick_params(axis="y", which="major", labelsize="large") + self.ax.yaxis.set_major_locator(ticker.FixedLocator(pos_major)) + self.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels_major)) + self.ax.yaxis.set_minor_locator(ticker.FixedLocator(pos)) + self.ax.yaxis.set_minor_formatter(ticker.FixedFormatter(labels)) + + def add_stages(self, vlines, stages): + x_max = self.x_max + self.space_intern_x + self.width + for l in vlines: + self.ax.vlines(l, *self.anchor, "black", "dashed") + vlines = [0] + vlines + [x_max] + pos = [(vlines[i] + vlines[i+1]) / 2 for i in range(len(vlines)-1)] + self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos)) + self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages)) + + def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True): + x, x_max = 0, 0 + v_lines, stages = [], [] + for stage, track_chain in track_chain_dict.items(): + if stage == "RunEnvironment" and skip_run_env is True: + continue + if x > 0: + v_lines.append(x - self.space_extern_x / 2) + for e in track_chain: + x_max = max(x_max, self.plot_track_chain(e, self.y_pos, x_pos=x, stage=stage, sparse_conn_mode=sparse_conn_mode)) + x = x_max + self.space_extern_x + self.width + stages.append(stage) + self.x_max = x_max + return stages, v_lines + + def set_lims(self): + x_max = self.x_max + self.space_intern_x + self.width + self.ax.set_xlim((0, x_max)) + self.ax.set_ylim(self.anchor) diff --git a/src/run_modules/__init__.py b/src/run_modules/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f06d627f6ff482e11c6d1c520fa59197feb831cd 100644 --- a/src/run_modules/__init__.py +++ b/src/run_modules/__init__.py @@ -0,0 +1,6 @@ +from src.run_modules.experiment_setup import ExperimentSetup +from src.run_modules.model_setup import ModelSetup +from src.run_modules.post_processing import PostProcessing +from src.run_modules.pre_processing import PreProcessing +from src.run_modules.run_environment import RunEnvironment +from src.run_modules.training import Training diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 32b79a340adc221e89079884e8c54278d20f6217..97b0ea304e4e7236058609aac93cb3cc16f255df 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -4,7 +4,7 @@ __date__ = '2019-11-15' import argparse import logging import os -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List from src.configuration import path_config from src import helpers @@ -27,14 +27,89 @@ DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatological class ExperimentSetup(RunEnvironment): """ - params: - trainable: Train new model if true, otherwise try to load existing model + Set up the model. + + Schedule of experiment setup: + #. set channels (from variables dimension) + #. build imported model + #. plot model architecture + #. load weights if enabled (e.g. to resume a training) + #. set callbacks and checkpoint + #. compile model + + Sets + * `channels` [model] + * `model` [model] + * `hist` [model] + * `callbacks` [model] + * `model_name` [model] + * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] + + Creates + * plot of model architecture in `<model_name>.pdf` + + :param parser_args: argument parser, currently only accepting experiment_data argument + :param stations: list of stations or single station to use for experiment + :param network: name of network to restrict to use only stations from this measurement network + :param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial) + :param variables: list of all variables to use + :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN) + :param start: start date of overall data + :param end: end date of overall data + :param window_history_size: number of time steps to use for input data + :param target_var: target variable to predict by model + :param target_dim: dimension of this variable + :param window_lead_time: number of time steps to predict by model + :param dimensions: + :param interpolate_dim: + :param interpolate_method: + :param limit_nan_fill: + :param train_start: + :param train_end: + :param val_start: + :param val_end: + :param test_start: + :param test_end: + :param use_all_stations_on_all_data_sets: + :param trainable: + :param fraction_of_train: + :param experiment_path: + :param plot_path: + :param forecast_path: + :param overwrite_local_data: + :param sampling: + :param create_new_model: + :param bootstrap_path: + :param permute_data_on_training: + :param transformation: + :param train_min_length: + :param val_min_length: + :param test_min_length: + :param extreme_values: + :param extremes_on_right_tail_only: + :param evaluate_bootstraps: + :param plot_list: + :param number_of_bootstraps: + :param create_new_bootstraps: + """ - def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None, - statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", + def __init__(self, + parser_args=None, + stations: Union[str, List[str]] = None, + network: str = None, + station_type: str = None, + variables: Union[str, List[str]] = None, + statistics_per_var: Dict = None, + start: str = None, + end: str = None, + window_history_size: int = None, + target_var="o3", target_dim=None, - window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, + window_lead_time: int = None, + dimensions=None, + interpolate_dim=None, + interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", @@ -42,8 +117,7 @@ class ExperimentSetup(RunEnvironment): train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None, extremes_on_right_tail_only=None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None): - - # create run framework + """Set uo experiment.""" super().__init__() # experiment setup @@ -86,7 +160,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("station_type", station_type, default=None) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) - self._compare_variables_and_statistics() self._set_param("start", start, default="1997-01-01") self._set_param("end", end, default="2017-12-31") self._set_param("window_history_size", window_history_size, default=13) @@ -97,7 +170,6 @@ class ExperimentSetup(RunEnvironment): # target self._set_param("target_var", target_var, default="o3") - self._check_target_var() self._set_param("target_dim", target_dim, default='variables') self._set_param("window_lead_time", window_lead_time, default=3) @@ -138,7 +210,12 @@ class ExperimentSetup(RunEnvironment): self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + # check variables, statistics and target variable + self._check_target_var() + self._compare_variables_and_statistics() + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: + """Set given parameter and log in debug.""" if value is None and default is not None: value = default self.data_store.set(param, value, scope) @@ -147,8 +224,10 @@ class ExperimentSetup(RunEnvironment): @staticmethod def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict: """ - Transform args to dict if given as argparse.Namespace + Transform args to dict if given as argparse.Namespace. + :param args: either a dictionary or an argument parser instance + :return: dictionary with all arguments """ if isinstance(args, argparse.Namespace): @@ -159,21 +238,23 @@ class ExperimentSetup(RunEnvironment): return {} def _compare_variables_and_statistics(self): + """ + Compare variables and statistics. + + * raise error, if a variable is missing. + * remove unused variables from statistics. + """ logging.debug("check if all variables are included in statistics_per_var") stat = self.data_store.get("statistics_per_var") var = self.data_store.get("variables") + # too less entries, raise error if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " f"variables are part of statistics_per_var. Please add also information on the missing " f"statistics for the variables: {missing}") - - def _check_target_var(self): + # too much entries, remove unused target_var = helpers.to_list(self.data_store.get("target_var")) - stat = self.data_store.get("statistics_per_var") - var = self.data_store.get("variables") - if not set(target_var).issubset(stat.keys()): - raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") unused_vars = set(stat.keys()).difference(set(var).union(target_var)) if len(unused_vars) > 0: logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") @@ -181,6 +262,14 @@ class ExperimentSetup(RunEnvironment): self._set_param("statistics_per_var", stat_new) + def _check_target_var(self): + """Check if target variable is in statistics_per_var dictionary.""" + target_var = helpers.to_list(self.data_store.get("target_var")) + stat = self.data_store.get("statistics_per_var") + var = self.data_store.get("variables") + if not set(target_var).issubset(stat.keys()): + raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") + if __name__ == "__main__": formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' logging.basicConfig(format=formatter, level=logging.DEBUG) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index f5fc1f0fd627120f266b419b150eeb85b62c7389..823995181807a25b3e1759bdb77c639e16f8fd64 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -50,7 +50,7 @@ class ModelSetup(RunEnvironment): * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] Creates - * plot of model architecture in `<model_name>.pdf` + * plot of model architecture `<model_name>.pdf` """ diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py index b47c76d0a319eaa59546e031f8650c60c17be03f..16d21ae9b294a481e87e66a28b67f7ab759bbe78 100644 --- a/src/run_modules/run_environment.py +++ b/src/run_modules/run_environment.py @@ -3,6 +3,7 @@ __author__ = "Lukas Leufen" __date__ = '2019-11-25' +import json import logging import os import shutil @@ -12,6 +13,7 @@ from src.helpers.datastore import DataStoreByScope as DataStoreObject from src.helpers.datastore import NameNotFoundInDataStore from src.helpers import Logger from src.helpers import TimeTracking +from src.plotting.tracker_plot import TrackPlot class RunEnvironment(object): @@ -88,11 +90,15 @@ class RunEnvironment(object): del_by_exit = False data_store = DataStoreObject() logger = Logger() + tracker_list = [] def __init__(self): """Start time tracking automatically and logs as info.""" self.time = TimeTracking() logging.info(f"{self.__class__.__name__} started") + # atexit.register(self.__del__) + self.data_store.tracker.append({}) + self.tracker_list.extend([{self.__class__.__name__: self.data_store.tracker[-1]}]) def __del__(self): """ @@ -106,10 +112,15 @@ class RunEnvironment(object): self.time.stop() logging.info(f"{self.__class__.__name__} finished after {self.time}") self.del_by_exit = True - # copy log file and clear data store only if called as base class and not as super class - if self.__class__.__name__ == "RunEnvironment": - self.__copy_log_file() - self.data_store.clear_data_store() + # copy log file and clear data store only if called as base class and not as super class + if self.__class__.__name__ == "RunEnvironment": + try: + TrackPlot(self.tracker_list, True, plot_folder=self.data_store.get_default("experiment_path", ".")) + self.__save_tracking() + self.__copy_log_file() + except FileNotFoundError: + pass + self.data_store.clear_data_store() def __enter__(self): """Enter run environment.""" @@ -123,17 +134,28 @@ class RunEnvironment(object): def __copy_log_file(self): try: - counter = 0 - filename_pattern = os.path.join(self.data_store.get("experiment_path"), "logging_%03i.log") - new_file = filename_pattern % counter - while os.path.exists(new_file): - counter += 1 - new_file = filename_pattern % counter + new_file = self.__find_file_pattern("logging_%03i.log") logging.info(f"Copy log file to {new_file}") shutil.copyfile(self.logger.log_file, new_file) except (NameNotFoundInDataStore, FileNotFoundError): pass + def __save_tracking(self): + tracker = self.data_store.tracker + new_file = self.__find_file_pattern("tracking_%03i.json") + logging.info(f"Copy tracker file to {new_file}") + with open(new_file, "w") as f: + json.dump(tracker, f) + + def __find_file_pattern(self, name): + counter = 0 + filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name) + new_file = filename_pattern % counter + while os.path.exists(new_file): + counter += 1 + new_file = filename_pattern % counter + return new_file + @staticmethod def do_stuff(length=2): """Just a placeholder method for testing without any sense.""" diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 338bb7fd5cd1cb9a743cbc54ec9c4ed388a6bdaf..6abc722273613a1f4d6727396b114939b4d6a552 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -65,7 +65,7 @@ class TestPreProcessing: caplog.set_level(logging.DEBUG) obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general") obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") - assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") + assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples data_store = obj_with_exp_setup.data_store assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) with pytest.raises(NameNotFoundInScope): @@ -75,8 +75,8 @@ class TestPreProcessing: def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG) obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") - assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', " - "'DEBW076', 'DEBW087', 'DEBW001']") + message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']" + assert ('root', 10, message) in caplog.record_tuples data_store = obj_with_exp_setup.data_store assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) with pytest.raises(NameNotFoundInScope): diff --git a/test/test_modules/test_run_environment.py b/test/test_modules/test_run_environment.py index d82675b57ea6feb4f83c99dab6f648c2846e4137..59bb8535c4dab44e646bd6bc4aa83a8553be4d26 100644 --- a/test/test_modules/test_run_environment.py +++ b/test/test_modules/test_run_environment.py @@ -17,7 +17,7 @@ class TestRunEnvironment: with RunEnvironment() as r: r.do_stuff(0.1) expression = PyTestRegex(r"RunEnvironment finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) + assert ('root', 20, expression) in caplog.record_tuples[-3:] def test_init(self, caplog): caplog.set_level(logging.INFO) @@ -30,4 +30,4 @@ class TestRunEnvironment: r.do_stuff(0.2) del r expression = PyTestRegex(r"RunEnvironment finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) + assert ('root', 20, expression) in caplog.record_tuples[-3:] diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 83b109940ca74475e7d865eaf690e1a757075815..33f9ddf62bd91c870643727de4d146ce332fbe07 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -202,8 +202,8 @@ class TestTraining: model_name = "test_model.h5" assert model_name not in os.listdir(path) init_without_run.save_model() - assert caplog.record_tuples[0] == ( - "root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}")) + message = PyTestRegex(f"save best model to {os.path.join(path, model_name)}") + assert caplog.record_tuples[1] == ("root", 10, message) assert model_name in os.listdir(path) def test_load_best_model_no_weights(self, init_without_run, caplog): diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..9a92360a819c130c213d06b89a48a896e082adad --- /dev/null +++ b/test/test_plotting/test_tracker_plot.py @@ -0,0 +1,447 @@ +import pytest + +from collections import OrderedDict +import os +import shutil + +from matplotlib import pyplot as plt +import numpy as np + +from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot +from src.helpers import PyTestAllEqual + + +class TestTrackObject: + + @pytest.fixture + def track_obj(self): + return TrackObject("custom_name", "your_stage") + + def test_init(self, track_obj): + assert track_obj.name == ["custom_name"] + assert track_obj.stage == "your_stage" + assert all(track_obj.__getattribute__(obj) is None for obj in ["precursor", "successor", "x", "y"]) + + def test_repr(self, track_obj): + track_obj.name = ["custom", "name"] + assert repr(track_obj) == "custom/name" + + def test_x_property(self, track_obj): + assert track_obj.x is None + track_obj.x = 23 + assert track_obj.x == 23 + + def test_y_property(self, track_obj): + assert track_obj.y is None + track_obj.y = 21 + assert track_obj.y == 21 + + def test_add_precursor(self, track_obj): + assert track_obj.precursor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_precursor(another_track_obj) + assert isinstance(track_obj.precursor, list) + assert track_obj.precursor[-1] == another_track_obj + assert len(track_obj.precursor) == 1 + assert another_track_obj.successor is not None + track_obj.add_precursor(another_track_obj) + assert len(track_obj.precursor) == 1 + track_obj.add_precursor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.precursor) == 2 + + def test_add_successor(self, track_obj): + assert track_obj.successor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_successor(another_track_obj) + assert isinstance(track_obj.successor, list) + assert track_obj.successor[-1] == another_track_obj + assert len(track_obj.successor) == 1 + assert another_track_obj.precursor is not None + track_obj.add_successor(another_track_obj) + assert len(track_obj.successor) == 1 + track_obj.add_successor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.successor) == 2 + + +class TestTrackChain: + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'Stage3': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def track_chain(self, track_list): + return TrackChain(track_list) + + @pytest.fixture + def track_chain_object(self): + return object.__new__(TrackChain) + + def test_init(self, track_list): + chain = TrackChain(track_list) + assert chain.track_list == track_list + + def test_get_all_scopes(self, track_chain, track_list): + scopes = track_chain.get_all_scopes(track_list) + expected_scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + assert scopes == expected_scopes + + def test_get_unique_scopes(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.daytime.noon", + "general.nighttime"]) + + def test_get_unique_scopes_no_general(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.nighttime"]) + + def test_get_all_dims(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + dims = track_chain_object.get_all_dims(scopes) + expected_dims = {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + assert dims == expected_dims + + def test_create_track_chain(self, track_chain): + train_chain_dict = track_chain.create_track_chain() + assert list(train_chain_dict.keys()) == ["Stage1", "Stage2", "Stage3"] + assert len(train_chain_dict["Stage1"]) == 3 + assert len(train_chain_dict["Stage2"]) == 3 + assert len(train_chain_dict["Stage3"]) == 3 + + def test_control_dict(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + control = track_chain_object.control_dict(scopes) + expected_control = {"another": {"general": None, "general.daytime": None, "general.daytime.noon": None, + "general.nighttime": None}, + "moonlight": {"general": None, "general.daytime": None}, + "sunlight": {"general": None}, + "test": {"general": None, "general.daytime": None}} + assert control == expected_control + + def test__create_track_chain(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, + 'skip': {'general': None, 'general.sub': None}} + sorted_track_dict = OrderedDict([("another", [{"method": "set", "scope": "general"}, + {"method": "get", "scope": "general"}, + {"method": "get", "scope": "general.sub"}]), + ("first", [{"method": "set", "scope": "general.sub"}, + {"method": "get", "scope": "general.sub"}]), + ("skip", [{"method": "get", "scope": "general.sub"}]),]) + stage = "Stage1" + track_objects, control = track_chain_object._create_track_chain(control, sorted_track_dict, stage) + assert len(track_objects) == 2 + assert control["another"]["general"] is not None + assert control["first"]["general"] is None + assert control["skip"]["general.sub"] is None + + def test_add_precursor(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_precursor(track_objects, tr_obj, prev_obj)) == 0 + assert tr_obj.precursor[0] == prev_obj + + def test_add_track_object_same_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 0 + + def test_add_track_object_different_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage2") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 1 + tr_obj = TrackObject(["first", "get", "general.sub"], "Stage2") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 2 + + def test_update_control(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, } + variable, scope, tr_obj = "first", "general", 23 + track_chain_object._update_control(control, variable, scope, tr_obj) + assert control[variable][scope] == tr_obj + + def test_add_set_object(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + control_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 0 + assert len(tr_obj.precursor) == 1 + control_obj = TrackObject(["first", "set", "general"], "Stage0") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 1 + assert len(tr_obj.precursor) == 2 + + def test_add_set_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, None)) == 1 + assert tr_obj.precursor is None + + def test_add_get_object_no_new_track_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, pre, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_skip_update(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + control = {"testVar": {"general": None, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], True) + + def test_recursive_decent_avail_in_1_up(self, track_chain_object): + scope = "general.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_in_2_up(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + expected_pre.add_successor(TrackObject(["first", "get", "general.sub"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": expected_pre.successor[0]} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(TrackObject(["first", "set", "general"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_multiple_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + start_obj = TrackObject(["first", "set", "general"], "Stage1") + mid_obj = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(mid_obj) + mid_obj.add_precursor(start_obj) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_clean_control(self, track_chain_object): + tr1 = TrackObject(["first", "get", "general"], "Stage1") + tr2 = TrackObject(["first", "set", "general"], "Stage1") + tr2.add_precursor(tr1) + tr3 = TrackObject(["first", "get", "general/sub"], "Stage1") + tr3.add_precursor(tr1) + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': tr3}, } + control = track_chain_object.clean_control(control) + expected_control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': None}, } + assert control == expected_control + + +class TestTrackPlot: + + @pytest.fixture + def track_plot_obj(self): + return object.__new__(TrackPlot) + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'RunEnvironment': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def scopes(self): + return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + + @pytest.fixture + def dims(self): + return {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + + @pytest.fixture + def track_chain_dict(self, track_list): + return TrackChain(track_list).create_track_chain() + + @pytest.fixture + def path(self): + p = os.path.join(os.path.dirname(__file__), "TestExperiment") + if not os.path.exists(p): + os.makedirs(p) + yield p + shutil.rmtree(p, ignore_errors=True) + + def test_init(self, path, track_list): + assert "tracking.pdf" not in os.listdir(path) + TrackPlot(track_list, plot_folder=path) + assert "tracking.pdf" in os.listdir(path) + + def test_plot(self): + pass + + def test_line(self, track_plot_obj): + h, w = 0.6, 0.65 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.line(start_x=5, end_x=6, y=2) + assert len(track_plot_obj.ax.lines) == 2 + pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "darkgrey" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_step(self, track_plot_obj): + x_int, h, w = 0.5, 0.6, 0.65 + track_plot_obj.space_intern_x = x_int + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.step(start_x=5, end_x=6, start_y=2, end_y=3) + assert len(track_plot_obj.ax.lines) == 2 + pos_x = np.array([5 + w, 6 - x_int / 2, 6 - x_int / 2, 6]) + pos_y = np.array([2 + h / 2, 2 + h / 2, 3 + h / 2, 3 + h / 2]) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "black" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_rect(self, track_plot_obj): + h, w = 0.5, 0.6 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.artists) == 0 + assert len(track_plot_obj.ax.texts) == 0 + track_plot_obj.rect(x=4, y=2) + assert len(track_plot_obj.ax.artists) == 1 + assert len(track_plot_obj.ax.texts) == 1 + track_plot_obj.ax.artists[0].xy == (4, 2) + track_plot_obj.ax.artists[0]._height == h + track_plot_obj.ax.artists[0]._width == w + track_plot_obj.ax.artists[0]._original_facecolor == "orange" + track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2) + track_plot_obj.ax.texts[0]._color == "w" + track_plot_obj.ax.texts[0]._text == "get" + track_plot_obj.rect(x=4, y=2, method="set") + assert len(track_plot_obj.ax.artists) == 2 + assert len(track_plot_obj.ax.texts) == 2 + track_plot_obj.ax.artists[0]._original_facecolor == "lightblue" + track_plot_obj.ax.texts[0]._text == "set" + + + + def test_set_ypos_anchor(self, track_plot_obj, scopes, dims): + assert not hasattr(track_plot_obj, "y_pos") + assert not hasattr(track_plot_obj, "anchor") + y_int, y_ext, h = 0.5, 0.7, 0.6 + track_plot_obj.space_intern_y = y_int + track_plot_obj.height = h + track_plot_obj.space_extern_y = y_ext + track_plot_obj.set_ypos_anchor(scopes, dims) + d_y = 0 - sum([factor * (y_int + h) + y_ext - y_int for factor in dims.values()]) + expected_anchor = (d_y + sum(dims.values()), h + y_ext + sum(dims.values())) + assert np.testing.assert_array_almost_equal(track_plot_obj.anchor, expected_anchor) is None + assert track_plot_obj.y_pos["another"]["general"] == sum(dims.values()) + assert track_plot_obj.y_pos["another"]["general.daytime"] == sum(dims.values()) - (h + y_int) + assert track_plot_obj.y_pos["another"]["general.daytime.noon"] == sum(dims.values()) - 2 * (h + y_int) + + def test_plot_track_chain(self): + pass + + def test_add_variable_names(self): + pass + + def test_add_stages(self): + pass + + def test_create_track_chain_plot_run_env(self): + pass + + def test_set_lims(self, track_plot_obj): + track_plot_obj.x_max = 10 + track_plot_obj.space_intern_x = 0.5 + track_plot_obj.width = 0.4 + track_plot_obj.anchor = np.array((0.1, 12.5)) + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert track_plot_obj.ax.get_ylim() == (0, 1) # matplotlib default + assert track_plot_obj.ax.get_xlim() == (0, 1) # matplotlib default + track_plot_obj.set_lims() + assert track_plot_obj.ax.get_ylim() == (0.1, 12.5) + assert track_plot_obj.ax.get_xlim() == (0, 10+0.5+0.4) \ No newline at end of file