diff --git a/.gitignore b/.gitignore
index ff59ade5d38dac9c3cf2fecee6a676ee728a2162..f7793d5f492cced655aeb62a8c29af48ac3e452e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,3 +59,7 @@ htmlcov/
 report.html
 /TestExperiment/
 /testrun_network*/
+
+# secret variables #
+####################
+/src/join_settings.py
\ No newline at end of file
diff --git a/README.md b/README.md
index e49362e95e9a69c159a5b8d857ccb336cf58d3c6..3467a31f23b7f770d32afb91cb62d5207ccf3d62 100644
--- a/README.md
+++ b/README.md
@@ -12,4 +12,68 @@ and [Network In Network (Lin et al., 2014)](https://arxiv.org/abs/1312.4400).
 # Installation
 
 * Install __proj__ on your machine using the console. E.g. for opensuse / leap `zypper install proj`
-* c++ compiler required for cartopy installation
\ No newline at end of file
+* c++ compiler required for cartopy installation
+
+# Security
+
+* To use hourly data from ToarDB via JOIN interface, a private token is required. Request your personal access token and
+add it to `src/join_settings.py` in the hourly data section. Replace the `TOAR_SERVICE_URL` and the `Authorization` 
+value. To make sure, that this **sensitive** data is not uploaded to the remote server, use the following command to
+prevent git from tracking this file: `git update-index --assume-unchanged src/join_settings.py
+`
+
+# Customise your experiment
+
+This section summarises which parameters can be customised for a training.
+
+## Transformation
+
+There are two different approaches (called scopes) to transform the data:
+1) `station`: transform data for each station independently (somehow like batch normalisation)
+1) `data`: transform all data of each station with shared metrics
+
+Transformation must be set by the `transformation` attribute. If `transformation = None` is given to `ExperimentSetup`, 
+data is not transformed at all. For all other setups, use the following dictionary structure to specify the 
+transformation.
+```
+transformation = {"scope": <...>, 
+                  "method": <...>,
+                  "mean": <...>,
+                  "std": <...>}
+ExperimentSetup(..., transformation=transformation, ...)
+```
+
+### scopes
+
+**station**: mean and std are not used
+
+**data**: either provide already calculated values for mean and std (if required by transformation method), or choose 
+from different calculation schemes, explained in the mean and std section.
+
+### supported transformation methods
+Currently supported methods are:
+* standardise (default, if method is not given)
+* centre
+
+### mean and std
+`"mean"="accurate"`: calculate the accurate values of mean and std (depending on method) by using all data. Although, 
+this method is accurate, it may take some time for the calculation. Furthermore, this could potentially lead to memory 
+issue (not explored yet, but could appear for a very big amount of data)
+
+`"mean"="estimate"`: estimate mean and std (depending on method). For each station, mean and std are calculated and
+afterwards aggregated using the mean value over all station-wise metrics. This method is less accurate, especially 
+regarding the std calculation but therefore much faster.
+
+We recommend to use the later method *estimate* because of following reasons:
+* much faster calculation
+* real accuracy of mean and std is less important, because it is "just" a transformation / scaling
+* accuracy of mean is almost as high as in the *accurate* case, because of 
+$\bar{x_{ij}} = \bar{\left(\bar{x_i}\right)_j}$. The only difference is, that in the *estimate* case, each mean is 
+equally weighted for each station independently of the actual data count of the station.
+* accuracy of std is lower for *estimate* because of $\var{x_{ij}} \ne \bar{\left(\var{x_i}\right)_j}$, but still the mean of all 
+station-wise std is a decent estimate of the true std.
+
+`"mean"=<value, e.g. xr.DataArray>`: If mean and std are already calculated or shall be set manually, just add the
+scaling values instead of the calculation method. For method *centre*, std can still be None, but is required for the
+*standardise* method. **Important**: Format of given values **must** match internal data format of DataPreparation 
+class: `xr.DataArray` with `dims=["variables"]` and one value for each variable.
\ No newline at end of file
diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f..b1624410e746ab779b20a60d6a7d19b4ae3b1267 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -12,11 +12,11 @@ import numpy as np
 class Distributor(keras.utils.Sequence):
 
     def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 fit_call: bool = True):
+                 permute_data: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
-        self.fit_call = fit_call
+        self.do_data_permutation = permute_data
 
     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -33,6 +33,16 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)
 
+    def _permute_data(self, x, y):
+        """
+        Permute inputs x and labels y
+        """
+        if self.do_data_permutation:
+            p = np.random.permutation(len(x))  # equiv to .shape[0]
+            x = x[p]
+            y = y[p]
+        return x, y
+
     def distribute_on_batches(self, fit_call=True):
         while True:
             for k, v in enumerate(self.generator):
@@ -42,6 +52,8 @@ class Distributor(keras.utils.Sequence):
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                # permute order for mini-batches
+                x_total, y_total = self._permute_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 19a94fbb9dbbc8f382a225c852f34971a98395b8..79e1e7e72c1779d18a11652ab132c253e1dff806 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -2,8 +2,9 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-11-07'
 
 import os
-from typing import Union, List, Tuple, Any
+from typing import Union, List, Tuple, Any, Dict
 
+import dask.array as da
 import keras
 import xarray as xr
 import pickle
@@ -11,6 +12,7 @@ import logging
 
 from src import helpers
 from src.data_handling.data_preparation import DataPrep
+from src.join import EmptyQueryResult
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -25,7 +27,7 @@ class DataGenerator(keras.utils.Sequence):
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
@@ -41,8 +43,8 @@ class DataGenerator(keras.utils.Sequence):
         self.limit_nan_fill = limit_nan_fill
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.transform_method = transform_method
         self.kwargs = kwargs
+        self.transformation = self.setup_transformation(transformation)
 
     def __repr__(self):
         """
@@ -94,18 +96,86 @@ class DataGenerator(keras.utils.Sequence):
         data = self.get_data_generator(key=item)
         return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
+    def setup_transformation(self, transformation):
+        if transformation is None:
+            return
+        transformation = transformation.copy()
+        scope = transformation.get("scope", "station")
+        method = transformation.get("method", "standardise")
+        mean = transformation.get("mean", None)
+        std = transformation.get("std", None)
+        if scope == "data":
+            if isinstance(mean, str):
+                if mean == "accurate":
+                    mean, std = self.calculate_accurate_transformation(method)
+                elif mean == "estimate":
+                    mean, std = self.calculate_estimated_transformation(method)
+                else:
+                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
+                                     f"be an array with already calculated means. Given was: {mean}")
+        elif scope == "station":
+            mean, std = None, None
+        else:
+            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
+        transformation["method"] = method
+        transformation["mean"] = mean
+        transformation["std"] = std
+        return transformation
+
+    def calculate_accurate_transformation(self, method):
+        tmp = []
+        mean = None
+        std = None
+        for station in self.stations:
+            try:
+                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                **self.kwargs)
+                chunks = (1, 100, data.data.shape[2])
+                tmp.append(da.from_array(data.data.data, chunks=chunks))
+            except EmptyQueryResult:
+                continue
+        tmp = da.concatenate(tmp, axis=1)
+        if method in ["standardise", "centre"]:
+            mean = da.nanmean(tmp, axis=1).compute()
+            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+            if method == "standardise":
+                std = da.nanstd(tmp, axis=1).compute()
+                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+        else:
+            raise NotImplementedError
+        return mean, std
+
+    def calculate_estimated_transformation(self, method):
+        data = [[]]*len(self.variables)
+        coords = {"variables": self.variables, "Stations": range(0)}
+        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        for station in self.stations:
+            try:
+                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                **self.kwargs)
+                data.transform("datetime", method=method)
+                mean = mean.combine_first(data.mean)
+                std = std.combine_first(data.std)
+                data.transform("datetime", method=method, inverse=True)
+            except EmptyQueryResult:
+                continue
+        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
+
+    def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
+                           save_local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
-        :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from
-            tmp pickle file to save computational time (but of course more disk space required).
+        :param load_local_tmp_storage: say if data should be processed from scratch or loaded as already processed data
+            from tmp pickle file to save computational time (but of course more disk space required).
+        :param save_local_tmp_storage: save processed data as temporal file locally (default True)
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not local_tmp_storage:
+            if not load_local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
@@ -113,11 +183,13 @@ class DataGenerator(keras.utils.Sequence):
             data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                             **self.kwargs)
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-            data.transform("datetime", method=self.transform_method)
+            if self.transformation is not None:
+                data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
             data.make_history_window(self.interpolate_dim, self.window_history_size)
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
             data.history_label_nan_remove(self.interpolate_dim)
-            self._save_pickle_data(data)
+            if save_local_tmp_storage:
+                self._save_pickle_data(data)
         return data
 
     def _save_pickle_data(self, data: Any):
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 5bca71f52c9f136b5910d4e080491e0ff86484ae..98b47a6df3825581564fa9aaef7be8698408760e 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -216,7 +216,7 @@ class DataPrep(object):
         self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
         self._transform_method = None
 
-    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False) -> None:
+    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean = None, std=None) -> None:
         """
         This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
         and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
@@ -247,11 +247,19 @@ class DataPrep(object):
             else:
                 raise NotImplementedError
 
+        def f_apply(data):
+            if method == "standardise":
+                return mean, std, statistics.standardise_apply(data, mean, std)
+            elif method == "centre":
+                return mean, None, statistics.centre_apply(data, mean)
+            else:
+                raise NotImplementedError
+
         if not inverse:
             if self._transform_method is not None:
                 raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
                                      f"{self._transform_method}. Please perform inverse transformation of data first.")
-            self.mean, self.std, self.data = f(self.data)
+            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
             self._transform_method = method
         else:
             self.inverse_transform()
