From 175841d5855cef26b63fdbb7e12075e0a4e17b19 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 13 Apr 2021 11:52:12 +0200
Subject: [PATCH] moved availability and station map to preprocessing

---
 mlair/run_modules/post_processing.py | 34 -------------------
 mlair/run_modules/pre_processing.py  | 50 ++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 34 deletions(-)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index ff74da37..a633dec1 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -21,7 +21,6 @@ from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales
-from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram
 from mlair.run_modules.run_environment import RunEnvironment
 
 
@@ -325,23 +324,6 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}")
 
-        try:
-            if "PlotStationMap" in plot_list:
-                if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
-                        "hostname")[:6] in self.data_store.get("hpc_hosts"):
-                    logging.warning(
-                        f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
-                else:
-                    gens = [(self.train_data, {"marker": 5, "ms": 9}),
-                            (self.val_data, {"marker": 6, "ms": 9}),
-                            (self.test_data, {"marker": 4, "ms": 9})]
-                    PlotStationMap(generators=gens, plot_folder=self.plot_path)
-                    gens = [(self.train_val_data, {"marker": 8, "ms": 9}),
-                            (self.test_data, {"marker": 9, "ms": 9})]
-                    PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var")
-        except Exception as e:
-            logging.error(f"Could not create plot PlotStationMap due to the following error: {e}")
-
         try:
             if "PlotMonthlySummary" in plot_list:
                 PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
@@ -372,22 +354,6 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}")
 
-        try:
-            if "PlotAvailability" in plot_list:
-                avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
-                PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim,
-                                 window_dimension=window_dim)
-        except Exception as e:
-            logging.error(f"Could not create plot PlotAvailability due to the following error: {e}")
-
-        try:
-            if "PlotAvailabilityHistogram" in plot_list:
-                avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
-                PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim,
-                                          history_dim=window_dim)
-        except Exception as e:
-            logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}")
-
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
 
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index f59a4e89..3c2670aa 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -18,6 +18,7 @@ from mlair.helpers import TimeTracking, to_list, tables
 from mlair.configuration import path_config
 from mlair.helpers.join import EmptyQueryResult
 from mlair.run_modules.run_environment import RunEnvironment
+from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram
 
 
 class PreProcessing(RunEnvironment):
@@ -67,6 +68,7 @@ class PreProcessing(RunEnvironment):
         self.split_train_val_test()
         self.report_pre_processing()
         self.prepare_competitors()
+        self.plot()
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
@@ -327,6 +329,54 @@ class PreProcessing(RunEnvironment):
         else:
             logging.info("No preparation required because no competitor was provided to the workflow.")
 
+    def plot(self):
+        """Create station map and data availability plots if they are requested in the plot list.
+
+        A failing plot is only logged as error and does not stop the run. The station map is skipped on hpc hosts.
+        """
+        logging.info("Run plotting routines...")
+        plot_list = self.data_store.get("plot_list", "postprocessing")
+        time_dim = self.data_store.get("time_dim")
+        window_dim = self.data_store.get("window_dim")
+        iter_dim = self.data_store.get("iter_dim")
+        train_data = self.data_store.get("data_collection", "train")
+        val_data = self.data_store.get("data_collection", "val")
+        test_data = self.data_store.get("data_collection", "test")
+        train_val_data = self.data_store.get("data_collection", "train_val")
+        plot_path: str = self.data_store.get("plot_path")
+
+        try:
+            if "PlotStationMap" in plot_list:
+                hostname, hpc_hosts = self.data_store.get("hostname"), self.data_store.get("hpc_hosts")
+                if hostname[:2] in hpc_hosts or hostname[:6] in hpc_hosts:  # compare 2 and 6 character prefix
+                    logging.warning(f"Skip 'PlotStationMap` because running on a hpc node: {hostname}")
+                else:
+                    gens = [(train_data, {"marker": 5, "ms": 9}),
+                            (val_data, {"marker": 6, "ms": 9}),
+                            (test_data, {"marker": 4, "ms": 9})]
+                    PlotStationMap(generators=gens, plot_folder=plot_path)
+                    gens = [(train_val_data, {"marker": 8, "ms": 9}),
+                            (test_data, {"marker": 9, "ms": 9})]
+                    PlotStationMap(generators=gens, plot_folder=plot_path, plot_name="station_map_var")
+        except Exception as e:
+            logging.error(f"Could not create plot PlotStationMap due to the following error: {e}")
+
+        try:
+            if "PlotAvailability" in plot_list:
+                avail_data = {"train": train_data, "val": val_data, "test": test_data}
+                PlotAvailability(avail_data, plot_folder=plot_path, time_dimension=time_dim,
+                                 window_dimension=window_dim)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotAvailability due to the following error: {e}")
+
+        try:
+            if "PlotAvailabilityHistogram" in plot_list:
+                avail_data = {"train": train_data, "val": val_data, "test": test_data}
+                PlotAvailabilityHistogram(avail_data, plot_folder=plot_path, station_dim=iter_dim,
+                                          history_dim=window_dim)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}")
+
 
 def f_proc(data_handler, station, name_affix, store, **kwargs):
     """
-- 
GitLab