From a3374daecd07bae9afb9878b30d20084ea916a7a Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 8 Dec 2021 16:35:38 +0100
Subject: [PATCH] can now load transformation from file instead of from parameters.

---
 mlair/data_handler/default_data_handler.py |  1 +
 mlair/helpers/testing.py                   | 29 +++++++++++++++
 mlair/run_modules/experiment_setup.py      |  9 ++++-
 mlair/run_modules/pre_processing.py        | 41 +++++++++++++++++++---
 4 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 68aff594..9b8efe81 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -293,6 +293,7 @@ class DefaultDataHandler(AbstractDataHandler):
             transformation_dict = ({}, {})
 
         max_process = kwargs.get("max_number_multiprocessing", 16)
+        set_stations = to_list(set_stations)
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
         if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
             logging.info("use parallel transformation approach")
diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index abb50883..8c3b301d 100644
--- a/mlair/helpers/testing.py
+++ b/mlair/helpers/testing.py
@@ -86,3 +86,32 @@ def PyTestAllEqual(check_list: List):
             return self._check_all_equal()
 
     return PyTestAllEqualClass(check_list).is_true()
+
+
+def test_nested_equality(obj1, obj2):
+
+    try:
+        print(f"check type {type(obj1)} and {type(obj2)}")
+        assert type(obj1) == type(obj2)
+
+        if isinstance(obj1, (tuple, list)):
+            print(f"check length {len(obj1)} and {len(obj2)}")
+            assert len(obj1) == len(obj2)
+            for pos in range(len(obj1)):
+                print(f"check pos {obj1[pos]} and {obj2[pos]}")
+                assert test_nested_equality(obj1[pos], obj2[pos]) is True
+        elif isinstance(obj1, dict):
+            print(f"check keys {obj1.keys()} and {obj2.keys()}")
+            assert sorted(obj1.keys()) == sorted(obj2.keys())
+            for k in obj1.keys():
+                print(f"check pos {obj1[k]} and {obj2[k]}")
+                assert test_nested_equality(obj1[k], obj2[k]) is True
+        elif isinstance(obj1, xr.DataArray):
+            print(f"check xr {obj1} and {obj2}")
+            assert xr.testing.assert_equal(obj1, obj2) is None
+        else:
+            print(f"check equal {obj1} and {obj2}")
+            assert obj1 == obj2
+    except AssertionError:
+        return False
+    return True
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 70b23c37..524d29b8 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -187,6 +187,9 @@ class ExperimentSetup(RunEnvironment):
     :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
         parameter to `True` (default). If set to `False` the computation is performed in an serial approach.
         Multiprocessing is disabled when running in debug mode and cannot be switched on.
+    :param transformation_file: path to a pickle file with transformation options to load instead of calculating them.
+    :param calculate_fresh_transformation: if set to True, transformation options are always freshly calculated from
+        the train data and a provided transformation_file is ignored.
""" @@ -224,7 +227,8 @@ class ExperimentSetup(RunEnvironment): max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None, uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None, - do_uncertainty_estimate: bool = None, model_display_name: str = None, **kwargs): + do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None, + calculate_fresh_transformation: bool = None, **kwargs): # create run framework super().__init__() @@ -311,6 +315,9 @@ class ExperimentSetup(RunEnvironment): scope="preprocessing") self._set_param("transformation", transformation, default={}) self._set_param("transformation", None, scope="preprocessing") + self._set_param("transformation_file", transformation_file, default=None) + if calculate_fresh_transformation is not None: + self._set_param("calculate_fresh_transformation", calculate_fresh_transformation) self._set_param("data_handler", data_handler, default=DefaultDataHandler) # iter and window dimension diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 116a37b3..92882a89 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -295,12 +295,43 @@ class PreProcessing(RunEnvironment): self.data_store.set(k, v) def transformation(self, data_handler: AbstractDataHandler, stations): + calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) if hasattr(data_handler, "transformation"): - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") - tmp_path = self.data_store.get_default("tmp_path", default=None) - transformation_dict = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) - if transformation_dict is not None: - self.data_store.set("transformation", transformation_dict) + transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() + if transformation_opts is None: + logging.info(f"start to calculate transformation parameters.") + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") + tmp_path = self.data_store.get_default("tmp_path", default=None) + transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) + else: + logging.info("In case no valid train data could be found due to problems with transformation, please " + "check your provided transformation file for compability with your data.") + self.data_store.set("transformation", transformation_opts) + if transformation_opts is not None: + self._store_transformation(transformation_opts) + + def _load_transformation(self): + """Try to load transformation options from file if transformation_file is provided.""" + transformation_file = self.data_store.get_default("transformation_file", None) + if transformation_file is not None: + if os.path.exists(transformation_file): + logging.info(f"use transformation from given transformation file: {transformation_file}") + with open(transformation_file, "rb") as pickle_file: + return dill.load(pickle_file) + else: + logging.info(f"cannot load transformation file: {transformation_file}. 
+                             f"transformation from train data.")
+
+    def _store_transformation(self, transformation_opts):
+        """Store transformation options locally inside experiment_path if they do not already exist."""
+        experiment_path = self.data_store.get("experiment_path")
+        transformation_path = os.path.join(experiment_path, "data", "transformation")
+        transformation_file = os.path.join(transformation_path, "transformation.pickle")
+        if not os.path.exists(transformation_file):
+            path_config.check_path_and_create(transformation_path)
+            with open(transformation_file, "wb") as f:
+                dill.dump(transformation_opts, f, protocol=4)
+            logging.info(f"Store transformation options locally for later use at: {transformation_file}")
 
     def prepare_competitors(self):
         """
-- 
GitLab
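
Usage sketch (reviewer note, not part of the patch): assuming the usual MLAir
entry point where DefaultWorkflow forwards its keyword arguments to
ExperimentSetup, an experiment could reuse stored transformation options as
below. The station ids and the file path are made-up examples; the path
follows the layout that _store_transformation() creates, i.e.
<experiment_path>/data/transformation/transformation.pickle.

    from mlair.workflows import DefaultWorkflow

    workflow = DefaultWorkflow(
        stations=["DEBW107", "DEBW013"],  # hypothetical station ids
        # reuse options stored by an earlier run (path is an assumption)
        transformation_file="/path/to/old_run/data/transformation/transformation.pickle",
        # with this patch the file is only consulted when fresh calculation is
        # explicitly switched off, because the data store default is True
        calculate_fresh_transformation=False,
    )
    workflow.run()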
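
Illustration (also not part of the patch): the new helper in
mlair/helpers/testing.py compares arbitrarily nested containers, descending
into tuples and lists by position, into dicts by key (order-independent), and
comparing xr.DataArray leaves via xr.testing.assert_equal. A minimal example
with made-up data:

    import xarray as xr

    from mlair.helpers.testing import test_nested_equality

    opts_a = {"mean": xr.DataArray([0.5, 1.0]), "std": [1.2, 0.9]}
    opts_b = {"std": [1.2, 0.9], "mean": xr.DataArray([0.5, 1.0])}

    assert test_nested_equality(opts_a, opts_b) is True   # key order is irrelevant
    assert test_nested_equality(opts_a, {"mean": 0}) is False  # key sets differ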