diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 09c16c68196b09fc7c1fbe5ef4b2639b684205a4..54ad03d05103323c3b68fa78218c011aaa9fe426 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -1,7 +1,7 @@
 """Data Preparation class to handle data processing for machine learning."""
 
-__author__ = 'Lukas Leufen'
-__date__ = '2020-06-29'
+__author__ = 'Lukas Leufen, Felix Kleinert'
+__date__ = '2020-07-20'
 
 import datetime as dt
 import logging
@@ -30,22 +30,6 @@ data_or_none = Union[xr.DataArray, None]
 class AbstractStationPrep():
     def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
         pass
-        # passed parameters
-        # self.path = os.path.abspath(path)
-        # self.station = helpers.to_list(station)
-        # self.statistics_per_var = statistics_per_var
-        # # self.target_dim = 'variable'
-        # self.transformation = self.setup_transformation(transformation)
-        # self.kwargs = kwargs
-        #
-        # # internal
-        # self.data = None
-        # self.meta = None
-        # self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
-        # self.history = None
-        # self.label = None
-        # self.observation = None
-
 
     def get_X(self):
         raise NotImplementedError
@@ -53,24 +37,18 @@ class AbstractStationPrep():
     def get_Y(self):
         raise NotImplementedError
 
-    # def load_data(self):
-    #     try:
-    #         self.read_data_from_disk()
-    #     except FileNotFoundError:
-    #         self.download_data()
-    #         self.load_data()
-    #
-    # def read_data_from_disk(self):
-    #     raise NotImplementedError
-    #
-    # def download_data(self):
-    #     raise NotImplementedError
 
 class StationPrep(AbstractStationPrep):
 
-    def __init__(self, station, data_path, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
-                 interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs):
+    def __init__(self, station, data_path, statistics_per_var, station_type, network, sampling,
+                 target_dim, target_var, interpolate_dim, window_history_size, window_lead_time,
+                 overwrite_local_data: bool = False, transformation=None, **kwargs):
         super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
+        self.station = helpers.to_list(station)
+        self.path = os.path.abspath(data_path)
+        self.statistics_per_var = statistics_per_var
+        self.transformation = self.setup_transformation(transformation)
+
         self.station_type = station_type
         self.network = network
         self.sampling = sampling
@@ -79,11 +57,7 @@ class StationPrep(AbstractStationPrep):
         self.interpolate_dim = interpolate_dim
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-
-        self.path = os.path.abspath(data_path)
-        self.station = helpers.to_list(station)
-        self.statistics_per_var = statistics_per_var
-        # self.target_dim = 'variable'
+        self.overwrite_local_data = overwrite_local_data
 
         # internal
         self.data = None
@@ -93,9 +67,15 @@ class StationPrep(AbstractStationPrep):
         self.label = None
         self.observation = None
 
-        self.transformation = None  # self.setup_transformation(transformation)
+        # internal for transformation
+        self.mean = None
+        self.std = None
+        self.max = None
+        self.min = None
+        self._transform_method = None
+
         self.kwargs = kwargs
-        self.kwargs["overwrite_local_data"] = overwrite_local_data
+        # self.kwargs["overwrite_local_data"] = overwrite_local_data
 
         self.make_samples()
 
@@ -129,8 +109,17 @@ class StationPrep(AbstractStationPrep):
     def get_Y(self):
         return self.get_transposed_label()
 
+    def call_transform(self, inverse=False):
+        self.transform(dim=self.interpolate_dim, method=self.transformation["method"],
+                       mean=self.transformation['mean'], std=self.transformation["std"],
+                       min_val=self.transformation["min"], max_val=self.transformation["max"],
+                       inverse=inverse
+                       )
+
     def make_samples(self):
         self.load_data()
+        if self.transformation is not None:
+            self.call_transform()
         self.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
         self.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
         self.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
