From 483d6ca9c856cd2e7949bf5d56930be2a470b01b Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 30 Jun 2020 11:58:44 +0200 Subject: [PATCH] intermediate save point --- src/data_handling/data_preparation.py | 81 +-- src/data_handling/data_preparation_join.py | 591 +++++++++++++++++++++ 2 files changed, 608 insertions(+), 64 deletions(-) create mode 100644 src/data_handling/data_preparation_join.py diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index fa7388e7..e85d8a3a 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -1,7 +1,7 @@ """Data Preparation class to handle data processing for machine learning.""" -__author__ = 'Felix Kleinert, Lukas Leufen' -__date__ = '2019-10-16' +__author__ = 'Lukas Leufen' +__date__ = '2020-06-29' import datetime as dt import logging @@ -25,7 +25,7 @@ num_or_list = Union[number, List[number]] data_or_none = Union[xr.DataArray, None] -class DataPrep(object): +class AbstractDataPrep(object): """ This class prepares data to be used in neural networks. @@ -55,14 +55,11 @@ class DataPrep(object): """ - def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str], - station_type: str = None, **kwargs): + def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], **kwargs): """Construct instance.""" self.path = os.path.abspath(path) - self.network = network self.station = helpers.to_list(station) self.variables = variables - self.station_type = station_type self.mean: data_or_none = None self.std: data_or_none = None self.history: data_or_none = None @@ -83,7 +80,7 @@ class DataPrep(object): def load_data(self): """ - Load data and meta data either from local disk (preferred) or download new data from TOAR database. + Load data and meta data either from local disk (preferred) or download new data by using a custom download method. Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not @@ -103,70 +100,26 @@ class DataPrep(object): else: try: logging.debug(f"try to load local data from: {file_name}") - data = self._slice_prep(xr.open_dataarray(file_name)) - self.data = self.check_for_negative_concentrations(data) + data = xr.open_dataarray(file_name) self.meta = pd.read_csv(meta_file, index_col=0) - if self.station_type is not None: - self.check_station_meta() logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) - self.download_data(file_name, meta_file) - logging.debug("loaded new data from JOIN") + logging.debug("load new data from JOIN") + data, self.meta = self.download_data(file_name, meta_file) + logging.debug("loading finished") + # create slices and check for negative concentration. + data = self._slice_prep(data) + self.data = self.check_for_negative_concentrations(data) - def download_data(self, file_name, meta_file): + def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]: """ - Download data from join, create slices and check for negative concentration. - - Handle sequence of required operation on new data downloads. First, download data using class method - download_data_from_join. Second, slice data using _slice_prep and lastly check for negative concentrations in - data with check_for_negative_concentrations. Finally, data is stored in instance attribute data. 
+ Download data and meta. :param file_name: name of file to save data to (containing full path) :param meta_file: name of the meta data file (also containing full path) """ - data, self.meta = self.download_data_from_join(file_name, meta_file) - data = self._slice_prep(data) - self.data = self.check_for_negative_concentrations(data) - - def check_station_meta(self): - """ - Search for the entries in meta data and compare the value with the requested values. - - Will raise a FileNotFoundError if the values mismatch. - """ - check_dict = {"station_type": self.station_type, "network_name": self.network} - for (k, v) in check_dict.items(): - if self.meta.at[k, self.station[0]] != v: - logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " - f"grapping from web.") - raise FileNotFoundError - - def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: - """ - Download data from TOAR database using the JOIN interface. - - Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally - stored locally using given names for file and meta file. - - :param file_name: name of file to save data to (containing full path) - :param meta_file: name of the meta data file (also containing full path) - - :return: downloaded data and its meta data - """ - df_all = {} - df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, - station_type=self.station_type, network_name=self.network, sampling=self.sampling) - df_all[self.station[0]] = df - # convert df_all to xarray - xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()} - xarr = xr.Dataset(xarr).to_array(dim='Stations') - if self.kwargs.get('store_data_locally', True): - # save locally as nc/csv file - xarr.to_netcdf(path=file_name) - meta.to_csv(meta_file) - return xarr, meta + raise NotImplementedError def _set_file_name(self): all_vars = sorted(self.statistics_per_var.keys()) @@ -178,8 +131,8 @@ class DataPrep(object): def __repr__(self): """Represent class attributes.""" - return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \ - f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})" + return f"AbstractDataPrep(path='{self.path}', station={self.station}, variables={self.variables}, " \ + f"**{self.kwargs})" def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, **kwargs): diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py new file mode 100644 index 00000000..43138911 --- /dev/null +++ b/src/data_handling/data_preparation_join.py @@ -0,0 +1,591 @@ +"""Data Preparation class to handle data processing for machine learning.""" + +__author__ = 'Felix Kleinert, Lukas Leufen' +__date__ = '2019-10-16' + +import datetime as dt +import logging +import os +from functools import reduce +from typing import Union, List, Iterable, Tuple + +import numpy as np +import pandas as pd +import xarray as xr + +from src.configuration import check_path_and_create +from src import helpers +from src.helpers import join, statistics + +# define a more general date type for type hinting +date = Union[dt.date, dt.datetime] +str_or_list = Union[str, List[str]] +number = Union[float, int] +num_or_list = Union[number, List[number]] 
+data_or_none = Union[xr.DataArray, None]
+
+
+class DataPrep(object):
+    """
+    This class prepares data to be used in neural networks.
+
+    The instance searches for locally stored data that meets the given demands. If no local data is found, the
+    DataPrep instance will load data from the TOAR database and store this data locally to use the next time. For the
+    moment, there is only support for daily aggregated time series. The aggregation can be set manually and differ for
+    each variable.
+
+    After data loading, different data pre-processing steps can be executed to prepare the data for further
+    applications. Especially the following methods can be used for the pre-processing step:
+
+    - interpolate: interpolate between data points by using xarray's interpolation method
+    - standardise: standardise data to mean=0 and std=1, centralise to mean=0, additional methods like normalise on \
+        interval [0, 1] are not implemented yet.
+    - make window history: represent the history (time steps before) for training/ testing; X
+    - make labels: create target vector with given leading time steps for training/ testing; y
+    - remove NaNs jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. \
+        Use this method after the creation of the window history and labels to clean up the data cube.
+
+    To create a DataPrep instance, specify the station by id (e.g. "DEBW107"), its network (e.g. UBA,
+    "Umweltbundesamt") and the variables to use. Further options can be set in the instance.
+
+    * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable.
+    * `start`: define a start date for the data cube creation. Default: Use the first entry in time series
+    * `end`: set the end date for the data cube. Default: Use last date in time series.
+    * `store_data_locally`: store recently downloaded data on local disk. Default: True
+    * set further parameters for xarray's interpolation methods to modify the interpolation scheme
+
+    """
+
+    def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str],
+                 station_type: str = None, **kwargs):
+        """Construct instance."""
+        self.path = os.path.abspath(path)
+        self.network = network
+        self.station = helpers.to_list(station)
+        self.variables = variables
+        self.station_type = station_type
+        self.mean: data_or_none = None
+        self.std: data_or_none = None
+        self.history: data_or_none = None
+        self.label: data_or_none = None
+        self.observation: data_or_none = None
+        self.extremes_history: data_or_none = None
+        self.extremes_label: data_or_none = None
+        self.kwargs = kwargs
+        self.data = None
+        self.meta = None
+        self._transform_method = None
+        self.statistics_per_var = kwargs.get("statistics_per_var", None)
+        self.sampling = kwargs.get("sampling", "daily")
+        if self.statistics_per_var is not None or self.sampling == "hourly":
+            self.load_data()
+        else:
+            raise NotImplementedError("Either select hourly data or provide statistics_per_var.")
+
+    def load_data(self):
+        """
+        Load data and meta data either from local disk (preferred) or download new data from the TOAR database.
+
+        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
+        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
+        set, it is assumed that data should be saved locally.
+ """ + check_path_and_create(self.path) + file_name = self._set_file_name() + meta_file = self._set_meta_file_name() + if self.kwargs.get('overwrite_local_data', False): + logging.debug(f"overwrite_local_data is true, therefore reload {file_name} from JOIN") + if os.path.exists(file_name): + os.remove(file_name) + if os.path.exists(meta_file): + os.remove(meta_file) + self.download_data(file_name, meta_file) + logging.debug("loaded new data from JOIN") + else: + try: + logging.debug(f"try to load local data from: {file_name}") + data = xr.open_dataarray(file_name) + self.meta = pd.read_csv(meta_file, index_col=0) + if self.station_type is not None: + self.check_station_meta() + logging.debug("loading finished") + except FileNotFoundError as e: + logging.debug(e) + logging.debug("load new data from JOIN") + data, self.meta = self.download_data(file_name, meta_file) + logging.debug("loading finished") + # create slices and check for negative concentration. + data = self._slice_prep(data) + self.data = self.check_for_negative_concentrations(data) + + def download_data(self, file_name, meta_file): + """ + Download data and meta from join. + + :param file_name: name of file to save data to (containing full path) + :param meta_file: name of the meta data file (also containing full path) + """ + data, meta = self.download_data_from_join(file_name, meta_file) + return data, meta + + def check_station_meta(self): + """ + Search for the entries in meta data and compare the value with the requested values. + + Will raise a FileNotFoundError if the values mismatch. + """ + check_dict = {"station_type": self.station_type, "network_name": self.network} + for (k, v) in check_dict.items(): + if self.meta.at[k, self.station[0]] != v: + logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " + f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " + f"grapping from web.") + raise FileNotFoundError + + def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: + """ + Download data from TOAR database using the JOIN interface. + + Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally + stored locally using given names for file and meta file. 
+
+        :param file_name: name of file to save data to (containing full path)
+        :param meta_file: name of the meta data file (also containing full path)
+
+        :return: downloaded data and its meta data
+        """
+        df_all = {}
+        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
+                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
+        df_all[self.station[0]] = df
+        # convert df_all to xarray
+        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
+        xarr = xr.Dataset(xarr).to_array(dim='Stations')
+        if self.kwargs.get('store_data_locally', True):
+            # save locally as nc/csv file
+            xarr.to_netcdf(path=file_name)
+            meta.to_csv(meta_file)
+        return xarr, meta
+
+    def _set_file_name(self):
+        all_vars = sorted(self.statistics_per_var.keys())
+        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
+
+    def _set_meta_file_name(self):
+        all_vars = sorted(self.statistics_per_var.keys())
+        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
+
+    def __repr__(self):
+        """Represent class attributes."""
+        return f"DataPrep(path='{self.path}', network='{self.network}', station={self.station}, " \
+               f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"
+
+    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
+                    **kwargs):
+        """
+        Interpolate values according to different methods.
+
+        (Copy paste from dataarray.interpolate_na)
+
+        :param dim:
+            Specifies the dimension along which to interpolate.
+        :param method:
+            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
+            'polynomial', 'barycentric', 'krog', 'pchip', 'spline', 'akima'}, optional
+            String indicating which method to use for interpolation:
+
+            - 'linear': linear interpolation (Default). Additional keyword
+              arguments are passed to ``numpy.interp``
+            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
+              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
+              method=='polynomial', the ``order`` keyword argument must also be
+              provided.
+            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
+              respective ``scipy.interpolate`` classes.
+        :param limit:
+            default None
+            Maximum number of consecutive NaNs to fill. Must be greater than 0
+            or None for no limit.
+        :param use_coordinate:
+            default True
+            Specifies which index to use as the x values in the interpolation
+            formulated as `y = f(x)`. If False, values are treated as if
+            equally-spaced along `dim`. If True, the IndexVariable `dim` is
+            used. If use_coordinate is a string, it specifies the name of a
+            coordinate variable to use as the index.
+        :param kwargs:
+
+        :return: xarray.DataArray
+        """
+        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
+                                             **kwargs)
+
+    @staticmethod
+    def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None:
+        """
+        Support inverse_transformation method.
+
+        Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas
+        normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements.
+ + :param mean: data with all mean values + :param std: data with all standard deviation values + :param method: name of transformation method + """ + msg = "" + if method in ['standardise', 'centre'] and mean is None: + msg += "mean, " + if method == 'standardise' and std is None: + msg += "std, " + if len(msg) > 0: + raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}") + + def inverse_transform(self) -> None: + """ + Perform inverse transformation. + + Will raise an AssertionError, if no transformation was performed before. Checks first, if all required + statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by + new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the + current data is not transformed. + """ + + def f_inverse(data, mean, std, method_inverse): + if method_inverse == 'standardise': + return statistics.standardise_inverse(data, mean, std), None, None + elif method_inverse == 'centre': + return statistics.centre_inverse(data, mean), None, None + elif method_inverse == 'normalise': + raise NotImplementedError + else: + raise NotImplementedError + + if self._transform_method is None: + raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.") + self.check_inverse_transform_params(self.mean, self.std, self._transform_method) + self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method) + self._transform_method = None + + def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None, + std=None) -> None: + """ + Transform data according to given transformation settings. + + This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0 + and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale + (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This + method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the + internal transform method, internal mean and internal standard deviation weren't set ('inverse=True'). + + :param string/int dim: This param is not used for inverse transformation. + | for xarray.DataArray as string: name of dimension which should be standardised + | for pandas.DataFrame as int: axis of dimension which should be standardised + :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented + yet. This param is not used for inverse transformation. + :param inverse: Switch between transformation and inverse transformation. + + :return: xarray.DataArrays or pandas.DataFrames: + #. mean: Mean of data + #. std: Standard deviation of data + #. 
data: Standardised data + """ + + def f(data): + if method == 'standardise': + return statistics.standardise(data, dim) + elif method == 'centre': + return statistics.centre(data, dim) + elif method == 'normalise': + # use min/max of data or given min/max + raise NotImplementedError + else: + raise NotImplementedError + + def f_apply(data): + if method == "standardise": + return mean, std, statistics.standardise_apply(data, mean, std) + elif method == "centre": + return mean, None, statistics.centre_apply(data, mean) + else: + raise NotImplementedError + + if not inverse: + if self._transform_method is not None: + raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with " + f"{self._transform_method}. Please perform inverse transformation of data first.") + self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data) + self._transform_method = method + else: + self.inverse_transform() + + def get_transformation_information(self, variable: str) -> Tuple[data_or_none, data_or_none, str]: + """ + Extract transformation statistics and method. + + Get mean and standard deviation for given variable and the transformation method if set. If a transformation + depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are + returned with None as fill value. + + :param variable: Variable for which the information on transformation is requested. + + :return: mean, standard deviation and transformation method + """ + try: + mean = self.mean.sel({'variables': variable}).values + except AttributeError: + mean = None + try: + std = self.std.sel({'variables': variable}).values + except AttributeError: + std = None + return mean, std, self._transform_method + + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: + """ + Create a xr.DataArray containing history data. + + Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted + data. This is used to represent history in the data. Results are stored in history attribute. + + :param dim_name_of_inputs: Name of dimension which contains the input variables + :param window: number of time steps to look back in history + Note: window will be treated as negative value. This should be in agreement with looking back on + a time line. Nonetheless positive values are allowed but they are converted to its negative + expression + :param dim_name_of_shift: Dimension along shift will be applied + """ + window = -abs(window) + self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables}) + + def shift(self, dim: str, window: int) -> xr.DataArray: + """ + Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0). + + :param dim: dimension along shift is applied + :param window: number of steps to shift (corresponds to the window length) + + :return: shifted data + """ + start = 1 + end = 1 + if window <= 0: + start = window + else: + end = window + 1 + res = [] + for w in range(start, end): + res.append(self.data.shift({dim: -w})) + window_array = self.create_index_array('window', range(start, end)) + res = xr.concat(res, dim=window_array) + return res + + def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, + window: int) -> None: + """ + Create a xr.DataArray containing labels. 
+ + Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label + attribute. + + :param dim_name_of_target: Name of dimension which contains the target variable + :param target_var: Name of target variable in 'dimension' + :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied + :param window: lead time of label + """ + window = abs(window) + self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) + + def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: + """ + Create a xr.DataArray containing observations. + + Observations are defined as value of the current time step t. Set observation attribute. + + :param dim_name_of_target: Name of dimension which contains the observation variable + :param target_var: Name of observation variable(s) in 'dimension' + :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied + """ + self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) + + def remove_nan(self, dim: str) -> None: + """ + Remove all NAs slices along dim which contain nans in history, label and observation. + + This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute. + + :param dim: dimension along the remove is performed. + """ + intersect = [] + if (self.history is not None) and (self.label is not None): + non_nan_history = self.history.dropna(dim=dim) + non_nan_label = self.label.dropna(dim=dim) + non_nan_observation = self.observation.dropna(dim=dim) + intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, + non_nan_observation.coords[dim].values)) + + min_length = self.kwargs.get("min_length", 0) + if len(intersect) < max(min_length, 1): + self.history = None + self.label = None + self.observation = None + else: + self.history = self.history.sel({dim: intersect}) + self.label = self.label.sel({dim: intersect}) + self.observation = self.observation.sel({dim: intersect}) + + @staticmethod + def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray: + """ + Create an 1D xr.DataArray with given index name and value. + + :param index_name: name of dimension + :param index_value: values of this dimension + + :return: this array + """ + ind = pd.DataFrame({'val': index_value}, index=index_value) + res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True) + res.name = index_name + return res + + def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: + """ + Set start and end date for slicing and execute self._slice(). + + :param data: data to slice + :param coord: name of axis to slice + + :return: sliced data + """ + start = self.kwargs.get('start', data.coords[coord][0].values) + end = self.kwargs.get('end', data.coords[coord][-1].values) + return self._slice(data, start, end, coord) + + @staticmethod + def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: + """ + Slice through a given data_item (for example select only values of 2011). 
+ + :param data: data to slice + :param start: start date of slice + :param end: end date of slice + :param coord: name of axis to slice + + :return: sliced data + """ + return data.loc[{coord: slice(str(start), str(end))}] + + def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: + """ + Set all negative concentrations to zero. + + Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/ + #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox", + "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene". + + :param data: data array containing variables to check + :param minimum: minimum value, by default this should be 0 + + :return: corrected data + """ + chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", + "propane", "so2", "toluene"] + used_chem_vars = list(set(chem_vars) & set(self.variables)) + data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) + return data + + def get_transposed_history(self) -> xr.DataArray: + """Return history. + + :return: history with dimensions datetime, window, Stations, variables. + """ + return self.history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_transposed_label(self) -> xr.DataArray: + """Return label. + + :return: label with dimensions datetime, window, Stations, variables. + """ + return self.label.squeeze("Stations").transpose("datetime", "window").copy() + + def get_extremes_history(self) -> xr.DataArray: + """Return extremes history. + + :return: extremes history with dimensions datetime, window, Stations, variables. + """ + return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_extremes_label(self) -> xr.DataArray: + """Return extremes label. + + :return: extremes label with dimensions datetime, window, Stations, variables. + """ + return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy() + + def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, + timedelta: Tuple[int, str] = (1, 'm')): + """ + Multiply extremes. + + This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can + also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of + floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised + space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be + extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is + used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can + identify those "artificial" data points later easily. Extreme inputs and labels are stored in + self.extremes_history and self.extreme_labels, respectively. 
+
+        :param extreme_values: user definition of extreme
+        :param extremes_on_right_tail_only: if False also multiply values which are smaller than -extreme_values,
+            if True only extract values larger than extreme_values
+        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
+        """
+        # check if labels or history is None
+        if (self.label is None) or (self.history is None):
+            logging.debug(f"{self.station} has 'None' labels, skip multiply extremes")
+            return
+
+        # check type of inputs
+        extreme_values = helpers.to_list(extreme_values)
+        for i in extreme_values:
+            if not isinstance(i, number.__args__):
+                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
+                                f"{i} is type {type(i)}")
+
+        for extr_val in sorted(extreme_values):
+            # check if some extreme values are already extracted
+            if (self.extremes_label is None) or (self.extremes_history is None):
+                # extract extremes based on occurrence in labels
+                if extremes_on_right_tail_only:
+                    extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1, )
+                else:
+                    extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1),
+                                                        (self.label > extr_val).any(axis=0).values.reshape(-1, 1)),
+                                                       axis=1).any(axis=1)
+                extremes_label = self.label[..., extreme_label_idx]
+                extremes_history = self.history[..., extreme_label_idx, :]
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_label = extremes_label  # .squeeze('Stations').transpose('datetime', 'window')
+                self.extremes_history = extremes_history  # .transpose('datetime', 'window', 'Stations', 'variables')
+            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
+                if extremes_on_right_tail_only:
+                    extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
+                else:
+                    extreme_label_idx = np.concatenate(
+                        ((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
+                         (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
+                         ), axis=1).any(axis=1)
+                # check on existing extracted extremes to minimise computational costs for comparison
+                extremes_label = self.extremes_label[..., extreme_label_idx]
+                extremes_history = self.extremes_history[..., extreme_label_idx, :]
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
+                self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
+
+
+if __name__ == "__main__":
+    dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+    print(dp)
--
GitLab
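Editor's note: to make the refactoring in this patch easier to follow, below is a minimal sketch of how the new split is meant to be used. AbstractDataPrep (src/data_handling/data_preparation.py) keeps the generic logic (local caching, slicing, negative-concentration check) and only leaves download_data() unimplemented, while the JOIN-specific DataPrep (src/data_handling/data_preparation_join.py) overrides it. The subclass below, its CSV layout, and the constant meta data values are hypothetical and only illustrate the extension point; they are not part of the patch.

import pandas as pd
import xarray as xr

from src.data_handling.data_preparation import AbstractDataPrep


class CsvDataPrep(AbstractDataPrep):  # hypothetical example subclass, not part of this patch
    """Provide station data from a local CSV file instead of the TOAR/JOIN database."""

    def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]:
        # assumed layout: one CSV per station with a datetime index and one column per variable
        df = pd.read_csv(f"{self.station[0]}.csv", index_col=0, parse_dates=True)
        # arrange data as (Stations, datetime, variables), matching what the base class expects
        xarr = xr.DataArray(df, dims=['datetime', 'variables']).expand_dims({'Stations': self.station})
        # dummy meta data with the same orientation as the JOIN meta frame (index = meta key, column = station)
        meta = pd.DataFrame({self.station[0]: {"station_type": "unknown", "network_name": "csv"}})
        if self.kwargs.get('store_data_locally', True):
            # mirror the JOIN implementation: persist as nc/csv so load_data() finds it next time
            xarr.to_netcdf(path=file_name)
            meta.to_csv(meta_file)
        return xarr, meta


# usage mirrors the __main__ block above, but with the reduced AbstractDataPrep signature (path, station, variables)
# dp = CsvDataPrep('data/', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})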