diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index d890e7b0ff3beea812d8fc7766433a84d65a1ebe..a5fdcea7db6c1a2d6ca26f4b9d7c9d365aa45924 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -3,6 +3,7 @@
 __author__ = 'Lukas Leufen, Felix Kleinert'
 __date__ = '2020-01-31'
 
+import copy
 import logging
 import math
 import pickle
@@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
-                            pickle.dump(callback["callback"], f)
+                            c = copy.copy(callback["callback"])
+                            if hasattr(c, "model"):
+                                c.model = None
+                            pickle.dump(c, f)
                 else:
                     with open(file_path, "wb") as f:
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
-                        pickle.dump(callback["callback"], f)
+                        c = copy.copy(callback["callback"])
+                        if hasattr(c, "model"):
+                            c.model = None
+                        pickle.dump(c, f)
 
 
 clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
@@ -346,6 +353,8 @@ class CallbackHandler:
         for pos, callback in enumerate(self.__callbacks):
             path = callback["path"]
             clb = pickle.load(open(path, "rb"))
+            if clb.model is None:
+                clb.model = self._checkpoint.model
             self._update_callback(pos, clb)
 
     def update_checkpoint(self, history_name: str = "hist") -> None:
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 9d633a348bd1e24cd3f3abcdb83124f6107db2e9..f1b210e1c7429c96658238ac21d96b7843053da7 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -1,8 +1,10 @@
+import copy
 import glob
 import json
 import logging
 import os
 import shutil
+from typing import Callable
 
 import tensorflow.keras as keras
 import mock
@@ -76,10 +78,15 @@ class TestTraining:
         obj.data_store.set("plot_path", path_plot, "general")
         obj._train_model = True
         obj._create_new_model = False
-        yield obj
-        if os.path.exists(path):
-            shutil.rmtree(path)
-        RunEnvironment().__del__()
+        try:
+            yield obj
+        finally:
+            if os.path.exists(path):
+                shutil.rmtree(path)
+            try:
+                RunEnvironment().__del__()
+            except AssertionError:
+                pass
 
     @pytest.fixture
     def learning_rate(self):
@@ -223,9 +239,10 @@ class TestTraining:
         assert ready_to_run._run() is None  # just test, if nothing fails
 
     def test_make_predict_function(self, init_without_run):
-        assert hasattr(init_without_run.model, "predict_function") is False
+        assert hasattr(init_without_run.model, "predict_function") is True
+        assert init_without_run.model.predict_function is None
         init_without_run.make_predict_function()
-        assert hasattr(init_without_run.model, "predict_function")
+        assert isinstance(init_without_run.model.predict_function, Callable)
 
     def test_set_gen(self, init_without_run):
         assert init_without_run.train_set is None
@@ -242,10 +259,10 @@ class TestTraining:
             [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
 
     def test_train(self, ready_to_train, path):
-        assert not hasattr(ready_to_train.model, "history")
+        assert ready_to_train.model.history is None
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
         ready_to_train.train()
-        assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
+        assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
         assert ready_to_train.model.history.epoch == [0, 1]
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
 
@@ -260,8 +277,8 @@ class TestTraining:
 
     def test_load_best_model_no_weights(self, init_without_run, caplog):
         caplog.set_level(logging.DEBUG)
-        init_without_run.load_best_model("notExisting")
-        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
+        init_without_run.load_best_model("notExisting.h5")
+        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
     def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
@@ -290,3 +307,10 @@ class TestTraining:
         history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
         init_without_run.create_monitoring_plots(history, learning_rate)
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
+
+    def test_resume_training(self, ready_to_run):
+        with copy.copy(ready_to_run) as pre_run:
+            assert pre_run._run() is None  # run once to create model
+            ready_to_run.epochs = 4  # continue training up to epoch 4
+            assert ready_to_run._run() is None
+