@@ -370,7 +378,7 @@ class DataPrep(object):
         :param coord: name of axis to slice
         :return:
         """
-        return data.loc[{coord: slice(start, end)}]
+        return data.loc[{coord: slice(str(start), str(end))}]
 
     def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
         """
@@ -387,8 +395,7 @@ class DataPrep(object):
         return data
 
     def get_transposed_history(self):
-        if self.history is not None:
-            return self.history.transpose("datetime", "window", "Stations", "variables")
+        return self.history.transpose("datetime", "window", "Stations", "variables")
 
 
 if __name__ == "__main__":
diff --git a/src/helpers.py b/src/helpers.py
index c33684508b7b36cbebc3cbf3ac826d1779f9df50..e1496c3b232db29194878892647c59581b6a70a3 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -190,3 +190,8 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
     """
     multiplier = 10. ** decimals
     return round_type(number * multiplier) / multiplier
+
+
+def dict_pop(dict: Dict, pop_keys):
+    pop_keys = to_list(pop_keys)
+    return {k: v for k, v in dict.items() if k not in pop_keys}
diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index e2a4b93219be2cbebfb35749560efa65c07226bb..180e324602da25e1df8fb218c1d3bba180004ac8 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -10,6 +10,8 @@ import numpy as np
 from keras import backend as K
 from keras.callbacks import History, ModelCheckpoint
 
+from src import helpers
+
 
 class HistoryAdvanced(History):
     """
@@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
         Update all stored callback objects. The argument callbacks needs to follow the same convention like described
         in the class description (list of dictionaries). Must be run before resuming a training process.
         """
-        self.callbacks = callbacks
+        self.callbacks = helpers.to_list(callbacks)
 
     def on_epoch_end(self, epoch, logs=None):
         """
@@ -139,12 +141,73 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
-                        if self.verbose > 0:
+                        if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
                             pickle.dump(callback["callback"], f)
                 else:
                     with open(file_path, "wb") as f:
-                        if self.verbose > 0:
+                        if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         pickle.dump(callback["callback"], f)
+
+
+class CallbackHandler:
+
+    def __init__(self):
+        self.__callbacks = []
+        self._checkpoint = None
+        self.editable = True
+
+    @property
+    def _callbacks(self):
+        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
+
+    @_callbacks.setter
+    def _callbacks(self, value):
+        name, callback, callback_path = value
+        self.__callbacks.append({"name": name, name: callback, "path": callback_path})
+
+    def _update_callback(self, pos, value):
+        name = self.__callbacks[pos]["name"]
+        self.__callbacks[pos][name] = value
+
+    def add_callback(self, callback, callback_path, name="callback"):
+        if self.editable:
+            self._callbacks = (name, callback, callback_path)
+        else:
+            raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
+
+    def get_callbacks(self, as_dict=True):
+        if as_dict:
+            return self._get_callbacks()
+        else:
+            return [clb["callback"] for clb in self._get_callbacks()]
+
+    def get_callback_by_name(self, obj_name):
+        if obj_name != "callback":
+            return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
+
+    def _get_callbacks(self):
+        clbks = self._callbacks
+        if self._checkpoint is not None:
+            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
+        return clbks
+
+    def get_checkpoint(self):
+        if self._checkpoint is not None:
+            return self._checkpoint
+
+    def create_model_checkpoint(self, **kwargs):
+        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
+        self.editable = False
+
+    def load_callbacks(self):
+        for pos, callback in enumerate(self.__callbacks):
+            path = callback["path"]
+            clb = pickle.load(open(path, "rb"))
+            self._update_callback(pos, clb)
+
+    def update_checkpoint(self, history_name="hist"):
+        self._checkpoint.update_callbacks(self._callbacks)
+        self._checkpoint.update_best(self.get_callback_by_name(history_name))
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index bb4801acfb5c3a643aecbcfad9cfdb758258d0ef..ebbd7a25cef9031436d932a6502c9726bfe3e318 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -8,6 +8,8 @@ from abc import ABC
 from typing import Any, Callable
 
 import keras
+from src.model_modules.inception_model import InceptionModelBase
+from src.model_modules.flatten import flatten_tail
 
 
 class AbstractModelClass(ABC):
@@ -27,6 +29,7 @@ class AbstractModelClass(ABC):
 
         self.__model = None
         self.__loss = None
+        self.model_name = self.__class__.__name__
 
     def __getattr__(self, name: str) -> Any:
 
@@ -239,3 +242,112 @@ class MyBranchedModel(AbstractModelClass):
 
         self.loss = [keras.losses.mean_absolute_error] + [keras.losses.mean_squared_error] + \
                     [keras.losses.mean_squared_error]
+
+
+class MyTowerModel(AbstractModelClass):
+
+    def __init__(self, window_history_size, window_lead_time, channels):
+
+        """
+        Sets model and loss depending on the given arguments.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param regularizer: <not used here>
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        """
+
+        super().__init__()
+
+        # settings
+        self.window_history_size = window_history_size
+        self.window_lead_time = window_lead_time
+        self.channels = channels
+        self.dropout_rate = 1e-2
+        self.regularizer = keras.regularizers.l2(0.1)
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
+        self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.epochs = 20
+        self.batch_size = int(256*4)
+        self.activation = keras.layers.PReLU
+
+        # apply to model
+        self.set_model()
+        self.set_loss()
+
+    def set_model(self):
+
+        """
+        Build the model.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        :return: built keras model
+        """
+        activation = self.activation
+        conv_settings_dict1 = {
+            'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
+            'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation},
+            'tower_3': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (1, 1), 'activation': activation},
+        }
+
+        pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
+
+        conv_settings_dict2 = {
+            'tower_1': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (3, 1),
+                        'activation': activation},
+            'tower_2': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (5, 1),
+                        'activation': activation},
+            'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1),
+                        'activation': activation},
+            }
+        pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
+
+        conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
+                                           'activation': activation},
+                               'tower_2': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
+                                           'activation': activation},
+                               'tower_3': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
+                                           'activation': activation},
+                               }
+
+        pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
+
+        ##########################################
+        inception_model = InceptionModelBase()
+
+        X_input = keras.layers.Input(
+            shape=(self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+
+        X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1,
+                                               regularizer=self.regularizer,
+                                               batch_normalisation=True)
+
+        X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
+
+        X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=self.regularizer,
+                                               batch_normalisation=True)
+
+        X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
+
+        X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=self.regularizer,
+                                               batch_normalisation=True)
+        #############################################
+
+        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
+                                reduction_filter=64, first_dense=64, window_lead_time=self.window_lead_time)
+
+        self.model = keras.Model(inputs=X_input, outputs=[out_main])
+
+    def set_loss(self):
+
+        """
+        Set the loss
+        :return: loss function
+        """
+
+        self.loss = [keras.losses.mean_squared_error]
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 01a565e3db284e56bf0b8c94420b71268fd21a80..a41c636b5ab17d2039f7976fca625e9c8e11ce6e 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -50,8 +50,8 @@ class PlotMonthlySummary(RunEnvironment):
 
     def _prepare_data(self, stations: List) -> xr.DataArray:
         """
-        Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN and orig
-        prediction and group them into monthly bins (no aggregation, only sorting them).
+        Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN prediction
+        and the observation and group them into monthly bins (no aggregation, only sorting them).
         :param stations: all stations to plot
         :return: The entire data set, flagged with the corresponding month.
         """
@@ -65,10 +65,10 @@ class PlotMonthlySummary(RunEnvironment):
             if len(data_cnn.shape) > 1:
                 data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values]
 
-            data_orig = data.sel(type="orig", ahead=1).squeeze()
-            data_orig.coords["ahead"] = "orig"
+            data_obs = data.sel(type="obs", ahead=1).squeeze()
+            data_obs.coords["ahead"] = "obs"
 
-            data_concat = xr.concat([data_orig, data_cnn], dim="ahead")
+            data_concat = xr.concat([data_obs, data_cnn], dim="ahead")
             data_concat = data_concat.drop("type")
 
             data_concat.index.values = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
@@ -189,7 +189,7 @@ class PlotStationMap(RunEnvironment):
         plt.close('all')
 
 
