diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index cfb8638816e85ec8427f3df5d45d480d21fd929f..180e324602da25e1df8fb218c1d3bba180004ac8 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         pickle.dump(callback["callback"], f)
+
+
+class CallbackHandler:
+
+    def __init__(self):
+        self.__callbacks = []
+        self._checkpoint = None
+        self.editable = True
+
+    @property
+    def _callbacks(self):
+        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
+
+    @_callbacks.setter
+    def _callbacks(self, value):
+        name, callback, callback_path = value
+        self.__callbacks.append({"name": name, name: callback, "path": callback_path})
+
+    def _update_callback(self, pos, value):
+        name = self.__callbacks[pos]["name"]
+        self.__callbacks[pos][name] = value
+
+    def add_callback(self, callback, callback_path, name="callback"):
+        if self.editable:
+            self._callbacks = (name, callback, callback_path)
+        else:
+            raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
+
+    def get_callbacks(self, as_dict=True):
+        if as_dict:
+            return self._get_callbacks()
+        else:
+            return [clb["callback"] for clb in self._get_callbacks()]
+
+    def get_callback_by_name(self, obj_name):
+        if obj_name != "callback":
+            return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
+
+    def _get_callbacks(self):
+        clbks = self._callbacks
+        if self._checkpoint is not None:
+            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
+        return clbks
+
+    def get_checkpoint(self):
+        if self._checkpoint is not None:
+            return self._checkpoint
+
+    def create_model_checkpoint(self, **kwargs):
+        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
+        self.editable = False
+
+    def load_callbacks(self):
+        for pos, callback in enumerate(self.__callbacks):
+            path = callback["path"]
+            clb = pickle.load(open(path, "rb"))
+            self._update_callback(pos, clb)
+
+    def update_checkpoint(self, history_name="hist"):
+        self._checkpoint.update_callbacks(self._callbacks)
+        self._checkpoint.update_best(self.get_callback_by_name(history_name))
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index 9141434d448f8d076836ef357279ae5815686767..17ab4f6d65c95a5a54c9d931818f889acadef532 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced:
         assert "callback_hist" in os.listdir(path)
         os.remove(os.path.join(path, "callback_hist"))
         os.remove(os.path.join(path, "callback_lr"))
+
+
+class TestCallbackHandler:
+
+    @pytest.fixture
+    def clbk_handler(self):
+        return CallbackHandler()
+
+    @pytest.fixture
+    def clbk_handler_with_dummies(self, clbk_handler):
+        clbk_handler.add_callback("callback_new_instance", "this_path")
+        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
+        return clbk_handler
+
+    @pytest.fixture
+    def callback_handler(self, clbk_handler):
+        clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist")
+        clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr")
+        return clbk_handler
+
+    @pytest.fixture
+    def prepare_pickle_files(self):
+        hist = HistoryAdvanced()
+        hist.epoch = [1, 2, 3]
+        hist.history = {"val_loss": [10, 5, 4]}
+        lr = LearningRateDecay()
+        lr.epoch = [1, 2, 3]
+        pickle.dump(hist, open("callbacks_hist.pickle", "wb"))
+        pickle.dump(lr, open("callbacks_lr.pickle", "wb"))
+        yield
+        os.remove("callbacks_hist.pickle")
+        os.remove("callbacks_lr.pickle")
+
+    def test_init(self, clbk_handler):
+        assert len(clbk_handler._CallbackHandler__callbacks) == 0
+        assert clbk_handler._checkpoint is None
+        assert clbk_handler.editable is True
+
+    def test_callbacks_set(self, clbk_handler):
+        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
+                                                             "path": "callback_path"}]
+        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
+                                                             "path": "callback_path"},
+                                                            {"name": "another", "another": "callback_instance2",
+                                                             "path": "callback_path"}]
+
+    def test_callbacks_get(self, clbk_handler):
+        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
+        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2")
+        assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"},
+                                           {"callback": "callback_instance2", "path": "callback_path2"}]
+
+    def test_update_callback(self, clbk_handler_with_dummies):
+        clbk_handler_with_dummies._update_callback(0, "old_instance")
+        assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"},
+                                                             {"callback": "callback_other", "path": "otherpath"}]
+
+    def test_add_callback(self, clbk_handler):
+        clbk_handler.add_callback("callback_new_instance", "this_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
+                                                             "path": "this_path"}]
+        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
+                                                             "path": "this_path"},
+                                                            {"name": "other_clbk", "other_clbk": "callback_other",
+                                                             "path": "otherpath"}]
+
+    def test_get_callbacks_as_dict(self, clbk_handler_with_dummies):
+        clbk = clbk_handler_with_dummies
+        assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                        {"callback": "callback_other", "path": "otherpath"}]
+        assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True)
+
+    def test_get_callbacks_no_dict(self, clbk_handler_with_dummies):
+        assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"]
+
+    def test_get_callback_by_name(self, clbk_handler_with_dummies):
+        assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other"
+        assert clbk_handler_with_dummies.get_callback_by_name("callback") is None
+
+    def test__get_callbacks(self, clbk_handler_with_dummies):
+        clbk = clbk_handler_with_dummies
+        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                         {"callback": "callback_other", "path": "otherpath"}]
+        ckpt = keras.callbacks.ModelCheckpoint("testFilePath")
+        clbk._checkpoint = ckpt
+        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                         {"callback": "callback_other", "path": "otherpath"},
+                                         {"callback": ckpt, "path": "testFilePath"}]
+
+    def test_get_checkpoint(self, clbk_handler):
+        assert clbk_handler.get_checkpoint() is None
+        clbk_handler._checkpoint = "testCKPT"
+        assert clbk_handler.get_checkpoint() == "testCKPT"
+
+    def test_create_model_checkpoint(self, callback_handler):
+        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
+        assert callback_handler.editable is False
+        assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced)
+        assert callback_handler._checkpoint.filepath == "tester_path"
+        assert callback_handler._checkpoint.verbose == 1
+        assert callback_handler._checkpoint.monitor == "val_loss"
+
+    def test_load_callbacks(self, callback_handler, prepare_pickle_files):
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
+        callback_handler.load_callbacks()
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
+
+    def test_update_checkpoint(self, callback_handler, prepare_pickle_files):
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
+        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
+        callback_handler.load_callbacks()
+        callback_handler.update_checkpoint()
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
+        assert callback_handler._checkpoint.best == 4