diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 7781a0d2cb10f474240a30a6b7fd06a72b1754fa..4140f15b7afb496452357793c8e4d4c914fd9f86 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -9,12 +9,9 @@ import pickle
 from typing import Union, List
 from typing_extensions import TypedDict
 
-from mlair.helpers import TimeTracking
-
 import numpy as np
 from keras import backend as K
 from keras.callbacks import History, ModelCheckpoint, Callback
-import keras
 
 from mlair import helpers
 
@@ -150,7 +147,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
-        self.custom_objects = kwargs.pop("custom_objects")
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -184,23 +180,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
             if self.epochs_since_last_save == 0 and epoch != 0:
                 if self.save_best_only:
                     current = logs.get(self.monitor)
-                    model_save = None
-                    if hasattr(callback["callback"], "model"):
-                        # ToDo: store model in cache?
-                        callback["callback"].model.save(
-                            callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
-                        callback["callback"].model = None
                     if current == self.best:
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
                             # ToDo: create "save" method
                             pickle.dump(callback["callback"], f)
-                    if callback["callback"].model is None:
-                        with TimeTracking("load_model"):
-                            callback["callback"].model = keras.models.load_model(
-                                callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
-                                custom_objects=callback["custom_objects"])
                 else:
                     with open(file_path, "wb") as f:
                         if self.verbose > 0:  # pragma: no branch
@@ -269,16 +254,12 @@ class CallbackHandler:
     def __init__(self):
         """Initialise CallbackHandler."""
         self.__callbacks: List[clbk_type] = []
-        self.custom_objects = {}
         self._checkpoint = None
         self.editable = True
 
-    def add_custom_objects(self, custom_objects):
-        self.custom_objects = custom_objects
-
     @property
     def _callbacks(self):
-        return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk
+        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk
                 in self.__callbacks]
 
     @_callbacks.setter
@@ -334,8 +315,7 @@ class CallbackHandler:
         """Return all callbacks and append checkpoint if available on last position."""
         clbks = self._callbacks
         if self._checkpoint is not None:
-            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath,
-                       "custom_objects": self.custom_objects}]
+            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
         return clbks
 
     def get_checkpoint(self) -> ModelCheckpointAdvanced:
@@ -345,8 +325,7 @@ class CallbackHandler:
 
     def create_model_checkpoint(self, **kwargs):
         """Create a model checkpoint and enable edit."""
-        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects,
-                                                   **kwargs)
+        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
         self.editable = False
 
     def load_callbacks(self) -> None:
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 26c61ff37748d6110752fe5afaf5122e2df16303..0ce13790a10999d69de7a39676d2a1f92997bbf9 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -118,7 +118,6 @@ class ModelSetup(RunEnvironment):
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="model")
         callbacks = CallbackHandler()
-        callbacks.add_custom_objects(self.model.custom_objects)
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
@@ -141,6 +140,7 @@ class ModelSetup(RunEnvironment):
         model = self.data_store.get("model_class")
         self.model = model(**args)
         self.get_model_settings()
+        keras.utils.get_custom_objects().update(self.model.custom_objects)
 
     def get_model_settings(self):
         """Load all model settings and store in data store."""