From fb17c979da16bee07148893aef9a1d94fbb8b8da Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Mon, 16 May 2022 10:51:22 +0200 Subject: [PATCH] update tests --- mlair/run_modules/training.py | 13 ------------- test/test_run_modules/test_model_setup.py | 4 ++-- test/test_run_modules/test_training.py | 8 +------- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index cb9527ff..5ce90612 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -187,19 +187,6 @@ class Training(RunEnvironment): self.model.save(model_name, save_format="tf") self.data_store.set("model", self.model) - def load_best_model(self, name: str) -> None: - """ - Load model weights for model with name. Skip if no weights are available. - - :param name: name of the model to load weights for - """ - logging.debug(f"load best model: {name}") - try: - self.model.load_model(name, compile=True) - logging.info(f"reload model...") - except OSError: - logging.info("no weights to reload...") - def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ Save callbacks (history, learning rate) of training. diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 60b37207..962287e0 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -80,7 +80,7 @@ class TestModelSetup: setup._set_callbacks() assert "general.model" in setup.data_store.search_name("callbacks") callbacks = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 4 + assert len(callbacks.get_callbacks()) == 5 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") @@ -88,7 +88,7 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 3 + assert len(callbacks.get_callbacks()) == 4 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 29717674..8f1fcd19 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -326,16 +326,10 @@ class TestTraining: model_name = "test_model.h5" assert model_name not in os.listdir(model_path) init_without_run.save_model() - message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}") + message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}") assert caplog.record_tuples[1] == ("root", 10, message) assert model_name in os.listdir(model_path) - def test_load_best_model_no_weights(self, init_without_run, caplog): - caplog.set_level(logging.DEBUG) - init_without_run.load_best_model("notExisting.h5") - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5")) - assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) - def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) assert "history.json" in os.listdir(model_path) -- GitLab