diff --git a/run.py b/run.py
index 71244fb9d15f594ac3ffbce60341d5c8dcb15f03..9f38fdca9c51cbed332725ce8e120e1493551b93 100644
--- a/run.py
+++ b/run.py
@@ -30,8 +30,8 @@ def main(parser_args):
 if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    # logging.basicConfig(format=formatter, level=logging.INFO)
-    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    # logging.basicConfig(format=formatter, level=logging.DEBUG)
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
diff --git a/src/helpers.py b/src/helpers.py
index 172a8dd3cf04a15e9069347dac7f06c6d2d8ed60..a4ce625c8ae9bbf3c03425116a6bc10abf328bc9 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -11,12 +11,10 @@ import time
 import socket
 import datetime as dt
 
-import keras
 import keras.backend as K
-import numpy as np
 import xarray as xr
 
-from typing import Union, Dict, Callable
+from typing import Dict, Callable
 
 
 def to_list(arg):
@@ -45,55 +43,6 @@ def l_p_loss(power: int):
     return loss
 
 
-class LearningRateDecay(keras.callbacks.History):
-    """
-    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
-    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
-    """
-
-    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
-        super().__init__()
-        self.lr = {'lr': []}
-        self.base_lr = self.check_param(base_lr, 'base_lr')
-        self.drop = self.check_param(drop, 'drop')
-        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
-
-    @staticmethod
-    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
-        """
-        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
-        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
-        value without any check.
-        :param value: value to check
-        :param name: name of the variable to display in error message
-        :param lower: left (lower) endpoint of interval, opened
-        :param upper: right (upper) endpoint of interval, closed
-        :return: unchanged value or raise ValueError
-        """
-        if lower is None:
-            lower = -np.inf
-        if upper is None:
-            upper = np.inf
-        if lower < value <= upper:
-            return value
-        else:
-            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
-                             f"{name}={value}")
-
-    def on_epoch_begin(self, epoch: int, logs=None):
-        """
-        Lower learning rate every epochs_drop epochs by factor drop.
-        :param epoch: current epoch
-        :param logs: ?
-        :return: update keras learning rate
-        """
-        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
-        K.set_value(self.model.optimizer.lr, current_lr)
-        self.lr['lr'].append(current_lr)
-        logging.info(f"Set learning rate to {current_lr}")
-        return K.get_value(self.model.optimizer.lr)
-
-
 class TimeTracking(object):
     """
     Track time to measure execution time. Time tracking automatically starts on initialisation and ends by calling stop
diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2a4b93219be2cbebfb35749560efa65c07226bb
--- /dev/null
+++ b/src/model_modules/keras_extensions.py
@@ -0,0 +1,150 @@
+__author__ = 'Lukas Leufen, Felix Kleinert'
+__date__ = '2020-01-31'
+
+import logging
+import math
+import pickle
+from typing import Union
+
+import numpy as np
+from keras import backend as K
+from keras.callbacks import History, ModelCheckpoint
+
+
+class HistoryAdvanced(History):
+    """
+    This is almost an identical clone of the original History class. The only difference is that attributes epoch and
+    history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
+    started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
+    additional callback. To get the full history use this object for further steps instead of the default return of
+    training methods like fit_generator().
+
+        hist = HistoryAdvanced()
+        history = model.fit_generator(generator=.... , callbacks=[hist])
+        history = hist
+
+    If training was started from beginning this class is identical to the returned history class object.
+    """
+
+    def __init__(self):
+        self.epoch = []
+        self.history = {}
+        super().__init__()
+
+    def on_train_begin(self, logs=None):
+        pass
+
+
+class LearningRateDecay(History):
+    """
+    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
+    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
+    """
+
+    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
+        super().__init__()
+        self.lr = {'lr': []}
+        self.base_lr = self.check_param(base_lr, 'base_lr')
+        self.drop = self.check_param(drop, 'drop')
+        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+        self.epoch = []
+        self.history = {}
+
+    @staticmethod
+    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
+        """
+        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
+        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
+        value without any check.
+        :param value: value to check
+        :param name: name of the variable to display in error message
+        :param lower: left (lower) endpoint of interval, opened
+        :param upper: right (upper) endpoint of interval, closed
+        :return: unchanged value or raise ValueError
+        """
+        if lower is None:
+            lower = -np.inf
+        if upper is None:
+            upper = np.inf
+        if lower < value <= upper:
+            return value
+        else:
+            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
+                             f"{name}={value}")
+
+    def on_train_begin(self, logs=None):
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs=None):
+        """
+        Lower learning rate every epochs_drop epochs by factor drop.
+        :param epoch: current epoch
+        :param logs: ?
+        :return: update keras learning rate
+        """
+        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
+        K.set_value(self.model.optimizer.lr, current_lr)
+        self.lr['lr'].append(current_lr)
+        logging.info(f"Set learning rate to {current_lr}")
+        return K.get_value(self.model.optimizer.lr)
+
+
+class ModelCheckpointAdvanced(ModelCheckpoint):
+    """
+    Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow:
+
+        lr = CustomLearningRate()
+        hist = CustomHistory()
+        callbacks_name = "your_custom_path_%s.pickle"
+        callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
+                 {"callback": hist, "path": callbacks_name % "hist"}]
+        ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
+
+    Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks
+    as last callback to properly update all tracked callbacks, e.g.
+
+        fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
+
+    """
+    def __init__(self, *args, **kwargs):
+        self.callbacks = kwargs.pop("callbacks")
+        super().__init__(*args, **kwargs)
+
+    def update_best(self, hist):
+        """
+        Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the
+        performance metric and the first trained model (first of the resuming training process) will always saved as
+        best model because its performance will be better than infinity. To prevent this behaviour and compare the
+        performance with the best model performance, call this method before resuming the training process.
+        :param hist: The History object from the previous (interrupted) training.
+        """
+        self.best = hist.history.get(self.monitor)[-1]
+
+    def update_callbacks(self, callbacks):
+        """
+        Update all stored callback objects. The argument callbacks needs to follow the same convention like described
+        in the class description (list of dictionaries). Must be run before resuming a training process.
+        """
+        self.callbacks = callbacks
+
+    def on_epoch_end(self, epoch, logs=None):
+        """
+        Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
+        """
+        super().on_epoch_end(epoch, logs)
+
+        for callback in self.callbacks:
+            file_path = callback["path"]
+            if self.epochs_since_last_save == 0 and epoch != 0:
+                if self.save_best_only:
+                    current = logs.get(self.monitor)
+                    if current == self.best:
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
+                        with open(file_path, "wb") as f:
+                            pickle.dump(callback["callback"], f)
+                else:
+                    with open(file_path, "wb") as f:
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
+                        pickle.dump(callback["callback"], f)
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 6b0fe236ff8ee726c34a721a6be0ed8be91f2bb8..02f43dc1b208cfd8a52a937298217216f26fbdb6 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -1,3 +1,5 @@
+import src.model_modules.keras_extensions
+
 __author__ = "Lukas Leufen"
 __date__ = '2019-12-12'
 
@@ -109,8 +111,8 @@ class MyLittleModel(AbstractModelClass):
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
@@ -189,8 +191,8 @@ class MyBranchedModel(AbstractModelClass):
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
-        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
-        self.epochs = 2
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.epochs = 20
         self.batch_size = int(256)
         self.activation = keras.layers.PReLU
 
diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index 87e4071a0ac98c15c22131cd5d3418eb7c1b6976..339ba63711cde8122f2879d66b66d490429e6d23 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -9,8 +9,7 @@ import pandas as pd
 import matplotlib
 import matplotlib.pyplot as plt
 
-from src.helpers import LearningRateDecay
-
+from src.model_modules.keras_extensions import LearningRateDecay
 
 matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index a7722018c52275b390a10199cb30b7b936ed37a3..4a72189283ff1bedc3014b29cdb752fe244c84bf 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -4,19 +4,17 @@ __date__ = '2019-12-02'
 
 import keras
 from keras import losses
-from keras.callbacks import ModelCheckpoint
-from keras.regularizers import l2
-from keras.optimizers import SGD
 import tensorflow as tf
 import logging
 import os
 
 from src.run_modules.run_environment import RunEnvironment
-from src.helpers import l_p_loss, LearningRateDecay
+from src.helpers import l_p_loss
+from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-from src.model_modules.model_class import MyBranchedModel as MyModel
-# from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyBranchedModel as MyModel
+from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
@@ -30,13 +28,11 @@ class ModelSetup(RunEnvironment):
         exp_name = self.data_store.get("experiment_name", "general")
         self.scope = "general.model"
         self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5")
+        self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle")
         self._run()
 
     def _run(self):
 
-        # create checkpoint
-        self._set_checkpoint()
-
         # set channels depending on inputs
         self._set_channels()
 
@@ -50,6 +46,9 @@ class ModelSetup(RunEnvironment):
         if self.data_store.get("trainable", self.scope) is False:
             self.load_weights()
 
+        # create checkpoint
+        self._set_checkpoint()
+
         # compile model
         self.compile_model()
 
@@ -64,7 +63,17 @@ class ModelSetup(RunEnvironment):
         self.data_store.set("model", self.model, self.scope)
 
     def _set_checkpoint(self):
-        checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
+        """
+        Must be run after all callback functions that shall be tracked during training have been created (currently this
+        affects the learning rate decay and the advanced history [actually created in this method]).
+        """
+        lr = self.data_store.get("lr_decay", scope="general.model")
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="general.model")
+        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
+                     {"callback": hist, "path": self.callbacks_name % "hist"}]
+        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                             save_best_only=True, mode='auto', callbacks=callbacks)
         self.data_store.set("checkpoint", checkpoint, self.scope)
 
     def load_weights(self):
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index dd0def055afc00202f9139638586807b0d2b832b..2a4632d515a36a77b01a09a539da4f51ecd3e07a 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -42,17 +42,17 @@ class PreProcessing(RunEnvironment):
         self.report_pre_processing()
 
     def report_pre_processing(self):
-        logging.info(20 * '##')
+        logging.debug(20 * '##')
         n_train = len(self.data_store.get('generator', 'general.train'))
         n_val = len(self.data_store.get('generator', 'general.val'))
         n_test = len(self.data_store.get('generator', 'general.test'))
         n_total = n_train + n_val + n_test
-        logging.info(f"Number of all stations: {n_total}")
-        logging.info(f"Number of training stations: {n_train}")
-        logging.info(f"Number of val stations: {n_val}")
-        logging.info(f"Number of test stations: {n_test}")
-        logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
-                     f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
+        logging.debug(f"Number of all stations: {n_total}")
+        logging.debug(f"Number of training stations: {n_train}")
+        logging.debug(f"Number of val stations: {n_val}")
+        logging.debug(f"Number of test stations: {n_test}")
+        logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
+                      f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
 
     def split_train_val_test(self):
         fraction_of_training = self.data_store.get("fraction_of_training", "general")
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 96936ce124e05251af483e758401d833a44531f4..99afd8300ca28fec3a589e69fa5b4eff1a37914a 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -5,11 +5,12 @@ import logging
 import os
 import json
 import keras
+import pickle
 
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
-from src.helpers import LearningRateDecay
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
 
 
 class Training(RunEnvironment):
@@ -22,8 +23,9 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint = self.data_store.get("checkpoint", "general.model")
+        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
+        self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._run()
 
@@ -35,7 +37,7 @@ class Training(RunEnvironment):
         2) make_predict_function():
             create predict function before distribution on multiple nodes (detailed information in method description)
         3) train():
-            train model and save callbacks
+            start or resume training of model and save callbacks
         4) save_model():
             save best model from training as final model
         """
