From f9b7105f9fbd534cc659bef7039b270e1d37763f Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 10 Feb 2021 16:41:10 +0100
Subject: [PATCH] parallel transformation is back online, tests will still fail

---
 mlair/data_handler/default_data_handler.py | 100 +++++++++++++++++++++
 1 file changed, 100 insertions(+)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 070da625..a7ce2dc0 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -248,6 +248,106 @@ class DefaultDataHandler(AbstractDataHandler):
             return
         if isinstance(transformation_dict, dict):  # tuple for (input, target) transformation
             transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+
+        def _inner():
+            """Inner method that is performed in both serial and parallel approach."""
+            if dh is not None:
+                for i, transformation in enumerate(dh._transformation):
+                    for var in transformation.keys():
+                        if var not in transformation_dict[i].keys():
+                            transformation_dict[i][var] = {}
+                        opts = transformation[var]
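+                        # all stations must use the same transformation method per variable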
+                        assert transformation_dict[i][var].get("method", opts["method"]) == opts["method"]
+                        transformation_dict[i][var]["method"] = opts["method"]
+                        for k in ["mean", "std"]:
+                            old = transformation_dict[i][var].get(k, None)
+                            new = opts.get(k)
+                            transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
+
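+        # estimate each station's transformation, in parallel if more than one CPU is available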
+        if multiprocessing.cpu_count() > 1:  # parallel solution
+            logging.info("use parallel transformation approach")
+            pool = multiprocessing.Pool()
+            logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+            output = [
+                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                for station in set_stations]
+            for p in output:
+                dh, s = p.get()
+                _inner()
+            pool.close()  # release the worker processes once all results have been collected
+            pool.join()
+        else:  # serial solution
+            logging.info("use serial transformation approach")
+            for station in set_stations:
+                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+                _inner()
+
+        # aggregate all information
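+        # average the collected statistics over the Stations dimension; variables without
+        # fitted statistics raise a KeyError and are removed afterwards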
+        pop_list = []
+        for i, transformation in enumerate(transformation_dict):
+            for k in transformation.keys():
+                try:
+                    if transformation[k]["mean"] is not None:
+                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean("Stations")
+                    if transformation[k]["std"] is not None:
+                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean("Stations")
+                except KeyError:
+                    pop_list.append((i, k))
+        for (i, k) in pop_list:
+            transformation_dict[i].pop(k)
+        return transformation_dict
+
+    @classmethod
+    def transformation_old(cls, set_stations, **kwargs):
+        """
+        ### supported transformation methods
+
+        Currently supported methods are:
+
+        * standardise (default, if method is not given)
+        * centre
+
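+        A transformation setting might, for example, look like the following sketch (the variable
+        names are only illustrative, not a prescription)::
+
+            transformation = {"o3": {"method": "standardise"},
+                              "temp": {"method": "centre"}}
+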
+        ### mean and std estimation
+
+        Mean and std (depending on method) are estimated. For each station, mean and std are calculated and afterwards
+        aggregated using the mean value over all station-wise metrics. This method is not exactly accurate, especially
+        regarding the std calculation, but it is much faster. Furthermore, it acts as a weighted mean, weighted by the
+        time series length / number of data points: a longer time series has more influence on the transformation
+        settings than a shorter one. The estimation of the std is less accurate, because the unweighted mean of all
+        stds is not equal to the true std, but the mean of all station-wise stds is still a decent estimate. Finally,
+        the real accuracy of mean and std is of minor importance, because it is "just" a transformation / scaling.
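+
+        For example, two stations with station-wise stds of 1 and 3 are aggregated to a std of 2, which is in general
+        not the std of the pooled data, but close enough for scaling purposes.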
+
+        ### mean and std given
+
+        If mean and std are not None, the default data handler expects these parameters to match the data and applies
+        these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
+        """
+
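+        # keep only those kwargs that are required by the data handler class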
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.get("transformation", None)
+        if transformation_dict is None:
+            return
+        if isinstance(transformation_dict, dict):  # tuple for (input, target) transformation
+            transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+
         for station in set_stations:
             dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
             if dh is not None:
-- 
GitLab