diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 0b9e8ec56592901d9feba15eb50b6b21a0c53560..a64ce917393dda9605e81cbef1dfd60fc57beaa5 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment):
 
         # load weights if no training shall be performed
         if not self._train_model and not self._create_new_model:
-            self.load_weights()
+            self.load_model()
 
         # create checkpoint
         self._set_callbacks()
@@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment):
                                           save_best_only=True, mode='auto')
         self.data_store.set("callbacks", callbacks, self.scope)
 
-    def load_weights(self):
-        """Try to load weights from existing model or skip if not possible."""
+    def load_model(self):
+        """Try to load model from disk or skip if not possible."""
         try:
-            self.model.load_weights(self.model_name)
-            logging.info(f"reload weights from model {self.model_name} ...")
+            self.model = keras.models.load_model(self.model_name)
+            logging.info(f"reload model {self.model_name} from disk ...")
         except OSError:
-            logging.info('no weights to reload...')
+            logging.info('no local model to load...')
 
     def build_model(self):
         """Build model using input and output shapes from data store."""
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 0696c2e7b8daa75925cf16096e183de94c21fe85..64b323fa74b3770698d40364ee7defae88a01b4c 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -189,7 +189,7 @@ class Training(RunEnvironment):
         """
         logging.debug(f"load best model: {name}")
         try:
-            self.model.load_weights(name)
+            self.model = keras.models.load_model(name)
             logging.info('reload weights...')
         except OSError:
             logging.info('no weights to reload...')
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index f1b210e1c7429c96658238ac21d96b7843053da7..46a5ba92b78cd4711e8f48c8df8e3721cf5d52cc 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -308,9 +308,58 @@ class TestTraining:
         init_without_run.create_monitoring_plots(history, learning_rate)
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
 
-    def test_resume_training(self, ready_to_run):
-        with copy.copy(ready_to_run) as pre_run:
-            assert pre_run._run() is None  # rune once to create model
-            ready_to_run.epochs = 4  # continue train up to epoch 4
-            assert ready_to_run._run() is None
+    def test_resume_training(self, ready_to_run, path: str, model: keras.Model, model_path,
+                             batch_path, data_collection):
+        with ready_to_run as run_obj:
+            assert run_obj._run() is None  # rune once to create model
+
+            # init new object
+            obj = object.__new__(Training)
+            super(Training, obj).__init__()
+            obj.model = model
+            obj.train_set = None
+            obj.val_set = None
+            obj.test_set = None
+            obj.batch_size = 256
+            obj.epochs = 4
+
+            clbk = CallbackHandler()
+            hist = HistoryAdvanced()
+            epo_timing = EpoTimingCallback()
+            clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+            lr = LearningRateDecay()
+            clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+            clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
+            clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                         save_best_only=True)
+            obj.callbacks = clbk
+            obj.lr_sc = lr
+            obj.hist = hist
+            obj.experiment_name = "TestExperiment"
+            obj.data_store.set("data_collection", data_collection, "general.train")
+            obj.data_store.set("data_collection", data_collection, "general.val")
+            obj.data_store.set("data_collection", data_collection, "general.test")
+            obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
+            if not os.path.exists(path):
+                os.makedirs(path)
+            obj.data_store.set("experiment_path", path, "general")
+            os.makedirs(batch_path, exist_ok=True)
+            obj.data_store.set("batch_path", batch_path, "general")
+            os.makedirs(model_path, exist_ok=True)
+            obj.data_store.set("model_path", model_path, "general")
+            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
+            obj.data_store.set("experiment_name", "TestExperiment", "general")
+
+            path_plot = os.path.join(path, "plots")
+            os.makedirs(path_plot, exist_ok=True)
+            obj.data_store.set("plot_path", path_plot, "general")
+            obj._train_model = True
+            obj._create_new_model = False
+
+
+            assert obj._run() is None
+            assert 1 == 1
+        assert 1 == 1
+
+