@@ -73,17 +75,42 @@ class Training(RunEnvironment):
     def train(self) -> None:
         """
         Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best
-        model from training is saved for class variable model.
+        model from training is saved for class variable model. If the file path of checkpoint is not empty, this method
+        assumes, that this is not a new training starting from the very beginning, but a resumption from a previous
+        started but interrupted training (or a stopped and now continued training). Train will automatically load the
+        locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-                                           steps_per_epoch=len(self.train_set),
-                                           epochs=self.epochs,
-                                           verbose=2,
-                                           validation_data=self.val_set.distribute_on_batches(),
-                                           validation_steps=len(self.val_set),
-                                           callbacks=[self.checkpoint, self.lr_sc])
-        self.save_callbacks(history)
+        if not os.path.exists(self.checkpoint.filepath):
+            history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                               steps_per_epoch=len(self.train_set),
+                                               epochs=self.epochs,
+                                               verbose=2,
+                                               validation_data=self.val_set.distribute_on_batches(),
+                                               validation_steps=len(self.val_set),
+                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
+        else:
+            logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
+            lr_filepath = self.checkpoint.callbacks[0]["path"]
+            hist_filepath = self.checkpoint.callbacks[1]["path"]
+            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
+            self.hist = pickle.load(open(hist_filepath, "rb"))
+            self.model = keras.models.load_model(self.checkpoint.filepath)
+            initial_epoch = max(self.hist.epoch) + 1
+            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
+                         {"callback": self.hist, "path": hist_filepath}]
+            self.checkpoint.update_callbacks(callbacks)
+            self.checkpoint.update_best(self.hist)
+            _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                         steps_per_epoch=len(self.train_set),
+                                         epochs=self.epochs,
+                                         verbose=2,
+                                         validation_data=self.val_set.distribute_on_batches(),
+                                         validation_steps=len(self.val_set),
+                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         initial_epoch=initial_epoch)
+            history = self.hist
+        self.save_callbacks_as_json(history)
         self.load_best_model(self.checkpoint.filepath)
         self.create_monitoring_plots(history, self.lr_sc)
 
@@ -110,7 +137,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
diff --git a/test/test_helpers.py b/test/test_helpers.py
index e98a46fad6365a3a05ab28c9d118a119e35ff86a..c909960b4e5e053b9291c12e64e3649e957886bc 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -7,6 +7,8 @@ import numpy as np
 import mock
 import platform
 
+from src.model_modules.keras_extensions import LearningRateDecay
+
 
 class TestToList:
 
@@ -44,44 +46,6 @@ class TestLoss:
         assert hist.history['loss'][0] == 2.25
 
 
-class TestLearningRateDecay:
-
-    def test_init(self):
-        lr_decay = LearningRateDecay()
-        assert lr_decay.lr == {'lr': []}
-        assert lr_decay.base_lr == 0.01
-        assert lr_decay.drop == 0.96
-        assert lr_decay.epochs_drop == 8
-
-    def test_check_param(self):
-        lr_decay = object.__new__(LearningRateDecay)
-        assert lr_decay.check_param(1, "tester") == 1
-        assert lr_decay.check_param(0.5, "tester") == 0.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0, "tester")
-        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(1.5, "tester")
-        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
-        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0, "tester", upper=None)
-        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
-        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
-        with pytest.raises(ValueError) as e:
-            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
-        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
-        assert lr_decay.check_param(10, "tester", upper=None, lower=None)
-
-    def test_on_epoch_begin(self):
-        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
-        model = keras.Sequential()
-        model.add(keras.layers.Dense(1, input_dim=1))
-        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
-        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
-        assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
-
-
 class TestTimeTracking:
 
     def test_init(self):
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50e5e425779575eb2b492213a0b39b2b7c3376e
--- /dev/null
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -0,0 +1,43 @@
+import pytest
+from src.model_modules.keras_extensions import *
+from src.helpers import l_p_loss
+import keras
+import numpy as np
+
+
+class TestLearningRateDecay:
+
+    def test_init(self):
+        lr_decay = LearningRateDecay()
+        assert lr_decay.lr == {'lr': []}
+        assert lr_decay.base_lr == 0.01
+        assert lr_decay.drop == 0.96
+        assert lr_decay.epochs_drop == 8
+
+    def test_check_param(self):
+        lr_decay = object.__new__(LearningRateDecay)
+        assert lr_decay.check_param(1, "tester") == 1
+        assert lr_decay.check_param(0.5, "tester") == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(1.5, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
+        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester", upper=None)
+        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
+        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
+        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
+        assert lr_decay.check_param(10, "tester", upper=None, lower=None)
+
+    def test_on_epoch_begin(self):
+        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_dim=1))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
+        assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index d604d7474af84740a4b7a1cc51e5e94f1c94533b..65c683003d173031115164b4800d7759ff9cec2f 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -18,6 +18,9 @@ class TestModelSetup:
         super(ModelSetup, obj).__init__()
         obj.scope = "general.modeltest"
         obj.model = None
+        obj.callbacks_name = "placeholder_%s_str.pickle"
+        obj.data_store.set("lr_decay", "dummy_str", "general.model")
+        obj.data_store.set("hist", "dummy_str", "general.model")
         yield obj
         RunEnvironment().__del__()
 
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index a562e7b05a79f0068b10e9e36771669fe47d4ce8..c6f70169adee6eec74c834d952721f41d3c3fa03 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -47,14 +47,15 @@ class TestPreProcessing:
         assert obj_with_exp_setup.data_store.search_name("generator") == []
         assert obj_with_exp_setup._run() is None
         assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
-                                                                                 "general.test"])
+                                                                                 "general.train_val", "general.test"])
 
     def test_split_train_val_test(self, obj_with_exp_setup):
         assert obj_with_exp_setup.data_store.search_name("generator") == []
         obj_with_exp_setup.split_train_val_test()
         data_store = obj_with_exp_setup.data_store
         assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"])
-        assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test"])
+        assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
+                                                              "general.train_val"])
 
     def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
         caplog.set_level(logging.DEBUG)
@@ -93,10 +94,11 @@ class TestPreProcessing:
 
     def test_split_set_indices(self, obj_super_init):
         dummy_list = list(range(0, 15))
-        train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9)
+        train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
         assert dummy_list[train] == list(range(0, 10))
         assert dummy_list[val] == list(range(10, 13))
         assert dummy_list[test] == list(range(13, 15))
+        assert dummy_list[train_val] == list(range(0, 13))
 
     def test_create_args_dict_default_scope(self, obj_super_init):
         assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2}
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index accb32e5e3ec0fc425065ae6199c0418c524b174..08b9eaf19e831ed1662efbe21f4ad29d18dff9b4 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,8 @@ from src.run_modules.training import Training
 from src.run_modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
 from src.data_handling.data_generator import DataGenerator
-from src.helpers import LearningRateDecay, PyTestRegex
+from src.helpers import PyTestRegex
+from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
 
 
 def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
@@ -48,6 +49,7 @@ class TestTraining:
         obj.epochs = 2
         obj.checkpoint = checkpoint
         obj.lr_sc = LearningRateDecay()
+        obj.hist = HistoryAdvanced()
         obj.experiment_name = "TestExperiment"
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
@@ -132,6 +134,7 @@ class TestTraining:
         obj.data_store.set("epochs", 2, "general.model")
         obj.data_store.set("checkpoint", checkpoint, "general.model")
         obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
+        obj.data_store.set("hist", HistoryAdvanced(), "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         obj.data_store.set("experiment_path", path, "general")
         path_plot = os.path.join(path, "plots")
@@ -188,21 +191,21 @@ class TestTraining:
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
     def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+        init_without_run.save_callbacks_as_json(history)
         assert "history.json" in os.listdir(path)
 
     def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks(history)
+        init_with_lr.save_callbacks_as_json(history)
         assert "history_lr.json" in os.listdir(path)
 
     def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+        init_without_run.save_callbacks_as_json(history)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
     def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks(history)
+        init_with_lr.save_callbacks_as_json(history)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
             assert lr == init_with_lr.lr_sc.lr
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index e2d17057facee8ae07284109392f3e42e8e2fb6e..358a19adf5c81c90f7e77b787ca7b50923990f00 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -3,7 +3,7 @@ import pytest
 import os
 
 from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
-from src.helpers import LearningRateDecay
+from src.model_modules.keras_extensions import LearningRateDecay
 
 
 @pytest.fixture