From d083f8bd6239b55f9049f2a033fd9f2b75ed540d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 21 Sep 2020 13:19:39 +0200
Subject: [PATCH] split advanced data handler into abstract and default data
 handler, updated import statements; station preparation now inherits from
 AbstractDataHandler and states its requirements explicitly (no longer taken
 from kwargs)

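With this change, a custom data handler only has to inherit from
AbstractDataHandler and declare its constructor arguments explicitly; the
workflow can then collect everything the handler needs via cls.requirements()
instead of digging through kwargs. A minimal sketch of such a subclass (class
and argument names are illustrative only):

    from mlair.data_handler import AbstractDataHandler

    class MyDataHandler(AbstractDataHandler):
        """Hypothetical handler to illustrate the new interface."""

        def __init__(self, station, data_path, sampling="daily"):
            super().__init__()
            self.station = station

        def get_X(self, upsampling=False, as_numpy=False):
            ...  # return input data

        def get_Y(self, upsampling=False, as_numpy=False):
            ...  # return target data

    # requirements() is derived from the constructor signature:
    # MyDataHandler.requirements() -> ["station", "data_path", "sampling"]
    # (order is undefined because duplicates are removed via a set)
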
---
 mlair/data_handler/__init__.py                |   3 +-
 mlair/data_handler/abstract_data_handler.py   |  47 +++
 mlair/data_handler/advanced_data_handler.py   | 290 +-----------------
 mlair/data_handler/bootstraps.py              |   2 +-
 .../data_preparation_neighbors.py             |   2 +-
 mlair/data_handler/default_data_handler.py    | 238 ++++++++++++++
 mlair/data_handler/station_preparation.py     |  31 +-
 mlair/run_modules/experiment_setup.py         |   2 +-
 8 files changed, 301 insertions(+), 314 deletions(-)
 create mode 100644 mlair/data_handler/abstract_data_handler.py
 create mode 100644 mlair/data_handler/default_data_handler.py

diff --git a/mlair/data_handler/__init__.py b/mlair/data_handler/__init__.py
index 6510b336..01d66003 100644
--- a/mlair/data_handler/__init__.py
+++ b/mlair/data_handler/__init__.py
@@ -11,5 +11,6 @@ __date__ = '2020-04-17'
 
 from .bootstraps import BootStraps
 from .iterator import KerasIterator, DataCollection
-from .advanced_data_handler import DefaultDataHandler, AbstractDataHandler
+from .default_data_handler import DefaultDataHandler
+from .abstract_data_handler import AbstractDataHandler
 from .data_preparation_neighbors import DataHandlerNeighbors
diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
new file mode 100644
index 00000000..04b3d465
--- /dev/null
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -0,0 +1,47 @@
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-09-21'
+
+import inspect
+from typing import Union, Dict
+
+from mlair.helpers import remove_items
+
+
+class AbstractDataHandler:
+
+    _requirements = []
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @classmethod
+    def build(cls, *args, **kwargs):
+        """Return initialised class."""
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def requirements(cls):
+        """Return requirements and own arguments without duplicates."""
+        return list(set(cls._requirements + cls.own_args()))
+
+    @classmethod
+    def own_args(cls, *args):
+        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
+
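+    # note: inspect.getfullargspec(cls) inspects the constructor, so own_args
+    # returns the argument names of cls.__init__ without "self" and without
+    # any names passed via *args
+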
+    @classmethod
+    def transformation(cls, *args, **kwargs):
+        return None
+
+    def get_X(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_Y(self, upsampling=False, as_numpy=False):
+        raise NotImplementedError
+
+    def get_data(self, upsampling=False, as_numpy=False):
+        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
+
+    def get_coordinates(self) -> Union[None, Dict]:
+        """Return coordinates as dictionary with keys `lon` and `lat`."""
+        return None
diff --git a/mlair/data_handler/advanced_data_handler.py b/mlair/data_handler/advanced_data_handler.py
index bf7defa5..c2d210bf 100644
--- a/mlair/data_handler/advanced_data_handler.py
+++ b/mlair/data_handler/advanced_data_handler.py
@@ -2,306 +2,20 @@
 __author__ = 'Lukas Leufen'
 __date__ = '2020-07-08'
 
-
-from mlair.helpers import to_list, remove_items
 import numpy as np
 import xarray as xr
-import pickle
 import os
 import pandas as pd
 import datetime as dt
-import shutil
-import inspect
-import copy
 
-from typing import Union, List, Tuple, Dict
-import logging
-from functools import reduce
-from mlair.data_handler.station_preparation import DataHandlerSingleStation
-from mlair.helpers.join import EmptyQueryResult
+from mlair.data_handler import AbstractDataHandler
 
+from typing import Union, List
 
 number = Union[float, int]
 num_or_list = Union[number, List[number]]
 
 
-class DummyDataSingleStation:  # pragma: no cover
-
-    def __init__(self, name, number_of_samples=None):
-        self.name = name
-        self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)
-
-    def get_X(self):
-        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
-        datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
-        return xr.DataArray(X1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
-                                                                                  "window": range(14),
-                                                                                  "variables": range(5)})
-
-    def get_Y(self):
-        Y1 = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1)  # samples, window, variables
-        datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist()
-        return xr.DataArray(Y1, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist,
-                                                                                  "window": range(5),
-                                                                                  "variables": range(1)})
-
-    def __str__(self):
-        return self.name
-
-
-class AbstractDataHandler:
-
-    _requirements = []
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    @classmethod
-    def build(cls, *args, **kwargs):
-        """Return initialised class."""
-        return cls(*args, **kwargs)
-
-    @classmethod
-    def requirements(cls):
-        """Return requirements and own arguments without duplicates."""
-        return list(set(cls._requirements + cls.own_args()))
-
-    @classmethod
-    def own_args(cls, *args):
-        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
-
-    @classmethod
-    def transformation(cls, *args, **kwargs):
-        return None
-
-    def get_X(self, upsampling=False, as_numpy=False):
-        raise NotImplementedError
-
-    def get_Y(self, upsampling=False, as_numpy=False):
-        raise NotImplementedError
-
-    def get_data(self, upsampling=False, as_numpy=False):
-        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
-
-    def get_coordinates(self) -> Union[None, Dict]:
-        """Return coordinates as dictionary with keys `lon` and `lat`."""
-        return None
-
-
-class DefaultDataHandler(AbstractDataHandler):
-
-    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
-
-    def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0,
-                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
-        super().__init__()
-        self.id_class = id_class
-        self.interpolation_dim = "datetime"
-        self.min_length = min_length
-        self._X = None
-        self._Y = None
-        self._X_extreme = None
-        self._Y_extreme = None
-        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
-        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
-        self._collection = self._create_collection()
-        self.harmonise_X()
-        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
-        self._store(fresh_store=True)
-
-    @classmethod
-    def build(cls, station: str, **kwargs):
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        sp = DataHandlerSingleStation(station, **sp_keys)
-        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
-        return cls(sp, **dp_args)
-
-    def _create_collection(self):
-        return [self.id_class]
-
-    @classmethod
-    def requirements(cls):
-        return remove_items(super().requirements(), "id_class")
-
-    def _reset_data(self):
-        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
-
-    def _cleanup(self):
-        directory = os.path.dirname(self._save_file)
-        if os.path.exists(directory) is False:
-            os.makedirs(directory)
-        if os.path.exists(self._save_file):
-            shutil.rmtree(self._save_file, ignore_errors=True)
-
-    def _store(self, fresh_store=False):
-        self._cleanup() if fresh_store is True else None
-        data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
-        with open(self._save_file, "wb") as f:
-            pickle.dump(data, f)
-        logging.debug(f"save pickle data to {self._save_file}")
-        self._reset_data()
-
-    def _load(self):
-        try:
-            with open(self._save_file, "rb") as f:
-                data = pickle.load(f)
-            logging.debug(f"load pickle data from {self._save_file}")
-            self._X, self._Y = data["X"], data["Y"]
-            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
-        except FileNotFoundError:
-            pass
-
-    def get_data(self, upsampling=False, as_numpy=True):
-        self._load()
-        X = self.get_X(upsampling, as_numpy)
-        Y = self.get_Y(upsampling, as_numpy)
-        self._reset_data()
-        return X, Y
-
-    def __repr__(self):
-        return ";".join(list(map(lambda x: str(x), self._collection)))
-
-    def get_X_original(self):
-        X = []
-        for data in self._collection:
-            X.append(data.get_X())
-        return X
-
-    def get_Y_original(self):
-        Y = self._collection[0].get_Y()
-        return Y
-
-    @staticmethod
-    def _to_numpy(d):
-        return list(map(lambda x: np.copy(x), d))
-
-    def get_X(self, upsampling=False, as_numpy=True):
-        no_data = (self._X is None)
-        self._load() if no_data is True else None
-        X = self._X if upsampling is False else self._X_extreme
-        self._reset_data() if no_data is True else None
-        return self._to_numpy(X) if as_numpy is True else X
-
-    def get_Y(self, upsampling=False, as_numpy=True):
-        no_data = (self._Y is None)
-        self._load() if no_data is True else None
-        Y = self._Y if upsampling is False else self._Y_extreme
-        self._reset_data() if no_data is True else None
-        return self._to_numpy([Y]) if as_numpy is True else Y
-
-    def harmonise_X(self):
-        X_original, Y_original = self.get_X_original(), self.get_Y_original()
-        dim = self.interpolation_dim
-        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
-        if len(intersect) < max(self.min_length, 1):
-            X, Y = None, None
-        else:
-            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
-            Y = Y_original.sel({dim: intersect})
-        self._X, self._Y = X, Y
-
-    def get_observation(self):
-        return self.id_class.observation.copy().squeeze()
-
-    def get_transformation_Y(self):
-        return self.id_class.get_transformation_information()
-
-    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
-                          timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
-        """
-        Multiply extremes.
-
-        This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can
-        also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of
-        floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised
-        space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be
-        extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is
-        used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can
-        identify those "artificial" data points later easily. Extreme inputs and labels are stored in
-        self.extremes_history and self.extreme_labels, respectively.
-
-        :param extreme_values: user definition of extreme
-        :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values,
-            if True only extract values larger than extreme_values
-        :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime
-        """
-        # check if X or Y is None
-        if (self._X is None) or (self._Y is None):
-            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
-            return
-        if extreme_values is None:
-            logging.debug(f"No extreme values given, skip multiply extremes")
-            self._X_extreme, self._Y_extreme = self._X, self._Y
-            return
-
-        # check type if inputs
-        extreme_values = to_list(extreme_values)
-        for i in extreme_values:
-            if not isinstance(i, number.__args__):
-                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
-                                f"{i} is type {type(i)}")
-
-        for extr_val in sorted(extreme_values):
-            # check if some extreme values are already extracted
-            if (self._X_extreme is None) or (self._Y_extreme is None):
-                X = self._X
-                Y = self._Y
-            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
-                X = self._X_extreme
-                Y = self._Y_extreme
-
-            # extract extremes based on occurrence in labels
-            other_dims = remove_items(list(Y.dims), dim)
-            if extremes_on_right_tail_only:
-                extreme_idx = (Y > extr_val).any(dim=other_dims)
-            else:
-                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
-                                           (Y > extr_val).any(dim=other_dims[0])],
-                                          dim=other_dims[1]).any(dim=other_dims[1])
-
-            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
-            self._add_timedelta(extremes_X, dim, timedelta)
-            # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
-
-            extremes_Y = Y.sel(**{dim: extreme_idx})
-            extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
-
-            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
-            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
-
-    @staticmethod
-    def _add_timedelta(data, dim, timedelta):
-        for d in data:
-            d.coords[dim].values += np.timedelta64(*timedelta)
-
-    @classmethod
-    def transformation(cls, set_stations, **kwargs):
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        transformation_dict = sp_keys.pop("transformation")
-        if transformation_dict is None:
-            return
-        scope = transformation_dict.pop("scope")
-        method = transformation_dict.pop("method")
-        if transformation_dict.pop("mean", None) is not None:
-            return
-        mean, std = None, None
-        for station in set_stations:
-            try:
-                sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys)
-                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
-                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
-            except (AttributeError, EmptyQueryResult):
-                continue
-        if mean is None:
-            return None
-        mean_estimated = mean.mean("Stations")
-        std_estimated = std.mean("Stations")
-        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
-
-    def get_coordinates(self):
-        return self.id_class.get_coordinates()
-
-
 def run_data_prep():
 
     from .data_preparation_neighbors import DataHandlerNeighbors
diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py
index f7f5c3c7..68a4bbc4 100644
--- a/mlair/data_handler/bootstraps.py
+++ b/mlair/data_handler/bootstraps.py
@@ -19,7 +19,7 @@ from itertools import chain
 import numpy as np
 import xarray as xr
 
-from mlair.data_handler.advanced_data_handler import AbstractDataHandler
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 
 class BootstrapIterator(Iterator):
diff --git a/mlair/data_handler/data_preparation_neighbors.py b/mlair/data_handler/data_preparation_neighbors.py
index 37e19225..1482bb9f 100644
--- a/mlair/data_handler/data_preparation_neighbors.py
+++ b/mlair/data_handler/data_preparation_neighbors.py
@@ -5,7 +5,7 @@ __date__ = '2020-07-17'
 
 from mlair.helpers import to_list
 from mlair.data_handler.station_preparation import DataHandlerSingleStation
-from mlair.data_handler.advanced_data_handler import DefaultDataHandler
+from mlair.data_handler import DefaultDataHandler
 import os
 
 from typing import Union, List
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
new file mode 100644
index 00000000..47f63a3e
--- /dev/null
+++ b/mlair/data_handler/default_data_handler.py
@@ -0,0 +1,238 @@
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-09-21'
+
+import copy
+import inspect
+import logging
+import os
+import pickle
+from functools import reduce
+from typing import Tuple, Union, List
+
+import numpy as np
+import xarray as xr
+
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
+from mlair.data_handler.station_preparation import DataHandlerSingleStation
+from mlair.helpers import remove_items, to_list
+from mlair.helpers.join import EmptyQueryResult
+
+
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+
+
+class DefaultDataHandler(AbstractDataHandler):
+
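+    # everything the single-station handler's constructor accepts (except
+    # "self" and "station") becomes a requirement of this class and has to be
+    # provided via kwargs to cls.build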
+    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
+
+    def __init__(self, id_class: DataHandlerSingleStation, data_path: str, min_length: int = 0,
+                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
+        super().__init__()
+        self.id_class = id_class
+        self.interpolation_dim = "datetime"
+        self.min_length = min_length
+        self._X = None
+        self._Y = None
+        self._X_extreme = None
+        self._Y_extreme = None
+        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
+        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
+        self._collection = self._create_collection()
+        self.harmonise_X()
+        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
+        self._store(fresh_store=True)
+
+    @classmethod
+    def build(cls, station: str, **kwargs):
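+        # split the shared kwargs: deep copies of the single-station
+        # requirements go to DataHandlerSingleStation, deep copies of this
+        # class' own constructor arguments stay here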
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        sp = DataHandlerSingleStation(station, **sp_keys)
+        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
+        return cls(sp, **dp_args)
+
+    def _create_collection(self):
+        return [self.id_class]
+
+    @classmethod
+    def requirements(cls):
+        return remove_items(super().requirements(), "id_class")
+
+    def _reset_data(self):
+        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
+
+    def _cleanup(self):
+        directory = os.path.dirname(self._save_file)
+        if os.path.exists(directory) is False:
+            os.makedirs(directory)
+        if os.path.exists(self._save_file):
+            os.remove(self._save_file)  # the save file is a single pickle file, not a directory
+
+    def _store(self, fresh_store=False):
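+        # move X and Y (incl. their extremes) from memory into a pickle file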
+        self._cleanup() if fresh_store is True else None
+        data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
+        with open(self._save_file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {self._save_file}")
+        self._reset_data()
+
+    def _load(self):
+        try:
+            with open(self._save_file, "rb") as f:
+                data = pickle.load(f)
+            logging.debug(f"load pickle data from {self._save_file}")
+            self._X, self._Y = data["X"], data["Y"]
+            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
+        except FileNotFoundError:
+            pass
+
+    def get_data(self, upsampling=False, as_numpy=True):
+        self._load()
+        X = self.get_X(upsampling, as_numpy)
+        Y = self.get_Y(upsampling, as_numpy)
+        self._reset_data()
+        return X, Y
+
+    def __repr__(self):
+        return ";".join(list(map(lambda x: str(x), self._collection)))
+
+    def get_X_original(self):
+        X = []
+        for data in self._collection:
+            X.append(data.get_X())
+        return X
+
+    def get_Y_original(self):
+        Y = self._collection[0].get_Y()
+        return Y
+
+    @staticmethod
+    def _to_numpy(d):
+        return list(map(lambda x: np.copy(x), d))
+
+    def get_X(self, upsampling=False, as_numpy=True):
+        no_data = (self._X is None)
+        self._load() if no_data is True else None
+        X = self._X if upsampling is False else self._X_extreme
+        self._reset_data() if no_data is True else None
+        return self._to_numpy(X) if as_numpy is True else X
+
+    def get_Y(self, upsampling=False, as_numpy=True):
+        no_data = (self._Y is None)
+        self._load() if no_data is True else None
+        Y = self._Y if upsampling is False else self._Y_extreme
+        self._reset_data() if no_data is True else None
+        return self._to_numpy([Y]) if as_numpy is True else Y
+
+    def harmonise_X(self):
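+        # reduce all inputs and labels to the intersection of their time
+        # indices; if fewer than max(min_length, 1) common time steps remain,
+        # this station provides no data at all (X and Y are set to None)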
+        X_original, Y_original = self.get_X_original(), self.get_Y_original()
+        dim = self.interpolation_dim
+        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
+        if len(intersect) < max(self.min_length, 1):
+            X, Y = None, None
+        else:
+            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
+            Y = Y_original.sel({dim: intersect})
+        self._X, self._Y = X, Y
+
+    def get_observation(self):
+        return self.id_class.observation.copy().squeeze()
+
+    def get_transformation_Y(self):
+        return self.id_class.get_transformation_information()
+
+    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
+                          timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
+        """
+        Multiply extremes.
+
+        This method extracts extreme values from the labels, where "extreme" is defined by the argument
+        extreme_values. One can also decide to only extract extremes on the right tail of the distribution. When
+        extreme_values is a list of floats/ints, all values larger than an entry (and smaller than its negative;
+        extraction is performed in standardised space) are extracted iteratively. If for example
+        extreme_values = [1., 2.], a value of 1.5 would be extracted once (for the 0th entry in the list), while a
+        2.5 would be extracted twice (once for each entry). timedelta is used to mark those extracted values by
+        adding one minute to each timestamp. As TOAR data are hourly, one can easily identify those "artificial"
+        data points later. Extreme inputs and labels are stored in self._X_extreme and self._Y_extreme, respectively.
+
+        :param extreme_values: user definition of extreme
+        :param extremes_on_right_tail_only: if False, also multiply values which are smaller than -extreme_values;
+            if True, only extract values larger than extreme_values
+        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
+        """
+        # check if X or Y is None
+        if (self._X is None) or (self._Y is None):
+            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
+            return
+        if extreme_values is None:
+            logging.debug(f"No extreme values given, skip multiply extremes")
+            self._X_extreme, self._Y_extreme = self._X, self._Y
+            return
+
+        # check type of inputs
+        extreme_values = to_list(extreme_values)
+        for i in extreme_values:
+            if not isinstance(i, number.__args__):
+                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
+                                f"{i} is type {type(i)}")
+
+        for extr_val in sorted(extreme_values):
+            # check if some extreme values are already extracted
+            if (self._X_extreme is None) or (self._Y_extreme is None):
+                X = self._X
+                Y = self._Y
+            else:  # at least one extreme value iteration is done already: self._Y_extreme is not None
+                X = self._X_extreme
+                Y = self._Y_extreme
+
+            # extract extremes based on occurrence in labels
+            other_dims = remove_items(list(Y.dims), dim)
+            if extremes_on_right_tail_only:
+                extreme_idx = (Y > extr_val).any(dim=other_dims)
+            else:
+                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
+                                         (Y > extr_val).any(dim=other_dims[0])],
+                                        dim=other_dims[1]).any(dim=other_dims[1])
+
+            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
+            self._add_timedelta(extremes_X, dim, timedelta)
+
+            extremes_Y = Y.sel(**{dim: extreme_idx})
+            extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
+
+            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
+            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+
+    @staticmethod
+    def _add_timedelta(data, dim, timedelta):
+        for d in data:
+            d.coords[dim].values += np.timedelta64(*timedelta)
+
+    @classmethod
+    def transformation(cls, set_stations, **kwargs):
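+        # estimate a single transformation for all given stations: create a
+        # temporary single-station handler per station, combine the individual
+        # means and stds, and average them over the "Stations" dimension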
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.pop("transformation")
+        if transformation_dict is None:
+            return
+        scope = transformation_dict.pop("scope")
+        method = transformation_dict.pop("method")
+        if transformation_dict.pop("mean", None) is not None:
+            return
+        mean, std = None, None
+        for station in set_stations:
+            try:
+                sp = DataHandlerSingleStation(station, transformation={"method": method}, **sp_keys)
+                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
+                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        if mean is None:
+            return None
+        mean_estimated = mean.mean("Stations")
+        std_estimated = std.mean("Stations")
+        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
+
+    def get_coordinates(self):
+        return self.id_class.get_coordinates()
\ No newline at end of file
diff --git a/mlair/data_handler/station_preparation.py b/mlair/data_handler/station_preparation.py
index a278d0df..6112e7c5 100644
--- a/mlair/data_handler/station_preparation.py
+++ b/mlair/data_handler/station_preparation.py
@@ -16,6 +16,7 @@ import xarray as xr
 from mlair.configuration import check_path_and_create
 from mlair import helpers
 from mlair.helpers import join, statistics
+from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
@@ -39,18 +40,7 @@ DEFAULT_SAMPLING = "daily"
 DEFAULT_INTERPOLATION_METHOD = "linear"
 
 
-class AbstractDataHandlerSingleStation(object):
-    def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
-        pass
-
-    def get_X(self):
-        raise NotImplementedError
-
-    def get_Y(self):
-        raise NotImplementedError
-
-
-class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
+class DataHandlerSingleStation(AbstractDataHandler):
 
     def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
                  network=DEFAULT_NETWORK, sampling=DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM,
@@ -58,7 +48,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
                  window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                  interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
-                 min_length: int = 0, start=None, end=None, **kwargs):
+                 min_length: int = 0, start=None, end=None, variables=None, **kwargs):
         super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
         self.station = helpers.to_list(station)
         self.path = os.path.abspath(data_path)
@@ -86,7 +76,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         # internal
         self.data = None
         self.meta = None
-        self.variables = kwargs.get('variables', list(statistics_per_var.keys()))
+        self.variables = list(statistics_per_var.keys()) if variables is None else variables
         self.history = None
         self.label = None
         self.observation = None
@@ -98,10 +88,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         self.min = None
         self._transform_method = None
 
-        self.kwargs = kwargs
-        # self.kwargs["overwrite_local_data"] = overwrite_local_data
-
-        # self.make_samples()
+        # create samples
         self.setup_samples()
 
     def __str__(self):
@@ -123,7 +110,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
                f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
                f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
                f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data}, " \
-               f"transformation={self._print_transformation_as_string}, **{self.kwargs})"
+               f"transformation={self._print_transformation_as_string})"
 
     @property
     def _print_transformation_as_string(self):
@@ -155,10 +142,10 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         """
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
-    def get_X(self):
+    def get_X(self, **kwargs):
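+        # **kwargs absorbs the upsampling/as_numpy arguments of the
+        # AbstractDataHandler interface, which are ignored on station level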
         return self.get_transposed_history()
 
-    def get_Y(self):
+    def get_Y(self, **kwargs):
         return self.get_transposed_label()
 
     def get_coordinates(self):
@@ -498,7 +485,7 @@ class DataHandlerSingleStation(AbstractDataHandlerSingleStation):
         """
         chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
                      "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
+        used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
         data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index d66954b0..f5d7d80f 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -18,7 +18,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
-from mlair.data_handler.advanced_data_handler import DefaultDataHandler
+from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.model_class import MyLittleModel as VanillaModel
 
-- 
GitLab