diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 66697e07d68a01ffa203798c11417b8440d54214..e3945a542d60b09dc9855bd28be87cdba729ed72 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -7,14 +7,11 @@ import os
 
 import keras
 import tensorflow as tf
-from keras import losses
 
-from src.helpers import l_p_loss
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
 
@@ -52,7 +49,7 @@ class ModelSetup(RunEnvironment):
             self.load_weights()
 
         # create checkpoint
-        self._set_checkpoint()
+        self._set_callbacks()
 
         # compile model
         self.compile_model()
@@ -67,19 +64,20 @@ class ModelSetup(RunEnvironment):
         self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
         self.data_store.set("model", self.model, self.scope)
 
-    def _set_checkpoint(self):
+    def _set_callbacks(self):
         """
-        Must be run after all callback functions that shall be tracked during training have been created (currently this
-        affects the learning rate decay and the advanced history [actually created in this method]).
+        Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
+        advanced model checkpoint is added.
         """
         lr = self.data_store.get("lr_decay", scope="general.model")
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
-        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
-                     {"callback": hist, "path": self.callbacks_name % "hist"}]
-        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                             save_best_only=True, mode='auto', callbacks=callbacks)
-        self.data_store.set("checkpoint", checkpoint, self.scope)
+        callbacks = CallbackHandler()
+        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                          save_best_only=True, mode='auto')
+        self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_weights(self):
         try:
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 0b11da8d8f9a23c51d787f00d00a74d7517ea3b1..7a522af0298bcabee62579f68bd29ed123cac7b0 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,7 +9,7 @@ import pickle
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
 
@@ -24,9 +24,7 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("lr_decay", "general.model")
-        self.hist = self.data_store.get("hist", "general.model")
+        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
         self._create_new_model = self.data_store.get("create_new_model", "general")
@@ -87,38 +85,35 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
+        checkpoint = self.callbacks.get_checkpoint()
+        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
-                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
+                                               callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
-            lr_filepath = self.checkpoint.callbacks[0]["path"]
-            hist_filepath = self.checkpoint.callbacks[1]["path"]
-            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
-            self.hist = pickle.load(open(hist_filepath, "rb"))
-            self.model = keras.models.load_model(self.checkpoint.filepath)
-            initial_epoch = max(self.hist.epoch) + 1
-            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
-                         {"callback": self.hist, "path": hist_filepath}]
-            self.checkpoint.update_callbacks(callbacks)
-            self.checkpoint.update_best(self.hist)
+            self.callbacks.load_callbacks()
+            self.callbacks.update_checkpoint()
+            self.model = keras.models.load_model(checkpoint.filepath)
+            hist = self.callbacks.get_callback_by_name("hist")
+            initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
                                          validation_data=self.val_set.distribute_on_batches(),
                                          validation_steps=len(self.val_set),
-                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
-            history = self.hist
-        self.save_callbacks_as_json(history)
-        self.load_best_model(self.checkpoint.filepath)
-        self.create_monitoring_plots(history, self.lr_sc)
+            history = hist
+        lr = self.callbacks.get_callback_by_name("lr")
+        self.save_callbacks_as_json(history, lr)
+        self.load_best_model(checkpoint.filepath)
+        self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
         """
@@ -141,7 +136,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History, lr_sc: keras.callbacks) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
@@ -153,7 +148,7 @@ class Training(RunEnvironment):
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(self.lr_sc.lr, f)
+            json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 35c5f8ee7581856a9feee3abd0face73ee83952c..ade35a244601d138d22af6305e67b5aeae964680 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -56,11 +56,11 @@ class TestModelSetup:
     def current_scope_as_set(model_cls):
         return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
 
-    def test_set_checkpoint(self, setup):
-        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
+    def test_set_callbacks(self, setup):
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
         setup.checkpoint_name = "TestName"
-        setup._set_checkpoint()
-        assert "general.modeltest" in setup.data_store.search_name("checkpoint")
+        setup._set_callbacks()
+        assert "general.modeltest" in setup.data_store.search_name("callbacks")
 
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index ac040c3a286c25dc84853c26c8509278642a1495..31c673f05d055eb7c4ee76318711de030d97d480 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,7 @@ from src.data_handling.data_generator import DataGenerator
 from src.helpers import PyTestRegex
 from src.model_modules.flatten import flatten_tail
 from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
 from src.run_modules.run_environment import RunEnvironment
 from src.run_modules.training import Training
 
@@ -39,7 +39,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
+    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -48,9 +48,10 @@ class TestTraining:
         obj.test_set = None
         obj.batch_size = 256
         obj.epochs = 2
-        obj.checkpoint = checkpoint
-        obj.lr_sc = LearningRateDecay()
-        obj.hist = HistoryAdvanced()
+        clbk, hist, lr = callbacks
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
         obj.experiment_name = "TestExperiment"
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
@@ -70,12 +71,9 @@ class TestTraining:
 
     @pytest.fixture
     def learning_rate(self):
-        return {"lr": [0.01, 0.0094]}
-
-    @pytest.fixture
-    def init_with_lr(self, init_without_run, learning_rate):
-        init_without_run.lr_sc.lr = learning_rate
-        return init_without_run
+        lr = LearningRateDecay()
+        lr.lr = {"lr": [0.01, 0.0094]}
+        return lr
 
     @pytest.fixture
     def history(self):
@@ -105,8 +103,15 @@ class TestTraining:
         return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
 
     @pytest.fixture
-    def checkpoint(self, path):
-        return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
+    def callbacks(self, path):
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        return clbk, hist, lr
 
     @pytest.fixture
     def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
@@ -125,7 +130,7 @@ class TestTraining:
         return obj
 
     @pytest.fixture
-    def ready_to_init(self, generator, model, checkpoint, path):
+    def ready_to_init(self, generator, model, callbacks, path):
         os.makedirs(path)
         obj = RunEnvironment()
         obj.data_store.set("generator", generator, "general.train")
@@ -136,14 +141,14 @@ class TestTraining:
         obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
         obj.data_store.set("batch_size", 256, "general.model")
         obj.data_store.set("epochs", 2, "general.model")
-        obj.data_store.set("checkpoint", checkpoint, "general.model")
-        obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
-        obj.data_store.set("hist", HistoryAdvanced(), "general.model")
+        clbk, hist, lr = callbacks
+        obj.data_store.set("callbacks", clbk, "general.model")
+        obj.data_store.set("lr_decay", lr, "general.model")
+        obj.data_store.set("hist", hist, "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         obj.data_store.set("experiment_path", path, "general")
         obj.data_store.set("trainable", True, "general")
-        obj.data_store.set("create_new_model"
-                           "", True, "general")
+        obj.data_store.set("create_new_model", True, "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
@@ -197,25 +202,25 @@ class TestTraining:
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
-    def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history.json" in os.listdir(path)
 
-    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history_lr.json" in os.listdir(path)
 
-    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
-            assert lr == init_with_lr.lr_sc.lr
+            assert lr == learning_rate.lr
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0