diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 1a5dd9da38520d6d732253015e8a67325e24c460..ebbd7a25cef9031436d932a6502c9726bfe3e318 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -8,6 +8,8 @@ from abc import ABC
 from typing import Any, Callable
 
 import keras
+from src.model_modules.inception_model import InceptionModelBase
+from src.model_modules.flatten import flatten_tail
 
 
 class AbstractModelClass(ABC):
@@ -240,3 +242,112 @@ class MyBranchedModel(AbstractModelClass):
 
         self.loss = [keras.losses.mean_absolute_error] + [keras.losses.mean_squared_error] + \
                     [keras.losses.mean_squared_error]
+
+
+class MyTowerModel(AbstractModelClass):
+
+    def __init__(self, window_history_size, window_lead_time, channels):
+
+        """
+        Sets model and loss depending on the given arguments.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param regularizer: <not used here>
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        """
+
+        super().__init__()
+
+        # settings
+        self.window_history_size = window_history_size
+        self.window_lead_time = window_lead_time
+        self.channels = channels
+        self.dropout_rate = 1e-2
+        self.regularizer = keras.regularizers.l2(0.1)
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.epochs = 20
+        self.batch_size = int(256*4)
+        self.activation = keras.layers.PReLU
+
+        # apply to model
+        self.set_model()
+        self.set_loss()
+
+    def set_model(self):
+
+        """
+        Build the model.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        :return: built keras model
+        """
+        activation = self.activation
+        conv_settings_dict1 = {
+            'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
+            'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation},
+            'tower_3': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (1, 1), 'activation': activation},
+        }
+
+        pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
+
+        conv_settings_dict2 = {
+            'tower_1': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (3, 1),
+                        'activation': activation},
+            'tower_2': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (5, 1),
+                        'activation': activation},
+            'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1),
+                        'activation': activation},
+            }
+        pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
+
+        conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
+                                           'activation': activation},
+                               'tower_2': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
+                                           'activation': activation},
+                               'tower_3': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
+                                           'activation': activation},
+                               }
+
+        pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
+
+        ##########################################
+        inception_model = InceptionModelBase()
+
+        X_input = keras.layers.Input(
+            shape=(self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+
+        X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1,
+                                               regularizer=self.regularizer,
+                                               batch_normalisation=True)
+
+        X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
+
+        X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=self.regularizer,
+                                               batch_normalisation=True)
+
+        X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
+
+        X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=self.regularizer,
+                                               batch_normalisation=True)
+        #############################################
+
+        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
+                                reduction_filter=64, first_dense=64, window_lead_time=self.window_lead_time)
+
+        self.model = keras.Model(inputs=X_input, outputs=[out_main])
+
+    def set_loss(self):
+
+        """
+        Set the loss
+        :return: loss function
+        """
+
+        self.loss = [keras.losses.mean_squared_error]
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index c14298d7d21f63cdc4465c1ed8e8bb30868b3c1a..e3945a542d60b09dc9855bd28be87cdba729ed72 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -7,14 +7,11 @@ import os
 
 import keras
 import tensorflow as tf
-from keras import losses
 
-from src.helpers import l_p_loss
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
 
@@ -52,7 +49,7 @@ class ModelSetup(RunEnvironment):
             self.load_weights()
 
         # create checkpoint
-        self._set_checkpoint()
+        self._set_callbacks()
 
         # compile model
         self.compile_model()
@@ -67,19 +64,20 @@ class ModelSetup(RunEnvironment):
         self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
         self.data_store.set("model", self.model, self.scope)
 
-    def _set_checkpoint(self):
+    def _set_callbacks(self):
         """
-        Must be run after all callback functions that shall be tracked during training have been created (currently this
-        affects the learning rate decay and the advanced history [actually created in this method]).
+        Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
+        advanced model checkpoint is added.
         """
         lr = self.data_store.get("lr_decay", scope="general.model")
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
-        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
-                     {"callback": hist, "path": self.callbacks_name % "hist"}]
-        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                             save_best_only=True, mode='auto', callbacks=callbacks)
-        self.data_store.set("checkpoint", checkpoint, self.scope)
+        callbacks = CallbackHandler()
+        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                          save_best_only=True, mode='auto')
+        self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_weights(self):
         try:
