diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py index e2a4b93219be2cbebfb35749560efa65c07226bb..cfb8638816e85ec8427f3df5d45d480d21fd929f 100644 --- a/src/model_modules/keras_extensions.py +++ b/src/model_modules/keras_extensions.py @@ -10,6 +10,8 @@ import numpy as np from keras import backend as K from keras.callbacks import History, ModelCheckpoint +from src import helpers + class HistoryAdvanced(History): """ @@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): Update all stored callback objects. The argument callbacks needs to follow the same convention like described in the class description (list of dictionaries). Must be run before resuming a training process. """ - self.callbacks = callbacks + self.callbacks = helpers.to_list(callbacks) def on_epoch_end(self, epoch, logs=None): """ @@ -139,12 +141,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.save_best_only: current = logs.get(self.monitor) if current == self.best: - if self.verbose > 0: + if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: pickle.dump(callback["callback"], f) else: with open(file_path, "wb") as f: - if self.verbose > 0: + if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) pickle.dump(callback["callback"], f) diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py index 2f6565b4cabe295169047a6582d2b89cbf387062..9141434d448f8d076836ef357279ae5815686767 100644 --- a/test/test_model_modules/test_keras_extensions.py +++ b/test/test_model_modules/test_keras_extensions.py @@ -1,6 +1,8 @@ import keras import numpy as np import pytest +import mock +import os from src.helpers import l_p_loss from src.model_modules.keras_extensions import * @@ -60,3 +62,51 @@ class TestLearningRateDecay: model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay]) assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95] + + +class TestModelCheckpointAdvanced: + + @pytest.fixture() + def callbacks(self): + callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s") + return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"}, + {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}] + + @pytest.fixture + def ckpt(self, callbacks): + ckpt_name = "ckpt.test" + return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, verbose=1) + + def test_init(self, ckpt, callbacks): + assert ckpt.callbacks == callbacks + assert ckpt.monitor == "val_loss" + assert ckpt.save_best_only is True + assert ckpt.best == np.inf + + def test_update_best(self, ckpt): + hist = HistoryAdvanced() + hist.history["val_loss"] = [10, 6] + ckpt.update_best(hist) + assert ckpt.best == 6 + + def test_update_callbacks(self, ckpt, callbacks): + ckpt.update_callbacks(callbacks[0]) + assert ckpt.callbacks == [callbacks[0]] + + def test_on_epoch_end(self, ckpt): + path = os.path.dirname(__file__) + ckpt.set_model(mock.MagicMock()) + ckpt.best = 6 + ckpt.on_epoch_end(0, {"val_loss": 6}) + assert "callback_hist" not in os.listdir(path) + ckpt.on_epoch_end(9, {"val_loss": 10}) + assert "callback_hist" not in os.listdir(path) + ckpt.on_epoch_end(10, {"val_loss": 4}) + assert "callback_hist" in os.listdir(path) + os.remove(os.path.join(path, "callback_hist")) + os.remove(os.path.join(path, "callback_lr")) + ckpt.save_best_only = False + ckpt.on_epoch_end(10, {"val_loss": 3}) + assert "callback_hist" in os.listdir(path) + os.remove(os.path.join(path, "callback_hist")) + os.remove(os.path.join(path, "callback_lr"))