diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index 8ccf3ba6515d31ebd2b35901d3c9e58734d653d8..ee455d83f0debc10faa09ffd82cad9a77930d936 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -1,7 +1,9 @@ sphinx==3.0.3 -sphinx-autoapi==1.3.0 -sphinx-autodoc-typehints==1.10.3 +sphinx-autoapi==1.8.4 +sphinx-autodoc-typehints==1.12.0 sphinx-rtd-theme==0.4.3 #recommonmark==0.6.0 -m2r2==0.2.5 -docutils<0.18 \ No newline at end of file +m2r2==0.3.1 +docutils<0.18 +mistune==0.8.4 +setuptools>=59.5.0 \ No newline at end of file diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 68aff5947743dfdc66f95d93d5b8b284a87789d8..9b8efe811d3ca987a9a67765cdde8ac1e73a9cca 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -293,6 +293,7 @@ class DefaultDataHandler(AbstractDataHandler): transformation_dict = ({}, {}) max_process = kwargs.get("max_number_multiprocessing", 16) + set_stations = to_list(set_stations) n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus if n_process > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution logging.info("use parallel transformation approach") diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py index abb50883c7af49a0c1571d99f737e310abff9b13..1fb8012f50dab520df1d154303f727c36bfca418 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -86,3 +86,35 @@ def PyTestAllEqual(check_list: List): return self._check_all_equal() return PyTestAllEqualClass(check_list).is_true() + + +def test_nested_equality(obj1, obj2): + + try: + print(f"check type {type(obj1)} and {type(obj2)}") + assert type(obj1) == type(obj2) + + if isinstance(obj1, (tuple, list)): + print(f"check length {len(obj1)} and {len(obj2)}") + assert len(obj1) == len(obj2) + for pos in range(len(obj1)): + print(f"check pos {obj1[pos]} 
and {obj2[pos]}") + assert test_nested_equality(obj1[pos], obj2[pos]) is True + elif isinstance(obj1, dict): + print(f"check keys {obj1.keys()} and {obj2.keys()}") + assert sorted(obj1.keys()) == sorted(obj2.keys()) + for k in obj1.keys(): + print(f"check pos {obj1[k]} and {obj2[k]}") + assert test_nested_equality(obj1[k], obj2[k]) is True + elif isinstance(obj1, xr.DataArray): + print(f"check xr {obj1} and {obj2}") + assert xr.testing.assert_equal(obj1, obj2) is None + elif isinstance(obj1, np.ndarray): + print(f"check np {obj1} and {obj2}") + assert np.testing.assert_array_equal(obj1, obj2) is None + else: + print(f"check equal {obj1} and {obj2}") + assert obj1 == obj2 + except AssertionError: + return False + return True diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 70b23c3730d9091d3780746cbb3913eefe4dcf95..524d29b8cc5adda337cf80866071c1697253300f 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -187,6 +187,9 @@ class ExperimentSetup(RunEnvironment): :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this parameter to `True` (default). If set to `False` the computation is performed in an serial approach. Multiprocessing is disabled when running in debug mode and cannot be switched on. + :param transformation_file: Use transformation options from this file for transformation + :param calculate_fresh_transformation: can either be True or False, indicates if new transformation options should + be calculated in any case (transformation_file is not used in this case!). 
""" @@ -224,7 +227,8 @@ class ExperimentSetup(RunEnvironment): max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None, uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None, - do_uncertainty_estimate: bool = None, model_display_name: str = None, **kwargs): + do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None, + calculate_fresh_transformation: bool = None, **kwargs): # create run framework super().__init__() @@ -311,6 +315,9 @@ class ExperimentSetup(RunEnvironment): scope="preprocessing") self._set_param("transformation", transformation, default={}) self._set_param("transformation", None, scope="preprocessing") + self._set_param("transformation_file", transformation_file, default=None) + if calculate_fresh_transformation is not None: + self._set_param("calculate_fresh_transformation", calculate_fresh_transformation) self._set_param("data_handler", data_handler, default=DefaultDataHandler) # iter and window dimension diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 116a37b305fbe0c2e81dd89bd8ba43257d29a61c..92882a897d012a90ea052d9491973b0be83ad3ef 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -295,12 +295,43 @@ class PreProcessing(RunEnvironment): self.data_store.set(k, v) def transformation(self, data_handler: AbstractDataHandler, stations): + calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) if hasattr(data_handler, "transformation"): - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") - tmp_path = self.data_store.get_default("tmp_path", default=None) - transformation_dict = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) - if transformation_dict is not None: - 
self.data_store.set("transformation", transformation_dict) + transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() + if transformation_opts is None: + logging.info(f"start to calculate transformation parameters.") + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") + tmp_path = self.data_store.get_default("tmp_path", default=None) + transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) + else: + logging.info("In case no valid train data could be found due to problems with transformation, please " + "check your provided transformation file for compatibility with your data.") + self.data_store.set("transformation", transformation_opts) + if transformation_opts is not None: + self._store_transformation(transformation_opts) + + def _load_transformation(self): + """Try to load transformation options from file if transformation_file is provided.""" + transformation_file = self.data_store.get_default("transformation_file", None) + if transformation_file is not None: + if os.path.exists(transformation_file): + logging.info(f"use transformation from given transformation file: {transformation_file}") + with open(transformation_file, "rb") as pickle_file: + return dill.load(pickle_file) + else: + logging.info(f"cannot load transformation file: {transformation_file}. 
Use fresh calculation of " + f"transformation from train data.") + + def _store_transformation(self, transformation_opts): + """Store transformation options locally inside experiment_path if not exists already.""" + experiment_path = self.data_store.get("experiment_path") + transformation_path = os.path.join(experiment_path, "data", "transformation") + transformation_file = os.path.join(transformation_path, "transformation.pickle") + if not os.path.exists(transformation_file): + path_config.check_path_and_create(transformation_path) + with open(transformation_file, "wb") as f: + dill.dump(transformation_opts, f, protocol=4) + logging.info(f"Store transformation options locally for later use at: {transformation_file}") def prepare_competitors(self): """ diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py index 385161c740f386847ef2f2dc4df17c1c84fa7fa5..83ba0101cd452869af8c56f44432e697d290fa97 100644 --- a/test/test_helpers/test_testing_helpers.py +++ b/test/test_helpers/test_testing_helpers.py @@ -1,4 +1,4 @@ -from mlair.helpers.testing import PyTestRegex, PyTestAllEqual +from mlair.helpers.testing import PyTestRegex, PyTestAllEqual, test_nested_equality import re import xarray as xr @@ -46,3 +46,35 @@ class TestPyTestAllEqual: [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]]) assert PyTestAllEqual([["test", "test2"], ["test", "test2"]]) + + +class TestNestedEquality: + + def test_nested_equality_single_entries(self): + assert test_nested_equality(3, 3) is True + assert test_nested_equality(3.9, 3.9) is True + assert test_nested_equality(3.91, 3.9) is False + assert test_nested_equality("3", 3) is False + assert test_nested_equality("3", "3") is True + assert test_nested_equality(None, None) is True + + def test_nested_equality_xarray(self): + obj1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + obj2 = xr.ones_like(obj1) * obj1 + assert test_nested_equality(obj1, 
obj2) is True + + def test_nested_equality_numpy(self): + obj1 = np.random.randn(2, 3) + obj2 = obj1 * 1 + assert test_nested_equality(obj1, obj2) is True + + def test_nested_equality_list_tuple(self): + assert test_nested_equality([3, 3], [3, 3]) is True + assert test_nested_equality((2, 6), (2, 6)) is True + assert test_nested_equality([3, 3.5], [3.5, 3]) is False + assert test_nested_equality([3, 3.5, 10], [3, 3.5]) is False + + def test_nested_equality_dict(self): + assert test_nested_equality({"a": 3, "b": 10}, {"b": 10, "a": 3}) is True + assert test_nested_equality({"a": 3, "b": [10, 100]}, {"b": [10, 100], "a": 3}) is True + assert test_nested_equality({"a": 3, "b": 10, "c": "c"}, {"b": 10, "a": 3}) is False