diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index 8898a6b2d0591328f2bb7010ccbfe144a48ca40b..4dc9521abf3569eb57249286e92c1e6a259c667d 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -38,6 +38,11 @@ class AbstractModelClass(ABC):
         self._input_shape = input_shape
         self._output_shape = self.__extract_from_tuple(output_shape)
 
+    def load_model(self, name: str):
+        hist = self.model.history
+        self.model = keras.models.load_model(name)
+        self.model.history = hist
+
     def __getattr__(self, name: str) -> Any:
         """
         Is called if __getattribute__ is not able to find requested attribute.
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index a64ce917393dda9605e81cbef1dfd60fc57beaa5..98263eb732d8067fba0950c7a4882fb3ef020995 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -134,7 +134,7 @@ class ModelSetup(RunEnvironment):
     def load_model(self):
         """Try to load model from disk or skip if not possible."""
         try:
-            self.model = keras.models.load_model(self.model_name)
+            self.model.load_model(self.model_name)
             logging.info(f"reload model {self.model_name} from disk ...")
         except OSError:
             logging.info('no local model to load...')
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 64b323fa74b3770698d40364ee7defae88a01b4c..0d875766926e870349337a0597e2b3612a93ee07 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -149,7 +149,7 @@ class Training(RunEnvironment):
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             self.callbacks.load_callbacks()
             self.callbacks.update_checkpoint()
-            self.model = keras.models.load_model(checkpoint.filepath)
+            self.model.load_model(checkpoint.filepath)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit(self.train_set,
@@ -179,6 +179,7 @@ class Training(RunEnvironment):
         model_name = self.data_store.get("model_name", "model")
         logging.debug(f"save best model to {model_name}")
         self.model.save(model_name, save_format='h5')
+        self.model.save(model_name)
         self.data_store.set("best_model", self.model)
 
     def load_best_model(self, name: str) -> None:
@@ -189,7 +190,7 @@ class Training(RunEnvironment):
         """
         logging.debug(f"load best model: {name}")
         try:
-            self.model = keras.models.load_model(name)
+            self.model.load_model(name)
             logging.info('reload weights...')
         except OSError:
             logging.info('no weights to reload...')
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 46a5ba92b78cd4711e8f48c8df8e3721cf5d52cc..44e664e4f47dfd842ed956fcf7f7e56becb758ef 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -13,6 +13,7 @@ from tensorflow.keras.callbacks import History
 
 from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
 from mlair.helpers import PyTestRegex
+from mlair.model_modules.fully_connected_networks import FCN
 from mlair.model_modules.flatten import flatten_tail
 from mlair.model_modules.inception_model import InceptionModelBase
 from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
@@ -160,7 +161,10 @@ class TestTraining:
     @pytest.fixture
     def model(self, window_history_size, window_lead_time, statistics_per_var):
         channels = len(list(statistics_per_var.keys()))
-        return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
+
+        return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
+
+        # return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
 
     @pytest.fixture
     def callbacks(self, path):