Skip to content
Snippets Groups Projects
Commit d4849364 authored by lukas leufen's avatar lukas leufen
Browse files

updated tests

parent df8e3cba
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!73Resolve "lr_decay should be optional"
Pipeline #31877 passed
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.datastore import EmptyScope from src.datastore import EmptyScope
from src.model_modules.keras_extensions import CallbackHandler
from src.model_modules.model_class import AbstractModelClass from src.model_modules.model_class import AbstractModelClass
from src.run_modules.model_setup import ModelSetup from src.run_modules.model_setup import ModelSetup
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
...@@ -61,6 +62,18 @@ class TestModelSetup: ...@@ -61,6 +62,18 @@ class TestModelSetup:
setup.checkpoint_name = "TestName" setup.checkpoint_name = "TestName"
setup._set_callbacks() setup._set_callbacks()
assert "general.modeltest" in setup.data_store.search_name("callbacks") assert "general.modeltest" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.modeltest")
assert len(callbacks.get_callbacks()) == 3
def test_set_callbacks_no_lr_decay(self, setup):
setup.data_store.set("lr_decay", None, "general.model")
assert "general.modeltest" not in setup.data_store.search_name("callbacks")
setup.checkpoint_name = "TestName"
setup._set_callbacks()
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest")
assert len(callbacks.get_callbacks()) == 2
with pytest.raises(IndexError):
callbacks.get_callback_by_name("lr_decay")
def test_get_model_settings(self, setup_with_model): def test_get_model_settings(self, setup_with_model):
with pytest.raises(EmptyScope): with pytest.raises(EmptyScope):
...@@ -73,7 +86,7 @@ class TestModelSetup: ...@@ -73,7 +86,7 @@ class TestModelSetup:
setup_with_gen.build_model() setup_with_gen.build_model()
assert isinstance(setup_with_gen.model, AbstractModelClass) assert isinstance(setup_with_gen.model, AbstractModelClass)
expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr", expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
"optimizer", "lr_decay", "epochs", "batch_size", "activation"} "optimizer", "epochs", "batch_size", "activation"}
assert expected <= self.current_scope_as_set(setup_with_gen) assert expected <= self.current_scope_as_set(setup_with_gen)
def test_set_channels(self, setup_with_gen_tiny): def test_set_channels(self, setup_with_gen_tiny):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment