From fb17c979da16bee07148893aef9a1d94fbb8b8da Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 16 May 2022 10:51:22 +0200
Subject: [PATCH] update tests

---
 mlair/run_modules/training.py             | 13 -------------
 test/test_run_modules/test_model_setup.py |  4 ++--
 test/test_run_modules/test_training.py    |  8 +-------
 3 files changed, 3 insertions(+), 22 deletions(-)

diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index cb9527ff..5ce90612 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -187,19 +187,6 @@ class Training(RunEnvironment):
         self.model.save(model_name, save_format="tf")
         self.data_store.set("model", self.model)
 
-    def load_best_model(self, name: str) -> None:
-        """
-        Load model weights for model with name. Skip if no weights are available.
-
-        :param name: name of the model to load weights for
-        """
-        logging.debug(f"load best model: {name}")
-        try:
-            self.model.load_model(name, compile=True)
-            logging.info(f"reload model...")
-        except OSError:
-            logging.info("no weights to reload...")
-
     def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
         """
         Save callbacks (history, learning rate) of training.
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
index 60b37207..962287e0 100644
--- a/test/test_run_modules/test_model_setup.py
+++ b/test/test_run_modules/test_model_setup.py
@@ -80,7 +80,7 @@ class TestModelSetup:
         setup._set_callbacks()
         assert "general.model" in setup.data_store.search_name("callbacks")
         callbacks = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 4
+        assert len(callbacks.get_callbacks()) == 5
 
     def test_set_callbacks_no_lr_decay(self, setup):
         setup.data_store.set("lr_decay", None, "general.model")
@@ -88,7 +88,7 @@ class TestModelSetup:
         setup.checkpoint_name = "TestName"
         setup._set_callbacks()
         callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 3
+        assert len(callbacks.get_callbacks()) == 4
         with pytest.raises(IndexError):
             callbacks.get_callback_by_name("lr_decay")
 
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 29717674..8f1fcd19 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -326,16 +326,10 @@ class TestTraining:
         model_name = "test_model.h5"
         assert model_name not in os.listdir(model_path)
         init_without_run.save_model()
-        message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}")
+        message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
         assert caplog.record_tuples[1] == ("root", 10, message)
         assert model_name in os.listdir(model_path)
 
-    def test_load_best_model_no_weights(self, init_without_run, caplog):
-        caplog.set_level(logging.DEBUG)
-        init_without_run.load_best_model("notExisting.h5")
-        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
-        assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
-
     def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
         init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history.json" in os.listdir(model_path)
-- 
GitLab