From 4cd8fe3c9d40bb33a61348009fc62e37ad78154c Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 6 Jan 2020 14:51:41 +0100
Subject: [PATCH] post processing uses plot_monthly_summary()

---
 run.py                             |  3 ++-
 src/run_modules/post_processing.py | 16 ++++++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/run.py b/run.py
index 03eda042..71244fb9 100644
--- a/run.py
+++ b/run.py
@@ -30,7 +30,8 @@ def main(parser_args):
 if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    # logging.basicConfig(format=formatter, level=logging.INFO)
+    logging.basicConfig(format=formatter, level=logging.DEBUG)
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 35d93dcb..5d4c805a 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -12,10 +12,12 @@ import statsmodels.api as sm
 
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
+from src.data_handling.data_generator import DataGenerator
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src import statistics
 from src import helpers
 from src.helpers import TimeTracking
+from src.plotting.postprocessing_plotting import plot_monthly_summary
 
 
 class PostProcessing(RunEnvironment):
@@ -25,14 +27,23 @@ class PostProcessing(RunEnvironment):
         self.model = self.data_store.get("best_model", "general")
         self.ols_model = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
-        self.test_data = self.data_store.get("generator", "general.test")
+        self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
         self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
-        self.train_data = self.data_store.get("generator", "general.train")
+        self.train_data: DataGenerator = self.data_store.get("generator", "general.train")
+        self.plot_path = self.data_store.get("plot_path", "general")
         self._run()
 
     def _run(self):
         self.train_ols_model()
         preds_for_all_stations = self.make_prediction()
+        self.plot()
+
+    def plot(self):
+        path = self.data_store.get("forecast_path", "general")
+        window_lead_time = self.data_store.get("window_lead_time", "general")
+        target_var = self.data_store.get("target_var", "general")
+        plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var,
+                             plot_folder=self.plot_path)
 
     def calculate_test_score(self):
         test_score = self.model.evaluate(generator=self.test_data_distributed.distribute_on_batches(),
@@ -50,6 +61,7 @@ class PostProcessing(RunEnvironment):
         self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
 
     def make_prediction(self, freq="1D"):
+        logging.debug("start make_prediction")
         nn_prediction_all_stations = []
         for i, v in enumerate(self.test_data):
             data = self.test_data.get_data_generator(i)
-- 
GitLab