diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index d890e7b0ff3beea812d8fc7766433a84d65a1ebe..a5fdcea7db6c1a2d6ca26f4b9d7c9d365aa45924 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -3,6 +3,7 @@ __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2020-01-31' +import copy import logging import math import pickle @@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: - pickle.dump(callback["callback"], f) + c = copy.copy(callback["callback"]) + if hasattr(c, "model"): + c.model = None + pickle.dump(c, f) else: with open(file_path, "wb") as f: if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) - pickle.dump(callback["callback"], f) + c = copy.copy(callback["callback"]) + if hasattr(c, "model"): + c.model = None + pickle.dump(c, f) clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str}) @@ -346,6 +353,8 @@ class CallbackHandler: for pos, callback in enumerate(self.__callbacks): path = callback["path"] clb = pickle.load(open(path, "rb")) + if clb.model is None: + clb.model = self._checkpoint.model self._update_callback(pos, clb) def update_checkpoint(self, history_name: str = "hist") -> None: diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 9d633a348bd1e24cd3f3abcdb83124f6107db2e9..f1b210e1c7429c96658238ac21d96b7843053da7 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -1,8 +1,10 @@ +import copy import glob import json import logging import os import shutil +from typing import Callable import tensorflow.keras as keras import mock @@ -76,10 +78,24 @@ class TestTraining: obj.data_store.set("plot_path", path_plot, "general") obj._train_model = True obj._create_new_model = False - yield obj - if os.path.exists(path): - shutil.rmtree(path) - RunEnvironment().__del__() + try: + yield obj + finally: + if os.path.exists(path): + shutil.rmtree(path) + try: + RunEnvironment().__del__() + except AssertionError: + pass + # try: + # yield obj + # finally: + # if os.path.exists(path): + # shutil.rmtree(path) + # try: + # RunEnvironment().__del__() + # except AssertionError: + # pass @pytest.fixture def learning_rate(self): @@ -223,9 +239,10 @@ class TestTraining: assert ready_to_run._run() is None # just test, if nothing fails def test_make_predict_function(self, init_without_run): - assert hasattr(init_without_run.model, "predict_function") is False + assert hasattr(init_without_run.model, "predict_function") is True + assert init_without_run.model.predict_function is None init_without_run.make_predict_function() - assert hasattr(init_without_run.model, "predict_function") + assert isinstance(init_without_run.model.predict_function, Callable) def test_set_gen(self, init_without_run): assert init_without_run.train_set is None @@ -242,10 +259,10 @@ class TestTraining: [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets]) def test_train(self, ready_to_train, path): - assert not hasattr(ready_to_train.model, "history") + assert ready_to_train.model.history is None assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0 ready_to_train.train() - assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"] + assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"] assert ready_to_train.model.history.epoch == [0, 1] assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 @@ -260,8 +277,8 @@ class TestTraining: def test_load_best_model_no_weights(self, init_without_run, caplog): caplog.set_level(logging.DEBUG) - init_without_run.load_best_model("notExisting") - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting")) + init_without_run.load_best_model("notExisting.h5") + assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5")) assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): @@ -290,3 +307,10 @@ class TestTraining: history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"]) init_without_run.create_monitoring_plots(history, learning_rate) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 + + def test_resume_training(self, ready_to_run): + with copy.copy(ready_to_run) as pre_run: + assert pre_run._run() is None # rune once to create model + ready_to_run.epochs = 4 # continue train up to epoch 4 + assert ready_to_run._run() is None +