From d79abc09a4cfe05bcbb6fb43d9387f60f761edb1 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 31 May 2022 10:52:20 +0200
Subject: [PATCH] running version, but must be tested on other systems

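Parallelise PostProcessing.make_prediction: the per-station forecast creation
(NN, persistence, OLS, observation) is moved into the standalone function
f_proc_make_prediction and the helper class MakePrediction, so that it can be
dispatched to a pathos process pool. The pool size is bounded by the number of
physical CPUs, the number of stations in the subset, and the
max_number_multiprocessing parameter; a serial fallback is kept. This adds
pathos and multiprocess as dependencies, with the process start method forced
to 'spawn'.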
---
 HPC_setup/requirements_HDFML_additionals.txt  |   2 +
 HPC_setup/requirements_JUWELS_additionals.txt |   2 +
 mlair/run_modules/post_processing.py          | 333 +++++++++++++++---
 requirements.txt                              |   4 +-
 4 files changed, 298 insertions(+), 43 deletions(-)

diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index ebfac3cd..bfe94684 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -2,7 +2,9 @@ astropy==4.1
 bottleneck==1.3.2
 cached-property==1.5.2
 iniconfig==1.1.1
+multiprocess==0.70.12.2
 ordered-set==4.0.2
+pathos==0.2.8
 pyshp==2.1.3
 pytest-html==3.1.1
 pytest-lazy-fixture==0.6.3
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index ebfac3cd..bfe94684 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -2,7 +2,9 @@ astropy==4.1
 bottleneck==1.3.2
 cached-property==1.5.2
 iniconfig==1.1.1
+multiprocess==0.70.12.2
 ordered-set==4.0.2
+pathos==0.2.8
 pyshp==2.1.3
 pytest-html==3.1.1
 pytest-lazy-fixture==0.6.3
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 07ef1ce4..5bf6e90f 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -8,6 +8,10 @@ import logging
 import os
 import sys
 import traceback
+import pathos
+import multiprocess.context as ctx
+# force the 'spawn' start method, as forking worker processes is unreliable once TensorFlow state exists
+ctx._force_start_method('spawn')
+import psutil
 from typing import Dict, Tuple, Union, List, Callable
 
 import tensorflow.keras as keras
@@ -695,6 +699,7 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start train_ols_model on train data")
         self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
 
+    @TimeTrackingWrapper
     def make_prediction(self, subset):
         """
         Create predictions for NN, OLS, and persistence and add true observation as reference.
@@ -707,48 +712,82 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start make_prediction for {subset_type}")
         time_dimension = self.data_store.get("time_dim")
         window_dim = self.data_store.get("window_dim")
-        subset_type = subset.name
-        for i, data in enumerate(subset):
-            input_data = data.get_X()
-            target_data = data.get_Y(as_numpy=False)
-            observation_data = data.get_observation()
-
-            # get scaling parameters
-            transformation_func = data.apply_transformation
-
-            for normalised in [True, False]:
-                # create empty arrays
-                nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
-                    target_data, count=4)
-
-                # nn forecast
-                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
-
-                # persistence
-                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
-                                                                           transformation_func, normalised)
-
-                # ols
-                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, normalised)
-
-                # observation
-                observation = self._create_observation(target_data, observation, transformation_func, normalised)
-
-                # merge all predictions
-                full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
-                prediction_dict = {self.forecast_indicator: nn_prediction,
-                                   "persi": persistence_prediction,
-                                   self.observation_indicator: observation,
-                                   "ols": ols_prediction}
-                all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
-                                                              time_dimension, ahead_dim=self.ahead_dim,
-                                                              index_dim=self.index_dim, type_dim=self.model_type_dim,
-                                                              **prediction_dict)
-
-                # save all forecasts locally
-                prefix = "forecasts_norm" if normalised is True else "forecasts"
-                file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc")
-                all_predictions.to_netcdf(file)
+        frequency = self._get_frequency()
+
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
+        max_process = self.data_store.get("max_number_multiprocessing")
+        n_process = min([psutil.cpu_count(logical=False), len(subset), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
+            logging.info("use parallel make prediction approach")
+            pool = pathos.multiprocessing.ProcessingPool(n_process)
+            logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel")
+            # note: the underlying keras model (self.model.model) is handed over here, while the
+            # serial branch below passes the full model wrapper (self.model)
+            output = [
+                pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator,
+                           self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+                           self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model)
+                for data in subset]
+            for i, p in enumerate(output):
+                p.get()
+                logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            pool.close()
+            pool.join()
+            pool.clear()
+        else:  # serial solution
+            logging.info("use serial make prediction approach")
+            for i, data in enumerate(subset):
+                f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator,
+                                       self.observation_indicator, window_dim, self.ahead_dim, self.index_dim,
+                                       self.model_type_dim, self.forecast_path, subset_type, self.model,
+                                       self.window_lead_time, self.ols_model)
+                logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)")
+
 
     def _get_frequency(self) -> str:
         """Get frequency abbreviation."""
@@ -1110,3 +1149,213 @@ class PostProcessing(RunEnvironment):
                     file_name = f"error_report_{metric}_{model_type}.%s".replace(' ', '_').replace('/', '_')
                 tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df)
                 tables.save_to_md(report_path, file_name % "md", df=df)
+
+
+class MakePrediction:
+    """Bundle all steps to create forecasts for a single station, so that they can run inside a worker process."""
+
+    def __init__(self, model, window_lead_time, ols_model):
+        self.model = model
+        self.window_lead_time = window_lead_time
+        self.ols_model = ols_model  # TODO: check whether a copy is required for parallel execution
+
+    @staticmethod
+    def _create_empty_prediction_arrays(target_data, count=1):
+        """
+        Create array to collect all predictions. Expand target data by a station dimension. """
+        return [target_data.copy() for _ in range(count)]
+
+    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
+                            normalised: bool) -> xr.DataArray:
+        """
+        Create NN forecast for given input data.
+
+        Inverse transformation is applied to the forecast to get the output in the original space. If the network has
+        multiple output branches, only the output of the main branch is returned; the main branch is defined to be the
+        last entry of all outputs.
+
+        :param input_data: transposed history from DataPrep
+        :param nn_prediction: empty array in right shape to fill with data
+        :param transformation_func: a callable function to apply inverse transformation
+        :param normalised: transform prediction in original space if false, or use normalised predictions if true
+
+        :return: filled data array with nn predictions
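+
+        A minimal sketch of the branch selection (shapes and values are illustrative):
+
+        >>> import numpy as np
+        >>> out = [np.zeros((10, 4)), np.ones((10, 4))]  # multi-branch output, main branch last
+        >>> float(out[-1].mean())
+        1.0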
+        """
+        tmp_nn = self.model.predict(input_data)
+        if isinstance(tmp_nn, list):
+            nn_prediction.values = tmp_nn[-1]
+        elif tmp_nn.ndim == 3:
+            nn_prediction.values = tmp_nn[-1, ...]
+        elif tmp_nn.ndim == 2:
+            nn_prediction.values = tmp_nn
+        else:
+            raise NotImplementedError(f"Number of dimensions of model output must be 2 or 3, but is {tmp_nn.ndim}.")
+        if not normalised:
+            nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
+        return nn_prediction
+
+    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_func: Callable,
+                                     normalised: bool) -> xr.DataArray:
+        """
+        Create persistence forecast with given data.
+
+        The persistence forecast is derived from the observed value at t=0 and applied to all following time steps
+        (t+1, ..., t+window).
+        Inverse transformation is applied to the forecast to get the output in the original space.
+
+        :param data: observation
+        :param persistence_prediction: empty array in right shape to fill with data
+        :param transformation_func: a callable function to apply inverse transformation
+        :param normalised: transform prediction in original space if false, or use normalised predictions if true
+
+        :return: filled data array with persistence predictions
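+
+        A minimal shape sketch (values are illustrative): ``np.tile`` repeats the t=0 observations over all lead
+        times:
+
+        >>> import numpy as np
+        >>> np.tile(np.array([1., 2., 3.]), (4, 1)).T.shape
+        (3, 4)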
+        """
+        tmp_persi = data.copy()
+        persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
+        if not normalised:
+            persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
+        return persistence_prediction
+
+    def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray,
+                             transformation_func: Callable, normalised: bool) -> xr.DataArray:
+        """
+        Create ordinary least square model forecast with given input data.
+
+        Inverse transformation is applied to the forecast to get the output in the original space.
+
+        :param input_data: transposed history from DataPrep
+        :param ols_prediction: empty array in right shape to fill with data
+        :param transformation_func: a callable function to apply inverse transformation
+        :param normalised: transform prediction in original space if false, or use normalised predictions if true
+
+        :return: filled data array with ols predictions
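+
+        A minimal shape sketch (shapes are illustrative) of the axis swap applied below:
+
+        >>> import numpy as np
+        >>> np.swapaxes(np.zeros((4, 10)), 1, 0).shape
+        (10, 4)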
+        """
+        tmp_ols = self.ols_model.predict(input_data)
+        target_shape = ols_prediction.values.shape
+        if target_shape != tmp_ols.shape:
+            # swap axes so that the OLS output matches the shape of the target array
+            if len(target_shape) == 2:
+                new_values = np.swapaxes(tmp_ols, 1, 0)
+            else:
+                new_values = np.swapaxes(tmp_ols, 2, 0)
+        else:
+            new_values = tmp_ols
+        ols_prediction.values = new_values
+        if not normalised:
+            ols_prediction = transformation_func(ols_prediction, "target", inverse=True)
+        return ols_prediction
+
+    def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
+        """
+        Create observation as ground truth from given data.
+
+        Inverse transformation is applied to the ground truth to get the output in the original space.
+
+        :param data: observation
+        :param transformation_func: a callable function to apply inverse transformation
+        :param normalised: transform ground truth in original space if false, or use normalised predictions if true
+
+        :return: filled data array with observation
+        """
+        if not normalised:
+            data = transformation_func(data, "target", inverse=True)
+        return data
+
+    @staticmethod
+    def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame:
+        """
+        Create full index from first and last date inside df and resample with given frequency.
+
+        :param df: use time range of this data set
+        :param freq: frequency of full index
+
+        :return: empty data frame with full index.
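+
+        A minimal sketch (dates are illustrative):
+
+        >>> import pandas as pd
+        >>> MakePrediction.create_fullindex(pd.DatetimeIndex(["2020-01-01", "2020-01-05"]), "1D").shape
+        (5, 0)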
+        """
+        if isinstance(df, pd.DataFrame):
+            earliest = df.index[0]
+            latest = df.index[-1]
+        elif isinstance(df, xr.DataArray):
+            earliest = df.index[0].values
+            latest = df.index[-1].values
+        elif isinstance(df, pd.DatetimeIndex):
+            earliest = df[0]
+            latest = df[-1]
+        else:
+            raise AttributeError(f"unknown array type. Only pandas DataFrames, xarray DataArrays and pandas "
+                                 f"DatetimeIndex are supported. Given type is {type(df)}.")
+        index = pd.DataFrame(index=pd.date_range(earliest, latest, freq=freq))
+        return index
+
+    @staticmethod
+    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
+                               ahead_dim="ahead", index_dim="index", type_dim="type", **kwargs):
+        """
+        Combine different forecast types into single xarray.
+
+        :param index: index for forecasts (e.g. time)
+        :param ahead_names: names of ahead values (e.g. hours or days)
+        :param time_dimension: name of the time dimension in the given forecast arrays
+        :param ahead_dim: name of the ahead dimension in the result
+        :param index_dim: name of the index dimension in the result
+        :param type_dim: name of the forecast type dimension in the result
+        :param kwargs: forecast data as xarrays, one keyword entry per forecast type
+
+        :return: xarray of dimension 3: index, ahead_names, # predictions
+
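+        A minimal sketch (dimension names and values are illustrative):
+
+        >>> import numpy as np, pandas as pd, xarray as xr
+        >>> idx = pd.DataFrame(index=pd.date_range("2020-01-01", periods=3, freq="1D"))
+        >>> v = xr.DataArray(np.zeros((3, 2)), coords=[idx.index, [1, 2]], dims=["datetime", "ahead"])
+        >>> MakePrediction.create_forecast_arrays(idx, [1, 2], "datetime", nn=v).shape
+        (3, 2, 1)
+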
+        """
+        keys = list(kwargs.keys())
+        res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
+                           coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim])
+        for k, v in kwargs.items():
+            intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
+            match_index = np.array(list(intersection))
+            res.loc[match_index, :, k] = v.loc[match_index]
+        return res
+
+
+def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim,
+                           ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, model, window_lead_time,
+                           ols_model, custom_objects=None):
+    # TODO: if handing over a loaded model turns out not to work on other systems, the model could be
+    #   loaded from disk inside the worker instead, e.g. via
+    #   keras.models.load_model(model, custom_objects=custom_objects)
+
+    prediction_maker = MakePrediction(model, window_lead_time, ols_model)
+
+    input_data = data.get_X()
+    target_data = data.get_Y(as_numpy=False)
+    observation_data = data.get_observation()
+
+    # get scaling parameters
+    transformation_func = data.apply_transformation
+
+    for normalised in [True, False]:
+
+        # create empty arrays
+        nn_pred, persi_pred, ols_pred, observation = prediction_maker._create_empty_prediction_arrays(target_data,
+                                                                                                      count=4)
+
+        # nn forecast
+        nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised)
+
+        # persistence
+        persi_pred = prediction_maker._create_persistence_forecast(observation_data, persi_pred,
+                                                                   transformation_func, normalised)
+
+        # ols
+        ols_pred = prediction_maker._create_ols_forecast(input_data, ols_pred, transformation_func, normalised)
+
+        # observation
+        observation = prediction_maker._create_observation(target_data, observation, transformation_func, normalised)
+
+        # merge all predictions
+        full_index = prediction_maker.create_fullindex(observation_data.indexes[time_dimension], frequency)
+        prediction_dict = {forecast_indicator: nn_pred,
+                           "persi": persi_pred,
+                           observation_indicator: observation,
+                           "ols": ols_pred}
+        all_predictions = prediction_maker.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
+                                                                  time_dimension, ahead_dim=ahead_dim,
+                                                                  index_dim=index_dim, type_dim=model_type_dim,
+                                                                  **prediction_dict)
+
+        # save all forecasts locally
+        prefix = "forecasts_norm" if normalised is True else "forecasts"
+        file = os.path.join(forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc")
+        all_predictions.to_netcdf(file)
diff --git a/requirements.txt b/requirements.txt
index 3afc17b6..7b9e676d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,17 +2,19 @@ astropy==4.1
 auto_mix_prep==0.2.0
 Cartopy==0.18.0
 dask==2021.3.0
-dill==0.3.3
+dill==0.3.4
 fsspec==2021.11.0
 keras==2.6.0
 keras_nightly==2.5.0.dev2021032900
 locket==0.2.1
 matplotlib==3.3.4
 mock==4.0.3
+multiprocess==0.70.12.2
 netcdf4==1.5.8
 numpy==1.19.5
 pandas==1.1.5
 partd==1.2.0
+pathos==0.2.8
 psutil==5.8.0
 pydot==1.4.2
 pytest==6.2.2
-- 
GitLab