Skip to content
Snippets Groups Projects
Commit 496e332b authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue121_refac_model-in-exp-setup' into 'develop'

Resolve "model link in experiment setup"

See merge request toar/machinelearningtools!102
parents 980d9ded 109e51d9
Branches
Tags
3 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!102Resolve "model link in experiment setup"
Pipeline #39333 passed
......@@ -6,10 +6,10 @@ import logging
import os
from typing import Union, Dict, Any, List
import src.configuration.path_config
from src.configuration import path_config
from src import helpers
from src.run_modules.run_environment import RunEnvironment
from src.model_modules.model_class import MyLittleModel as VanillaModel
DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004',
'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081',
......@@ -111,13 +111,6 @@ class ExperimentSetup(RunEnvironment):
self._compare_variables_and_statistics()
Creates
* plot of model architecture in `<model_name>.pdf`
......@@ -240,7 +233,7 @@ class ExperimentSetup(RunEnvironment):
create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None,
train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None,
create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None):
create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None):
# create run framework
super().__init__()
......@@ -351,6 +344,9 @@ class ExperimentSetup(RunEnvironment):
self._check_target_var()
self._compare_variables_and_statistics()
# set model architecture class
self._set_param("model_class", model, VanillaModel)
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
"""Set given parameter and log in debug."""
if value is None and default is not None:
......
......@@ -10,10 +10,6 @@ import keras
import tensorflow as tf
from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyPaperModel as MyModel
from src.run_modules.run_environment import RunEnvironment
......@@ -134,7 +130,8 @@ class ModelSetup(RunEnvironment):
"""Build model using window_history_size, window_lead_time and channels from data store."""
args_list = ["window_history_size", "window_lead_time", "channels"]
args = self.data_store.create_args_dict(args_list, self.scope)
self.model = MyModel(**args)
model = self.data_store.get("model_class")
self.model = model(**args)
self.get_model_settings()
def get_model_settings(self):
......
......@@ -58,9 +58,10 @@ class TestSetExperimentName:
exp_name, exp_path = set_experiment_name()
assert exp_name == "TestExperiment"
assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path="./test2")
exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path=os.path.join(
os.path.dirname(__file__), "test2"))
assert exp_name == "2019-11-14_network"
assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "test2", exp_name))
assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "test2", exp_name))
def test_set_experiment_from_sys(self):
exp_name, _ = set_experiment_name(experiment_name="2019-11-14")
......
......@@ -5,7 +5,7 @@ import pytest
from src.data_handling.data_generator import DataGenerator
from src.helpers.datastore import EmptyScope
from src.model_modules.keras_extensions import CallbackHandler
from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import AbstractModelClass, MyLittleModel
from src.run_modules.model_setup import ModelSetup
from src.run_modules.run_environment import RunEnvironment
......@@ -19,6 +19,7 @@ class TestModelSetup:
obj.scope = "general.model"
obj.model = None
obj.callbacks_name = "placeholder_%s_str.pickle"
obj.data_store.set("model_class", MyLittleModel)
obj.data_store.set("lr_decay", "dummy_str", "general.model")
obj.data_store.set("hist", "dummy_str", "general.model")
obj.model_name = "%s.h5"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment