diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index cb9527ff9243c0d35c2fddb3d22368ef918ac2af..5ce906122ef184d6dcad5527e923e44f04028fe5 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -187,19 +187,6 @@ class Training(RunEnvironment): self.model.save(model_name, save_format="tf") self.data_store.set("model", self.model) - def load_best_model(self, name: str) -> None: - """ - Load model weights for model with name. Skip if no weights are available. - - :param name: name of the model to load weights for - """ - logging.debug(f"load best model: {name}") - try: - self.model.load_model(name, compile=True) - logging.info(f"reload model...") - except OSError: - logging.info("no weights to reload...") - def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ Save callbacks (history, learning rate) of training. diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 60b37207ceefc4088b33fa002dac9db7c6c35399..962287e09aacd3c44961a827c86b331d643ec401 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -80,7 +80,7 @@ class TestModelSetup: setup._set_callbacks() assert "general.model" in setup.data_store.search_name("callbacks") callbacks = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 4 + assert len(callbacks.get_callbacks()) == 5 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") @@ -88,7 +88,7 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 3 + assert len(callbacks.get_callbacks()) == 4 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 29717674241d912e14534300302790e56fec1df3..8f1fcd1943f9f203e738053017e00f8c269afef1 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -326,16 +326,10 @@ class TestTraining: model_name = "test_model.h5" assert model_name not in os.listdir(model_path) init_without_run.save_model() - message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}") + message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}") assert caplog.record_tuples[1] == ("root", 10, message) assert model_name in os.listdir(model_path) - def test_load_best_model_no_weights(self, init_without_run, caplog): - caplog.set_level(logging.DEBUG) - init_without_run.load_best_model("notExisting.h5") - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5")) - assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) - def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) assert "history.json" in os.listdir(model_path)