@@ -467,100 +456,153 @@ class StationPrep(AbstractStationPrep):
         """
         if transformation is None:
             return
+        elif not isinstance(transformation, dict):
+            raise TypeError(f"`transformation' must be either `None' or dict like e.g. `{{'method': 'standardise'}},"
+                            f" but transformation is of type {type(transformation)}.")
         transformation = transformation.copy()
         scope = transformation.get("scope", "station")
-        method = transformation.get("method", "standardise")
+        # method = transformation.get("method", "standardise")
+        method = transformation.get("method", None)
         mean = transformation.get("mean", None)
         std = transformation.get("std", None)
-        if scope == "data":
-            if isinstance(mean, str):
-                if mean == "accurate":
-                    mean, std = self.calculate_accurate_transformation(method)
-                elif mean == "estimate":
-                    mean, std = self.calculate_estimated_transformation(method)
-                else:
-                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
-                                     f"be an array with already calculated means. Given was: {mean}")
-        elif scope == "station":
-            mean, std = None, None
-        else:
-            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
+        max_val = transformation.get("max", None)
+        min_val = transformation.get("min", None)
+        # if scope == "data":
+        #     if isinstance(mean, str):
+        #         if mean == "accurate":
+        #             mean, std = self.calculate_accurate_transformation(method)
+        #         elif mean == "estimate":
+        #             mean, std = self.calculate_estimated_transformation(method)
+        #         else:
+        #             raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
+        #                              f"be an array with already calculated means. Given was: {mean}")
+        # if scope == "station":
+        #     mean, std = None, None
+        # else:
+        #     raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
         transformation["method"] = method
         transformation["mean"] = mean
         transformation["std"] = std
+        transformation["max"] = max_val
+        transformation["min"] = min_val
         return transformation
 
-    def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]:
+    def load_data(self):
+        try:
+            self.read_data_from_disk()
+        except FileNotFoundError:
+            self.download_data()
+            self.load_data()
+
+    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None,
+                  std=None, min_val=None, max_val=None) -> None:
         """
-        Calculate accurate transformation statistics.
+        Transform data according to given transformation settings.
 
-        Use all stations of this generator and calculate mean and standard deviation on entire data set using dask.
-        Because there can be much data, this can take a while.
+        This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
+        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
+        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
+        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
+        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
 
-        :param method: name of transformation method
+        :param string/int dim: This param is not used for inverse transformation.
+                | for xarray.DataArray as string: name of dimension which should be standardised
+                | for pandas.DataFrame as int: axis of dimension which should be standardised
+        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
+                    yet. This param is not used for inverse transformation.
+        :param inverse: Switch between transformation and inverse transformation.
+        :param mean: Used for transformation (if required by 'method') based on external data. If 'None' the mean is
+                    calculated over the data in this class instance.
+        :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is
+                    calculated over the data in this class instance.
+        :param min_val: Used for transformation (if required by 'method') based on external data. If 'None' min_val is
+                    extracted from the data in this class instance.
+        :param max_val: Used for transformation (if required by 'method') based on external data. If 'None' max_val is
+                    extracted from the data in this class instance.
 
-        :return: accurate calculated mean and std (depending on transformation)
+        :return: xarray.DataArrays or pandas.DataFrames:
+                #. mean: Mean of data
+                #. std: Standard deviation of data
+                #. data: Standardised data
         """
-        tmp = []
-        mean = None
-        std = None
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                chunks = (1, 100, data.data.shape[2])
-                tmp.append(da.from_array(data.data.data, chunks=chunks))
-            except EmptyQueryResult:
-                continue
-        tmp = da.concatenate(tmp, axis=1)
-        if method in ["standardise", "centre"]:
-            mean = da.nanmean(tmp, axis=1).compute()
-            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+
+        def f(data):
+            if method == 'standardise':
+                return statistics.standardise(data, dim)
+            elif method == 'centre':
+                return statistics.centre(data, dim)
+            elif method == 'normalise':
+                # use min/max of data or given min/max
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        def f_apply(data):
             if method == "standardise":
-                std = da.nanstd(tmp, axis=1).compute()
-                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+                return mean, std, statistics.standardise_apply(data, mean, std)
+            elif method == "centre":
+                return mean, None, statistics.centre_apply(data, mean)
+            else:
+                raise NotImplementedError
+
+        if not inverse:
+            if self._transform_method is not None:
+                raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
+                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
+            # apply transformation on local data instance (f) if mean is None, else apply by using mean (and std) from
+            # external data.
+            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
+
+            # set transform method to find correct method for inverse transformation.
+            self._transform_method = method
         else:
-            raise NotImplementedError
-        return mean, std
+            self.inverse_transform()
 
-    def calculate_estimated_transformation(self, method):
+    @staticmethod
+    def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None:
         """
