diff --git a/src/datastore.py b/src/datastore.py index d9f844ff97acb3f5c6600205f91100219d9c53e6..fb1650808a72f2a4d8b6afc10940cd9d14f894ba 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -3,6 +3,9 @@ __date__ = '2019-11-22' from abc import ABC +from functools import wraps +import inspect +import types from typing import Any, List, Tuple, Dict @@ -27,6 +30,57 @@ class EmptyScope(Exception): pass +class CorrectScope: + """ + This class is used as decorator for all class methods, that have scope in parameters. After decoration, the scope + argument is not required on method call anymore. If no scope parameter is given, this decorator automatically adds + the default scope=`general` to the arguments. Furthermore, calls like `scope=general.sub` are obsolete, because this + decorator adds the prefix `general.` if not provided. Therefore, a call like `scope=sub` will actually become + `scope=general.sub` after passing this decorator. + """ + + def __init__(self, func): + wraps(func)(self) + + def __call__(self, *args, **kwargs): + f_arg = inspect.getfullargspec(self.__wrapped__) + pos_scope = f_arg.args.index("scope") + if len(args) < (len(f_arg.args) - len(f_arg.defaults or "")): + new_arg = kwargs.pop("scope", "general") or "general" + args = self.update_tuple(args, new_arg, pos_scope) + else: + args = self.update_tuple(args, args[pos_scope], pos_scope, update=True) + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + return types.MethodType(self, instance) + + @staticmethod + def correct(arg: str): + """ + adds leading general prefix + :param arg: string argument of scope to add prefix general if necessary + :return: corrected string + """ + if not arg.startswith("general"): + arg = "general." + arg + return arg + + def update_tuple(self, t: Tuple, new: Any, ind: int, update: bool = False): + """ + Either updates a entry in given tuple t (<old1>, <old2>, <old3>) --(ind=1)--> (<old1>, <new>, <old3>) or slots + entry into given position (<old1>, <old2>, <old3>) --(ind=1,update=True)--> (<old1>, <new>, <old2>, <old3>). In + the latter case, length of returned tuple is increased by 1 in comparison to given tuple. + :param t: tuple to update + :param new: new element to add to tuple + :param ind: position to add or slot in + :param update: updates entry if true, otherwise slot in (default: False) + :return: updated tuple + """ + t_new = (*t[:ind], self.correct(new), *t[ind + update:]) + return t_new + + class AbstractDataStore(ABC): """ @@ -119,6 +173,7 @@ class DataStoreByVariable(AbstractDataStore): <scope3>: value """ + @CorrectScope def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are @@ -132,6 +187,7 @@ class DataStoreByVariable(AbstractDataStore): self._store[name] = {} self._store[name][scope] = obj + @CorrectScope def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative @@ -144,6 +200,7 @@ class DataStoreByVariable(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + @CorrectScope def get_default(self, name: str, scope: str, default: Any) -> Any: """ Same functionality like the standard get method. But this method adds a default argument that is returned if no @@ -160,6 +217,7 @@ class DataStoreByVariable(AbstractDataStore): except (NameNotFoundInDataStore, NameNotFoundInScope): return default + @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] @@ -183,6 +241,7 @@ class DataStoreByVariable(AbstractDataStore): """ return sorted(self._store[name] if name in self._store.keys() else []) + @CorrectScope def search_scope(self, scope: str, current_scope_only=True, return_all=False) -> List[str or Tuple]: """ Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes @@ -259,6 +318,7 @@ class DataStoreByScope(AbstractDataStore): <variable3>: value """ + @CorrectScope def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are @@ -271,6 +331,7 @@ class DataStoreByScope(AbstractDataStore): self._store[scope] = {} self._store[scope][name] = obj + @CorrectScope def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative @@ -283,6 +344,7 @@ class DataStoreByScope(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + @CorrectScope def get_default(self, name: str, scope: str, default: Any) -> Any: """ Same functionality like the standard get method. But this method adds a default argument that is returned if no @@ -299,6 +361,7 @@ class DataStoreByScope(AbstractDataStore): except (NameNotFoundInDataStore, NameNotFoundInScope): return default + @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] @@ -326,6 +389,7 @@ class DataStoreByScope(AbstractDataStore): keys.append(key) return sorted(keys) + @CorrectScope def search_scope(self, scope: str, current_scope_only: bool = True, return_all: bool = False) -> List[str or Tuple]: """ Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 295d4342c4527aa54f6e302bbf77c92bcf760c56..6414b3818cc270799ed6f247857b7b2904665af5 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -44,18 +44,18 @@ class ExperimentSetup(RunEnvironment): # experiment setup self._set_param("data_path", helpers.prepare_host(sampling=sampling)) self._set_param("create_new_model", create_new_model, default=True) - if self.data_store.get("create_new_model", "general"): + if self.data_store.get("create_new_model"): trainable = True - data_path = self.data_store.get("data_path", "general") + data_path = self.data_store.get("data_path") bootstrap_path = helpers.set_bootstrap_path(bootstrap_path, data_path, sampling) self._set_param("bootstrap_path", bootstrap_path) self._set_param("trainable", trainable, default=True) self._set_param("fraction_of_training", fraction_of_train, default=0.8) - self._set_param("extreme_values", extreme_values, default=None, scope="general.train") - self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="general.train") - self._set_param("upsampling", extreme_values is not None, scope="general.train") - upsampling = self.data_store.get("upsampling", "general.train") - self._set_param("permute_data", max([permute_data_on_training, upsampling]), scope="general.train") + self._set_param("extreme_values", extreme_values, default=None, scope="train") + self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="train") + self._set_param("upsampling", extreme_values is not None, scope="train") + upsampling = self.data_store.get("upsampling", "train") + self._set_param("permute_data", max([permute_data_on_training, upsampling]), scope="train") # set experiment name exp_date = self._get_parser_args(parser_args).get("experiment_date") @@ -63,32 +63,32 @@ class ExperimentSetup(RunEnvironment): sampling=sampling) self._set_param("experiment_name", exp_name) self._set_param("experiment_path", exp_path) - helpers.check_path_and_create(self.data_store.get("experiment_path", "general")) + helpers.check_path_and_create(self.data_store.get("experiment_path")) # set plot path default_plot_path = os.path.join(exp_path, "plots") self._set_param("plot_path", plot_path, default=default_plot_path) - helpers.check_path_and_create(self.data_store.get("plot_path", "general")) + helpers.check_path_and_create(self.data_store.get("plot_path")) # set results path default_forecast_path = os.path.join(exp_path, "forecasts") self._set_param("forecast_path", forecast_path, default_forecast_path) - helpers.check_path_and_create(self.data_store.get("forecast_path", "general")) + helpers.check_path_and_create(self.data_store.get("forecast_path")) # setup for data self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("network", network, default="AIRBASE") self._set_param("station_type", station_type, default=None) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) - self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys())) + self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) self._compare_variables_and_statistics() - self._set_param("start", start, default="1997-01-01", scope="general") - self._set_param("end", end, default="2017-12-31", scope="general") + self._set_param("start", start, default="1997-01-01") + self._set_param("end", end, default="2017-12-31") self._set_param("window_history_size", window_history_size, default=13) - self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing") + self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="preprocessing") self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) - self._set_param("transformation", None, scope="general.preprocessing") + self._set_param("transformation", None, scope="preprocessing") # target self._set_param("target_var", target_var, default="o3") @@ -103,25 +103,25 @@ class ExperimentSetup(RunEnvironment): self._set_param("limit_nan_fill", limit_nan_fill, default=1) # train set parameters - self._set_param("start", train_start, default="1997-01-01", scope="general.train") - self._set_param("end", train_end, default="2007-12-31", scope="general.train") - self._set_param("min_length", train_min_length, default=90, scope="general.train") + self._set_param("start", train_start, default="1997-01-01", scope="train") + self._set_param("end", train_end, default="2007-12-31", scope="train") + self._set_param("min_length", train_min_length, default=90, scope="train") # validation set parameters - self._set_param("start", val_start, default="2008-01-01", scope="general.val") - self._set_param("end", val_end, default="2009-12-31", scope="general.val") - self._set_param("min_length", val_min_length, default=90, scope="general.val") + self._set_param("start", val_start, default="2008-01-01", scope="val") + self._set_param("end", val_end, default="2009-12-31", scope="val") + self._set_param("min_length", val_min_length, default=90, scope="val") # test set parameters - self._set_param("start", test_start, default="2010-01-01", scope="general.test") - self._set_param("end", test_end, default="2017-12-31", scope="general.test") - self._set_param("min_length", test_min_length, default=90, scope="general.test") + self._set_param("start", test_start, default="2010-01-01", scope="test") + self._set_param("end", test_end, default="2017-12-31", scope="test") + self._set_param("min_length", test_min_length, default=90, scope="test") # train_val set parameters - self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val") - self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val") - train_val_min_length = sum([self.data_store.get("min_length", f"general.{s}") for s in ["train", "val"]]) - self._set_param("min_length", train_val_min_length, default=180, scope="general.train_val") + self._set_param("start", self.data_store.get("start", "train"), scope="train_val") + self._set_param("end", self.data_store.get("end", "val"), scope="train_val") + train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]]) + self._set_param("min_length", train_val_min_length, default=180, scope="train_val") # use all stations on all data sets (train, val, test) self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) @@ -148,8 +148,8 @@ class ExperimentSetup(RunEnvironment): def _compare_variables_and_statistics(self): logging.debug("check if all variables are included in statistics_per_var") - stat = self.data_store.get("statistics_per_var", "general") - var = self.data_store.get("variables", "general") + stat = self.data_store.get("statistics_per_var") + var = self.data_store.get("variables") if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " @@ -157,9 +157,9 @@ class ExperimentSetup(RunEnvironment): f"statistics for the variables: {missing}") def _check_target_var(self): - target_var = helpers.to_list(self.data_store.get("target_var", "general")) - stat = self.data_store.get("statistics_per_var", "general") - var = self.data_store.get("variables", "general") + target_var = helpers.to_list(self.data_store.get("target_var")) + stat = self.data_store.get("statistics_per_var") + var = self.data_store.get("variables") if not set(target_var).issubset(stat.keys()): raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") unused_vars = set(stat.keys()).difference(set(var).union(target_var)) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 54d150e0bb44aa1ade473f5a184652ad2c3444d8..d04be92c4209ccad58ff20e97376fa4e96a07636 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -23,15 +23,15 @@ class ModelSetup(RunEnvironment): # create run framework super().__init__() self.model = None - path = self.data_store.get("experiment_path", "general") - exp_name = self.data_store.get("experiment_name", "general") - self.scope = "general.model" + path = self.data_store.get("experiment_path") + exp_name = self.data_store.get("experiment_name") + self.scope = "model" self.path = os.path.join(path, f"{exp_name}_%s") self.model_name = self.path % "%s.h5" self.checkpoint_name = self.path % "model-best.h5" self.callbacks_name = self.path % "model-best-callbacks-%s.pickle" - self._trainable = self.data_store.get("trainable", "general") - self._create_new_model = self.data_store.get("create_new_model", "general") + self._trainable = self.data_store.get("trainable") + self._create_new_model = self.data_store.get("create_new_model") self._run() def _run(self): @@ -56,7 +56,7 @@ class ModelSetup(RunEnvironment): self.compile_model() def _set_channels(self): - channels = self.data_store.get("generator", "general.train")[0][0].shape[-1] + channels = self.data_store.get("generator", "train")[0][0].shape[-1] self.data_store.set("channels", channels, self.scope) def compile_model(self): @@ -70,9 +70,9 @@ class ModelSetup(RunEnvironment): Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. """ - lr = self.data_store.get_default("lr_decay", scope="general.model", default=None) + lr = self.data_store.get_default("lr_decay", scope="model", default=None) hist = HistoryAdvanced() - self.data_store.set("hist", hist, scope="general.model") + self.data_store.set("hist", hist, scope="model") callbacks = CallbackHandler() if lr: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 07c257d7216fc340ec3493b7c2ff7bae7895e356..9922ba4e655d551ab23fce33bfe40f6f262274f6 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -30,14 +30,14 @@ class PostProcessing(RunEnvironment): super().__init__() self.model: keras.Model = self._load_model() self.ols_model = None - self.batch_size: int = self.data_store.get_default("batch_size", "general.model", 64) - self.test_data: DataGenerator = self.data_store.get("generator", "general.test") + self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) + self.test_data: DataGenerator = self.data_store.get("generator", "test") self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) - self.train_data: DataGenerator = self.data_store.get("generator", "general.train") - self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val") - self.plot_path: str = self.data_store.get("plot_path", "general") - self.target_var = self.data_store.get("target_var", "general") - self._sampling = self.data_store.get("sampling", "general") + self.train_data: DataGenerator = self.data_store.get("generator", "train") + self.train_val_data: DataGenerator = self.data_store.get("generator", "train_val") + self.plot_path: str = self.data_store.get("plot_path") + self.target_var = self.data_store.get("target_var") + self._sampling = self.data_store.get("sampling") self.skill_scores = None self.bootstrap_skill_scores = None self._run() @@ -59,9 +59,9 @@ class PostProcessing(RunEnvironment): # forecast - bootstrap_path = self.data_store.get("bootstrap_path", "general") - forecast_path = self.data_store.get("forecast_path", "general") - window_lead_time = self.data_store.get("window_lead_time", "general") + bootstrap_path = self.data_store.get("bootstrap_path") + forecast_path = self.data_store.get("forecast_path") + window_lead_time = self.data_store.get("window_lead_time") bootstraps = BootStraps(self.test_data, bootstrap_path, 20) with TimeTracking(name="boot predictions"): bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), @@ -114,17 +114,17 @@ class PostProcessing(RunEnvironment): def _load_model(self): try: - model = self.data_store.get("best_model", "general") + model = self.data_store.get("best_model") except NameNotFoundInDataStore: logging.info("no model saved in data store. trying to load model from experiment path") - model_name = self.data_store.get("model_name", "general.model") - model_class: AbstractModelClass = self.data_store.get("model", "general.model") + model_name = self.data_store.get("model_name", "model") + model_class: AbstractModelClass = self.data_store.get("model", "model") model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) return model def plot(self): logging.debug("Run plotting routines...") - path = self.data_store.get("forecast_path", "general") + path = self.data_store.get("forecast_path") plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="obs", forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) @@ -147,7 +147,7 @@ class PostProcessing(RunEnvironment): self._save_test_score(test_score) def _save_test_score(self, score): - path = self.data_store.get("experiment_path", "general") + path = self.data_store.get("experiment_path") with open(os.path.join(path, "test_scores.txt")) as f: for index, item in enumerate(score): f.write(f"{self.model.metrics[index]}, {item}\n") @@ -190,7 +190,7 @@ class PostProcessing(RunEnvironment): OLS=ols_prediction) # save all forecasts locally - path = self.data_store.get("forecast_path", "general") + path = self.data_store.get("forecast_path") prefix = "forecasts_norm" if normalised else "forecasts" file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc") all_predictions.to_netcdf(file) @@ -218,7 +218,7 @@ class PostProcessing(RunEnvironment): tmp_persi = data.observation.copy().sel({'window': 0}) if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) - window_lead_time = self.data_store.get("window_lead_time", "general") + window_lead_time = self.data_store.get("window_lead_time") persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)), axis=1) return persistence_prediction @@ -307,8 +307,8 @@ class PostProcessing(RunEnvironment): return None def calculate_skill_scores(self): - path = self.data_store.get("forecast_path", "general") - window_lead_time = self.data_store.get("window_lead_time", "general") + path = self.data_store.get("forecast_path") + window_lead_time = self.data_store.get("window_lead_time") skill_score_competitive = {} skill_score_climatological = {} for station in self.test_data.stations: diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 439793f941e6aaf4085241f200a63614563b550a..f83b659f3d8ae449767e8f125893dc6df0de17d8 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -34,26 +34,26 @@ class PreProcessing(RunEnvironment): self._run() def _run(self): - args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing") - kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing") - stations = self.data_store.get("stations", "general") + args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing") + kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing") + stations = self.data_store.get("stations") valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False) - self.data_store.set("stations", valid_stations, "general") + self.data_store.set("stations", valid_stations) self.split_train_val_test() self.report_pre_processing() def report_pre_processing(self): logging.debug(20 * '##') - n_train = len(self.data_store.get('generator', 'general.train')) - n_val = len(self.data_store.get('generator', 'general.val')) - n_test = len(self.data_store.get('generator', 'general.test')) + n_train = len(self.data_store.get('generator', 'train')) + n_val = len(self.data_store.get('generator', 'val')) + n_test = len(self.data_store.get('generator', 'test')) n_total = n_train + n_val + n_test logging.debug(f"Number of all stations: {n_total}") logging.debug(f"Number of training stations: {n_train}") logging.debug(f"Number of val stations: {n_val}") logging.debug(f"Number of test stations: {n_test}") - logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}" - f"{self.data_store.get('generator', 'general.test')[0][1].shape}") + logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}" + f"{self.data_store.get('generator', 'test')[0][1].shape}") def split_train_val_test(self) -> None: """ @@ -61,8 +61,8 @@ class PreProcessing(RunEnvironment): but as an separate generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs always to be executed at first, to set a proper transformation. """ - fraction_of_training = self.data_store.get("fraction_of_training", "general") - stations = self.data_store.get("stations", "general") + fraction_of_training = self.data_store.get("fraction_of_training") + stations = self.data_store.get("stations") train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training) subset_names = ["train", "val", "test", "train_val"] if subset_names[0] != "train": # pragma: no cover @@ -97,9 +97,9 @@ class PreProcessing(RunEnvironment): sure, that the train set is executed first, and all other subsets afterwards. :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True, this list is ignored. - :param set_name: name to load/save all information from/to data store without the leading general prefix. + :param set_name: name to load/save all information from/to data store. """ - scope = f"general.{set_name}" + scope = set_name args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope) stations = args["stations"] @@ -114,7 +114,7 @@ class PreProcessing(RunEnvironment): data_set = DataGenerator(**set_args, **kwargs) self.data_store.set("generator", data_set, scope) if set_name == "train": - self.data_store.set("transformation", data_set.transformation, "general") + self.data_store.set("transformation", data_set.transformation) @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True): diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py index 1c44786dfd4830c8053ae1673eac1473fbd19338..c5bc2cb4af46a6920fdfbb69ed3d0261efa17411 100644 --- a/src/run_modules/run_environment.py +++ b/src/run_modules/run_environment.py @@ -40,7 +40,7 @@ class RunEnvironment(object): self.del_by_exit = True if self.__class__.__name__ == "RunEnvironment": try: - new_file = os.path.join(self.data_store.get("experiment_path", "general"), "logging.log") + new_file = os.path.join(self.data_store.get("experiment_path"), "logging.log") shutil.copyfile(self.logger.log_file, new_file) except (NameNotFoundInDataStore, FileNotFoundError): pass diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 0d6279b132b64f287f541088c2675012a2d1e933..2d949af8c68f244c0a0da2bad6580c616695da8d 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -20,16 +20,16 @@ class Training(RunEnvironment): def __init__(self): super().__init__() - self.model: keras.Model = self.data_store.get("model", "general.model") + self.model: keras.Model = self.data_store.get("model", "model") self.train_set: Union[Distributor, None] = None self.val_set: Union[Distributor, None] = None self.test_set: Union[Distributor, None] = None - self.batch_size = self.data_store.get("batch_size", "general.model") - self.epochs = self.data_store.get("epochs", "general.model") - self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model") - self.experiment_name = self.data_store.get("experiment_name", "general") - self._trainable = self.data_store.get("trainable", "general") - self._create_new_model = self.data_store.get("create_new_model", "general") + self.batch_size = self.data_store.get("batch_size", "model") + self.epochs = self.data_store.get("epochs", "model") + self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model") + self.experiment_name = self.data_store.get("experiment_name") + self._trainable = self.data_store.get("trainable") + self._create_new_model = self.data_store.get("create_new_model") self._run() def _run(self) -> None: @@ -66,9 +66,9 @@ class Training(RunEnvironment): Set and distribute the generators for given mode regarding batch size :param mode: name of set, should be from ["train", "val", "test"] """ - gen = self.data_store.get("generator", f"general.{mode}") - # permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False) - kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=f"general.{mode}") + gen = self.data_store.get("generator", mode) + # permute_data = self.data_store.get_default("permute_data", mode, default=False) + kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode) setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs)) def set_generators(self) -> None: @@ -129,10 +129,10 @@ class Training(RunEnvironment): """ save model in local experiment directory. Model is named as <experiment_name>_<custom_model_name>.h5 . """ - model_name = self.data_store.get("model_name", "general.model") + model_name = self.data_store.get("model_name", "model") logging.debug(f"save best model to {model_name}") self.model.save(model_name) - self.data_store.set("best_model", self.model, "general") + self.data_store.set("best_model", self.model) def load_best_model(self, name: str) -> None: """ @@ -154,7 +154,7 @@ class Training(RunEnvironment): :param history: history object of training """ logging.debug("saving callbacks") - path = self.data_store.get("experiment_path", "general") + path = self.data_store.get("experiment_path") with open(os.path.join(path, "history.json"), "w") as f: json.dump(history.history, f) if lr_sc: @@ -169,8 +169,8 @@ class Training(RunEnvironment): :param history: keras history object with losses to plot (must include 'loss' and 'val_loss') :param lr_sc: learning rate decay object with 'lr' attribute """ - path = self.data_store.get("plot_path", "general") - name = self.data_store.get("experiment_name", "general") + path = self.data_store.get("plot_path") + name = self.data_store.get("experiment_name") # plot history of loss and mse (if available) filename = os.path.join(path, f"{name}_history_loss.pdf") diff --git a/test/test_datastore.py b/test/test_datastore.py index 9fcb319f51954b365c59274a4a9744f093e155f1..5b6cd17a00271a17b8fe5c30ca26665b42e56141 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -4,7 +4,7 @@ __date__ = '2019-11-22' import pytest -from src.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope +from src.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope @@ -68,7 +68,7 @@ class TestDataStoreByVariable: ds.set("number", 3, "general") assert ds.get_default("number", "general", 45) == 3 assert ds.get_default("number", "general.sub", 45) == 3 - assert ds.get_default("number", "other", 45) == 45 + assert ds.get_default("other", 45) == 45 def test_search(self, ds): ds.set("number", 22, "general") @@ -161,6 +161,19 @@ class TestDataStoreByVariable: assert ds.get("tester1", "general.sub") == 111 assert ds.get("tester3", "general.sub") == 21 + def test_no_scope_given(self, ds): + ds.set("tester", 34) + assert ds._store["tester"]["general"] == 34 + assert ds.get("tester") == 34 + assert ds.get("tester", "sub") == 34 + ds.set("tester", 99, "sub") + assert ds.list_all_scopes() == ["general", "general.sub"] + assert ds.get_default("test2", 4) == 4 + assert ds.get_default("tester", "sub", 4) == 99 + ds.set("test2", 4) + assert sorted(ds.search_scope(current_scope_only=False)) == sorted(["tester", "test2"]) + assert ds.search_scope("sub", current_scope_only=True) == ["tester"] + class TestDataStoreByScope: @@ -206,7 +219,7 @@ class TestDataStoreByScope: ds.set("number", 3, "general") assert ds.get_default("number", "general", 45) == 3 assert ds.get_default("number", "general.sub", 45) == 3 - assert ds.get_default("number", "other", 45) == 45 + assert ds.get_default("other", "other", 45) == 45 def test_search(self, ds): ds.set("number", 22, "general") @@ -297,4 +310,31 @@ class TestDataStoreByScope: assert ds.get("tester3", "general") == 21 ds.set_args_from_dict({"tester1": 111}, "general.sub") assert ds.get("tester1", "general.sub") == 111 - assert ds.get("tester3", "general.sub") == 21 \ No newline at end of file + assert ds.get("tester3", "general.sub") == 21 + + def test_no_scope_given(self, ds): + ds.set("tester", 34) + assert ds._store["general"]["tester"] == 34 + assert ds.get("tester") == 34 + assert ds.get("tester", "sub") == 34 + ds.set("tester", 99, "sub") + assert ds.list_all_scopes() == ["general", "general.sub"] + assert ds.get_default("test2", 4) == 4 + assert ds.get_default("tester", "sub", 4) == 99 + ds.set("test2", 4) + assert sorted(ds.search_scope(current_scope_only=False)) == sorted(["tester", "test2"]) + assert ds.search_scope("sub", current_scope_only=True) == ["tester"] + + +class TestCorrectScope: + + @staticmethod + @CorrectScope + def function1(a, scope, b=44): + return a, scope, b + + def test_init(self): + assert self.function1(22, "general") == (22, "general", 44) + assert self.function1(21) == (21, "general", 44) + assert self.function1(55, "sub", 34) == (55, "general.sub", 34) + assert self.function1("string", b=99, scope="tester") == ("string", "general.tester", 99)