"""Tests for the ModelSetup run module."""
import os

import pytest

from src.data_handling.data_generator import DataGenerator
from src.helpers.datastore import EmptyScope
from src.model_modules.keras_extensions import CallbackHandler
from src.model_modules.model_class import AbstractModelClass
from src.run_modules.model_setup import ModelSetup
from src.run_modules.run_environment import RunEnvironment


class TestModelSetup:
    """Exercise ModelSetup helpers against hand-built instances and data store entries."""

    @pytest.fixture
    def setup(self):
        """Yield a ModelSetup instance created without running ModelSetup.__init__.

        The object is allocated via ``object.__new__`` and only the next initializer in the MRO is
        executed; the attributes and data store entries used by the tests are then filled in by hand.
        """
        model_setup = object.__new__(ModelSetup)
        super(ModelSetup, model_setup).__init__()
        model_setup.scope = "general.model"
        model_setup.model = None
        model_setup.callbacks_name = "placeholder_%s_str.pickle"
        # both entries only need to exist; their content is a dummy value
        for entry in ("lr_decay", "hist"):
            model_setup.data_store.set(entry, "dummy_str", "general.model")
        model_setup.model_name = "%s.h5"
        yield model_setup
        # explicit finalizer call after each test -- presumably resets shared state, confirm in RunEnvironment
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a DataGenerator reading from the ``data`` directory located next to this test file."""
        data_dir = os.path.join(os.path.dirname(__file__), 'data')
        stats = {'o3': 'dma8eu', 'temp': 'maximum'}
        return DataGenerator(data_dir, 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3',
                             statistics_per_var=stats)

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Yield ``setup`` after storing the generator, both window sizes and the channel count."""
        store = setup.data_store
        store.set("generator", gen, "general.train")
        store.set("window_history_size", gen.window_history_size, "general")
        store.set("window_lead_time", gen.window_lead_time, "general")
        store.set("channels", 2, "general")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_gen_tiny(self, setup, gen):
        """Yield ``setup`` with only the generator stored (no window sizes, no channels)."""
        setup.data_store.set("generator", gen, "general.train")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup):
        """Yield ``setup`` carrying an AbstractModelClass with ``epochs`` and ``batch_size`` set."""
        model = AbstractModelClass()
        setup.model = model
        model.epochs = 2
        model.batch_size = 256
        yield setup
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return the result of searching the data store in exactly ``model_cls.scope`` as a set."""
        found = model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True)
        return set(found)

    def test_set_callbacks(self, setup):
        """``_set_callbacks`` stores a handler with three callbacks under ``general.model``."""
        scope = "general.model"
        assert scope not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        assert scope in setup.data_store.search_name("callbacks")
        handler = setup.data_store.get("callbacks", scope)
        assert len(handler.get_callbacks()) == 3

    def test_set_callbacks_no_lr_decay(self, setup):
        """With ``lr_decay`` set to None only two callbacks remain and ``lr_decay`` cannot be looked up."""
        scope = "general.model"
        setup.data_store.set("lr_decay", None, scope)
        assert scope not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        handler: CallbackHandler = setup.data_store.get("callbacks", scope)
        assert len(handler.get_callbacks()) == 2
        with pytest.raises(IndexError):
            handler.get_callback_by_name("lr_decay")

    def test_get_model_settings(self, setup_with_model):
        """``get_model_settings`` writes at least ``epochs`` and ``batch_size`` into the current scope."""
        setup_with_model.scope = "model_test"
        # the scope has not been created yet, so searching it must raise
        with pytest.raises(EmptyScope):
            self.current_scope_as_set(setup_with_model)
        # stores the model parameters (epochs, batch_size) in the scope
        setup_with_model.get_model_settings()
        stored = self.current_scope_as_set(setup_with_model)
        assert {"epochs", "batch_size"} <= stored

    def test_build_model(self, setup_with_gen):
        """``build_model`` creates an AbstractModelClass and fills the scope with its settings."""
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {
            "window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer",
            "initial_lr", "optimizer", "epochs", "batch_size", "activation",
        }
        stored = self.current_scope_as_set(setup_with_gen)
        assert expected <= stored

    def test_set_channels(self, setup_with_gen_tiny):
        """``_set_channels`` stores a ``channels`` value of 2 for the two-variable generator."""
        tiny = setup_with_gen_tiny
        assert len(tiny.data_store.search_name("channels")) == 0
        tiny._set_channels()
        assert tiny.data_store.get("channels", tiny.scope) == 2

    def test_load_weights(self):
        """Placeholder: no assertions yet."""
        pass

    def test_compile_model(self):
        """Placeholder: no assertions yet."""
        pass

    def test_run(self):
        """Placeholder: no assertions yet."""
        pass

    def test_init(self):
        """Placeholder: no assertions yet."""
        pass