-        Calculate estimated transformation statistics.
+        Support inverse_transformation method.
 
-        Use all stations of this generator and calculate mean and standard deviation first for each station separately.
-        Afterwards, calculate the average mean and standard devation as estimated statistics. Because this method does
-        not consider the length of each data set, the estimated mean distinguishes from the real data mean. Furthermore,
-        the estimated standard deviation is assumed to be the mean (also not weighted) of all deviations. But this is
-        mathematically not true, but still a rough and faster estimation of the true standard deviation. Do not use this
-        method for further statistical calculation. However, in the scope of data preparation for machine learning, this
-        approach is decent ("it is just scaling").
+        Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas
+        normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements.
 
+        :param mean: data with all mean values
+        :param std: data with all standard deviation values
         :param method: name of transformation method
+        """
+        msg = ""
+        if method in ['standardise', 'centre'] and mean is None:
+            msg += "mean, "
+        if method == 'standardise' and std is None:
+            msg += "std, "
+        if len(msg) > 0:
+            raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
 
-        :return: accurate calculated mean and std (depending on transformation)
+    def inverse_transform(self) -> None:
         """
-        data = [[]] * len(self.variables)
-        coords = {"variables": self.variables, "Stations": range(0)}
-        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
-        for station in self.stations:
-            try:
-                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
-                                     **self.kwargs)
-                data.transform("datetime", method=method)
-                mean = mean.combine_first(data.mean)
-                std = std.combine_first(data.std)
-                data.transform("datetime", method=method, inverse=True)
-            except EmptyQueryResult:
-                continue
-        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
+        Perform inverse transformation.
 
-    def load_data(self):
-        try:
-            self.read_data_from_disk()
-        except FileNotFoundError:
-            self.download_data()
-            self.load_data()
+        Will raise an AssertionError, if no transformation was performed before. Checks first, if all required
+        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
+        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the
+        current data is not transformed.
+        """
+
+        def f_inverse(data, mean, std, method_inverse):
+            if method_inverse == 'standardise':
+                return statistics.standardise_inverse(data, mean, std), None, None
+            elif method_inverse == 'centre':
+                return statistics.centre_inverse(data, mean), None, None
+            elif method_inverse == 'normalise':
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        if self._transform_method is None:
+            raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
+        self.check_inverse_transform_params(self.mean, self.std, self._transform_method)
+        self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
+        self._transform_method = None
 
 
 class AbstractDataPrep(object):
@@ -1095,10 +1137,11 @@ if __name__ == "__main__":
     # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
     # print(dp)
     statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}
-    sp = StationPrep(path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
-                     statistics_per_var=statistics_per_var, transformation={}, station_type='background',
+    sp = StationPrep(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122',
+                     statistics_per_var=statistics_per_var, station_type='background',
                      network='UBA', sampling='daily', target_dim='variables', target_var='o3',
-                     interpolate_dim='datetime', window_history_size=7, window_lead_time=3)
+                     interpolate_dim='datetime', window_history_size=7, window_lead_time=3,
+                     transformation={'method': 'standardise'})
     sp.get_X()
     sp.get_Y()
     print(sp)