Select Git revision
test_model_setup.py 4.30 KiB
import os
import pytest
from src.data_handling.data_generator import DataGenerator
from src.helpers.datastore import EmptyScope
from src.model_modules.keras_extensions import CallbackHandler
from src.model_modules.model_class import AbstractModelClass, MyLittleModel
from src.run_modules.model_setup import ModelSetup
from src.run_modules.run_environment import RunEnvironment
class TestModelSetup:
    """Unit tests for ``ModelSetup``.

    The fixtures build a bare ``ModelSetup`` instance without running
    ``ModelSetup.__init__`` and fill its data store by hand, so that single
    private steps (``_set_callbacks``, ``_set_channels``, ...) can be tested
    in isolation.
    """

    @pytest.fixture
    def setup(self):
        """Yield a bare ``ModelSetup`` with a minimal ``general.model`` configuration.

        ``object.__new__`` skips ``ModelSetup.__init__``; only the next
        initializer in ``ModelSetup``'s MRO is run via ``super(...)``
        (presumably ``RunEnvironment.__init__`` -- confirm), which provides
        ``data_store``.
        """
        obj = object.__new__(ModelSetup)
        super(ModelSetup, obj).__init__()
        obj.scope = "general.model"
        obj.model = None
        obj.callbacks_name = "placeholder_%s_str.pickle"
        obj.data_store.set("model_class", MyLittleModel)
        obj.data_store.set("lr_decay", "dummy_str", "general.model")
        obj.data_store.set("hist", "dummy_str", "general.model")
        obj.data_store.set("epochs", 2)
        obj.model_name = "%s.h5"
        yield obj
        # Teardown: explicitly invoke __del__ on a fresh RunEnvironment,
        # presumably to release shared run state between tests -- confirm.
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a ``DataGenerator`` reading from the ``data`` directory next to this file.

        Configured for 'AIRBASE' / 'DEBW107' with the variables o3 and temp.
        """
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Extend ``setup`` with the training generator and the model shape settings.

        Stores the generator under ``general.train`` and the window sizes plus
        ``channels = 2`` under ``general``, which is what ``build_model`` reads.
        """
        setup.data_store.set("generator", gen, "general.train")
        setup.data_store.set("window_history_size", gen.window_history_size, "general")
        setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
        setup.data_store.set("channels", 2, "general")
        yield setup
        # NOTE(review): the ``setup`` fixture performs the same teardown after
        # this one, so this call is presumably redundant -- confirm before removing.
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_gen_tiny(self, setup, gen):
        """Extend ``setup`` with only the training generator (no ``channels`` entry)."""
        setup.data_store.set("generator", gen, "general.train")
        yield setup
        # NOTE(review): duplicate of ``setup``'s teardown; presumably redundant.
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup):
        """Extend ``setup`` with an ``AbstractModelClass`` carrying ``test_param = "42"``."""
        setup.model = AbstractModelClass()
        setup.model.test_param = "42"
        yield setup
        # NOTE(review): duplicate of ``setup``'s teardown; presumably redundant.
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return the set of names stored directly in ``model_cls.scope``.

        Only the current scope is searched (``current_scope_only=True``); the
        lookup raises ``EmptyScope`` when that scope does not exist yet.
        """
        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

    def test_set_callbacks(self, setup):
        """``_set_callbacks`` registers three callbacks under ``general.model``."""
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        assert "general.model" in setup.data_store.search_name("callbacks")
        callbacks = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 3

    def test_set_callbacks_no_lr_decay(self, setup):
        """Without ``lr_decay`` only two callbacks are registered and none is named 'lr_decay'."""
        setup.data_store.set("lr_decay", None, "general.model")
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 2
        with pytest.raises(IndexError):
            callbacks.get_callback_by_name("lr_decay")

    def test_get_model_settings(self, setup_with_model):
        """``get_model_settings`` writes the model's attributes into the (new) scope."""
        setup_with_model.scope = "model_test"
        with pytest.raises(EmptyScope):
            self.current_scope_as_set(setup_with_model)  # fails: scope has not been created yet
        setup_with_model.get_model_settings()  # saves test_param (and model_name) into the scope
        assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)

    def test_build_model(self, setup_with_gen):
        """``build_model`` creates the model and stores its hyper-parameters in scope."""
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
                    "optimizer", "batch_size", "activation"}
        assert expected <= self.current_scope_as_set(setup_with_gen)

    def test_set_channels(self, setup_with_gen_tiny):
        """``_set_channels`` derives ``channels`` (2 here) from the training generator."""
        assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
        setup_with_gen_tiny._set_channels()
        assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2

    # The following tests are placeholders. Marking them as skipped keeps them
    # visible in the report instead of letting empty bodies count as passes.
    @pytest.mark.skip(reason="not implemented yet")
    def test_load_weights(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_compile_model(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_run(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_init(self):
        pass