From cc88fd2e5d9b189f8a0293f3c931e96f991ba900 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 19 Jun 2020 14:02:55 +0200 Subject: [PATCH] moved model import from model_setup to exp_setup --- src/run_modules/experiment_setup.py | 14 +++++--------- src/run_modules/model_setup.py | 7 ++----- test/test_configuration/test_path_config.py | 2 +- test/test_modules/test_model_setup.py | 3 ++- 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 139969c7..ac353716 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -6,10 +6,10 @@ import logging import os from typing import Union, Dict, Any, List -import src.configuration.path_config from src.configuration import path_config from src import helpers from src.run_modules.run_environment import RunEnvironment +from src.model_modules.model_class import MyLittleModel as DefaultModel DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081', @@ -111,13 +111,6 @@ class ExperimentSetup(RunEnvironment): self._compare_variables_and_statistics() - - - - - - - Creates * plot of model architecture in `<model_name>.pdf` @@ -240,7 +233,7 @@ class ExperimentSetup(RunEnvironment): create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, - create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None): + create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None): # create run framework super().__init__() @@ -351,6 +344,9 @@ class ExperimentSetup(RunEnvironment): self._check_target_var() self._compare_variables_and_statistics() + # set model architecture class + self._set_param("model_class", model, DefaultModel) + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: """Set given parameter and log in debug.""" if value is None and default is not None: diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 618d7cd8..13a13bb7 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -10,10 +10,6 @@ import keras import tensorflow as tf from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler -# from src.model_modules.model_class import MyBranchedModel as MyModel -from src.model_modules.model_class import MyLittleModel as MyModel -# from src.model_modules.model_class import MyTowerModel as MyModel -# from src.model_modules.model_class import MyPaperModel as MyModel from src.run_modules.run_environment import RunEnvironment @@ -134,7 +130,8 @@ class ModelSetup(RunEnvironment): """Build model using window_history_size, window_lead_time and channels from data store.""" args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = MyModel(**args) + model = self.data_store.get("model_class") + self.model = model(**args) self.get_model_settings() def get_model_settings(self): diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py index b27eae67..164326cc 100644 --- a/test/test_configuration/test_path_config.py +++ b/test/test_configuration/test_path_config.py @@ -60,7 +60,7 @@ class TestSetExperimentName: assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path="./test2") assert exp_name == "2019-11-14_network" - assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "test2", exp_name)) + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "test2", exp_name)) def test_set_experiment_from_sys(self): exp_name, _ = set_experiment_name(experiment_name="2019-11-14") diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 24e89772..b8652c33 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -5,7 +5,7 @@ import pytest from src.data_handling.data_generator import DataGenerator from src.helpers.datastore import EmptyScope from src.model_modules.keras_extensions import CallbackHandler -from src.model_modules.model_class import AbstractModelClass +from src.model_modules.model_class import AbstractModelClass, MyLittleModel from src.run_modules.model_setup import ModelSetup from src.run_modules.run_environment import RunEnvironment @@ -19,6 +19,7 @@ class TestModelSetup: obj.scope = "general.model" obj.model = None obj.callbacks_name = "placeholder_%s_str.pickle" + obj.data_store.set("model_class", MyLittleModel) obj.data_store.set("lr_decay", "dummy_str", "general.model") obj.data_store.set("hist", "dummy_str", "general.model") obj.model_name = "%s.h5" -- GitLab