import logging
import os

import keras
import pytest

from src.data_generator import DataGenerator
from src.modules.model_setup import ModelSetup
from src.modules.run_environment import RunEnvironment


class TestModelSetup:
    """Tests for the ModelSetup run module."""

    @pytest.fixture
    def setup(self):
        """Yield a ModelSetup instance whose own ``__init__`` is bypassed.

        ``object.__new__`` creates the instance without running
        ``ModelSetup.__init__``. Only the parent class initializer is invoked,
        so the test can drive individual methods by hand.
        NOTE(review): presumably ModelSetup's constructor would otherwise run
        its whole pipeline — confirm.
        """
        obj = object.__new__(ModelSetup)
        # Run only the parent-class initializer; the tests below rely on it
        # providing ``data_store`` (assumed to come from RunEnvironment — confirm).
        super(ModelSetup, obj).__init__()
        obj.scope = "general.modeltest"
        obj.model = None
        yield obj
        # Teardown: explicitly finalize a RunEnvironment so state shared between
        # tests is reset (presumably clears the data store — confirm).
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a DataGenerator built from the ``data`` directory next to this file.

        Uses station 'DEBW107' with variables o3 and temp, and the
        per-variable statistics given below.
        """
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107',
                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Return ``setup`` with the generator and its window sizes put into the data store.

        The generator is stored under scope "general.train"; the window history
        size and lead time are stored under scope "general".
        No teardown is needed here. This fixture depends on ``setup``, whose
        own teardown already finalizes the RunEnvironment after this fixture
        is released. Repeating that teardown here finalized the environment
        twice per test.
        """
        setup.data_store.put("generator", gen, "general.train")
        setup.data_store.put("window_history_size", gen.window_history_size, "general")
        setup.data_store.put("window_lead_time", gen.window_lead_time, "general")
        return setup

    def test_set_checkpoint(self, setup):
        """_set_checkpoint stores a "checkpoint" entry under the setup's scope."""
        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
        setup.model_name = "TestName"
        setup._set_checkpoint()
        assert "general.modeltest" in setup.data_store.search_name("checkpoint")

    def test_my_model_settings(self, setup_with_gen):
        """my_model_settings puts at least the expected hyperparameter names into the scope."""
        setup_with_gen.my_model_settings()
        expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay",
                    "epochs", "batch_size", "activation", "loss"}
        stored = set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True))
        assert expected <= stored

    def test_build_model(self, setup_with_gen):
        """build_model replaces the initial ``None`` model with a keras.Model."""
        setup_with_gen.my_model_settings()
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, keras.Model)

    # The two tests below were empty ``pass`` bodies, so they reported as passing
    # without testing anything. Skipping them keeps them visible in the report
    # as missing coverage.
    @pytest.mark.skip(reason="test_load_weights is not implemented yet")
    def test_load_weights(self):
        pass

    @pytest.mark.skip(reason="test_compile_model is not implemented yet")
    def test_compile_model(self):
        pass