From d48493644636f208b848c23e2f9b5ee57f3988bc Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 13 Mar 2020 14:55:54 +0100 Subject: [PATCH] updated tests --- test/test_modules/test_model_setup.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index ade35a24..9ff7494f 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -4,6 +4,7 @@ import pytest from src.data_handling.data_generator import DataGenerator from src.datastore import EmptyScope +from src.model_modules.keras_extensions import CallbackHandler from src.model_modules.model_class import AbstractModelClass from src.run_modules.model_setup import ModelSetup from src.run_modules.run_environment import RunEnvironment @@ -61,6 +62,18 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() assert "general.modeltest" in setup.data_store.search_name("callbacks") + callbacks = setup.data_store.get("callbacks", "general.modeltest") + assert len(callbacks.get_callbacks()) == 3 + + def test_set_callbacks_no_lr_decay(self, setup): + setup.data_store.set("lr_decay", None, "general.model") + assert "general.modeltest" not in setup.data_store.search_name("callbacks") + setup.checkpoint_name = "TestName" + setup._set_callbacks() + callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest") + assert len(callbacks.get_callbacks()) == 2 + with pytest.raises(IndexError): + callbacks.get_callback_by_name("lr_decay") def test_get_model_settings(self, setup_with_model): with pytest.raises(EmptyScope): @@ -73,7 +86,7 @@ class TestModelSetup: setup_with_gen.build_model() assert isinstance(setup_with_gen.model, AbstractModelClass) expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr", - "optimizer", "lr_decay", "epochs", "batch_size", "activation"} + "optimizer", "epochs", "batch_size", "activation"} assert expected <= self.current_scope_as_set(setup_with_gen) def test_set_channels(self, setup_with_gen_tiny): -- GitLab