From ac46e586168b38edd03993dc5c2ab252cb38aa20 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 20 Jan 2021 15:55:21 +0100
Subject: [PATCH] competitors and their path can now be named in a workflow
 setup, development comments have been removed, /close #198

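Competitors are passed as a list of model names. For each competitor and
station, post-processing expects a forecast file at
<competitor_path>/<competitor_name>/forecasts_<station>_test.nc. A minimal
sketch of the new setup (model names and path are illustrative only):

    ExperimentSetup(competitors=["test_model", "test_model2"],
                    competitor_path="/path/to/competitors")

If competitor_path is not set, it defaults to
<data_path>/competitors/<target_var>; competitors defaults to an empty
list, in which case no competing forecasts are loaded.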
---
 mlair/run_modules/experiment_setup.py |  8 ++-
 mlair/run_modules/post_processing.py  | 84 +++++++++++++--------------
 2 files changed, 47 insertions(+), 45 deletions(-)

diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 54d23077..e772bbaa 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -226,7 +226,7 @@ class ExperimentSetup(RunEnvironment):
                  number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
                  hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
-                 data_origin: Dict = None, **kwargs):
+                 data_origin: Dict = None, competitors: list = None, competitor_path: str = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -345,6 +345,12 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
         self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing
 
+        # set competitors
+        self._set_param("competitors", helpers.to_list(competitors), default=[])
+        competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors",
+                                               "_".join(helpers.to_list(self.data_store.get("target_var"))))
+        self._set_param("competitor_path", competitor_path, default=competitor_path_default)
+
         # check variables, statistics and target variable
         self._check_target_var()
         self._compare_variables_and_statistics()
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index d4b23f34..ebcf105f 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -64,6 +64,7 @@ class PostProcessing(RunEnvironment):
         """Initialise and run post-processing."""
         super().__init__()
         self.model: keras.Model = self._load_model()
+        self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0]
         self.ols_model = None
         self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
         self.test_data = self.data_store.get("data_collection", "test")
@@ -79,42 +80,28 @@ class PostProcessing(RunEnvironment):
         self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
         self.skill_scores = None
         self.bootstrap_skill_scores = None
-        # ToDo: adjust this hard coded by a new entry in the data store setup in experiment setup phase
-        self.competitor_path = os.path.join(self.data_store.get("data_path"),
-                                            "competitors",
-                                            self.target_var
-                                            # ToDo: make sure this is a string, multiple vars are joined by underscore
-                                            )
-        self.competitors = ["test_model", "test_model2"]  # ToDo: this shouldn't be hardcoded
-        self.competitor_forecast_name = "cnn"  # ToDo: another refac, rename the cnn field to something like forecast to be more general
-        # ToDo: this must be applied to all predictions. Maybe there should be a model name parameter in the data store
-        # that can be requested. I would also argue that forecast_name and the string in competitors should be the same
-        # name.
+        self.competitor_path = self.data_store.get("competitor_path")
+        self.competitors = self.data_store.get_default("competitors", default=[])
+        self.forecast_indicator = "nn"
         self._run()
 
     def _run(self):
         # ols model
-        self.train_ols_model()  # ToDo: remove comment for final commit
+        self.train_ols_model()
 
         # forecasts
-        self.make_prediction()  # ToDo: remove comment for final commit
-
-        # competitors
-        self.load_competitors(
-            "DEBW107")  # ToDo: remove this line for final commit, it is not required to load the competitors here
+        self.make_prediction()
 
         # calculate error metrics on test data
-        # self.calculate_test_score()  # ToDo: remove comment for final commit
+        self.calculate_test_score()
 
         # bootstraps
-        # ToDo: remove comment for final commit
-        # if self.data_store.get("evaluate_bootstraps", "postprocessing"):
-        #     with TimeTracking(name="calculate bootstraps"):
-        #         create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-        #         self.bootstrap_postprocessing(create_new_bootstraps)
+        if self.data_store.get("evaluate_bootstraps", "postprocessing"):
+            with TimeTracking(name="calculate bootstraps"):
+                create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
+                self.bootstrap_postprocessing(create_new_bootstraps)
 
         # skill scores
-        # ToDo: remove for final commit
         with TimeTracking(name="calculate skill scores"):
             self.skill_scores = self.calculate_skill_scores()
 
@@ -137,7 +124,8 @@ class PostProcessing(RunEnvironment):
             try:
                 prediction = self._create_competitor_forecast(station_name, competitor_name)
                 competing_predictions.append(prediction)
-            except FileNotFoundError:
+            except (FileNotFoundError, KeyError):
+                logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.")
                 continue
         return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
@@ -255,13 +243,18 @@ class PostProcessing(RunEnvironment):
                 score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"])
             return score
 
-    @staticmethod
-    def get_orig_prediction(path, file_name, number_of_bootstraps, prediction_name="cnn"):
+    def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None):
+        if prediction_name is None:
+            prediction_name = self.forecast_indicator
         file = os.path.join(path, file_name)
         prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
         vals = np.tile(prediction.data, (number_of_bootstraps, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
 
+    def _get_model_name(self):
+        """Return the model name without any path information or file extension."""
+        return os.path.basename(self.data_store.get("model_name", "model")).split(".", 1)[0]
+
     def _load_model(self) -> keras.models:
         """
         Load NN model either from data store or from local path.
@@ -311,7 +304,8 @@ class PostProcessing(RunEnvironment):
 
         try:
             if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
-                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="cnn")
+                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
+                                        model_setup=self.forecast_indicator)
         except Exception as e:
             logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
 
@@ -347,15 +341,17 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotClimatologicalSkillScore" in plot_list:
-                PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="cnn")
+                PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path,
+                                             model_setup=self.forecast_indicator)
                 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
-                                             extra_name_tag="all_terms_", model_setup="cnn")
+                                             extra_name_tag="all_terms_", model_setup=self.forecast_indicator)
         except Exception as e:
             logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}")
 
         try:
             if "PlotCompetitiveSkillScore" in plot_list:
-                PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="cnn")
+                PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path,
+                                          model_setup=self.forecast_indicator)
         except Exception as e:
             logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}")
 
@@ -430,16 +426,17 @@ class PostProcessing(RunEnvironment):
                                                            normalised)
 
                 # observation
-                observation = self._create_observation(target_data, observation, mean, std, transformation_method, normalised)
+                observation = self._create_observation(target_data, observation, mean, std, transformation_method,
+                                                       normalised)
 
                 # merge all predictions
                 full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
+                prediction_dict = {self.forecast_indicator: nn_prediction,
+                                   "persi": persistence_prediction,
+                                   "obs": observation,
+                                   "ols": ols_prediction}
                 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
-                                                              time_dimension,
-                                                              cnn=nn_prediction,
-                                                              persi=persistence_prediction,
-                                                              obs=observation,
-                                                              ols=ols_prediction)
+                                                              time_dimension, **prediction_dict)
 
                 # save all forecasts locally
                 path = self.data_store.get("forecast_path")
@@ -456,7 +453,8 @@ class PostProcessing(RunEnvironment):
         """
         Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station
         indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
-        raise an `FileNotFoundError` if no competitor could be found for the given station.
+        raise either a `FileNotFoundError` or a `KeyError` if no competitor could be found for the given station:
+        either there is no file in the expected path, or the file contains no forecast for the given `competitor_name`.
 
         :param station_name: name of the station to load data for
         :param competitor_name: name of the model
@@ -465,11 +463,10 @@
         path = os.path.join(self.competitor_path, competitor_name)
         file = os.path.join(path, f"forecasts_{station_name}_test.nc")
         data = xr.open_dataarray(file)
-        # data = data.expand_dims(Stations=[station_name])
-        forecast = data.sel(type=[self.competitor_forecast_name])
+        forecast = data.sel(type=[self.forecast_indicator])
         forecast.coords["type"] = [competitor_name]
         return forecast
 
     @staticmethod
     def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
                             normalised: bool) -> xr.DataArray:
@@ -497,7 +494,7 @@
 
         Inverse transformation is applied to the forecast to get the output in the original space.
 
-        :param data: transposed history from DataPrep
+        :param input_data: transposed history from DataPrep
         :param ols_prediction: empty array in right shape to fill with data
         :param mean: mean of target value transformation
         :param std: standard deviation of target value transformation
@@ -655,7 +652,6 @@
         skill_score_competitive = {}
         skill_score_climatological = {}
         for station in self.test_data:
-            logging.info(str(station))
             file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
             data = xr.open_dataarray(file)
             competitor = self.load_competitors(str(station))
@@ -663,6 +659,6 @@
             skill_score = statistics.SkillScores(combined, models=remove_items(list(combined.type.values), "obs"))
             external_data = self._get_external_data(station)  # ToDo: check if external is still right?
             skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
-            skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
-                                                                                          self.window_lead_time)
+            skill_score_climatological[station] = skill_score.climatological_skill_scores(
+                external_data, self.window_lead_time, forecast_name=self.forecast_indicator)
         return skill_score_competitive, skill_score_climatological
-- 
GitLab