diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 33358e566ef80f28ee7740531b71d1a83abde115..7781a0d2cb10f474240a30a6b7fd06a72b1754fa 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -9,9 +9,12 @@ import pickle
 from typing import Union, List
 from typing_extensions import TypedDict
 
+from mlair.helpers import TimeTracking
+
 import numpy as np
 from keras import backend as K
 from keras.callbacks import History, ModelCheckpoint, Callback
+import keras
 
 from mlair import helpers
 
@@ -147,6 +150,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
+        self.custom_objects = kwargs.pop("custom_objects")
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -174,17 +178,29 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def on_epoch_end(self, epoch, logs=None):
         """Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
         super().on_epoch_end(epoch, logs)
-
         for callback in self.callbacks:
+            print(callback.keys())
             file_path = callback["path"]
             if self.epochs_since_last_save == 0 and epoch != 0:
                 if self.save_best_only:
                     current = logs.get(self.monitor)
+                    model_save = None
+                    if hasattr(callback["callback"], "model"):
+                        # ToDo: store model in cache?
+                        callback["callback"].model.save(
+                            callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
+                        callback["callback"].model = None
                     if current == self.best:
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
+                            # ToDo: create "save" method
                             pickle.dump(callback["callback"], f)
+                    if callback["callback"].model is None:
+                        with TimeTracking("load_model"):
+                            callback["callback"].model = keras.models.load_model(
+                                callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
+                                custom_objects=callback["custom_objects"])
                 else:
                     with open(file_path, "wb") as f:
                         if self.verbose > 0:  # pragma: no branch
@@ -253,12 +269,17 @@ class CallbackHandler:
     def __init__(self):
         """Initialise CallbackHandler."""
         self.__callbacks: List[clbk_type] = []
+        self.custom_objects = {}
         self._checkpoint = None
         self.editable = True
 
+    def add_custom_objects(self, custom_objects):
+        self.custom_objects = custom_objects
+
     @property
     def _callbacks(self):
-        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
+        return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk
+                in self.__callbacks]
 
     @_callbacks.setter
     def _callbacks(self, value):
@@ -313,7 +334,8 @@ class CallbackHandler:
         """Return all callbacks and append checkpoint if available on last position."""
         clbks = self._callbacks
         if self._checkpoint is not None:
-            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
+            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath,
+                       "custom_objects": self.custom_objects}]
         return clbks
 
     def get_checkpoint(self) -> ModelCheckpointAdvanced:
@@ -323,7 +345,8 @@ class CallbackHandler:
 
     def create_model_checkpoint(self, **kwargs):
         """Create a model checkpoint and enable edit."""
-        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
+        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects,
+                                                   **kwargs)
         self.editable = False
 
     def load_callbacks(self) -> None:
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index 07eaa1ec19fbd394ad0bffd262d2698ccfd81055..fb2902dce93d37b52e91804ab92ed895d85a876a 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -669,10 +669,10 @@ class MyPaperModel(AbstractModelClass):
         conv_settings_dict1 = {
             'tower_1': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (3, 1),
                         'activation': activation},
-            'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1),
-                        'activation': activation},
-            'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1),
-                        'activation': activation},
+            # 'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1),
+            #             'activation': activation},
+            # 'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1),
+            #             'activation': activation},
             # 'tower_4':{'reduction_filter':8, 'tower_filter':8*2, 'tower_kernel':(7,1), 'activation':activation},
         }
         pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
@@ -680,10 +680,10 @@ class MyPaperModel(AbstractModelClass):
         conv_settings_dict2 = {
             'tower_1': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
                         'activation': activation},
-            'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
-                        'activation': activation},
-            'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
-                        'activation': activation},
+            # 'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
+            #             'activation': activation},
+            # 'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
+            #             'activation': activation},
             # 'tower_4':{'reduction_filter':8*2, 'tower_filter':16*2, 'tower_kernel':(7,1), 'activation':activation},
         }
         pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index c6af13b02e818431578c7423d837f95e64ca3d15..26c61ff37748d6110752fe5afaf5122e2df16303 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -118,6 +118,7 @@ class ModelSetup(RunEnvironment):
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="model")
         callbacks = CallbackHandler()
+        callbacks.add_custom_objects(self.model.custom_objects)
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")