From 6329b0b7554428929d706ff45b4519d397a5ebbf Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 17 Mar 2020 15:21:54 +0100
Subject: [PATCH] upsampling is included in experiment workflow and is
 exclusively restricted to train data

---
 src/run_modules/experiment_setup.py |  9 +++++++--
 src/run_modules/pre_processing.py   |  3 ++-
 src/run_modules/training.py         | 18 ++++++++++++------
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 2aafe6c6..39d973b2 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -35,7 +35,8 @@ class ExperimentSetup(RunEnvironment):
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
                  create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
-                 train_min_length=None, val_min_length=None, test_min_length=None):
+                 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None,
+                 extremes_on_right_tail_only=None):
 
         # create run framework
         super().__init__()
@@ -50,7 +51,11 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("bootstrap_path", bootstrap_path)
         self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
-        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
+        self._set_param("extreme_values", extreme_values, default=None, scope="general.train")
+        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="general.train")
+        self._set_param("upsampling", extreme_values is not None, scope="general.train")
+        upsampling = self.data_store.get("upsampling", "general.train")
+        self._set_param("permute_data", max([permute_data_on_training, upsampling]), default=False, scope="general.train")
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 20286bc4..439793f9 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,8 @@ from src.run_modules.run_environment import RunEnvironment
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation",
+                       "extreme_values", "extremes_on_right_tail_only"]
 
 
 class PreProcessing(RunEnvironment):
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 55b5c296..0d6279b1 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,19 +9,21 @@ import pickle
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
+from src.model_modules.keras_extensions import LearningRateDecay, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
 
+from typing import Union
+
 
 class Training(RunEnvironment):
 
     def __init__(self):
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "general.model")
-        self.train_set = None
-        self.val_set = None
-        self.test_set = None
+        self.train_set: Union[Distributor, None] = None
+        self.val_set: Union[Distributor, None] = None
+        self.test_set: Union[Distributor, None] = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
@@ -65,8 +67,9 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
+        # permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
+        kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=f"general.{mode}")
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs))
 
     def set_generators(self) -> None:
         """
@@ -86,6 +89,9 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
+        logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
+        logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.")
+
        checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-- 
GitLab
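
Note on the behaviour this patch configures (not part of the patch itself): the
hunks above only thread the new options through the workflow; the oversampling
itself lives in Distributor, which is not shown here. Below is a minimal,
self-contained Python sketch of a Distributor-like class honouring the two
keyword arguments passed via create_args_dict. Only the attribute names
`upsampling` and `do_data_permutation` are taken from the patch's logging
calls; the data layout and the threshold used to mark extremes are
hypothetical stand-ins, not the project's actual implementation.

import numpy as np


class DistributorSketch:
    """Illustrative stand-in for src.data_handling.data_distributor.Distributor."""

    def __init__(self, x, y, batch_size, permute_data=False, upsampling=False):
        self.x = x                                # inputs, samples along axis 0
        self.y = y                                # targets, samples along axis 0
        self.batch_size = batch_size
        self.do_data_permutation = permute_data  # attribute name as logged in train()
        self.upsampling = upsampling             # attribute name as logged in train()

    def distribute_on_batches(self):
        x, y = self.x, self.y
        if self.upsampling:
            # hypothetical extreme-value selection: duplicate samples whose
            # (standardised) target magnitude exceeds a fixed threshold
            extremes = np.abs(y) > 2
            x = np.concatenate([x, x[extremes]])
            y = np.concatenate([y, y[extremes]])
        if self.do_data_permutation:
            # shuffle so the duplicated extremes do not arrive as one block
            p = np.random.permutation(len(x))
            x, y = x[p], y[p]
        for i in range(0, len(x), self.batch_size):
            yield x[i:i + self.batch_size], y[i:i + self.batch_size]


# usage mirroring the patched workflow: only the train scope carries
# upsampling, so validation/test distributors keep the default False
train_set = DistributorSketch(np.random.randn(100, 4), np.random.randn(100),
                              batch_size=16, permute_data=True, upsampling=True)
val_set = DistributorSketch(np.random.randn(50, 4), np.random.randn(50),
                            batch_size=16)
for bx, by in train_set.distribute_on_batches():
    pass  # feed each mini batch to model.fit_generator

Since "upsampling" is stored only with scope="general.train", the lookups for
the validation and test scopes should not find it, leaving those Distributor
instances at their defaults. Together with setting permute_data on whenever
upsampling is active (the max([...]) in ExperimentSetup, presumably so that
duplicated extremes are shuffled into the batches), this is what restricts the
oversampling to the training data, as the commit subject states.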