diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index e2a4b93219be2cbebfb35749560efa65c07226bb..cfb8638816e85ec8427f3df5d45d480d21fd929f 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -10,6 +10,8 @@ import numpy as np
 from keras import backend as K
 from keras.callbacks import History, ModelCheckpoint
 
+from src import helpers
+
 
 class HistoryAdvanced(History):
     """
@@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
         Update all stored callback objects. The argument callbacks needs to follow the same convention like described
         in the class description (list of dictionaries). Must be run before resuming a training process.
         """
-        self.callbacks = callbacks
+        self.callbacks = helpers.to_list(callbacks)
 
     def on_epoch_end(self, epoch, logs=None):
         """
@@ -139,12 +141,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
-                        if self.verbose > 0:
+                        if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
                             pickle.dump(callback["callback"], f)
                 else:
                     with open(file_path, "wb") as f:
-                        if self.verbose > 0:
+                        if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         pickle.dump(callback["callback"], f)
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index 2f6565b4cabe295169047a6582d2b89cbf387062..9141434d448f8d076836ef357279ae5815686767 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -1,6 +1,8 @@
 import keras
 import numpy as np
 import pytest
+import mock
+import os
 
 from src.helpers import l_p_loss
 from src.model_modules.keras_extensions import *
@@ -60,3 +62,51 @@ class TestLearningRateDecay:
         model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
         model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
         assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
+
+
+class TestModelCheckpointAdvanced:
+
+    @pytest.fixture()
+    def callbacks(self):
+        callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s")
+        return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"},
+                     {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}]
+
+    @pytest.fixture
+    def ckpt(self, callbacks):
+        ckpt_name = "ckpt.test"
+        return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, verbose=1)
+
+    def test_init(self, ckpt, callbacks):
+        assert ckpt.callbacks == callbacks
+        assert ckpt.monitor == "val_loss"
+        assert ckpt.save_best_only is True
+        assert ckpt.best == np.inf
+
+    def test_update_best(self, ckpt):
+        hist = HistoryAdvanced()
+        hist.history["val_loss"] = [10, 6]
+        ckpt.update_best(hist)
+        assert ckpt.best == 6
+
+    def test_update_callbacks(self, ckpt, callbacks):
+        ckpt.update_callbacks(callbacks[0])
+        assert ckpt.callbacks == [callbacks[0]]
+
+    def test_on_epoch_end(self, ckpt):
+        path = os.path.dirname(__file__)
+        ckpt.set_model(mock.MagicMock())
+        ckpt.best = 6
+        ckpt.on_epoch_end(0, {"val_loss": 6})
+        assert "callback_hist" not in os.listdir(path)
+        ckpt.on_epoch_end(9, {"val_loss": 10})
+        assert "callback_hist" not in os.listdir(path)
+        ckpt.on_epoch_end(10, {"val_loss": 4})
+        assert "callback_hist" in os.listdir(path)
+        os.remove(os.path.join(path, "callback_hist"))
+        os.remove(os.path.join(path, "callback_lr"))
+        ckpt.save_best_only = False
+        ckpt.on_epoch_end(10, {"val_loss": 3})
+        assert "callback_hist" in os.listdir(path)
+        os.remove(os.path.join(path, "callback_hist"))
+        os.remove(os.path.join(path, "callback_lr"))