From a3374daecd07bae9afb9878b30d20084ea916a7a Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 8 Dec 2021 16:35:38 +0100
Subject: [PATCH] can now load transformation from file instead of from
 parameters.

---
 mlair/data_handler/default_data_handler.py |  1 +
 mlair/helpers/testing.py                   | 29 +++++++++++++++
 mlair/run_modules/experiment_setup.py      |  9 ++++-
 mlair/run_modules/pre_processing.py        | 41 +++++++++++++++++++---
 4 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 68aff594..9b8efe81 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -293,6 +293,7 @@ class DefaultDataHandler(AbstractDataHandler):
         transformation_dict = ({}, {})
 
         max_process = kwargs.get("max_number_multiprocessing", 16)
+        set_stations = to_list(set_stations)
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
         if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
             logging.info("use parallel transformation approach")
diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index abb50883..8c3b301d 100644
--- a/mlair/helpers/testing.py
+++ b/mlair/helpers/testing.py
@@ -86,3 +86,32 @@ def PyTestAllEqual(check_list: List):
             return self._check_all_equal()
 
     return PyTestAllEqualClass(check_list).is_true()
+
+
+def test_nested_equality(obj1, obj2):
+
+    try:
+        print(f"check type {type(obj1)} and {type(obj2)}")
+        assert type(obj1) == type(obj2)
+
+        if isinstance(obj1, (tuple, list)):
+            print(f"check length {len(obj1)} and {len(obj2)}")
+            assert len(obj1) == len(obj2)
+            for pos in range(len(obj1)):
+                print(f"check pos {obj1[pos]} and {obj2[pos]}")
+                assert test_nested_equality(obj1[pos], obj2[pos]) is True
+        elif isinstance(obj1, dict):
+            print(f"check keys {obj1.keys()} and {obj2.keys()}")
+            assert sorted(obj1.keys()) == sorted(obj2.keys())
+            for k in obj1.keys():
+                print(f"check pos {obj1[k]} and {obj2[k]}")
+                assert test_nested_equality(obj1[k], obj2[k]) is True
+        elif isinstance(obj1, xr.DataArray):
+            print(f"check xr {obj1} and {obj2}")
+            assert xr.testing.assert_equal(obj1, obj2) is None
+        else:
+            print(f"check equal {obj1} and {obj2}")
+            assert obj1 == obj2
+    except AssertionError:
+        return False
+    return True
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 70b23c37..524d29b8 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -187,6 +187,9 @@ class ExperimentSetup(RunEnvironment):
     :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
         parameter to `True` (default). If set to `False` the computation is performed in an serial approach.
         Multiprocessing is disabled when running in debug mode and cannot be switched on.
+    :param transformation_file: Use transformation options from this file for transformation
+    :param calculate_fresh_transformation: can either be True or False, indicates if new transformation options should
+        be calculated in any case (transformation_file is not used in this case!).
 
     """
 
@@ -224,7 +227,8 @@ class ExperimentSetup(RunEnvironment):
                  max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
                  overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None,
                  uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None,
-                 do_uncertainty_estimate: bool = None, model_display_name: str = None, **kwargs):
+                 do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None,
+                 calculate_fresh_transformation: bool = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -311,6 +315,9 @@ class ExperimentSetup(RunEnvironment):
                         scope="preprocessing")
         self._set_param("transformation", transformation, default={})
         self._set_param("transformation", None, scope="preprocessing")
+        self._set_param("transformation_file", transformation_file, default=None)
+        if calculate_fresh_transformation is not None:
+            self._set_param("calculate_fresh_transformation", calculate_fresh_transformation)
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
 
         # iter and window dimension
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 116a37b3..92882a89 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -295,12 +295,43 @@ class PreProcessing(RunEnvironment):
                 self.data_store.set(k, v)
 
     def transformation(self, data_handler: AbstractDataHandler, stations):
+        calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True)
         if hasattr(data_handler, "transformation"):
-            kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
-            tmp_path = self.data_store.get_default("tmp_path", default=None)
-            transformation_dict = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs)
-            if transformation_dict is not None:
-                self.data_store.set("transformation", transformation_dict)
+            transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation()
+            if transformation_opts is None:
+                logging.info(f"start to calculate transformation parameters.")
+                kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
+                tmp_path = self.data_store.get_default("tmp_path", default=None)
+                transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs)
+            else:
+                logging.info("In case no valid train data could be found due to problems with transformation, please "
+                             "check your provided transformation file for compatibility with your data.")
+            self.data_store.set("transformation", transformation_opts)
+            if transformation_opts is not None:
+                self._store_transformation(transformation_opts)
+
+    def _load_transformation(self):
+        """Try to load transformation options from file if transformation_file is provided."""
+        transformation_file = self.data_store.get_default("transformation_file", None)
+        if transformation_file is not None:
+            if os.path.exists(transformation_file):
+                logging.info(f"use transformation from given transformation file: {transformation_file}")
+                with open(transformation_file, "rb") as pickle_file:
+                    return dill.load(pickle_file)
+            else:
+                logging.info(f"cannot load transformation file: {transformation_file}. Use fresh calculation of "
+                             f"transformation from train data.")
+
+    def _store_transformation(self, transformation_opts):
+        """Store transformation options locally inside experiment_path if not exists already."""
+        experiment_path = self.data_store.get("experiment_path")
+        transformation_path = os.path.join(experiment_path, "data", "transformation")
+        transformation_file = os.path.join(transformation_path, "transformation.pickle")
+        if not os.path.exists(transformation_file):
+            path_config.check_path_and_create(transformation_path)
+            with open(transformation_file, "wb") as f:
+                dill.dump(transformation_opts, f, protocol=4)
+            logging.info(f"Store transformation options locally for later use at: {transformation_file}")
 
     def prepare_competitors(self):
         """
-- 
GitLab