diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 8a065fa5a1f1ed3159b6a90faba03dd00c390452..dab2e168c5a9f87d4aee42fc94489fd0fa67772a 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -354,7 +354,6 @@ class MyLittleModel(AbstractModelClass): self.channels = channels self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) - self.epochs = 20 self.activation = keras.layers.PReLU # apply to model @@ -427,7 +426,6 @@ class MyBranchedModel(AbstractModelClass): self.channels = channels self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) - self.epochs = 20 self.activation = keras.layers.PReLU # apply to model @@ -501,7 +499,6 @@ class MyTowerModel(AbstractModelClass): self.initial_lr = 1e-2 self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) - self.epochs = 20 self.activation = keras.layers.PReLU # apply to model @@ -614,7 +611,6 @@ class MyPaperModel(AbstractModelClass): self.initial_lr = 1e-3 self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) - self.epochs = 150 self.activation = keras.layers.ELU self.padding = "SymPad2D" diff --git a/src/run.py b/src/run.py index eda0373c1e609e0818e98358d00a00beddb63cdf..11029817a978b872d0f99954a50ab5f5b93aa012 100644 --- a/src/run.py +++ b/src/run.py @@ -26,7 +26,8 @@ def run(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW00 evaluate_bootstraps=True, number_of_bootstraps=None, create_new_bootstraps=False, plot_list=None, model=None, - batch_size=None): + batch_size=None, + epochs=None): params = inspect.getfullargspec(ExperimentSetup).args kwargs = {k: v for k, v in locals().items() if k in params} diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 5443e265762d7286ffd507db38a86a479e6cfc3f..ff6fec842714d599696b8726e9d25aa22e55583f 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -233,7 +233,8 @@ class ExperimentSetup(RunEnvironment): create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, - create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, batch_size=None): + create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, + batch_size=None, epochs=None): # create run framework super().__init__() @@ -258,6 +259,7 @@ class ExperimentSetup(RunEnvironment): permute_data = False if permute_data_on_training is None else permute_data_on_training self._set_param("permute_data", permute_data or upsampling, scope="train") self._set_param("batch_size", batch_size, default=int(256 * 2)) + self._set_param("epochs", epochs, default=20) # set experiment name exp_date = self._get_parser_args(parser_args).get("experiment_date") diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 74e9b799bfc1ea942063cc89b9d3aa984e7e4882..1a0d7beb1ec37bb5e59a4129da58572d79a73636 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -34,7 +34,7 @@ class Training(RunEnvironment): Required objects [scope] from data store: * `model` [model] * `batch_size` [.] - * `epochs` [model] + * `epochs` [.] * `callbacks` [model] * `model_name` [model] * `experiment_name` [.] @@ -68,7 +68,7 @@ class Training(RunEnvironment): self.val_set: Union[Distributor, None] = None self.test_set: Union[Distributor, None] = None self.batch_size = self.data_store.get("batch_size") - self.epochs = self.data_store.get("epochs", "model") + self.epochs = self.data_store.get("epochs") self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model") self.experiment_name = self.data_store.get("experiment_name") self._trainable = self.data_store.get("trainable") diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index b715445e61b050b5f52594c134ee261935a87cf2..60d140f8845b25432184de1f3890b3ee4d0b034e 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -22,6 +22,7 @@ class TestModelSetup: obj.data_store.set("model_class", MyLittleModel) obj.data_store.set("lr_decay", "dummy_str", "general.model") obj.data_store.set("hist", "dummy_str", "general.model") + obj.data_store.set("epochs", 2) obj.model_name = "%s.h5" yield obj RunEnvironment().__del__() @@ -49,8 +50,7 @@ class TestModelSetup: @pytest.fixture def setup_with_model(self, setup): setup.model = AbstractModelClass() - setup.model.epochs = 2 - setup.model.batch_size = int(256) + setup.model.test_param = "42" yield setup RunEnvironment().__del__() @@ -80,15 +80,15 @@ class TestModelSetup: setup_with_model.scope = "model_test" with pytest.raises(EmptyScope): self.current_scope_as_set(setup_with_model) # will fail because scope is not created - setup_with_model.get_model_settings() # this saves now the parameters epochs and batch_size into scope - assert {"epochs", "batch_size"} <= self.current_scope_as_set(setup_with_model) + setup_with_model.get_model_settings() # this saves now the parameter test_param into scope + assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model) def test_build_model(self, setup_with_gen): assert setup_with_gen.model is None setup_with_gen.build_model() assert isinstance(setup_with_gen.model, AbstractModelClass) expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr", - "optimizer", "epochs", "activation"} + "optimizer", "activation"} assert expected <= self.current_scope_as_set(setup_with_gen) def test_set_channels(self, setup_with_gen_tiny): diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 66ba0709c21b105bd798cd35f20715e6c0a83177..eb5dfe5adb170981d5d67c94ca1fbcb55e326550 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -156,7 +156,7 @@ class TestTraining: obj.data_store.set("model_path", model_path, "general") obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") obj.data_store.set("batch_size", 256, "general") - obj.data_store.set("epochs", 2, "general.model") + obj.data_store.set("epochs", 2, "general") clbk, hist, lr = callbacks obj.data_store.set("callbacks", clbk, "general.model") obj.data_store.set("lr_decay", lr, "general.model")