Skip to content
Snippets Groups Projects
Commit 9e6497c2 authored by lukas leufen's avatar lukas leufen
Browse files

another test update

parent a5d01b61
No related branches found
No related tags found
2 merge requests!24include recent development,!22model class
Pipeline #27486 passed
...@@ -11,19 +11,19 @@ class TestAbstractModelClass: ...@@ -11,19 +11,19 @@ class TestAbstractModelClass:
return AbstractModelClass() return AbstractModelClass()
def test_init(self, amc): def test_init(self, amc):
assert amc.__model is None assert amc.model is None
assert amc.__loss is None assert amc.loss is None
def test_model_property(self, amc): def test_model_property(self, amc):
amc.__model = keras.Model() amc.model = keras.Model()
assert isinstance(amc.model, keras.Model) is True assert isinstance(amc.model, keras.Model) is True
def test_loss_property(self, amc): def test_loss_property(self, amc):
amc.__loss = keras.losses.mean_absolute_error amc.loss = keras.losses.mean_absolute_error
assert amc.loss == keras.losses.mean_absolute_error assert amc.loss == keras.losses.mean_absolute_error
def test_getattr(self, amc): def test_getattr(self, amc):
amc.__model = keras.Model() amc.model = keras.Model()
assert hasattr(amc, "compile") is True assert hasattr(amc, "compile") is True
assert hasattr(amc.model, "compile") is True assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile assert amc.compile == amc.model.compile
...@@ -31,16 +31,22 @@ class TestModelSetup: ...@@ -31,16 +31,22 @@ class TestModelSetup:
setup.data_store.set("generator", gen, "general.train") setup.data_store.set("generator", gen, "general.train")
setup.data_store.set("window_history_size", gen.window_history_size, "general") setup.data_store.set("window_history_size", gen.window_history_size, "general")
setup.data_store.set("window_lead_time", gen.window_lead_time, "general") setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
setup.data_store.set("channels", 2, "general")
yield setup yield setup
RunEnvironment().__del__() RunEnvironment().__del__()
@pytest.fixture @pytest.fixture
def setup_with_model(self, setup_with_gen): def setup_with_gen_tiny(self, setup, gen):
setup_with_gen.data_store.set("channels", 2, "general") setup.data_store.set("generator", gen, "general.train")
setup_with_gen.model = AbstractModelClass() yield setup
setup_with_gen.model.epochs = 2 RunEnvironment().__del__()
setup_with_gen.model.batch_size = int(256)
yield setup_with_gen @pytest.fixture
def setup_with_model(self, setup):
setup.model = AbstractModelClass()
setup.model.epochs = 2
setup.model.batch_size = int(256)
yield setup
RunEnvironment().__del__() RunEnvironment().__del__()
@staticmethod @staticmethod
...@@ -67,10 +73,10 @@ class TestModelSetup: ...@@ -67,10 +73,10 @@ class TestModelSetup:
"optimizer", "lr_decay", "epochs", "batch_size", "activation"} "optimizer", "lr_decay", "epochs", "batch_size", "activation"}
assert expected <= self.current_scope_as_set(setup_with_gen) assert expected <= self.current_scope_as_set(setup_with_gen)
def test_set_channels(self, setup_with_gen): def test_set_channels(self, setup_with_gen_tiny):
assert len(setup_with_gen.data_store.search_name("channels")) == 0 assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
setup_with_gen._set_channels() setup_with_gen_tiny._set_channels()
assert setup_with_gen.data_store.get("channels", setup_with_gen.scope) == 2 assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2
def test_load_weights(self): def test_load_weights(self):
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment