diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 44e664e4f47dfd842ed956fcf7f7e56becb758ef..51ea1cd344c1ff1899af818c6b38a2cbb93b733a 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -1,6 +1,7 @@
 import copy
 import glob
 import json
+import time
 import logging
 import os
 import shutil
@@ -161,11 +163,8 @@ class TestTraining:
     @pytest.fixture
     def model(self, window_history_size, window_lead_time, statistics_per_var):
         channels = len(list(statistics_per_var.keys()))
-
         return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
 
-        # return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
-
     @pytest.fixture
     def callbacks(self, path):
         clbk = CallbackHandler()
@@ -194,7 +193,7 @@ class TestTraining:
         obj.data_store.set("data_collection", data_collection, "general.train")
         obj.data_store.set("data_collection", data_collection, "general.val")
         obj.data_store.set("data_collection", data_collection, "general.test")
-        obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
+        obj.model.compile(**obj.model.compile_options)
         return obj
 
     @pytest.fixture
@@ -229,6 +228,64 @@ class TestTraining:
         if os.path.exists(path):
             shutil.rmtree(path)
 
+    @staticmethod
+    def create_training_obj(epochs, path, data_collection, batch_path, model_path,
+                            statistics_per_var, window_history_size, window_lead_time) -> Training:
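+        """Build a Training object by hand instead of through its constructor.
+
+        Model, callbacks, epochs and all required data store entries are wired up
+        so that the returned object can execute Training._run() directly."""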
+
+        channels = len(list(statistics_per_var.keys()))
+        model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
+
+        obj = object.__new__(Training)
+        super(Training, obj).__init__()
+        obj.model = model
+        obj.train_set = None
+        obj.val_set = None
+        obj.test_set = None
+        obj.batch_size = 256
+        obj.epochs = epochs
+
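+        # callbacks with checkpoint files, plus a model checkpoint on best val_loss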
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        epo_timing = EpoTimingCallback()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
+        obj.experiment_name = "TestExperiment"
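+        # register data collections and experiment, batch, model and plot paths in the data store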
+        obj.data_store.set("data_collection", data_collection, "general.train")
+        obj.data_store.set("data_collection", data_collection, "general.val")
+        obj.data_store.set("data_collection", data_collection, "general.test")
+        if not os.path.exists(path):
+            os.makedirs(path)
+        obj.data_store.set("experiment_path", path, "general")
+        os.makedirs(batch_path, exist_ok=True)
+        obj.data_store.set("batch_path", batch_path, "general")
+        os.makedirs(model_path, exist_ok=True)
+        obj.data_store.set("model_path", model_path, "general")
+        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
+        obj.data_store.set("experiment_name", "TestExperiment", "general")
+
+        path_plot = os.path.join(path, "plots")
+        os.makedirs(path_plot, exist_ok=True)
+        obj.data_store.set("plot_path", path_plot, "general")
+        obj._train_model = True
+        obj._create_new_model = False
+
+        obj.model.compile(**obj.model.compile_options)
+        return obj
+
     def test_init(self, ready_to_init):
         assert isinstance(Training(), Training)  # just test, if nothing fails
 
@@ -312,58 +362,15 @@ class TestTraining:
         init_without_run.create_monitoring_plots(history, learning_rate)
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
 
-    def test_resume_training(self, ready_to_run, path: str, model: keras.Model, model_path,
-                             batch_path, data_collection):
-        with ready_to_run as run_obj:
-            assert run_obj._run() is None  # rune once to create model
-
-            # init new object
-            obj = object.__new__(Training)
-            super(Training, obj).__init__()
-            obj.model = model
-            obj.train_set = None
-            obj.val_set = None
-            obj.test_set = None
-            obj.batch_size = 256
-            obj.epochs = 4
-
-            clbk = CallbackHandler()
-            hist = HistoryAdvanced()
-            epo_timing = EpoTimingCallback()
-            clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
-            lr = LearningRateDecay()
-            clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
-            clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
-            clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
-                                         save_best_only=True)
-            obj.callbacks = clbk
-            obj.lr_sc = lr
-            obj.hist = hist
-            obj.experiment_name = "TestExperiment"
-            obj.data_store.set("data_collection", data_collection, "general.train")
-            obj.data_store.set("data_collection", data_collection, "general.val")
-            obj.data_store.set("data_collection", data_collection, "general.test")
-            obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
-            if not os.path.exists(path):
-                os.makedirs(path)
-            obj.data_store.set("experiment_path", path, "general")
-            os.makedirs(batch_path, exist_ok=True)
-            obj.data_store.set("batch_path", batch_path, "general")
-            os.makedirs(model_path, exist_ok=True)
-            obj.data_store.set("model_path", model_path, "general")
-            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
-            obj.data_store.set("experiment_name", "TestExperiment", "general")
-
-            path_plot = os.path.join(path, "plots")
-            os.makedirs(path_plot, exist_ok=True)
-            obj.data_store.set("plot_path", path_plot, "general")
-            obj._train_model = True
-            obj._create_new_model = False
-
-
-            assert obj._run() is None
-            assert 1 == 1
-        assert 1 == 1
-
-
+    def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
+                              window_history_size, window_lead_time):
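+        # run training for two epochs, then resume with a second Training object configured for four epochs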
 
+        obj_1st = self.create_training_obj(2, path, data_collection, batch_path, model_path, statistics_per_var,
+                                           window_history_size, window_lead_time)
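+        # make the model's custom objects known to keras so the stored model can be reloaded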
+        keras.utils.get_custom_objects().update(obj_1st.model.custom_objects)
+        assert obj_1st._run() is None
+        obj_2nd = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
+                                           window_history_size, window_lead_time)
+        assert obj_2nd._run() is None