-def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'orig',
+def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'obs',
                                pred_name: str = 'CNN', season: str = "", forecast_path: str = None,
                                plot_name_affix: str = "", units: str = "ppb"):
     """
@@ -222,7 +222,7 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w
         for station in stations:
             file = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
             data_tmp = xr.open_dataarray(file)
-            data_collector.append(data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station))
+            data_collector.append(data_tmp.loc[:, :, ['CNN', 'obs', 'OLS']].assign_coords(station=station))
         return xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station')
 
     def segment_data(data):
@@ -252,7 +252,7 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w
 
     def labels(plot_type, data_unit="ppb"):
         names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})")
-        if plot_type == "orig":
+        if plot_type == "obs":
             return names
         else:
             return names[::-1]
@@ -515,7 +515,7 @@ class PlotTimeSeries(RunEnvironment):
         logging.debug(f"... preprocess station {station}")
         file_name = os.path.join(self._data_path, self._data_name % station)
         data = xr.open_dataarray(file_name)
-        return data.sel(type=["CNN", "orig"])
+        return data.sel(type=["CNN", "obs"])
 
     def _plot(self, plot_folder):
         pdf_pages = self._create_pdf_pages(plot_folder)
@@ -527,9 +527,9 @@ class PlotTimeSeries(RunEnvironment):
             for i_year in range(end - start + 1):
                 data_year = data.sel(index=f"{start + i_year}")
                 for i_half_of_year in range(factor):
-                    pos = 2 * i_year + i_half_of_year
+                    pos = factor * i_year + i_half_of_year
                     plot_data = self._create_plot_data(data_year, factor, i_half_of_year)
-                    self._plot_orig(axes[pos], plot_data)
+                    self._plot_obs(axes[pos], plot_data)
                     self._plot_ahead(axes[pos], plot_data)
                     if np.isnan(plot_data.values).all():
                         nan_list.append(pos)
@@ -574,10 +574,10 @@ class PlotTimeSeries(RunEnvironment):
             label = f"{ahead}{self._sampling}"
             ax.plot(index, plot_data.values, color=color[ahead-1], label=label)
 
-    def _plot_orig(self, ax, data):
-        orig_data = data.sel(type="orig", ahead=1)
+    def _plot_obs(self, ax, data):
+        obs_data = data.sel(type="obs", ahead=1)
         index = data.index + np.timedelta64(1, self._sampling)
-        ax.plot(index, orig_data.values, color=matplotlib.colors.cnames["green"], label="orig")
+        ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
 
     @staticmethod
     def _get_time_range(data):
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 44f1a3821dfacba146a0aabe8fb0254068d9e6d3..4c3b8872575ea9929f1d4ba3f5a42e222ac2fff4 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -19,6 +19,7 @@ DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                         'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                         'pblheight': 'maximum'}
+DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
 
 
 class ExperimentSetup(RunEnvironment):
@@ -31,16 +32,21 @@ class ExperimentSetup(RunEnvironment):
                  statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None,
                  window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
-                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily"):
+                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
+                 create_new_model=None, permute_data_on_training=None, transformation=None):
 
         # create run framework
         super().__init__()
 
         # experiment setup
         self._set_param("data_path", helpers.prepare_host(sampling=sampling))
-        self._set_param("trainable", trainable, default=False)
+        self._set_param("create_new_model", create_new_model, default=True)
+        if self.data_store.get("create_new_model", "general"):
+            trainable = True
+        self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
+        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
@@ -73,6 +79,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=13)
         self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
         self._set_param("sampling", sampling)
+        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
+        self._set_param("transformation", None, scope="general.preprocessing")
 
         # target
         self._set_param("target_var", target_var, default="o3")
@@ -85,19 +93,19 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("interpolate_method", interpolate_method, default='linear')
         self._set_param("limit_nan_fill", limit_nan_fill, default=1)
 
-        # train parameters
+        # train set parameters
         self._set_param("start", train_start, default="1997-01-01", scope="general.train")
         self._set_param("end", train_end, default="2007-12-31", scope="general.train")
 
-        # validation parameters
+        # validation set parameters
         self._set_param("start", val_start, default="2008-01-01", scope="general.val")
         self._set_param("end", val_end, default="2009-12-31", scope="general.val")
 
-        # test parameters
+        # test set parameters
         self._set_param("start", test_start, default="2010-01-01", scope="general.test")
         self._set_param("end", test_end, default="2017-12-31", scope="general.test")
 
-        # train_val parameters
+        # train_val set parameters
         self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
         self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
 
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 7049af591cdf73434459d4c3f5f6c11e80ab64c0..e3945a542d60b09dc9855bd28be87cdba729ed72 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -7,14 +7,11 @@ import os
 
 import keras
 import tensorflow as tf
-from keras import losses
 
-from src.helpers import l_p_loss
-from src.model_modules.flatten import flatten_tail
-from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import HistoryAdvanced, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
 
@@ -28,8 +25,12 @@ class ModelSetup(RunEnvironment):
         path = self.data_store.get("experiment_path", "general")
         exp_name = self.data_store.get("experiment_name", "general")
         self.scope = "general.model"
-        self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5")
-        self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle")
+        self.path = os.path.join(path, f"{exp_name}_%s")
+        self.model_name = self.path % "%s.h5"
+        self.checkpoint_name = self.path % "model-best.h5"
+        self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
+        self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self):
@@ -44,11 +45,11 @@ class ModelSetup(RunEnvironment):
         self.plot_model()
 
         # load weights if no training shall be performed
-        if self.data_store.get("trainable", self.scope) is False:
+        if not self._trainable and not self._create_new_model:
             self.load_weights()
 
         # create checkpoint
-        self._set_checkpoint()
+        self._set_callbacks()
 
         # compile model
         self.compile_model()
@@ -63,24 +64,25 @@ class ModelSetup(RunEnvironment):
         self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"])
         self.data_store.set("model", self.model, self.scope)
 
-    def _set_checkpoint(self):
+    def _set_callbacks(self):
         """
