diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 139969c76fe21498a2ac738426e33d1680b60a3f..e79a1eef8c6ea3f4082d1b7c146b42e83a3b0eee 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -6,10 +6,10 @@ import logging import os from typing import Union, Dict, Any, List -import src.configuration.path_config from src.configuration import path_config from src import helpers from src.run_modules.run_environment import RunEnvironment +from src.model_modules.model_class import MyLittleModel as VanillaModel DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081', @@ -111,13 +111,6 @@ class ExperimentSetup(RunEnvironment): self._compare_variables_and_statistics() - - - - - - - Creates * plot of model architecture in `<model_name>.pdf` @@ -240,7 +233,7 @@ class ExperimentSetup(RunEnvironment): create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, - create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None): + create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None): # create run framework super().__init__() @@ -351,6 +344,9 @@ class ExperimentSetup(RunEnvironment): self._check_target_var() self._compare_variables_and_statistics() + # set model architecture class + self._set_param("model_class", model, VanillaModel) + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: """Set given parameter and log in debug.""" if value is None and default is not None: diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 618d7cd8bfff3a253b5f3084512e9ba72c603c8c..13a13bb72fd634b6ebbec20f46a3a08a9f0afa8e 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -10,10 +10,6 @@ import keras import tensorflow as tf from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler -# from src.model_modules.model_class import MyBranchedModel as MyModel -from src.model_modules.model_class import MyLittleModel as MyModel -# from src.model_modules.model_class import MyTowerModel as MyModel -# from src.model_modules.model_class import MyPaperModel as MyModel from src.run_modules.run_environment import RunEnvironment @@ -134,7 +130,8 @@ class ModelSetup(RunEnvironment): """Build model using window_history_size, window_lead_time and channels from data store.""" args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = MyModel(**args) + model = self.data_store.get("model_class") + self.model = model(**args) self.get_model_settings() def get_model_settings(self): diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py index b27eae67a82432de08b9a1de77233121f5c7e2e4..557e59d7c9cd1ed45e2fc80c93ede020976e2976 100644 --- a/test/test_configuration/test_path_config.py +++ b/test/test_configuration/test_path_config.py @@ -58,9 +58,10 @@ class TestSetExperimentName: exp_name, exp_path = set_experiment_name() assert exp_name == "TestExperiment" assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) - exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path="./test2") + exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path=os.path.join( + os.path.dirname(__file__), "test2")) assert exp_name == "2019-11-14_network" - assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "test2", exp_name)) + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "test2", exp_name)) def test_set_experiment_from_sys(self): exp_name, _ = set_experiment_name(experiment_name="2019-11-14") diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 24e89772cd2e9ddc4418617e2a622fe325d94b2f..b8652c33852b56b5eaee369a4662b68b95316291 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -5,7 +5,7 @@ import pytest from src.data_handling.data_generator import DataGenerator from src.helpers.datastore import EmptyScope from src.model_modules.keras_extensions import CallbackHandler -from src.model_modules.model_class import AbstractModelClass +from src.model_modules.model_class import AbstractModelClass, MyLittleModel from src.run_modules.model_setup import ModelSetup from src.run_modules.run_environment import RunEnvironment @@ -19,6 +19,7 @@ class TestModelSetup: obj.scope = "general.model" obj.model = None obj.callbacks_name = "placeholder_%s_str.pickle" + obj.data_store.set("model_class", MyLittleModel) obj.data_store.set("lr_decay", "dummy_str", "general.model") obj.data_store.set("hist", "dummy_str", "general.model") obj.model_name = "%s.h5"