diff --git a/src/plotting/tracker_plot.py b/src/plotting/tracker_plot.py index 3f979b65708c6c10377364c18d3e1013e3fdfd9d..e6502ea87dafe3dca1d79ea66cfe2879e2198aa4 100644 --- a/src/plotting/tracker_plot.py +++ b/src/plotting/tracker_plot.py @@ -112,48 +112,19 @@ class TrackChain: def _create_track_chain(self, control, sorted_track_dict, stage): track_objects = [] - for k, v in sorted_track_dict.items(): - for e in v: - tr = TrackObject([k, e["method"], e["scope"]], stage) - if e["method"] == "set": - if control[k][e["scope"]] is not None: - track_objects = self._add_precursor(track_objects, tr, control[k][e["scope"]]) - # tr.add_precursor(control[k][e["scope"]]) - # # if tr.stage != control[k][e["scope"]].stage: - # # track_objects.append(control[k][e["scope"]]) - # track_objects = self._add_track_object(track_objects, tr, control[k][e["scope"]]) - else: - track_objects.append(tr) - self._update_control(control, k, e["scope"], tr) - # control[k][e["scope"]] = tr - elif e["method"] == "get": - if control[k][e["scope"]] is not None: - track_objects = self._add_precursor(track_objects, tr, control[k][e["scope"]]) - # tr.add_precursor(control[k][e["scope"]]) - # # if tr.stage != control[k][e["scope"]].stage: - # # track_objects.append(control[k][e["scope"]]) - # track_objects = self._add_track_object(track_objects, tr, control[k][e["scope"]]) - # control[k][e["scope"]] = tr - self._update_control(control, k, e["scope"], tr) - else: - scope = e["scope"].rsplit(".", 1) - while len(scope) > 1: - scope = scope[0] - if control[k][scope] is not None: - pre = control[k][scope] - while pre.precursor is not None and pre.stage == stage and pre.name[1] != "set": - pre = pre.precursor[0] - # tr.add_precursor(pre) - # # if tr.stage != pre.stage: - # # track_objects.append(pre) - # track_objects = self._add_track_object(track_objects, tr, pre) - track_objects = self._add_precursor(track_objects, tr, pre) - break - scope = scope.rsplit(".", 1) - else: - continue - # control[k][e["scope"]] = tr - self._update_control(control, k, e["scope"], tr) + for variable, all_variable_tracks in sorted_track_dict.items(): + for track_details in all_variable_tracks: + method, scope = track_details["method"], track_details["scope"] + tr = TrackObject([variable, method, scope], stage) + control_obj = control[variable][scope] + if method == "set": + track_objects = self._add_set_object(track_objects, tr, control_obj) + elif method == "get": + track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj, + control, scope, variable) + if skip_control_update is True: + continue + self._update_control(control, variable, scope, tr) return track_objects, control @staticmethod @@ -170,6 +141,44 @@ class TrackChain: tr_obj.add_precursor(prev_obj) return self._add_track_object(track_objects, tr_obj, prev_obj) + def _add_set_object(self, track_objects, tr_obj, control_obj): + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + track_objects.append(tr_obj) + return track_objects + + def _recursive_decent(self, scope, control_obj_var): + scope = scope.rsplit(".", 1) + if len(scope) > 1: + scope = scope[0] + control_obj = control_obj_var[scope] + if control_obj is not None: + pre, candidate = control_obj, control_obj + while pre.precursor is not None and pre.name[1] != "set": + # change candidate on stage border + if pre.name[2] != pre.precursor[0].name[2]: + candidate = pre + pre = pre.precursor[0] + # correct pre if candidate is from same scope + if candidate.name[2] == pre.name[2]: + pre = candidate + return pre + else: + return self._recursive_decent(scope, control_obj_var) + + def _add_get_object(self, track_objects, tr_obj, control_obj, control, scope, variable): + skip_control_update = False + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + pre = self._recursive_decent(scope, control[variable]) + if pre is not None: + track_objects = self._add_precursor(track_objects, tr_obj, pre) + else: + skip_control_update = True + return track_objects, skip_control_update + @staticmethod def control_dict(scopes): """Create empty control dictionary with variables and scopes as keys and None as default for all values.""" @@ -186,8 +195,8 @@ class TrackChain: @staticmethod def clean_control(control): - for k, v in control.items(): - for kv, vv in v.items(): + for k, v in control.items(): # var. scopes + for kv, vv in v.items(): # scope tr_obj try: if vv.precursor[0].name[2] != vv.name[2]: control[k][kv] = None @@ -202,7 +211,7 @@ class TrackChain: return np.unique(scopes).tolist() -class TrackerPlot: +class TrackPlot: def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True): @@ -220,10 +229,7 @@ class TrackerPlot: scopes = track_chain_obj.scopes dims = track_chain_obj.dims track_chain_dict = track_chain_obj.create_track_chain() - # scopes = self.get_scopes(tracker_list) - # dims = self.get_dims(scopes) self.set_ypos_anchor(scopes, dims) - # track_chain_dict = self.create_track_chain(tracker_list, scopes) self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2, (self.anchor.max() - self.anchor.min()) / 3)) stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode, skip_run_env=skip_run_env) self.set_lims() @@ -255,8 +261,6 @@ class TrackerPlot: self.ax.add_line(l) def rect(self, x, y, method="get"): - # r = Rectangle((x, y), self.width, self.height, color=color, label=color) - # self.ax.add_patch(r) if method == "get": color = "orange" @@ -348,16 +352,6 @@ class TrackerPlot: self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos)) self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages)) - def create_track_chain(self, tracker_list, scopes): - control = self.control_dict(scopes) - track_chain_dict = OrderedDict() - for track_dict in tracker_list: - stage, stage_track = list(track_dict.items())[0] - track_chain, control = create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage) - control = self.clean_control(control) - track_chain_dict[stage] = track_chain - return track_chain_dict - def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True): x, x_max = 0, 0 v_lines, stages = [], [] @@ -377,52 +371,3 @@ class TrackerPlot: x_max = self.x_max + self.space_intern_x + self.width self.ax.set_xlim((0, x_max)) self.ax.set_ylim(self.anchor) - - @staticmethod - def control_dict(scopes): - control = {} - for k, v in scopes.items(): - control[k] = {} - for e in v: - update = {e: None} - if len(control[k].keys()) == 0: - control[k] = update - else: - control[k].update(update) - return control - - @staticmethod - def clean_control(control): - for k, v in control.items(): - for kv, vv in v.items(): - try: - if vv.precursor[0].name[2] != vv.name[2]: - control[k][kv] = None - except (TypeError, AttributeError): - pass - return control - - @staticmethod - def get_scopes(track_list): - dims = {} - for track_dict in track_list: - for track in track_dict.values(): - for k, v in track.items(): - scopes = get_dim_scope(v) - if dims.get(k) is None: - dims[k] = scopes - else: - dims[k] = np.unique(scopes + dims[k]).tolist() - return OrderedDict(sorted(dims.items())) - - @staticmethod - def get_dims(scopes): - dims = {} - for k, v in scopes.items(): - dims[k] = len(v) - return dims - - -def get_dim_scope(track_list): - scopes = [e["scope"] for e in track_list] + ["general"] - return np.unique(scopes).tolist() diff --git a/src/run_modules/__init__.py b/src/run_modules/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f06d627f6ff482e11c6d1c520fa59197feb831cd 100644 --- a/src/run_modules/__init__.py +++ b/src/run_modules/__init__.py @@ -0,0 +1,6 @@ +from src.run_modules.experiment_setup import ExperimentSetup +from src.run_modules.model_setup import ModelSetup +from src.run_modules.post_processing import PostProcessing +from src.run_modules.pre_processing import PreProcessing +from src.run_modules.run_environment import RunEnvironment +from src.run_modules.training import Training diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 32b79a340adc221e89079884e8c54278d20f6217..97b0ea304e4e7236058609aac93cb3cc16f255df 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -4,7 +4,7 @@ __date__ = '2019-11-15' import argparse import logging import os -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List from src.configuration import path_config from src import helpers @@ -27,14 +27,89 @@ DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatological class ExperimentSetup(RunEnvironment): """ - params: - trainable: Train new model if true, otherwise try to load existing model + Set up the model. + + Schedule of experiment setup: + #. set channels (from variables dimension) + #. build imported model + #. plot model architecture + #. load weights if enabled (e.g. to resume a training) + #. set callbacks and checkpoint + #. compile model + + Sets + * `channels` [model] + * `model` [model] + * `hist` [model] + * `callbacks` [model] + * `model_name` [model] + * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] + + Creates + * plot of model architecture in `<model_name>.pdf` + + :param parser_args: argument parser, currently only accepting experiment_data argument + :param stations: list of stations or single station to use for experiment + :param network: name of network to restrict to use only stations from this measurement network + :param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial) + :param variables: list of all variables to use + :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN) + :param start: start date of overall data + :param end: end date of overall data + :param window_history_size: number of time steps to use for input data + :param target_var: target variable to predict by model + :param target_dim: dimension of this variable + :param window_lead_time: number of time steps to predict by model + :param dimensions: + :param interpolate_dim: + :param interpolate_method: + :param limit_nan_fill: + :param train_start: + :param train_end: + :param val_start: + :param val_end: + :param test_start: + :param test_end: + :param use_all_stations_on_all_data_sets: + :param trainable: + :param fraction_of_train: + :param experiment_path: + :param plot_path: + :param forecast_path: + :param overwrite_local_data: + :param sampling: + :param create_new_model: + :param bootstrap_path: + :param permute_data_on_training: + :param transformation: + :param train_min_length: + :param val_min_length: + :param test_min_length: + :param extreme_values: + :param extremes_on_right_tail_only: + :param evaluate_bootstraps: + :param plot_list: + :param number_of_bootstraps: + :param create_new_bootstraps: + """ - def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None, - statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", + def __init__(self, + parser_args=None, + stations: Union[str, List[str]] = None, + network: str = None, + station_type: str = None, + variables: Union[str, List[str]] = None, + statistics_per_var: Dict = None, + start: str = None, + end: str = None, + window_history_size: int = None, + target_var="o3", target_dim=None, - window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, + window_lead_time: int = None, + dimensions=None, + interpolate_dim=None, + interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", @@ -42,8 +117,7 @@ class ExperimentSetup(RunEnvironment): train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None, extremes_on_right_tail_only=None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None): - - # create run framework + """Set uo experiment.""" super().__init__() # experiment setup @@ -86,7 +160,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("station_type", station_type, default=None) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) - self._compare_variables_and_statistics() self._set_param("start", start, default="1997-01-01") self._set_param("end", end, default="2017-12-31") self._set_param("window_history_size", window_history_size, default=13) @@ -97,7 +170,6 @@ class ExperimentSetup(RunEnvironment): # target self._set_param("target_var", target_var, default="o3") - self._check_target_var() self._set_param("target_dim", target_dim, default='variables') self._set_param("window_lead_time", window_lead_time, default=3) @@ -138,7 +210,12 @@ class ExperimentSetup(RunEnvironment): self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + # check variables, statistics and target variable + self._check_target_var() + self._compare_variables_and_statistics() + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: + """Set given parameter and log in debug.""" if value is None and default is not None: value = default self.data_store.set(param, value, scope) @@ -147,8 +224,10 @@ class ExperimentSetup(RunEnvironment): @staticmethod def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict: """ - Transform args to dict if given as argparse.Namespace + Transform args to dict if given as argparse.Namespace. + :param args: either a dictionary or an argument parser instance + :return: dictionary with all arguments """ if isinstance(args, argparse.Namespace): @@ -159,21 +238,23 @@ class ExperimentSetup(RunEnvironment): return {} def _compare_variables_and_statistics(self): + """ + Compare variables and statistics. + + * raise error, if a variable is missing. + * remove unused variables from statistics. + """ logging.debug("check if all variables are included in statistics_per_var") stat = self.data_store.get("statistics_per_var") var = self.data_store.get("variables") + # too less entries, raise error if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " f"variables are part of statistics_per_var. Please add also information on the missing " f"statistics for the variables: {missing}") - - def _check_target_var(self): + # too much entries, remove unused target_var = helpers.to_list(self.data_store.get("target_var")) - stat = self.data_store.get("statistics_per_var") - var = self.data_store.get("variables") - if not set(target_var).issubset(stat.keys()): - raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") unused_vars = set(stat.keys()).difference(set(var).union(target_var)) if len(unused_vars) > 0: logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") @@ -181,6 +262,14 @@ class ExperimentSetup(RunEnvironment): self._set_param("statistics_per_var", stat_new) + def _check_target_var(self): + """Check if target variable is in statistics_per_var dictionary.""" + target_var = helpers.to_list(self.data_store.get("target_var")) + stat = self.data_store.get("statistics_per_var") + var = self.data_store.get("variables") + if not set(target_var).issubset(stat.keys()): + raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") + if __name__ == "__main__": formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' logging.basicConfig(format=formatter, level=logging.DEBUG) diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py index b47c76d0a319eaa59546e031f8650c60c17be03f..b0027119690ff736bb9491a7681fa602345cd50a 100644 --- a/src/run_modules/run_environment.py +++ b/src/run_modules/run_environment.py @@ -3,6 +3,7 @@ __author__ = "Lukas Leufen" __date__ = '2019-11-25' +import json import logging import os import shutil @@ -12,6 +13,7 @@ from src.helpers.datastore import DataStoreByScope as DataStoreObject from src.helpers.datastore import NameNotFoundInDataStore from src.helpers import Logger from src.helpers import TimeTracking +from src.plotting.tracker_plot import TrackPlot class RunEnvironment(object): @@ -88,11 +90,15 @@ class RunEnvironment(object): del_by_exit = False data_store = DataStoreObject() logger = Logger() + tracker_list = [] def __init__(self): """Start time tracking automatically and logs as info.""" self.time = TimeTracking() logging.info(f"{self.__class__.__name__} started") + # atexit.register(self.__del__) + self.data_store.tracker.append({}) + self.tracker_list.extend([{self.__class__.__name__: self.data_store.tracker[-1]}]) def __del__(self): """ @@ -106,10 +112,12 @@ class RunEnvironment(object): self.time.stop() logging.info(f"{self.__class__.__name__} finished after {self.time}") self.del_by_exit = True - # copy log file and clear data store only if called as base class and not as super class - if self.__class__.__name__ == "RunEnvironment": - self.__copy_log_file() - self.data_store.clear_data_store() + # copy log file and clear data store only if called as base class and not as super class + if self.__class__.__name__ == "RunEnvironment": + TrackPlot(self.tracker_list, True, plot_folder=self.data_store.get_default("experiment_path", ".")) + self.__save_tracking() + self.__copy_log_file() + self.data_store.clear_data_store() def __enter__(self): """Enter run environment.""" @@ -123,17 +131,28 @@ class RunEnvironment(object): def __copy_log_file(self): try: - counter = 0 - filename_pattern = os.path.join(self.data_store.get("experiment_path"), "logging_%03i.log") - new_file = filename_pattern % counter - while os.path.exists(new_file): - counter += 1 - new_file = filename_pattern % counter + new_file = self.__find_file_pattern("logging_%03i.log") logging.info(f"Copy log file to {new_file}") shutil.copyfile(self.logger.log_file, new_file) except (NameNotFoundInDataStore, FileNotFoundError): pass + def __save_tracking(self): + tracker = self.data_store.tracker + new_file = self.__find_file_pattern("tracking_%03i.json") + logging.info(f"Copy tracker file to {new_file}") + with open(new_file, "w") as f: + json.dump(tracker, f) + + def __find_file_pattern(self, name): + counter = 0 + filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name) + new_file = filename_pattern % counter + while os.path.exists(new_file): + counter += 1 + new_file = filename_pattern % counter + return new_file + @staticmethod def do_stuff(length=2): """Just a placeholder method for testing without any sense.""" diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..37a53c1c2dc21e311ad9f59a885f066ca20a802f --- /dev/null +++ b/test/test_plotting/test_tracker_plot.py @@ -0,0 +1,286 @@ +import pytest + +from collections import OrderedDict + +from src.plotting.tracker_plot import TrackObject, TrackChain + +class TestTrackObject: + + @pytest.fixture + def track_obj(self): + return TrackObject("custom_name", "your_stage") + + def test_init(self, track_obj): + assert track_obj.name == ["custom_name"] + assert track_obj.stage == "your_stage" + assert all(track_obj.__getattribute__(obj) is None for obj in ["precursor", "successor", "x", "y"]) + + def test_repr(self, track_obj): + track_obj.name = ["custom", "name"] + assert repr(track_obj) == "custom/name" + + def test_x_property(self, track_obj): + assert track_obj.x is None + track_obj.x = 23 + assert track_obj.x == 23 + + def test_y_property(self, track_obj): + assert track_obj.y is None + track_obj.y = 21 + assert track_obj.y == 21 + + def test_add_precursor(self, track_obj): + assert track_obj.precursor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_precursor(another_track_obj) + assert isinstance(track_obj.precursor, list) + assert track_obj.precursor[-1] == another_track_obj + assert len(track_obj.precursor) == 1 + assert another_track_obj.successor is not None + track_obj.add_precursor(another_track_obj) + assert len(track_obj.precursor) == 1 + track_obj.add_precursor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.precursor) == 2 + + def test_add_successor(self, track_obj): + assert track_obj.successor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_successor(another_track_obj) + assert isinstance(track_obj.successor, list) + assert track_obj.successor[-1] == another_track_obj + assert len(track_obj.successor) == 1 + assert another_track_obj.precursor is not None + track_obj.add_successor(another_track_obj) + assert len(track_obj.successor) == 1 + track_obj.add_successor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.successor) == 2 + + +class TestTrackChain: + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'Stage3': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def track_chain(self, track_list): + return TrackChain(track_list) + + @pytest.fixture + def track_chain_object(self): + return object.__new__(TrackChain) + + def test_init(self, track_list): + chain = TrackChain(track_list) + assert chain.track_list == track_list + + def test_get_all_scopes(self, track_chain, track_list): + scopes = track_chain.get_all_scopes(track_list) + expected_scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + assert scopes == expected_scopes + + def test_get_unique_scopes(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.daytime.noon", + "general.nighttime"]) + + def test_get_unique_scopes_no_general(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.nighttime"]) + + def test_get_all_dims(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + dims = track_chain_object.get_all_dims(scopes) + expected_dims = {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + assert dims == expected_dims + + def test_create_track_chain(self, track_chain): + train_chain_dict = track_chain.create_track_chain() + assert list(train_chain_dict.keys()) == ["Stage1", "Stage2", "Stage3"] + assert len(train_chain_dict["Stage1"]) == 3 + assert len(train_chain_dict["Stage2"]) == 3 + assert len(train_chain_dict["Stage3"]) == 3 + + def test_control_dict(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + control = track_chain_object.control_dict(scopes) + expected_control = {"another": {"general": None, "general.daytime": None, "general.daytime.noon": None, + "general.nighttime": None}, + "moonlight": {"general": None, "general.daytime": None}, + "sunlight": {"general": None}, + "test": {"general": None, "general.daytime": None}} + assert control == expected_control + + def test__create_track_chain(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, + 'skip': {'general': None, 'general.sub': None}} + sorted_track_dict = OrderedDict([("another", [{"method": "set", "scope": "general"}, + {"method": "get", "scope": "general"}, + {"method": "get", "scope": "general.sub"}]), + ("first", [{"method": "set", "scope": "general.sub"}, + {"method": "get", "scope": "general.sub"}]), + ("skip", [{"method": "get", "scope": "general.sub"}]),]) + stage = "Stage1" + track_objects, control = track_chain_object._create_track_chain(control, sorted_track_dict, stage) + assert len(track_objects) == 2 + assert control["another"]["general"] is not None + assert control["first"]["general"] is None + assert control["skip"]["general.sub"] is None + + def test_add_precursor(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_precursor(track_objects, tr_obj, prev_obj)) == 0 + assert tr_obj.precursor[0] == prev_obj + + def test_add_track_object_same_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 0 + + def test_add_track_object_different_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage2") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 1 + tr_obj = TrackObject(["first", "get", "general.sub"], "Stage2") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 2 + + def test_update_control(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, } + variable, scope, tr_obj = "first", "general", 23 + track_chain_object._update_control(control, variable, scope, tr_obj) + assert control[variable][scope] == tr_obj + + def test_add_set_object(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + control_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 0 + assert len(tr_obj.precursor) == 1 + control_obj = TrackObject(["first", "set", "general"], "Stage0") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 1 + assert len(tr_obj.precursor) == 2 + + def test_add_set_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, None)) == 1 + assert tr_obj.precursor is None + + def test_add_get_object_no_new_track_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, pre, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_skip_update(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + control = {"testVar": {"general": None, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], True) + + def test_recursive_decent_avail_in_1_up(self, track_chain_object): + scope = "general.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_in_2_up(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + expected_pre.add_successor(TrackObject(["first", "get", "general.sub"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": expected_pre.successor[0]} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(TrackObject(["first", "set", "general"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_multiple_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + start_obj = TrackObject(["first", "set", "general"], "Stage1") + mid_obj = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(mid_obj) + mid_obj.add_precursor(start_obj) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_clean_control(self, track_chain_object): + tr1 = TrackObject(["first", "get", "general"], "Stage1") + tr2 = TrackObject(["first", "set", "general"], "Stage1") + tr2.add_precursor(tr1) + tr3 = TrackObject(["first", "get", "general/sub"], "Stage1") + tr3.add_precursor(tr1) + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': tr3}, } + control = track_chain_object.clean_control(control) + expected_control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': None}, } + assert control == expected_control