From 028b9df860f09902561bd4b5e5d0c40e45d0a87a Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 19 Jun 2020 16:32:16 +0200
Subject: [PATCH] move epochs from model classes to ExperimentSetup

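The number of epochs was hard-coded as an attribute in every model class
(MyLittleModel, MyBranchedModel, MyTowerModel, MyPaperModel) and read by
Training from the "model" scope of the data store. This patch removes the
attribute from the model classes and turns epochs into an experiment-level
parameter: ExperimentSetup now accepts an epochs argument and stores it in
the data store with a default of 20, and Training reads it from the
general scope instead of the "model" scope.

Note that MyPaperModel previously set epochs = 150; runs that relied on
that value must now pass it explicitly. A minimal sketch of the new call
site, assuming the remaining keyword arguments keep their defaults:

    from src.run_modules.experiment_setup import ExperimentSetup

    # epochs is now configured once at experiment setup; if omitted,
    # ExperimentSetup falls back to the default of 20
    ExperimentSetup(epochs=150)
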
---
 src/model_modules/model_class.py    | 4 ----
 src/run_modules/experiment_setup.py | 3 ++-
 src/run_modules/training.py         | 2 +-
 3 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index ced01e9a..5dd69608 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -355,7 +355,6 @@ class MyLittleModel(AbstractModelClass):
         self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
-        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
@@ -429,7 +428,6 @@ class MyBranchedModel(AbstractModelClass):
         self.channels = channels
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
-        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
@@ -504,7 +502,6 @@ class MyTowerModel(AbstractModelClass):
         self.initial_lr = 1e-2
         self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
                                                                              epochs_drop=10)
-        self.epochs = 20
         self.batch_size = int(256 * 4)
         self.activation = keras.layers.PReLU
 
@@ -618,7 +615,6 @@ class MyPaperModel(AbstractModelClass):
         self.initial_lr = 1e-3
         self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
                                                                              epochs_drop=10)
-        self.epochs = 150
         self.batch_size = int(256 * 2)
         self.activation = keras.layers.ELU
         self.padding = "SymPad2D"
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 110e7791..1f4c0634 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -233,7 +233,7 @@ class ExperimentSetup(RunEnvironment):
                  create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None):
+                 create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, epochs=None):
 
         # create run framework
         super().__init__()
@@ -257,6 +257,7 @@ class ExperimentSetup(RunEnvironment):
         upsampling = self.data_store.get("upsampling", "train")
         permute_data = False if permute_data_on_training is None else permute_data_on_training
         self._set_param("permute_data", permute_data or upsampling, scope="train")
+        self._set_param("epochs", epochs, default=20)
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 8624b515..5df8a15c 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -68,7 +68,7 @@ class Training(RunEnvironment):
         self.val_set: Union[Distributor, None] = None
         self.test_set: Union[Distributor, None] = None
         self.batch_size = self.data_store.get("batch_size", "model")
-        self.epochs = self.data_store.get("epochs", "model")
+        self.epochs = self.data_store.get("epochs")
         self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
         self.experiment_name = self.data_store.get("experiment_name")
         self._trainable = self.data_store.get("trainable")
-- 
GitLab