diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index ecaef6325c2bfef2246134c23d04df234b80bad2..1a8f7c4c400eaf75bdd1dc6af2e0993f662eac49 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -87,7 +87,7 @@ class MyLittleModel(AbstractModelClass): Dense layer. """ - def __init__(self, input_x, window_history_size, window_lead_time): + def __init__(self, window_history_size, window_lead_time, channels): """ Sets model and loss depending on the given arguments. @@ -104,7 +104,7 @@ class MyLittleModel(AbstractModelClass): # settings self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.channels = input_x.shape[-1] # input variables + self.channels = channels self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) self.initial_lr = 1e-2 diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py index f7f6b4f4b4594b8b7ad7b0016b4289258f787978..a62b53b86651109c4c1dd10d4a7dfccbaf3cf9c2 100644 --- a/src/modules/model_setup.py +++ b/src/modules/model_setup.py @@ -36,6 +36,9 @@ class ModelSetup(RunEnvironment): # create checkpoint self._set_checkpoint() + # set channels depending on inputs + self._set_channels() + # build model graph using settings from my_model_settings() self.build_model() @@ -49,6 +52,10 @@ class ModelSetup(RunEnvironment): # compile model self.compile_model() + def _set_channels(self): + channels = self.data_store.get("generator", "general.train")[0][0].shape[-1] + self.data_store.set("channels", channels, self.scope) + def compile_model(self): optimizer = self.data_store.get("optimizer", self.scope) loss = self.model.loss @@ -67,10 +74,12 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): - args_list = ["window_history_size", "window_lead_time"] + args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) - input_x = self.data_store.get("generator", "general.train")[0][0] - self.model = MyLittleModel(input_x, **args) + self.model = MyLittleModel(**args) + self.get_model_settings() + + def get_model_settings(self): model_settings = self.model.get_settings() self.data_store.set_args_from_dict(model_settings, self.scope) diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 95a242b093e298ef0cd543e0437b5dc4720c7abc..5a5a7bbd5e95f0db9f88fe3fcb944718b8ddab94 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -1,10 +1,13 @@ import pytest import os import keras +import mock from src.modules.model_setup import ModelSetup from src.modules.run_environment import RunEnvironment from src.data_handling.data_generator import DataGenerator +from src.model_modules.model_class import AbstractModelClass +from src.datastore import EmptyScope class TestModelSetup: @@ -31,23 +34,43 @@ class TestModelSetup: yield setup RunEnvironment().__del__() + @pytest.fixture + def setup_with_model(self, setup_with_gen): + setup_with_gen.data_store.set("channels", 2, "general") + setup_with_gen.model = AbstractModelClass() + setup_with_gen.model.epochs = 2 + setup_with_gen.model.batch_size = int(256) + yield setup_with_gen + RunEnvironment().__del__() + + @staticmethod + def current_scope_as_set(model_cls): + return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True)) + def test_set_checkpoint(self, setup): assert "general.modeltest" not in setup.data_store.search_name("checkpoint") setup.checkpoint_name = "TestName" setup._set_checkpoint() assert "general.modeltest" in setup.data_store.search_name("checkpoint") - def test_my_model_settings(self, setup_with_gen): - setup_with_gen.my_model_settings() - expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay", "epochs", - "batch_size", "activation", "loss"} - assert expected <= set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True)) + def test_get_model_settings(self, setup_with_model): + with pytest.raises(EmptyScope): + self.current_scope_as_set(setup_with_model) # will fail because scope is not created + setup_with_model.get_model_settings() # this saves now the parameters epochs and batch_size into scope + assert {"epochs", "batch_size"} <= self.current_scope_as_set(setup_with_model) def test_build_model(self, setup_with_gen): - setup_with_gen.my_model_settings() assert setup_with_gen.model is None setup_with_gen.build_model() - assert isinstance(setup_with_gen.model, keras.Model) + assert isinstance(setup_with_gen.model, AbstractModelClass) + expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr", + "optimizer", "lr_decay", "epochs", "batch_size", "activation"} + assert expected <= self.current_scope_as_set(setup_with_gen) + + def test_set_channels(self, setup_with_gen): + assert len(setup_with_gen.data_store.search_name("channels")) == 0 + setup_with_gen._set_channels() + assert setup_with_gen.data_store.get("channels", setup_with_gen.scope) == 2 def test_load_weights(self): pass @@ -55,3 +78,9 @@ class TestModelSetup: def test_compile_model(self): pass + def test_run(self): + pass + + def test_init(self): + pass +