@@ -104,90 +102,3 @@ class ModelSetup(RunEnvironment):
         with tf.device("/cpu:0"):
             file_name = f"{self.model_name.split(sep='.')[0]}.pdf"
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
-
-
-def my_loss():
-    loss = l_p_loss(4)
-    keras_loss = losses.mean_squared_error
-    loss_all = [loss] + [keras_loss]
-    return loss_all
-
-
-def my_little_loss():
-    return losses.mean_squared_error
-
-
-def my_little_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
-
-    X_input = keras.layers.Input(
-        shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
-    X_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(X_input)
-    X_in = activation(name='{}_conv_act'.format("major"))(X_in)
-    X_in = keras.layers.Flatten(name='{}'.format("major"))(X_in)
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format("major"))(X_in)
-    X_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(window_lead_time, name='{}_Dense'.format("major"))(X_in)
-    out_main = activation()(X_in)
-    return keras.Model(inputs=X_input, outputs=[out_main])
-
-
-def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
-
-    conv_settings_dict1 = {
-        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
-        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation},
-        'tower_3': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (1, 1), 'activation': activation},
-    }
-
-    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
-
-    conv_settings_dict2 = {'tower_1': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (3, 1),
-                                       'activation': activation},
-                           'tower_2': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (5, 1),
-                                       'activation': activation},
-                           'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1),
-                                       'activation': activation},
-                           }
-    pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
-
-    conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
-                                       'activation': activation},
-                           'tower_2': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
-                                       'activation': activation},
-                           'tower_3': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
-                                       'activation': activation},
-                           }
-
-    pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
-
-    ##########################################
-    inception_model = InceptionModelBase()
-
-    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
-
-    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer,
-                                           batch_normalisation=True)
-
-    out_minor = flatten_tail(X_in, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate,
-                             reduction_filter=4, first_dense=32, window_lead_time=window_lead_time)
-
-    X_in = keras.layers.Dropout(dropout_rate)(X_in)
-
-    X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=regularizer,
-                                           batch_normalisation=True)
-
-    X_in = keras.layers.Dropout(dropout_rate)(X_in)
-
-    X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=regularizer,
-                                           batch_normalisation=True)
-    #############################################
-
-    out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=dropout_rate,
-                            reduction_filter=64, first_dense=64, window_lead_time=window_lead_time)
-
-    return keras.Model(inputs=X_input, outputs=[out_minor, out_main])
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 0b11da8d8f9a23c51d787f00d00a74d7517ea3b1..7a522af0298bcabee62579f68bd29ed123cac7b0 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,7 +9,7 @@ import pickle
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
 
@@ -24,9 +24,7 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("lr_decay", "general.model")
-        self.hist = self.data_store.get("hist", "general.model")
+        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
         self._create_new_model = self.data_store.get("create_new_model", "general")
@@ -87,38 +85,35 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
+        checkpoint = self.callbacks.get_checkpoint()
+        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
-                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
+                                               callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
-            lr_filepath = self.checkpoint.callbacks[0]["path"]
-            hist_filepath = self.checkpoint.callbacks[1]["path"]
-            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
-            self.hist = pickle.load(open(hist_filepath, "rb"))
-            self.model = keras.models.load_model(self.checkpoint.filepath)
-            initial_epoch = max(self.hist.epoch) + 1
-            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
-                         {"callback": self.hist, "path": hist_filepath}]
-            self.checkpoint.update_callbacks(callbacks)
-            self.checkpoint.update_best(self.hist)
+            self.callbacks.load_callbacks()
+            self.callbacks.update_checkpoint()
+            self.model = keras.models.load_model(checkpoint.filepath)
+            hist = self.callbacks.get_callback_by_name("hist")
+            initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
                                          validation_data=self.val_set.distribute_on_batches(),
                                          validation_steps=len(self.val_set),
