import os
import shutil

import numpy as np
import pytest

from mlair.data_handler import KerasIterator
from mlair.data_handler import DataCollection
from mlair.helpers.datastore import EmptyScope
from mlair.model_modules.keras_extensions import CallbackHandler
from mlair.model_modules.fully_connected_networks import FCN_64_32_16
from mlair.model_modules import AbstractModelClass
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.run_environment import RunEnvironment


class TestModelSetup:
    """Tests for ``ModelSetup``, driven by hand-assembled instances and dummy data."""

    @pytest.fixture
    def setup(self):
        """Yield a ``ModelSetup`` object created without calling ``ModelSetup.__init__``.

        Only the initializer of the parent class is executed; the attributes and
        data store entries used by the tests are then filled in by hand.
        """
        obj = object.__new__(ModelSetup)
        super(ModelSetup, obj).__init__()
        obj.scope = "general.model"
        obj.model = None
        obj.callbacks_name = "placeholder_%s_str.pickle"
        obj.data_store.set("model_class", FCN_64_32_16)
        obj.data_store.set("lr_decay", "dummy_str", "general.model")
        obj.data_store.set("hist", "dummy_str", "general.model")
        obj.data_store.set("epochs", 2)
        obj.model_name = "%s.h5"
        yield obj
        # NOTE(review): presumably resets the shared run environment / data store
        # between tests -- confirm against RunEnvironment.
        RunEnvironment().__del__()

    @pytest.fixture
    def path(self):
        """Yield a ``testdata`` directory path next to this file; remove it before and after."""
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
        # ignore_errors=True already tolerates a missing directory, so no
        # existence check is needed before removing leftovers of earlier runs.
        shutil.rmtree(p, ignore_errors=True)
        yield p
        shutil.rmtree(p, ignore_errors=True)

    @pytest.fixture
    def keras_iterator(self, path):
        """Return a ``DataCollection`` of three ``DummyData`` objects (50, 51, 52 samples).

        A ``KerasIterator`` is instantiated on the collection with the values
        ``25`` and ``path``; the iterator itself is discarded and only the
        collection is returned.
        """
        coll = [DummyData(50 + i) for i in range(3)]
        data_coll = DataCollection(collection=coll)
        KerasIterator(data_coll, 25, path)
        return data_coll

    @pytest.fixture
    def setup_with_gen(self, setup, keras_iterator):
        """Yield ``setup`` with the train data collection and input/output shapes stored."""
        setup.data_store.set("data_collection", keras_iterator, "train")
        # Shapes are taken from the first array of the first element, without
        # the leading (sample) axis.
        input_shape = [keras_iterator[0].get_X()[0].shape[1:]]
        setup.data_store.set("input_shape", input_shape, "model")
        output_shape = [keras_iterator[0].get_Y()[0].shape[1:]]
        setup.data_store.set("output_shape", output_shape, "model")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_gen_tiny(self, setup, keras_iterator):
        """Yield ``setup`` with only the train data collection stored (no shapes)."""
        setup.data_store.set("data_collection", keras_iterator, "train")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup):
        """Yield ``setup`` carrying an ``AbstractModelClass`` with an extra ``test_param``."""
        setup.model = AbstractModelClass(input_shape=(12, 1), output_shape=2)
        setup.model.test_param = "42"
        yield setup
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return the data store search result for the object's own scope as a set."""
        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

    def test_set_callbacks(self, setup):
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        assert "general.model" in setup.data_store.search_name("callbacks")
        callbacks = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 3

    def test_set_callbacks_no_lr_decay(self, setup):
        setup.data_store.set("lr_decay", None, "general.model")
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 2
        with pytest.raises(IndexError):
            callbacks.get_callback_by_name("lr_decay")

    def test_get_model_settings(self, setup_with_model):
        setup_with_model.scope = "model_test"
        with pytest.raises(EmptyScope):
            self.current_scope_as_set(setup_with_model)  # will fail because scope is not created
        setup_with_model.get_model_settings()  # this now saves the parameter test_param into scope
        assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)

    def test_build_model(self, setup_with_gen):
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {"lr_decay", "model_name", "optimizer", "activation", "input_shape", "output_shape"}
        assert expected <= self.current_scope_as_set(setup_with_gen)

    def test_set_shapes(self, setup_with_gen_tiny):
        assert len(setup_with_gen_tiny.data_store.search_name("input_shape")) == 0
        assert len(setup_with_gen_tiny.data_store.search_name("output_shape")) == 0
        setup_with_gen_tiny._set_shapes()
        # Expected shapes mirror the per-sample shapes produced by DummyData.
        assert setup_with_gen_tiny.data_store.get("input_shape", setup_with_gen_tiny.scope) == [
            (14, 1, 5), (10, 1, 2), (1, 1, 2)]
        assert setup_with_gen_tiny.data_store.get("output_shape", setup_with_gen_tiny.scope) == [(5,), (3,)]

    def test_load_weights(self):
        pass

    def test_compile_model(self):
        pass

    def test_run(self):
        pass

    def test_init(self):
        pass


class DummyData:
    """Minimal stand-in for a data handler that serves random integer arrays.

    :param number_of_samples: length of the first axis of every returned array.
        If omitted, a value from ``np.random.randint(100, 150)`` is drawn for
        each new instance.
    """

    def __init__(self, number_of_samples=None):
        # Draw the fallback at call time: a random value used directly as the
        # default argument would be evaluated only once, at definition time,
        # and then be shared by every instance created without an argument.
        if number_of_samples is None:
            number_of_samples = np.random.randint(100, 150)
        self.number_of_samples = number_of_samples

    def get_X(self, upsampling=False, as_numpy=True):
        """Return three random input arrays; ``upsampling`` and ``as_numpy`` are ignored."""
        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 1, 5))  # samples x (14, 1, 5)
        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 1, 2))  # samples x (10, 1, 2)
        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 1, 2))  # samples x (1, 1, 2)
        return [X1, X2, X3]

    def get_Y(self, upsampling=False, as_numpy=True):
        """Return two random target arrays; ``upsampling`` and ``as_numpy`` are ignored."""
        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples x 5
        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples x 3
        return [Y1, Y2]