-        Must be run after all callback functions that shall be tracked during training have been created (currently this
-        affects the learning rate decay and the advanced history [actually created in this method]).
+        Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
+        advanced model checkpoint is added.
         """
         lr = self.data_store.get("lr_decay", scope="general.model")
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
-        callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
-                     {"callback": hist, "path": self.callbacks_name % "hist"}]
-        checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                             save_best_only=True, mode='auto', callbacks=callbacks)
-        self.data_store.set("checkpoint", checkpoint, self.scope)
+        callbacks = CallbackHandler()
+        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
+                                          save_best_only=True, mode='auto')
+        self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_weights(self):
         try:
-            self.model.load_weights(self.checkpoint_name)
-            logging.info('reload weights...')
+            self.model.load_weights(self.model_name)
+            logging.info(f"reload weights from model {self.model_name} ...")
         except OSError:
             logging.info('no weights to reload...')
 
@@ -93,97 +95,10 @@ class ModelSetup(RunEnvironment):
     def get_model_settings(self):
         model_settings = self.model.get_settings()
         self.data_store.set_args_from_dict(model_settings, self.scope)
+        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
+        self.data_store.set("model_name", self.model_name, self.scope)
 
     def plot_model(self):  # pragma: no cover
         with tf.device("/cpu:0"):
-            path = self.data_store.get("experiment_path", "general")
-            name = self.data_store.get("experiment_name", "general") + "_model.pdf"
-            file_name = os.path.join(path, name)
+            file_name = f"{self.model_name.split(sep='.')[0]}.pdf"
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
-
-
-def my_loss():
-    loss = l_p_loss(4)
-    keras_loss = losses.mean_squared_error
-    loss_all = [loss] + [keras_loss]
-    return loss_all
-
-
-def my_little_loss():
-    return losses.mean_squared_error
-
-
-def my_little_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
-
-    X_input = keras.layers.Input(
-        shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
-    X_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(X_input)
-    X_in = activation(name='{}_conv_act'.format("major"))(X_in)
-    X_in = keras.layers.Flatten(name='{}'.format("major"))(X_in)
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format("major"))(X_in)
-    X_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(X_in)
-    X_in = activation()(X_in)
-    X_in = keras.layers.Dense(window_lead_time, name='{}_Dense'.format("major"))(X_in)
-    out_main = activation()(X_in)
-    return keras.Model(inputs=X_input, outputs=[out_main])
-
-
-def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time):
-
-    conv_settings_dict1 = {
-        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
-        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation},
-        'tower_3': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (1, 1), 'activation': activation},
-    }
-
-    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
-
-    conv_settings_dict2 = {'tower_1': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (3, 1),
-                                       'activation': activation},
-                           'tower_2': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (5, 1),
-                                       'activation': activation},
-                           'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1),
-                                       'activation': activation},
-                           }
-    pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation}
-
-    conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1),
-                                       'activation': activation},
-                           'tower_2': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1),
-                                       'activation': activation},
-                           'tower_3': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1),
-                                       'activation': activation},
-                           }
-
-    pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation}
-
-    ##########################################
-    inception_model = InceptionModelBase()
-
-    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))  # add 1 to window_size to include current time step t0
-
-    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer,
-                                           batch_normalisation=True)
-
-    out_minor = flatten_tail(X_in, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate,
-                             reduction_filter=4, first_dense=32, window_lead_time=window_lead_time)
-
-    X_in = keras.layers.Dropout(dropout_rate)(X_in)
-
-    X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=regularizer,
-                                           batch_normalisation=True)
-
-    X_in = keras.layers.Dropout(dropout_rate)(X_in)
-
-    X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=regularizer,
-                                           batch_normalisation=True)
-    #############################################
-
-    out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=dropout_rate,
-                            reduction_filter=64, first_dense=64, window_lead_time=window_lead_time)
-
-    return keras.Model(inputs=X_input, outputs=[out_minor, out_main])
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 03d2e36e8662a573b96c970747e9fe4445244e9b..06203c879872891f57c719040482fe052824c65e 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -43,7 +43,7 @@ class PostProcessing(RunEnvironment):
         with TimeTracking():
             self.train_ols_model()
             logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
-                         "skip make_prediction() whenever it is possible to save time.")
+                         "skip train_ols_model() whenever it is possible to save time.")
         with TimeTracking():
             self.make_prediction()
             logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
@@ -55,10 +55,8 @@ class PostProcessing(RunEnvironment):
         try:
             model = self.data_store.get("best_model", "general")
         except NameNotFoundInDataStore:
-            logging.info("no model saved in data store. trying to load model from experiment")
-            path = self.data_store.get("experiment_path", "general")
-            name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5"
-            model_name = os.path.join(path, name)
+            logging.info("no model saved in data store. trying to load model from experiment path")
+            model_name = self.data_store.get("model_name", "general.model")
             model = keras.models.load_model(model_name)
         return model
 
@@ -66,9 +64,9 @@ class PostProcessing(RunEnvironment):
         logging.debug("Run plotting routines...")
         path = self.data_store.get("forecast_path", "general")
 
-        plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
+        plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="obs",
                                    forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
-        plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
+        plot_conditional_quantiles(self.test_data.stations, pred_name="obs", ref_name="CNN",
                                    forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
         PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
         PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var,
@@ -115,15 +113,15 @@ class PostProcessing(RunEnvironment):
             # ols
             ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method)
 
-            # orig pred
-            orig_pred = self._create_orig_forecast(data, None, mean, std, transformation_method)
+            # observation
+            observation = self._create_observation(data, None, mean, std, transformation_method)
 
             # merge all predictions
             full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
             all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
                                                           CNN=nn_prediction,
                                                           persi=persistence_prediction,
-                                                          orig=orig_pred,
+                                                          obs=observation,
                                                           OLS=ols_prediction)
 
             # save all forecasts locally
@@ -136,7 +134,7 @@ class PostProcessing(RunEnvironment):
         return getter.get(self._sampling, None)
 
     @staticmethod
-    def _create_orig_forecast(data, _, mean, std, transformation_method):
+    def _create_observation(data, _, mean, std, transformation_method):
         return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)
 
     def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method):
@@ -229,7 +227,7 @@ class PostProcessing(RunEnvironment):
         try:
             data = self.train_val_data.get_data_generator(station)
             mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
-            external_data = self._create_orig_forecast(data, None, mean, std, transformation_method)
+            external_data = self._create_observation(data, None, mean, std, transformation_method)
             external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
             return external_data.rename({'datetime': 'index'})
         except KeyError:
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 4660a8116b6d0b860a7d0d50b92cee5e0deb77d8..3263f5c4562eeac321c7ce621df551fdf6373ba0 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,7 @@ from src.run_modules.run_environment import RunEnvironment
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
 
 
 class PreProcessing(RunEnvironment):
@@ -35,7 +35,8 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
+        stations = self.data_store.get("stations", "general")
+        valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -53,11 +54,19 @@ class PreProcessing(RunEnvironment):
         logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
                       f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
 
-    def split_train_val_test(self):
+    def split_train_val_test(self) -> None:
+        """
+        Splits all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val,
+        but as an separate generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train
+        subset needs always to be executed at first, to set a proper transformation.
+        """
         fraction_of_training = self.data_store.get("fraction_of_training", "general")
         stations = self.data_store.get("stations", "general")
         train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training)
         subset_names = ["train", "val", "test", "train_val"]
+        if subset_names[0] != "train":  # pragma: no cover
+            raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset"
+                                 f"order was: {subset_names}.")
         for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
             self.create_set_split(ind, scope)
 
@@ -79,7 +88,16 @@ class PreProcessing(RunEnvironment):
         train_val_index = slice(0, pos_test_split)
         return train_index, val_index, test_index, train_val_index
 
-    def create_set_split(self, index_list, set_name):
+    def create_set_split(self, index_list: slice, set_name) -> None:
+        """
+        Create the subset for given split index and stores the DataGenerator with given set name in data store as
+        `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the
+        DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make
+        sure, that the train set is executed first, and all other subsets afterwards.
+        :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True,
+            this list is ignored.
+        :param set_name: name to load/save all information from/to data store without the leading general prefix.
+        """
         scope = f"general.{set_name}"
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope)
@@ -89,14 +107,16 @@ class PreProcessing(RunEnvironment):
         else:
             set_stations = stations[index_list]
         logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
-        set_stations = self.check_valid_stations(args, kwargs, set_stations)
+        set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False)
         self.data_store.set("stations", set_stations, scope)
         set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
         data_set = DataGenerator(**set_args, **kwargs)
         self.data_store.set("generator", data_set, scope)
+        if set_name == "train":
+            self.data_store.set("transformation", data_set.transformation, "general")
 
     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
         time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,7 +138,8 @@ class PreProcessing(RunEnvironment):
             t_inner.run()
             try:
                 # (history, label) = data_gen[station]
-                data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
+                data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
+                                                   save_local_tmp_storage=save_tmp)
                 valid_stations.append(station)
                 logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                 logging.debug(f"{station}: loading time = {t_inner}")
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index e2a98f27c65e6050b0edae2bbc178abbf97ab646..df60c4f2f8dff4a9acb82920ad3c1d203813033d 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -9,7 +9,7 @@ import pickle
 import keras
 
 from src.data_handling.data_distributor import Distributor
-from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, ModelCheckpointAdvanced, CallbackHandler
 from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from src.run_modules.run_environment import RunEnvironment
 
@@ -24,10 +24,10 @@ class Training(RunEnvironment):
         self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
-        self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("lr_decay", "general.model")
-        self.hist = self.data_store.get("hist", "general.model")
+        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
+        self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self) -> None:
@@ -44,8 +44,11 @@ class Training(RunEnvironment):
         """
         self.set_generators()
         self.make_predict_function()
-        self.train()
-        self.save_model()
+        if self._trainable:
+            self.train()
+            self.save_model()
+        else:
+            logging.info("No training has started, because trainable parameter was false.")
 
     def make_predict_function(self) -> None:
         """
@@ -62,7 +65,8 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
+        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
 
     def set_generators(self) -> None:
         """
@@ -82,46 +86,41 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath):
+        checkpoint = self.callbacks.get_checkpoint()
+        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
                                                verbose=2,
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
-                                               callbacks=[self.lr_sc, self.hist, self.checkpoint])
+                                               callbacks=self.callbacks.get_callbacks(as_dict=False))
         else:
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
-            lr_filepath = self.checkpoint.callbacks[0]["path"]
-            hist_filepath = self.checkpoint.callbacks[1]["path"]
-            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
-            self.hist = pickle.load(open(hist_filepath, "rb"))
-            self.model = keras.models.load_model(self.checkpoint.filepath)
-            initial_epoch = max(self.hist.epoch) + 1
-            callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
-                         {"callback": self.hist, "path": hist_filepath}]
-            self.checkpoint.update_callbacks(callbacks)
-            self.checkpoint.update_best(self.hist)
+            self.callbacks.load_callbacks()
+            self.callbacks.update_checkpoint()
+            self.model = keras.models.load_model(checkpoint.filepath)
+            hist = self.callbacks.get_callback_by_name("hist")
+            initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                          steps_per_epoch=len(self.train_set),
                                          epochs=self.epochs,
                                          verbose=2,
                                          validation_data=self.val_set.distribute_on_batches(),
                                          validation_steps=len(self.val_set),
-                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
-            history = self.hist
-        self.save_callbacks_as_json(history)
-        self.load_best_model(self.checkpoint.filepath)
-        self.create_monitoring_plots(history, self.lr_sc)
+            history = hist
+        lr = self.callbacks.get_callback_by_name("lr")
+        self.save_callbacks_as_json(history, lr)
+        self.load_best_model(checkpoint.filepath)
+        self.create_monitoring_plots(history, lr)
 
     def save_model(self) -> None:
         """
-        save model in local experiment directory. Model is named as <experiment_name>_my_model.h5 .
+        save model in local experiment directory. Model is named as <experiment_name>_<custom_model_name>.h5 .
         """
-        path = self.data_store.get("experiment_path", "general")
-        name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5"
-        model_name = os.path.join(path, name)
+        model_name = self.data_store.get("model_name", "general.model")
         logging.debug(f"save best model to {model_name}")
         self.model.save(model_name)
         self.data_store.set("best_model", self.model, "general")
@@ -138,7 +137,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History, lr_sc: keras.callbacks) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
@@ -150,7 +149,7 @@ class Training(RunEnvironment):
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(self.lr_sc.lr, f)
+            json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
diff --git a/src/statistics.py b/src/statistics.py
index df73784df830d5f7b96bf0fcd18a65d362516f12..26b2be8854c51584f20b753717ea94cc12967369 100644
--- a/src/statistics.py
+++ b/src/statistics.py
@@ -15,11 +15,11 @@ Data = Union[xr.DataArray, pd.DataFrame]
 
 
 def apply_inverse_transformation(data, mean, std=None, method="standardise"):
-    if method == 'standardise':
+    if method == 'standardise':  # pragma: no branch
         return standardise_inverse(data, mean, std)
-    elif method == 'centre':
+    elif method == 'centre':  # pragma: no branch
         return centre_inverse(data, mean)
-    elif method == 'normalise':
+    elif method == 'normalise':  # pragma: no cover
         # use min/max of data or given min/max
         raise NotImplementedError
     else:
@@ -52,6 +52,17 @@ def standardise_inverse(data: Data, mean: Data, std: Data) -> Data:
     return data * std + mean
 
 
+def standardise_apply(data: Data, mean: Data, std: Data) -> Data:
+    """
+    This applies `standardise` on data using given mean and std.
+    :param data:
+    :param mean:
+    :param std:
+    :return:
+    """
+    return (data - mean) / std
+
+
 def centre(data: Data, dim: Union[str, int]) -> Tuple[Data, None, Data]:
     """
     This function centres a xarray.dataarray (along dim) or pandas.DataFrame (along axis) to mean=0
@@ -77,6 +88,17 @@ def centre_inverse(data: Data, mean: Data) -> Data:
     return data + mean
 
 
+def centre_apply(data: Data, mean: Data) -> Data:
+    """
+    This applies `centre` on data using given mean and std.
+    :param data:
+    :param mean:
+    :param std:
+    :return:
+    """
+    return data - mean
+
+
 def mean_squared_error(a, b):
     return np.square(a - b).mean()
 
@@ -126,12 +148,12 @@ class SkillScores(RunEnvironment):
 
         return skill_score
 
-    def _climatological_skill_score(self, data, mu_type=1, observation_name="orig", forecast_name="CNN", external_data=None):
+    def _climatological_skill_score(self, data, mu_type=1, observation_name="obs", forecast_name="CNN", external_data=None):
         kwargs = {"external_data": external_data} if external_data is not None else {}
         return self.__getattribute__(f"skill_score_mu_case_{mu_type}")(data, observation_name, forecast_name, **kwargs)
 
     @staticmethod
-    def general_skill_score(data, observation_name="orig", forecast_name="CNN", reference_name="persi"):
+    def general_skill_score(data, observation_name="obs", forecast_name="CNN", reference_name="persi"):
         data = data.dropna("index")
         observation = data.sel(type=observation_name)
         forecast = data.sel(type=forecast_name)
@@ -159,12 +181,12 @@ class SkillScores(RunEnvironment):
         suffix = {"mean": mean, "sigma": sigma, "r": r, "p": p}
         return AI, BI, CI, data, suffix
 
-    def skill_score_mu_case_1(self, data, observation_name="orig", forecast_name="CNN"):
+    def skill_score_mu_case_1(self, data, observation_name="obs", forecast_name="CNN"):
         AI, BI, CI, data, _ = self.skill_score_pre_calculations(data, observation_name, forecast_name)
         skill_score = np.array(AI - BI - CI)
         return pd.DataFrame({"skill_score": [skill_score], "AI": [AI], "BI": [BI], "CI": [CI]}).to_xarray().to_array()
 
-    def skill_score_mu_case_2(self, data, observation_name="orig", forecast_name="CNN"):
+    def skill_score_mu_case_2(self, data, observation_name="obs", forecast_name="CNN"):
         AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
         monthly_mean = self.create_monthly_mean_from_daily_data(data)
         data = xr.concat([data, monthly_mean], dim="type")
@@ -177,14 +199,14 @@ class SkillScores(RunEnvironment):
         skill_score = np.array((AI - BI - CI - AII + BII) / (1 - AII + BII))
         return pd.DataFrame({"skill_score": [skill_score], "AII": [AII], "BII": [BII]}).to_xarray().to_array()
 
-    def skill_score_mu_case_3(self, data, observation_name="orig", forecast_name="CNN", external_data=None):
+    def skill_score_mu_case_3(self, data, observation_name="obs", forecast_name="CNN", external_data=None):
         AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
         mean, sigma = suffix["mean"], suffix["sigma"]
         AIII = (((external_data.mean().values - mean.loc[observation_name]) / sigma.loc[observation_name])**2).values
         skill_score = np.array((AI - BI - CI + AIII) / 1 + AIII)
         return pd.DataFrame({"skill_score": [skill_score], "AIII": [AIII]}).to_xarray().to_array()
 
-    def skill_score_mu_case_4(self, data, observation_name="orig", forecast_name="CNN", external_data=None):
+    def skill_score_mu_case_4(self, data, observation_name="obs", forecast_name="CNN", external_data=None):
         AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
         monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values, index=data.index)
         data = xr.concat([data, monthly_mean_external], dim="type")
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 4c6dbb1c38f2e4a49e53883fbe3cb33cb565118a..a26e76a0e7f3ef0f5cdbedc07d73a690116966c9 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -37,7 +37,7 @@ class TestDistributor:
 
     def test_init_defaults(self, distributor):
         assert distributor.batch_size == 256
-        assert distributor.fit_call is True
+        assert distributor.do_data_permutation is False
 
     def test_get_model_rank(self, distributor, model_with_minor_branch):
         assert distributor._get_model_rank() == 1
@@ -73,3 +73,28 @@ class TestDistributor:
         d = Distributor(gen, model)
         expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
         assert len(d) == expected
+
+    def test_permute_data_no_permutation(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
+    def test_permute_data(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        distributor.do_data_permutation = True
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert x_perm[0, 0] == y_perm[0]
+        assert x_perm[0, 1] == y_perm[0] + 10
+        assert x_perm[5, 0] == y_perm[5]
+        assert x_perm[5, 1] == y_perm[5] + 10
+        assert x_perm[-1, 0] == y_perm[-1]
+        assert x_perm[-1, 1] == y_perm[-1] + 10
+        # resort x_perm and compare if equal to x
+        x_perm.sort(axis=0)
+        y_perm.sort(axis=0)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 142acd166604951352ad6686548c2cb76f609ce0..9bf11154609afa9ada2b488455f7a341a41d21ae 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -1,12 +1,15 @@
 import os
 
+import operator as op
 import pytest
 
 import shutil
 import numpy as np
+import xarray as xr
 import pickle
 from src.data_handling.data_generator import DataGenerator
 from src.data_handling.data_preparation import DataPrep
+from src.join import EmptyQueryResult
 
 
 class TestDataGenerator:
@@ -22,6 +25,56 @@ class TestDataGenerator:
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3', start=2010, end=2014)
 
+    @pytest.fixture
+    def gen_with_transformation(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', start=2010, end=2014,
+                             transformation={"scope": "data", "mean": "estimate"},
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def gen_no_init(self):
+        generator = object.__new__(DataGenerator)
+        path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data'))
+        generator.data_path = path
+        if not os.path.exists(path):
+            os.makedirs(path)
+        generator.stations = ["DEBW107", "DEBW013", "DEBW001"]
+        generator.network = "AIRBASE"
+        generator.variables = ["temp", "o3"]
+        generator.station_type = "background"
+        generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        return generator
+
+    @pytest.fixture
+    def accurate_transformation(self, gen_no_init):
+        tmp = np.nan
+        for station in gen_no_init.stations:
+            try:
+                data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables,
+                                     station_type=gen_no_init.station_type, **gen_no_init.kwargs)
+                tmp = data_prep.data.combine_first(tmp)
+            except EmptyQueryResult:
+                continue
+        mean_expected = tmp.mean(dim=["Stations", "datetime"])
+        std_expected = tmp.std(dim=["Stations", "datetime"])
+        return mean_expected, std_expected
+
+    @pytest.fixture
+    def estimated_transformation(self, gen_no_init):
+        mean, std = None, None
+        for station in gen_no_init.stations:
+            try:
+                data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables,
+                                     station_type=gen_no_init.station_type, **gen_no_init.kwargs)
+                mean = data_prep.data.mean(axis=1).combine_first(mean)
+                std = data_prep.data.std(axis=1).combine_first(std)
+            except EmptyQueryResult:
+                continue
+        mean_expected = mean.mean(axis=0)
+        std_expected = std.mean(axis=0)
+        return mean_expected, std_expected
+
     class DummyDataPrep:
         def __init__(self, data):
             self.station = "DEBW107"
@@ -41,7 +94,7 @@ class TestDataGenerator:
         assert gen.limit_nan_fill == 1
         assert gen.window_history_size == 7
         assert gen.window_lead_time == 4
-        assert gen.transform_method == "standardise"
+        assert gen.transformation is None
         assert gen.kwargs == {"start": 2010, "end": 2014}
 
     def test_repr(self, gen):
@@ -76,6 +129,72 @@ class TestDataGenerator:
         assert station[1].data.shape[-1] == gen.window_lead_time
         assert station[0].data.shape[1] == gen.window_history_size + 1
 
+    def test_setup_transformation_no_transformation(self, gen_no_init):
+        assert gen_no_init.setup_transformation(None) is None
+        assert gen_no_init.setup_transformation({}) == {"method": "standardise", "mean": None, "std": None}
+        assert gen_no_init.setup_transformation({"scope": "station", "mean": "accurate"}) == \
+               {"scope": "station", "method": "standardise", "mean": None, "std": None}
+
+    def test_setup_transformation_calculate_statistics(self, gen_no_init):
+        transformation = {"scope": "data", "mean": "accurate"}
+        res_acc = gen_no_init.setup_transformation(transformation)
+        assert sorted(res_acc.keys()) == sorted(["scope", "mean", "std", "method"])
+        assert isinstance(res_acc["mean"], xr.DataArray)
+        assert isinstance(res_acc["std"], xr.DataArray)
+        transformation["mean"] = "estimate"
+        res_est = gen_no_init.setup_transformation(transformation)
+        assert sorted(res_est.keys()) == sorted(["scope", "mean", "std", "method"])
+        assert isinstance(res_est["mean"], xr.DataArray)
+        assert isinstance(res_est["std"], xr.DataArray)
+        assert np.testing.assert_array_compare(op.__ne__, res_est["std"].values, res_acc["std"].values) is None
+
+    def test_setup_transformation_use_given_statistics(self, gen_no_init):
+        mean = xr.DataArray([30, 15], coords={"variables": ["o3", "temp"]}, dims=["variables"])
+        transformation = {"scope": "data", "method": "centre", "mean": mean}
+        res = gen_no_init.setup_transformation(transformation)
+        assert np.testing.assert_equal(res["mean"].values, mean.values) is None
+        assert res["std"] is None
+
+    def test_setup_transformation_errors(self, gen_no_init):
+        transformation = {"scope": "random", "mean": "accurate"}
+        with pytest.raises(ValueError):
+            gen_no_init.setup_transformation(transformation)
+        transformation = {"scope": "data", "mean": "fit"}
+        with pytest.raises(ValueError):
+            gen_no_init.setup_transformation(transformation)
+
+    def test_calculate_accurate_transformation_standardise(self, gen_no_init, accurate_transformation):
+        mean_expected, std_expected = accurate_transformation
+        mean, std = gen_no_init.calculate_accurate_transformation("standardise")
+        assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None
+        assert np.testing.assert_almost_equal(std_expected.values, std.values) is None
+
+    def test_calculate_accurate_transformation_centre(self, gen_no_init, accurate_transformation):
+        mean_expected, _ = accurate_transformation
+        mean, std = gen_no_init.calculate_accurate_transformation("centre")
+        assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None
+        assert std is None
+
+    def test_calculate_accurate_transformation_all_others(self, gen_no_init):
+        with pytest.raises(NotImplementedError):
+            gen_no_init.calculate_accurate_transformation("normalise")
+
+    def test_calculate_estimated_transformation_standardise(self, gen_no_init, estimated_transformation):
+        mean_expected, std_expected = estimated_transformation
+        mean, std = gen_no_init.calculate_estimated_transformation("standardise")
+        assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None
+        assert np.testing.assert_almost_equal(std_expected.values, std.values) is None
+
+    def test_calculate_estimated_transformation_centre(self, gen_no_init, estimated_transformation):
+        mean_expected, _ = estimated_transformation
+        mean, std = gen_no_init.calculate_estimated_transformation("centre")
+        assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None
+        assert std is None
+
+    def test_calculate_estimated_transformation_all_others(self, gen_no_init):
+        with pytest.raises(NotImplementedError):
+            gen_no_init.calculate_estimated_transformation("normalise")
+
     def test_get_station_key(self, gen):
         gen.stations.append("DEBW108")
         f = gen.get_station_key
@@ -104,7 +223,7 @@ class TestDataGenerator:
         if os.path.exists(file):
             os.remove(file)
         assert not os.path.exists(file)
-        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep)
         t = os.stat(file).st_ctime
         assert os.path.exists(file)
         assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
@@ -113,6 +232,12 @@ class TestDataGenerator:
         assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
         assert os.stat(file).st_ctime > t
 
+    def test_get_data_generator_transform(self, gen_with_transformation):
+        gen = gen_with_transformation
+        data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False)
+        assert data._transform_method == "standardise"
+        assert data.mean is not None
+
     def test_save_pickle_data(self, gen):
         file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle")
         if os.path.exists(file):
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 72bacaf9cc1e5a9dc9736b8e8eb7161f35d8ea69..ac449c4dc6d4c83a457eccc93a766ec4f17f58c9 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -152,6 +152,26 @@ class TestDataPrep:
         assert isinstance(data.mean, xr.DataArray)
         assert isinstance(data.std, xr.DataArray)
 
+    def test_transform_standardise_apply(self, data):
+        assert data._transform_method is None
+        assert data.mean is None
+        assert data.std is None
+        data_mean_orig = data.data.mean('datetime').variable.values
+        data_std_orig = data.data.std('datetime').variable.values
+        mean_external = np.array([20, 12])
+        std_external = np.array([15, 5])
+        mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"])
+        std = xr.DataArray(std_external, coords={"variables": ['o3', 'temp']}, dims=["variables"])
+        data.transform('datetime', mean=mean, std=std)
+        assert all(data.mean.values == mean_external)
+        assert all(data.std.values == std_external)
+        data_mean_transformed = data.data.mean('datetime').variable.values
+        data_std_transformed = data.data.std('datetime').variable.values
+        data_mean_expected = (data_mean_orig - mean_external) / std_external  # mean scales as any other data
+        data_std_expected = data_std_orig / std_external  # std scales by given std
+        assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None
+        assert np.testing.assert_almost_equal(data_std_transformed, data_std_expected) is None
+
     @pytest.mark.parametrize('mean, std, method, msg', [(10, 3, 'standardise', ''), (6, None, 'standardise', 'std, '),
                                                         (None, 3, 'standardise', 'mean, '), (19, None, 'centre', ''),
                                                         (None, 2, 'centre', 'mean, '), (8, 2, 'centre', ''),
@@ -168,12 +188,29 @@ class TestDataPrep:
         assert data._transform_method is None
         assert data.mean is None
         assert data.std is None
-        data_std_org = data.data.std('datetime'). variable.values
+        data_std_orig = data.data.std('datetime'). variable.values
         data.transform('datetime', 'centre')
         assert data._transform_method == 'centre'
         assert np.testing.assert_almost_equal(data.data.mean('datetime').variable.values, np.array([[0, 0]])) is None
-        assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, data_std_org) is None
+        assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, data_std_orig) is None
+        assert data.std is None
+
+    def test_transform_centre_apply(self, data):
+        assert data._transform_method is None
+        assert data.mean is None
+        assert data.std is None
+        data_mean_orig = data.data.mean('datetime').variable.values
+        data_std_orig = data.data.std('datetime').variable.values
+        mean_external = np.array([20, 12])
+        mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"])
+        data.transform('datetime', 'centre', mean=mean)
+        assert all(data.mean.values == mean_external)
         assert data.std is None
+        data_mean_transformed = data.data.mean('datetime').variable.values
+        data_std_transformed = data.data.std('datetime').variable.values
+        data_mean_expected = (data_mean_orig - mean_external)  # mean scales as any other data
+        assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None
+        assert np.testing.assert_almost_equal(data_std_transformed, data_std_orig) is None
 
     @pytest.mark.parametrize('method', ['standardise', 'centre'])
     def test_transform_inverse(self, data, method):
diff --git a/test/test_datastore.py b/test/test_datastore.py
index 95a58deafc915dd6193960e77bb99cc8ab8d85cb..9fcb319f51954b365c59274a4a9744f093e155f1 100644
--- a/test/test_datastore.py
+++ b/test/test_datastore.py
@@ -30,6 +30,14 @@ class TestDataStoreByVariable:
     def ds(self):
         return DataStoreByVariable()
 
+    @pytest.fixture
+    def ds_with_content(self, ds):
+        ds.set("tester1", 1, "general")
+        ds.set("tester2", 11, "general")
+        ds.set("tester2", 10, "general.sub")
+        ds.set("tester3", 21, "general")
+        return ds
+
     def test_put(self, ds):
         ds.set("number", 3, "general.subscope")
         assert ds._store["number"]["general.subscope"] == 3
@@ -131,15 +139,18 @@ class TestDataStoreByVariable:
         assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \
             [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)]
 
-    def test_create_args_dict(self, ds):
-        ds.set("tester1", 1, "general")
-        ds.set("tester2", 11, "general")
-        ds.set("tester2", 10, "general.sub")
-        ds.set("tester3", 21, "general")
+    def test_create_args_dict_default_scope(self, ds_with_content):
         args = ["tester1", "tester2", "tester3", "tester4"]
-        assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21}
-        assert ds.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21}
-        assert ds.create_args_dict(["notAvail", "alsonot"]) == {}
+        assert ds_with_content.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21}
+
+    def test_create_args_dict_given_scope(self, ds_with_content):
+        args = ["tester1", "tester2", "tester3", "tester4"]
+        assert ds_with_content.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21}
+
+    def test_create_args_dict_missing_entry(self, ds_with_content):
+        args = ["tester1", "notAvail", "tester4"]
+        assert ds_with_content.create_args_dict(["notAvail", "alsonot"]) == {}
+        assert ds_with_content.create_args_dict(args) == {"tester1": 1}
 
     def test_set_args_from_dict(self, ds):
         ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21})
@@ -157,6 +168,14 @@ class TestDataStoreByScope:
     def ds(self):
         return DataStoreByScope()
 
+    @pytest.fixture
+    def ds_with_content(self, ds):
+        ds.set("tester1", 1, "general")
+        ds.set("tester2", 11, "general")
+        ds.set("tester2", 10, "general.sub")
+        ds.set("tester3", 21, "general")
+        return ds
+
     def test_put_with_scope(self, ds):
         ds.set("number", 3, "general.subscope")
         assert ds._store["general.subscope"]["number"] == 3
@@ -258,15 +277,18 @@ class TestDataStoreByScope:
         assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \
             [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)]
 
-    def test_create_args_dict(self, ds):
-        ds.set("tester1", 1, "general")
-        ds.set("tester2", 11, "general")
-        ds.set("tester2", 10, "general.sub")
-        ds.set("tester3", 21, "general")
+    def test_create_args_dict_default_scope(self, ds_with_content):
         args = ["tester1", "tester2", "tester3", "tester4"]
-        assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21}
-        assert ds.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21}
-        assert ds.create_args_dict(["notAvail", "alsonot"]) == {}
+        assert ds_with_content.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21}
+
+    def test_create_args_dict_given_scope(self, ds_with_content):
+        args = ["tester1", "tester2", "tester3", "tester4"]
+        assert ds_with_content.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21}
+
+    def test_create_args_dict_missing_entry(self, ds_with_content):
+        args = ["tester1", "notAvail", "tester4"]
+        assert ds_with_content.create_args_dict(["notAvail", "alsonot"]) == {}
+        assert ds_with_content.create_args_dict(args) == {"tester1": 1}
 
     def test_set_args_from_dict(self, ds):
         ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21})
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index 2f6565b4cabe295169047a6582d2b89cbf387062..17ab4f6d65c95a5a54c9d931818f889acadef532 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -1,6 +1,8 @@
 import keras
 import numpy as np
 import pytest
+import mock
+import os
 
 from src.helpers import l_p_loss
 from src.model_modules.keras_extensions import *
@@ -60,3 +62,172 @@ class TestLearningRateDecay:
         model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
         model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
         assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
+
+
+class TestModelCheckpointAdvanced:
+
+    @pytest.fixture()
+    def callbacks(self):
+        callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s")
+        return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"},
+                     {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}]
+
+    @pytest.fixture
+    def ckpt(self, callbacks):
+        ckpt_name = "ckpt.test"
+        return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, verbose=1)
+
+    def test_init(self, ckpt, callbacks):
+        assert ckpt.callbacks == callbacks
+        assert ckpt.monitor == "val_loss"
+        assert ckpt.save_best_only is True
+        assert ckpt.best == np.inf
+
+    def test_update_best(self, ckpt):
+        hist = HistoryAdvanced()
+        hist.history["val_loss"] = [10, 6]
+        ckpt.update_best(hist)
+        assert ckpt.best == 6
+
+    def test_update_callbacks(self, ckpt, callbacks):
+        ckpt.update_callbacks(callbacks[0])
+        assert ckpt.callbacks == [callbacks[0]]
+
+    def test_on_epoch_end(self, ckpt):
+        path = os.path.dirname(__file__)
+        ckpt.set_model(mock.MagicMock())
+        ckpt.best = 6
+        ckpt.on_epoch_end(0, {"val_loss": 6})
+        assert "callback_hist" not in os.listdir(path)
+        ckpt.on_epoch_end(9, {"val_loss": 10})
+        assert "callback_hist" not in os.listdir(path)
+        ckpt.on_epoch_end(10, {"val_loss": 4})
+        assert "callback_hist" in os.listdir(path)
+        os.remove(os.path.join(path, "callback_hist"))
+        os.remove(os.path.join(path, "callback_lr"))
+        ckpt.save_best_only = False
+        ckpt.on_epoch_end(10, {"val_loss": 3})
+        assert "callback_hist" in os.listdir(path)
+        os.remove(os.path.join(path, "callback_hist"))
+        os.remove(os.path.join(path, "callback_lr"))
+
+
+class TestCallbackHandler:
+
+    @pytest.fixture
+    def clbk_handler(self):
+        return CallbackHandler()
+
+    @pytest.fixture
+    def clbk_handler_with_dummies(self, clbk_handler):
+        clbk_handler.add_callback("callback_new_instance", "this_path")
+        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
+        return clbk_handler
+
+    @pytest.fixture
+    def callback_handler(self, clbk_handler):
+        clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist")
+        clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr")
+        return clbk_handler
+
+    @pytest.fixture
+    def prepare_pickle_files(self):
+        hist = HistoryAdvanced()
+        hist.epoch = [1, 2, 3]
+        hist.history = {"val_loss": [10, 5, 4]}
+        lr = LearningRateDecay()
+        lr.epoch = [1, 2, 3]
+        pickle.dump(hist, open("callbacks_hist.pickle", "wb"))
+        pickle.dump(lr, open("callbacks_lr.pickle", "wb"))
+        yield
+        os.remove("callbacks_hist.pickle")
+        os.remove("callbacks_lr.pickle")
+
+    def test_init(self, clbk_handler):
+        assert len(clbk_handler._CallbackHandler__callbacks) == 0
+        assert clbk_handler._checkpoint is None
+        assert clbk_handler.editable is True
+
+    def test_callbacks_set(self, clbk_handler):
+        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
+                                                             "path": "callback_path"}]
+        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
+                                                             "path": "callback_path"},
+                                                            {"name": "another", "another": "callback_instance2",
+                                                             "path": "callback_path"}]
+
+    def test_callbacks_get(self, clbk_handler):
+        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
+        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2")
+        assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"},
+                                           {"callback": "callback_instance2", "path": "callback_path2"}]
+
+    def test_update_callback(self, clbk_handler_with_dummies):
+        clbk_handler_with_dummies._update_callback(0, "old_instance")
+        assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"},
+                                                             {"callback": "callback_other", "path": "otherpath"}]
+
+    def test_add_callback(self, clbk_handler):
+        clbk_handler.add_callback("callback_new_instance", "this_path")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
+                                                             "path": "this_path"}]
+        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
+        assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
+                                                             "path": "this_path"},
+                                                            {"name": "other_clbk", "other_clbk": "callback_other",
+                                                             "path": "otherpath"}]
+
+    def test_get_callbacks_as_dict(self, clbk_handler_with_dummies):
+        clbk = clbk_handler_with_dummies
+        assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                        {"callback": "callback_other", "path": "otherpath"}]
+        assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True)
+
+    def test_get_callbacks_no_dict(self, clbk_handler_with_dummies):
+        assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"]
+
+    def test_get_callback_by_name(self, clbk_handler_with_dummies):
+        assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other"
+        assert clbk_handler_with_dummies.get_callback_by_name("callback") is None
+
+    def test__get_callbacks(self, clbk_handler_with_dummies):
+        clbk = clbk_handler_with_dummies
+        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                         {"callback": "callback_other", "path": "otherpath"}]
+        ckpt = keras.callbacks.ModelCheckpoint("testFilePath")
+        clbk._checkpoint = ckpt
+        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
+                                         {"callback": "callback_other", "path": "otherpath"},
+                                         {"callback": ckpt, "path": "testFilePath"}]
+
+    def test_get_checkpoint(self, clbk_handler):
+        assert clbk_handler.get_checkpoint() is None
+        clbk_handler._checkpoint = "testCKPT"
+        assert clbk_handler.get_checkpoint() == "testCKPT"
+
+    def test_create_model_checkpoint(self, callback_handler):
+        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
+        assert callback_handler.editable is False
+        assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced)
+        assert callback_handler._checkpoint.filepath == "tester_path"
+        assert callback_handler._checkpoint.verbose == 1
+        assert callback_handler._checkpoint.monitor == "val_loss"
+
+    def test_load_callbacks(self, callback_handler, prepare_pickle_files):
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
+        callback_handler.load_callbacks()
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
+
+    def test_update_checkpoint(self, callback_handler, prepare_pickle_files):
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
+        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
+        callback_handler.load_callbacks()
+        callback_handler.update_checkpoint()
+        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
+        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
+        assert callback_handler._checkpoint.best == 4
diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py
index 7a4f16fd055e0af0aea95181d475d061c185ca92..9e6d17627d1697a2150ea7f74a373a720d2f02ac 100644
--- a/test/test_modules/test_experiment_setup.py
+++ b/test/test_modules/test_experiment_setup.py
@@ -47,7 +47,8 @@ class TestExperimentSetup:
         data_store = exp_setup.data_store
         # experiment setup
         assert data_store.get("data_path", "general") == prepare_host()
-        assert data_store.get("trainable", "general") is False
+        assert data_store.get("trainable", "general") is True
+        assert data_store.get("create_new_model", "general") is True
         assert data_store.get("fraction_of_training", "general") == 0.8
         # set experiment name
         assert data_store.get("experiment_name", "general") == "TestExperiment"
@@ -104,13 +105,14 @@ class TestExperimentSetup:
                       target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
                       interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
                       train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
-                      test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=True, 
-                      fraction_of_train=0.5, experiment_path=experiment_path)
+                      test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
+                      fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True)
         exp_setup = ExperimentSetup(**kwargs)
         data_store = exp_setup.data_store
         # experiment setup
         assert data_store.get("data_path", "general") == prepare_host()
         assert data_store.get("trainable", "general") is True
+        assert data_store.get("create_new_model", "general") is True
         assert data_store.get("fraction_of_training", "general") == 0.5
         # set experiment name
         assert data_store.get("experiment_name", "general") == "TODAY_network"
@@ -150,10 +152,30 @@ class TestExperimentSetup:
         # use all stations on all data sets (train, val, test)
         assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False
 
+    def test_init_trainable_behaviour(self):
+        exp_setup = ExperimentSetup(trainable=False, create_new_model=True)
+        data_store = exp_setup.data_store
+        assert data_store.get("trainable", "general") is True
+        assert data_store.get("create_new_model", "general") is True
+        exp_setup = ExperimentSetup(trainable=False, create_new_model=False)
+        data_store = exp_setup.data_store
+        assert data_store.get("trainable", "general") is False
+        assert data_store.get("create_new_model", "general") is False
+        exp_setup = ExperimentSetup(trainable=True, create_new_model=True)
+        data_store = exp_setup.data_store
+        assert data_store.get("trainable", "general") is True
+        assert data_store.get("create_new_model", "general") is True
+        exp_setup = ExperimentSetup(trainable=True, create_new_model=False)
+        data_store = exp_setup.data_store
+        assert data_store.get("trainable", "general") is True
+        assert data_store.get("create_new_model", "general") is False
+
     def test_compare_variables_and_statistics(self):
+        experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
         kwargs = dict(parser_args={"experiment_date": "TODAY"},
                       var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'},
-                      stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None)
+                      stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None,
+                      experiment_path=experiment_path)
         with pytest.raises(ValueError) as e:
             ExperimentSetup(**kwargs)
         assert "for the variables: {'relhum'}" in e.value.args[0]
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 2864ae45bcd7d3c6109d6d84fe5ea152a7d86384..ade35a244601d138d22af6305e67b5aeae964680 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -20,6 +20,7 @@ class TestModelSetup:
         obj.callbacks_name = "placeholder_%s_str.pickle"
         obj.data_store.set("lr_decay", "dummy_str", "general.model")
         obj.data_store.set("hist", "dummy_str", "general.model")
+        obj.model_name = "%s.h5"
         yield obj
         RunEnvironment().__del__()
 
@@ -55,11 +56,11 @@ class TestModelSetup:
     def current_scope_as_set(model_cls):
         return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
 
-    def test_set_checkpoint(self, setup):
-        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
+    def test_set_callbacks(self, setup):
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
         setup.checkpoint_name = "TestName"
-        setup._set_checkpoint()
-        assert "general.modeltest" in setup.data_store.search_name("checkpoint")
+        setup._set_callbacks()
+        assert "general.modeltest" in setup.data_store.search_name("callbacks")
 
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index 9d1feb03ac71b980f6a4cd1b0e6cac2a52d9625b..29172a1b8500b605859e925574535c6158c7d805 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -54,7 +54,8 @@ class TestPreProcessing:
         assert obj_with_exp_setup.data_store.search_name("generator") == []
         obj_with_exp_setup.split_train_val_test()
         data_store = obj_with_exp_setup.data_store
-        assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"])
+        expected_params = ["generator", "start", "end", "stations", "permute_data"]
+        assert data_store.search_scope("general.train") == sorted(expected_params)
         assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
                                                               "general.train_val"])
 
@@ -100,13 +101,3 @@ class TestPreProcessing:
         assert dummy_list[val] == list(range(10, 13))
         assert dummy_list[test] == list(range(13, 15))
         assert dummy_list[train_val] == list(range(0, 13))
-
-    def test_create_args_dict_default_scope(self, obj_super_init):
-        assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2}
-
-    def test_create_args_dict_given_scope(self, obj_super_init):
-        assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2}
-
-    def test_create_args_dict_missing_entry(self, obj_super_init):
-        assert obj_super_init.data_store.create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2}
-        assert obj_super_init.data_store.create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2}
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 485348ceca740d8263394fca36efbfbde6dd2d0d..31c673f05d055eb7c4ee76318711de030d97d480 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -14,7 +14,7 @@ from src.data_handling.data_generator import DataGenerator
 from src.helpers import PyTestRegex
 from src.model_modules.flatten import flatten_tail
 from src.model_modules.inception_model import InceptionModelBase
-from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
+from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
 from src.run_modules.run_environment import RunEnvironment
 from src.run_modules.training import Training
 
@@ -39,7 +39,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
+    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -48,19 +48,22 @@ class TestTraining:
         obj.test_set = None
         obj.batch_size = 256
         obj.epochs = 2
-        obj.checkpoint = checkpoint
-        obj.lr_sc = LearningRateDecay()
-        obj.hist = HistoryAdvanced()
+        clbk, hist, lr = callbacks
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
         obj.experiment_name = "TestExperiment"
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
         obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
         os.makedirs(path)
         obj.data_store.set("experiment_path", path, "general")
+        obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
+        obj._trainable = True
         yield obj
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -68,12 +71,9 @@ class TestTraining:
 
     @pytest.fixture
     def learning_rate(self):
-        return {"lr": [0.01, 0.0094]}
-
-    @pytest.fixture
-    def init_with_lr(self, init_without_run, learning_rate):
-        init_without_run.lr_sc.lr = learning_rate
-        return init_without_run
+        lr = LearningRateDecay()
+        lr.lr = {"lr": [0.01, 0.0094]}
+        return lr
 
     @pytest.fixture
     def history(self):
@@ -103,8 +103,15 @@ class TestTraining:
         return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
 
     @pytest.fixture
-    def checkpoint(self, path):
-        return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
+    def callbacks(self, path):
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        return clbk, hist, lr
 
     @pytest.fixture
     def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
@@ -123,7 +130,7 @@ class TestTraining:
         return obj
 
     @pytest.fixture
-    def ready_to_init(self, generator, model, checkpoint, path):
+    def ready_to_init(self, generator, model, callbacks, path):
         os.makedirs(path)
         obj = RunEnvironment()
         obj.data_store.set("generator", generator, "general.train")
@@ -131,13 +138,17 @@ class TestTraining:
         obj.data_store.set("generator", generator, "general.test")
         model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
         obj.data_store.set("model", model, "general.model")
+        obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
         obj.data_store.set("batch_size", 256, "general.model")
         obj.data_store.set("epochs", 2, "general.model")
-        obj.data_store.set("checkpoint", checkpoint, "general.model")
-        obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
-        obj.data_store.set("hist", HistoryAdvanced(), "general.model")
+        clbk, hist, lr = callbacks
+        obj.data_store.set("callbacks", clbk, "general.model")
+        obj.data_store.set("lr_decay", lr, "general.model")
+        obj.data_store.set("hist", hist, "general.model")
         obj.data_store.set("experiment_name", "TestExperiment", "general")
         obj.data_store.set("experiment_path", path, "general")
+        obj.data_store.set("trainable", True, "general")
+        obj.data_store.set("create_new_model", True, "general")
         path_plot = os.path.join(path, "plots")
         os.makedirs(path_plot)
         obj.data_store.set("plot_path", path_plot, "general")
@@ -179,7 +190,7 @@ class TestTraining:
 
     def test_save_model(self, init_without_run, path, caplog):
         caplog.set_level(logging.DEBUG)
-        model_name = "TestExperiment_my_model.h5"
+        model_name = "test_model.h5"
         assert model_name not in os.listdir(path)
         init_without_run.save_model()
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}"))
@@ -191,25 +202,25 @@ class TestTraining:
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
-    def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history.json" in os.listdir(path)
 
-    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         assert "history_lr.json" in os.listdir(path)
 
-    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks_as_json(history)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path):
+        init_without_run.save_callbacks_as_json(history, learning_rate)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
-            assert lr == init_with_lr.lr_sc.lr
+            assert lr == learning_rate.lr
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 308ac655787e69f90b45e65e7e7df8f35875f652..cad915564aac675cadda0f625dca1a073b2c8959 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -3,7 +3,10 @@ import pandas as pd
 import pytest
 import xarray as xr
 
-from src.statistics import standardise, standardise_inverse, centre, centre_inverse
+from src.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply,\
+    apply_inverse_transformation
+
+lazy = pytest.lazy_fixture
 
 
 @pytest.fixture(scope='module')
@@ -18,44 +21,95 @@ def pandas(input_data):
     return pd.DataFrame(input_data)
 
 
+@pytest.fixture(scope='module')
+def pd_mean():
+    return [2, 10, 3]
+
+
+@pytest.fixture(scope='module')
+def pd_std():
+    return [3, 2, 3]
+
+
 @pytest.fixture(scope='module')
 def xarray(input_data):
-    return xr.DataArray(input_data, dims=['index', 'value'])
+    shape = input_data.shape
+    coords = {'index': range(shape[0]), 'value': range(shape[1])}
+    return xr.DataArray(input_data, coords=coords, dims=coords.keys())
+
+
+@pytest.fixture(scope='module')
+def xr_mean(input_data):
+    return xr.DataArray([2, 10, 3], coords={'value': range(3)}, dims=['value'])
+
+
+@pytest.fixture(scope='module')
+def xr_std(input_data):
+    return xr.DataArray([3, 2, 3], coords={'value': range(3)}, dims=['value'])
 
 
 class TestStandardise:
 
-    @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0),
-                                               (pytest.lazy_fixture('xarray'), 'index')])
-    def test_standardise(self, data_org, dim):
-        mean, std, data = standardise(data_org, dim)
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_standardise(self, data_orig, dim):
+        mean, std, data = standardise(data_orig, dim)
         assert np.testing.assert_almost_equal(mean, [2, -5, 10], decimal=1) is None
         assert np.testing.assert_almost_equal(std, [2, 3, 1], decimal=1) is None
         assert np.testing.assert_almost_equal(data.mean(dim), [0, 0, 0]) is None
         assert np.testing.assert_almost_equal(data.std(dim), [1, 1, 1]) is None
 
-    @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0),
-                                               (pytest.lazy_fixture('xarray'), 'index')])
-    def test_standardise_inverse(self, data_org, dim):
-        mean, std, data = standardise(data_org, dim)
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_standardise_inverse(self, data_orig, dim):
+        mean, std, data = standardise(data_orig, dim)
         data_recovered = standardise_inverse(data, mean, std)
-        assert np.testing.assert_array_almost_equal(data_org, data_recovered) is None
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_apply_standardise_inverse(self, data_orig, dim):
+        mean, std, data = standardise(data_orig, dim)
+        data_recovered = apply_inverse_transformation(data, mean, std)
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, mean, std, dim', [(lazy('pandas'), lazy('pd_mean'), lazy('pd_std'), 0),
+                                                           (lazy('xarray'), lazy('xr_mean'), lazy('xr_std'), 'index')])
+    def test_standardise_apply(self, data_orig, mean, std, dim):
+        data = standardise_apply(data_orig, mean, std)
+        mean_expected = (np.array([2, -5, 10]) - np.array([2, 10, 3])) / np.array([3, 2, 3])
+        std_expected = np.array([2, 3, 1]) / np.array([3, 2, 3])
+        assert np.testing.assert_almost_equal(data.mean(dim), mean_expected, decimal=1) is None
+        assert np.testing.assert_almost_equal(data.std(dim), std_expected, decimal=1) is None
 
 
 class TestCentre:
 
-    @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0),
-                                               (pytest.lazy_fixture('xarray'), 'index')])
-    def test_centre(self, data_org, dim):
-        mean, std, data = centre(data_org, dim)
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_centre(self, data_orig, dim):
+        mean, std, data = centre(data_orig, dim)
         assert np.testing.assert_almost_equal(mean, [2, -5, 10], decimal=1) is None
         assert std is None
         assert np.testing.assert_almost_equal(data.mean(dim), [0, 0, 0]) is None
 
-    @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0),
-                                               (pytest.lazy_fixture('xarray'), 'index')])
-    def test_centre_inverse(self, data_org, dim):
-        mean, _, data = centre(data_org, dim)
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_centre_inverse(self, data_orig, dim):
+        mean, _, data = centre(data_orig, dim)
         data_recovered = centre_inverse(data, mean)
-        assert np.testing.assert_array_almost_equal(data_org, data_recovered) is None
-
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_apply_centre_inverse(self, data_orig, dim):
+        mean, _, data = centre(data_orig, dim)
+        data_recovered = apply_inverse_transformation(data, mean, method="centre")
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, mean, dim', [(lazy('pandas'), lazy('pd_mean'), 0),
+                                                      (lazy('xarray'), lazy('xr_mean'), 'index')])
+    def test_centre_apply(self, data_orig, mean, dim):
+        data = centre_apply(data_orig, mean)
+        mean_expected = np.array([2, -5, 10]) - np.array([2, 10, 3])
+        assert np.testing.assert_almost_equal(data.mean(dim), mean_expected, decimal=1) is None