-                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
-            history = self.hist
-        self.save_callbacks_as_json(history)
-        self.load_best_model(self.checkpoint.filepath)
-        self.create_monitoring_plots(history, self.lr_sc)
+            history = hist
+        lr = self.callbacks.get_callback_by_name("lr")
+        self.save_callbacks_as_json(history, lr)
+        self.load_best_model(checkpoint.filepath)
+        self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
         """
@@ -141,7 +136,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History, lr_sc: keras.callbacks) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
@@ -153,7 +148,7 @@ class Training(RunEnvironment):
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(self.lr_sc.lr, f)
+            json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 35c5f8ee7581856a9feee3abd0face73ee83952c..ade35a244601d138d22af6305e67b5aeae964680 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -56,11 +56,11 @@ class TestModelSetup:
     def current_scope_as_set(model_cls):
         return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
 
-    def test_set_checkpoint(self, setup):
-        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
+    def test_set_callbacks(self, setup):
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
         setup.checkpoint_name = "TestName"
-        setup._set_checkpoint()
-        assert "general.modeltest" in setup.data_store.search_name("checkpoint")
+        setup._set_callbacks()
+        assert "general.modeltest" in setup.data_store.search_name("callbacks")
 
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index ac040c3a286c25dc84853c26c8509278642a1495..31c673f05d055eb7c4ee76318711de030d97d480 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,7 @@ from src.data_handling.data_generator import DataGenerator
 from src.helpers import PyTestRegex
 from src.model_modules.flatten import flatten_tail
 from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
 from src.run_modules.run_environment import RunEnvironment
 from src.run_modules.training import Training
 
@@ -39,7 +39,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
+    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -48,9 +48,10 @@ class TestTraining:
         obj.test_set = None
         obj.batch_size = 256
         obj.epochs = 2
-        obj.checkpoint = checkpoint
-        obj.lr_sc = LearningRateDecay()
-        obj.hist = HistoryAdvanced()
+        clbk, hist, lr = callbacks
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
         obj.experiment_name = "TestExperiment"
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
@@ -70,12 +71,9 @@ class TestTraining:
 
     @pytest.fixture
     def learning_rate(self):
-        return {"lr": [0.01, 0.0094]}
-
-    @pytest.fixture
-    def init_with_lr(self, init_without_run, learning_rate):
-        init_without_run.lr_sc.lr = learning_rate
-        return init_without_run
+        lr = LearningRateDecay()
+        lr.lr = {"lr": [0.01, 0.0094]}
+        return lr
 
     @pytest.fixture
     def history(self):
@@ -105,8 +103,15 @@ class TestTraining:
         return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
 
     @pytest.fixture
-    def checkpoint(self, path):
-        return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
+    def callbacks(self, path):
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        return clbk, hist, lr
 
     @pytest.fixture
     def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
@@ -125,7 +130,7 @@ class TestTraining:
         return obj
 
     @pytest.fixture
-    def ready_to_init(self, generator, model, checkpoint, path):
+    def ready_to_init(self, generator, model, callbacks, path):
         os.makedirs(path)
         obj = RunEnvironment()
         obj.data_store.set("generator", generator, "general.train")
@@ -136,14 +141,14 @@ class TestTraining:
         obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
         obj.data_store.set("batch_size", 256, "general.model")
         obj.data_store.set("epochs", 2, "general.model")
-        obj.data_store.set("checkpoint", checkpoint, "general.model")
-        obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
-        obj.data_store.set("hist", HistoryAdvanced(), "general.model")
+        clbk, hist, lr = callbacks
+        obj.data_store.set("callbacks", clbk, "general.model")
+        obj.data_store.set("lr_decay", lr, "general.model")
+        obj.data_store.set("hist", hist, "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         obj.data_store.set("experiment_path", path, "general")
         obj.data_store.set("trainable", True, "general")
-        obj.data_store.set("create_new_model"
-                           "", True, "general")
+        obj.data_store.set("create_new_model", True, "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
@@ -197,25 +202,25 @@ class TestTraining:
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
-    def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history.json" in os.listdir(path)
 
-    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history_lr.json" in os.listdir(path)
 
-    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
-            assert lr == init_with_lr.lr_sc.lr
+            assert lr == learning_rate.lr
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0