diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index a5fdcea7db6c1a2d6ca26f4b9d7c9d365aa45924..8b99acd0f5723d3b00ec1bd0098712753da21b52 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -353,7 +353,7 @@ class CallbackHandler:
         for pos, callback in enumerate(self.__callbacks):
             path = callback["path"]
             clb = pickle.load(open(path, "rb"))
-            if clb.model is None:
+            if clb.model is None and hasattr(self._checkpoint, "model"):
                 clb.model = self._checkpoint.model
             self._update_callback(pos, clb)