# Select Git revision
# test_model_setup.py 3.19 KiB
import pytest
import os
import keras
import mock
from src.modules.model_setup import ModelSetup
from src.modules.run_environment import RunEnvironment
from src.data_handling.data_generator import DataGenerator
from src.model_modules.model_class import AbstractModelClass
from src.datastore import EmptyScope
class TestModelSetup:
    """Tests for the ``ModelSetup`` stage.

    The fixtures build a ``ModelSetup`` instance without running ``ModelSetup.__init__``.
    They bind it to the scope ``general.modeltest`` and then progressively populate the
    shared data store with a data generator and a model.
    """

    @pytest.fixture
    def setup(self):
        """Yield a bare ``ModelSetup`` bound to scope ``general.modeltest`` with no model."""
        # object.__new__ skips ModelSetup.__init__; only the base-class initialiser is run.
        obj = object.__new__(ModelSetup)
        super(ModelSetup, obj).__init__()
        obj.scope = "general.modeltest"
        obj.model = None
        yield obj
        # NOTE(review): presumably resets the shared run environment / data store - confirm.
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a ``DataGenerator`` over the ``data`` directory located next to this file."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Extend ``setup``: store the generator and its window sizes in the data store."""
        setup.data_store.set("generator", gen, "general.train")
        setup.data_store.set("window_history_size", gen.window_history_size, "general")
        setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup_with_gen):
        """Extend ``setup_with_gen``: store 2 channels and attach an ``AbstractModelClass`` model."""
        setup_with_gen.data_store.set("channels", 2, "general")
        setup_with_gen.model = AbstractModelClass()
        setup_with_gen.model.epochs = 2
        setup_with_gen.model.batch_size = 256
        yield setup_with_gen
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return, as a set, the names found in ``model_cls.scope`` with ``current_scope_only=True``."""
        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

    def test_set_checkpoint(self, setup):
        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
        setup.checkpoint_name = "TestName"
        setup._set_checkpoint()
        assert "general.modeltest" in setup.data_store.search_name("checkpoint")

    def test_get_model_settings(self, setup_with_model):
        with pytest.raises(EmptyScope):
            # Raises because nothing has been stored in this scope yet.
            self.current_scope_as_set(setup_with_model)
        # Stores the model's epochs and batch_size in the scope.
        setup_with_model.get_model_settings()
        assert {"epochs", "batch_size"} <= self.current_scope_as_set(setup_with_model)

    def test_build_model(self, setup_with_gen):
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
                    "optimizer", "lr_decay", "epochs", "batch_size", "activation"}
        assert expected <= self.current_scope_as_set(setup_with_gen)

    def test_set_channels(self, setup_with_gen):
        assert not setup_with_gen.data_store.search_name("channels")
        setup_with_gen._set_channels()
        assert setup_with_gen.data_store.get("channels", setup_with_gen.scope) == 2

    # The tests below are placeholders. They are marked as skipped so they do not
    # report as passing (which would misrepresent coverage) until implemented.
    @pytest.mark.skip(reason="not implemented yet")
    def test_load_weights(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_compile_model(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_run(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_init(self):
        pass