From 6329b0b7554428929d706ff45b4519d397a5ebbf Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 17 Mar 2020 15:21:54 +0100
Subject: [PATCH] include upsampling in the experiment workflow and restrict
 it exclusively to train data

---
 src/run_modules/experiment_setup.py |  9 +++++++--
 src/run_modules/pre_processing.py   |  3 ++-
 src/run_modules/training.py         | 18 ++++++++++++------
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 2aafe6c6..39d973b2 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -35,7 +35,8 @@ class ExperimentSetup(RunEnvironment):
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
                  create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
-                 train_min_length=None, val_min_length=None, test_min_length=None):
+                 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None,
+                 extremes_on_right_tail_only=None):
 
         # create run framework
         super().__init__()
@@ -50,7 +51,11 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("bootstrap_path", bootstrap_path)
         self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
-        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
+        self._set_param("extreme_values", extreme_values, default=None, scope="general.train")
+        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="general.train")
+        self._set_param("upsampling", extreme_values is not None, scope="general.train")
+        upsampling = self.data_store.get("upsampling", "general.train")
+        self._set_param("permute_data", max([permute_data_on_training, upsampling]), default=False, scope="general.train")
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
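The coupling introduced above is the core of the change: passing `extreme_values` switches `upsampling` on for the train scope, and upsampling in turn forces `permute_data`, since duplicated extremes would otherwise stay clustered together. A minimal sketch of this logic, assuming a plain dict in place of the real data store (only the parameter names come from the diff; `set_param` and `setup` are simplified stand-ins):

```python
# Minimal sketch of the parameter coupling above; a plain dict keyed by
# (name, scope) stands in for the real data store.
store = {}

def set_param(name, value, default=None, scope="general"):
    # mirrors ExperimentSetup._set_param: fall back to the default on None
    store[(name, scope)] = default if value is None else value

def setup(extreme_values=None, extremes_on_right_tail_only=None,
          permute_data_on_training=None):
    set_param("extreme_values", extreme_values, scope="general.train")
    set_param("extremes_on_right_tail_only", extremes_on_right_tail_only,
              default=False, scope="general.train")
    # upsampling is active exactly when extreme_values is given ...
    set_param("upsampling", extreme_values is not None, scope="general.train")
    upsampling = store[("upsampling", "general.train")]
    # ... and upsampling implies shuffling, so permute_data is forced on
    set_param("permute_data", bool(permute_data_on_training) or upsampling,
              default=False, scope="general.train")

setup(extreme_values=[1.5])  # upsampling on -> permutation forced on
assert store[("permute_data", "general.train")] is True
setup()                      # no extremes -> upsampling off, no permutation
assert store[("permute_data", "general.train")] is False
```

Using `or` instead of `max([...])` keeps this None-safe: `max([None, False])` raises a `TypeError` in Python 3 when `permute_data_on_training` is left at its default of `None`.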
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 20286bc4..439793f9 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,8 @@ from src.run_modules.run_environment import RunEnvironment
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation",
+                       "extreme_values", "extremes_on_right_tail_only"]
 
 
 class PreProcessing(RunEnvironment):
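Adding the two keys to `DEFAULT_KWARGS_LIST` is what forwards the new settings from the data store into the data handling stage; anything not on this whitelist never reaches the generators. A hypothetical sketch of how such a whitelist can be resolved against a scoped store with fallback to the parent scope (`ScopedStore` is illustrative only; the real `create_args_dict` may differ in detail):

```python
# Hypothetical scoped store: collect whitelisted kwargs, preferring the
# requested scope and falling back to the shared "general" scope.
from typing import Any, Dict, List, Tuple

class ScopedStore:
    def __init__(self) -> None:
        self._data: Dict[Tuple[str, str], Any] = {}

    def set(self, name: str, value: Any, scope: str = "general") -> None:
        self._data[(name, scope)] = value

    def create_args_dict(self, names: List[str], scope: str = "general") -> Dict[str, Any]:
        # every whitelisted name that is defined somewhere ends up in the dict
        args: Dict[str, Any] = {}
        for name in names:
            for s in (scope, "general"):
                if (name, s) in self._data:
                    args[name] = self._data[(name, s)]
                    break
        return args

store = ScopedStore()
store.set("extreme_values", [1.5], scope="general.train")
store.set("extremes_on_right_tail_only", False, scope="general.train")
store.set("sampling", "daily")  # general scope, shared by all subsets
print(store.create_args_dict(
    ["extreme_values", "extremes_on_right_tail_only", "sampling"],
    scope="general.train"))
# {'extreme_values': [1.5], 'extremes_on_right_tail_only': False, 'sampling': 'daily'}
```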
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 55b5c296..0d6279b1 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,19 +9,21 @@ import pickle
+from typing import Union
+
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
+from src.model_modules.keras_extensions import LearningRateDecay, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
 
 
 class Training(RunEnvironment):
 
     def __init__(self):
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "general.model")
-        self.train_set = None
-        self.val_set = None
-        self.test_set = None
+        self.train_set: Union[Distributor, None] = None
+        self.val_set: Union[Distributor, None] = None
+        self.test_set: Union[Distributor, None] = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
@@ -65,8 +67,9 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
+        # fetch permute_data and upsampling for this subset from the data store
+        kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=f"general.{mode}")
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))
 
     def set_generators(self) -> None:
         """
@@ -86,6 +89,9 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
+        logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
+        logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
+
         checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
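With the options wired through, only the train `Distributor` ever receives `upsampling=True`; the val and test scopes never define it, so their generators stay untouched. A hedged, standalone numpy sketch of the behaviour these options enable, duplicating extreme samples and shuffling before batching (the real `Distributor` wraps keras generators and `extreme_values` may be richer than a single scalar threshold; this is a simplification):

```python
# Sketch: oversample extreme targets, then shuffle and batch.
import numpy as np

def distribute_on_batches(x, y, batch_size, extreme_values=None,
                          extremes_on_right_tail_only=False, permute=False,
                          seed=0):
    rng = np.random.default_rng(seed)
    if extreme_values is not None:
        # mark samples as extreme; with right-tail-only, ignore the low tail
        if extremes_on_right_tail_only:
            mask = y > extreme_values
        else:
            mask = np.abs(y) > extreme_values
        x = np.concatenate([x, x[mask]])  # duplicate the extreme samples
        y = np.concatenate([y, y[mask]])
    if permute:
        order = rng.permutation(len(y))  # spread duplicates over all batches
        x, y = x[order], y[order]
    for start in range(0, len(y), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]

x = np.arange(10, dtype=float).reshape(-1, 1)
y = np.array([0.1, 0.2, 3.0, 0.3, -2.5, 0.0, 0.4, 2.1, 0.2, 0.1])
for xb, yb in distribute_on_batches(x, y, batch_size=4,
                                    extreme_values=2.0, permute=True):
    print(yb)
```

Shuffling after duplication is exactly why upsampling forces `permute_data` on in `ExperimentSetup`: without it, the appended extremes would all land in the final mini batches.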
-- 
GitLab