diff --git a/conftest.py b/conftest.py
index abb0c0f52757e4b2228d7d48e3dc07e08b302841..b63d3efb33f5b2c02185f16e8753231d1853e66c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -66,5 +66,5 @@ def default_session_fixture(request):
 
     # request.addfinalizer(unpatch)
 
-    with mock.patch("multiprocessing.cpu_count", return_value=1):
+    with mock.patch("psutil.cpu_count", return_value=1):
         yield
diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 31b58a56375ea26a857ee132c2170680bab4e55a..00815419b43ffc2466f41bfede3a96311a752fdf 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -46,21 +46,25 @@ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True
 DEFAULT_EVALUATE_BOOTSTRAPS = True
 DEFAULT_CREATE_NEW_BOOTSTRAPS = False
 DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
+DEFAULT_BOOTSTRAP_TYPE = "singleinput"
+DEFAULT_BOOTSTRAP_METHOD = "shuffle"
 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                      "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                      "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotOversampling",
-                     "PlotOversamplingContingency"]
+                     "PlotOversamplingContingency", "PlotPeriodogram"]
 DEFAULT_SAMPLING = "daily"
 DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA",
                        "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "",
                        "pm10": "", "so2": ""}
 DEFAULT_USE_MULTIPROCESSING = True
 DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
+DEFAULT_MAX_NUMBER_MULTIPROCESSING = 16
 DEFAULT_OVERSAMPLING_BINS = 10
 DEFAULT_OVERSAMPLING_RATES_CAP = 100
 DEFAULT_OVERSAMPLING_METHOD = None
 
 
+
 def get_defaults():
     """Return all default parameters set in defaults.py"""
     return {key: value for key, value in globals().items() if key.startswith('DEFAULT')}
diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 419db059a58beeb4ed7e3e198e41b565f8dc7d25..36d6e9ae5394705af4b9fbcfd1d8ff77572642b5 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -11,6 +11,7 @@ from mlair.helpers import remove_items
 class AbstractDataHandler:
 
     _requirements = []
+    _store_attributes = []
 
     def __init__(self, *args, **kwargs):
         pass
@@ -32,6 +33,31 @@ class AbstractDataHandler:
         list_of_args = arg_spec.args + arg_spec.kwonlyargs
         return remove_items(list_of_args, ["self"] + list(args))
 
+    @classmethod
+    def store_attributes(cls) -> list:
+        """
+        Let MLAir know that some data should be stored in the data store. This is used for calculations on the train
+        subset that should be applied to validation and test subset.
+
+        To work properly, add a class variable cls._store_attributes to your data handler. If your custom data handler
+        is constructed on different data handlers (e.g. like the DefaultDataHandler), it is required to overwrite the
+        get_store_attributs method in addition to return attributes from the corresponding subclasses. This is not
+        required, if only attributes from the main class are to be returned.
+
+        Note, that MLAir will store these attributes with the data handler's identification. This depends on the custom
+        data handler setting. When loading an attribute from the data handler, it is therefore required to extract the
+        right information by using the class identification. In case of the DefaultDataHandler this can be achieved to
+        convert all keys of the attribute to string and compare these with the station parameter.
+        """
+        return list(set(cls._store_attributes))
+
+    def get_store_attributes(self):
+        """Returns all attribute names and values that are indicated by the store_attributes method."""
+        attr_dict = {}
+        for attr in self.store_attributes():
+            attr_dict[attr] = self.__getattribute__(attr)
+        return attr_dict
+
     @classmethod
     def transformation(cls, *args, **kwargs):
         return None
diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py
index 68a4bbc4bc9620bfb54ba23fef1ce882e76c8626..e03881484bfc9b8275ede8a4432072c74643994a 100644
--- a/mlair/data_handler/bootstraps.py
+++ b/mlair/data_handler/bootstraps.py
@@ -15,69 +15,175 @@ __date__ = '2020-02-07'
 import os
 from collections import Iterator, Iterable
 from itertools import chain
+from typing import Union, List
 
 import numpy as np
 import xarray as xr
 
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
+from mlair.helpers.helpers import to_list
 
 
 class BootstrapIterator(Iterator):
 
     _position: int = None
 
-    def __init__(self, data: "BootStraps"):
+    def __init__(self, data: "BootStraps", method):
         assert isinstance(data, BootStraps)
         self._data = data
         self._dimension = data.bootstrap_dimension
-        self._collection = self._data.bootstraps()
+        self.boot_dim = "boots"
+        self._method = method
+        self._collection = self.create_collection(self._data.data, self._dimension)
         self._position = 0
 
+    def __next__(self):
+        """Return next element or stop iteration."""
+        raise NotImplementedError
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        raise NotImplementedError
+
+    def _reshape(self, d):
+        if isinstance(d, list):
+            return list(map(lambda x: self._reshape(x), d))
+            # return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d))
+        else:
+            shape = d.shape
+            return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])
+
+    def _to_numpy(self, d):
+        if isinstance(d, list):
+            return list(map(lambda x: self._to_numpy(x), d))
+        else:
+            return d.values
+
+    def apply_bootstrap_method(self, data: np.ndarray) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Apply predefined bootstrap method from given data.
+
+        :param data: data to apply bootstrap method on
+        :return: processed data as numpy array
+        """
+        if isinstance(data, list):
+            return list(map(lambda x: self.apply_bootstrap_method(x.values), data))
+        else:
+            return self._method.apply(data)
+
+
+class BootstrapIteratorSingleInput(BootstrapIterator):
+    _position: int = None
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
     def __next__(self):
         """Return next element or stop iteration."""
         try:
             index, dimension = self._collection[self._position]
             nboot = self._data.number_of_bootstraps
             _X, _Y = self._data.data.get_data(as_numpy=False)
-            _X = list(map(lambda x: x.expand_dims({'boots': range(nboot)}, axis=-1), _X))
-            _Y = _Y.expand_dims({"boots": range(nboot)}, axis=-1)
+            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
             single_variable = _X[index].sel({self._dimension: [dimension]})
-            shuffled_variable = self.shuffle(single_variable.values)
-            shuffled_data = xr.DataArray(shuffled_variable, coords=single_variable.coords, dims=single_variable.dims)
-            _X[index] = shuffled_data.combine_first(_X[index]).reindex_like(_X[index])
+            bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+            bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                             dims=single_variable.dims)
+            _X[index] = bootstrapped_data.combine_first(_X[index]).reindex_like(_X[index])
             self._position += 1
         except IndexError:
             raise StopIteration()
         _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
         return self._reshape(_X), self._reshape(_Y), (index, dimension)
 
-    @staticmethod
-    def _reshape(d):
-        if isinstance(d, list):
-            return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d))
-        else:
-            shape = d.shape
-            return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])
+    @classmethod
+    def create_collection(cls, data, dim):
+        l = []
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            l.append(list(map(lambda y: (i, y), x.indexes[dim])))
+        return list(chain(*l))
 
-    @staticmethod
-    def _to_numpy(d):
-        if isinstance(d, list):
-            return list(map(lambda x: x.values, d))
-        else:
-            return d.values
 
-    @staticmethod
-    def shuffle(data: np.ndarray) -> np.ndarray:
-        """
-        Shuffle randomly from given data (draw elements with replacement).
+class BootstrapIteratorVariable(BootstrapIterator):
 
-        :param data: data to shuffle
-        :return: shuffled data as numpy array
-        """
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def __next__(self):
+        """Return next element or stop iteration."""
+        try:
+            dimension = self._collection[self._position]
+            nboot = self._data.number_of_bootstraps
+            _X, _Y = self._data.data.get_data(as_numpy=False)
+            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            for index in range(len(_X)):
+                single_variable = _X[index].sel({self._dimension: [dimension]})
+                bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+                bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                                 dims=single_variable.dims)
+                _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
+        return self._reshape(_X), self._reshape(_Y), (None, dimension)
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        l = set()
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            l.update(x.indexes[dim].to_list())
+        return to_list(l)
+
+
+class BootstrapIteratorBranch(BootstrapIterator):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def __next__(self):
+        try:
+            index = self._collection[self._position]
+            nboot = self._data.number_of_bootstraps
+            _X, _Y = self._data.data.get_data(as_numpy=False)
+            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            for dimension in _X[index].coords[self._dimension].values:
+                single_variable = _X[index].sel({self._dimension: [dimension]})
+                bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+                bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                                 dims=single_variable.dims)
+                _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
+        return self._reshape(_X), self._reshape(_Y), (None, index)
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        return list(range(len(data.get_X(as_numpy=False))))
+
+
+class ShuffleBootstraps:
+
+    @staticmethod
+    def apply(data):
         size = data.shape
         return np.random.choice(data.reshape(-1, ), size=size)
 
 
+class MeanBootstraps:
+
+    def __init__(self, mean):
+        self._mean = mean
+
+    def apply(self, data):
+        return np.ones_like(data) * self._mean
+
+
 class BootStraps(Iterable):
     """
     Main class to perform bootstrap operations.
@@ -89,10 +195,19 @@ class BootStraps(Iterable):
     this variable). The tuple is interesting if X consists on mutliple input streams X_i (e.g. two or more stations)
     because it shows which variable of which input X_i has been bootstrapped. All bootstrap combinations can be
     retrieved by calling the .bootstraps() method. Further more, by calling the .get_orig_prediction() this class
-    imitates according to the set number of bootstraps the original prediction
+    imitates according to the set number of bootstraps the original prediction.
+
+    As bootstrap method, this class can currently make use of the ShuffleBoostraps class that uses drawing with
+    replacement to destroy the variables information by keeping its statistical properties. Use `bootstrap="shuffle"` to
+    call this method. Another method is the zero mean bootstrapping triggered by `bootstrap="zero_mean"` and performed
+    by the MeanBootstraps class. This method destroy the variable's information by a mode collapse to constant value of
+    zero. In case, the variable is normalized with a zero mean, this is equivalent to a mode collapse to the variable's
+    mean value. Statistics in general are not conserved in this case, but the mean value of course. A custom mean value
+    for bootstrapping is currently not supported.
     """
+
     def __init__(self, data: AbstractDataHandler, number_of_bootstraps: int = 10,
-                 bootstrap_dimension: str = "variables"):
+                 bootstrap_dimension: str = "variables", bootstrap_type="singleinput", bootstrap_method="shuffle"):
         """
         Create iterable class to be ready to iter.
 
@@ -100,20 +215,24 @@ class BootStraps(Iterable):
         :param number_of_bootstraps: the number of bootstrap realisations
         """
         self.data = data
-        self.number_of_bootstraps = number_of_bootstraps
+        self.number_of_bootstraps = number_of_bootstraps if bootstrap_method == "shuffle" else 1
         self.bootstrap_dimension = bootstrap_dimension
+        self.bootstrap_method = {"shuffle": ShuffleBootstraps(),
+                                 "zero_mean": MeanBootstraps(mean=0)}.get(
+            bootstrap_method)  # todo adjust number of bootstraps if mean bootstrapping
+        self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput,
+                                  "branch": BootstrapIteratorBranch,
+                                  "variable": BootstrapIteratorVariable}.get(bootstrap_type,
+                                                                             BootstrapIteratorSingleInput)
 
     def __iter__(self):
-        return BootstrapIterator(self)
+        return self.BootstrapIterator(self, self.bootstrap_method)
 
     def __len__(self):
-        return len(self.bootstraps())
+        return len(self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension))
 
     def bootstraps(self):
-        l = []
-        for i, x in enumerate(self.data.get_X(as_numpy=False)):
-            l.append(list(map(lambda y: (i, y), x.indexes['variables'])))
-        return list(chain(*l))
+        return self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension)
 
     def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray:
         """
diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
deleted file mode 100644
index 539712b39e51c32203e1c55e28ce2eff24069479..0000000000000000000000000000000000000000
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""Data Handler using kz-filtered data."""
-
-__author__ = 'Lukas Leufen'
-__date__ = '2020-08-26'
-
-import inspect
-import numpy as np
-import pandas as pd
-import xarray as xr
-from typing import List, Union, Tuple, Optional
-from functools import partial
-
-from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
-from mlair.data_handler import DefaultDataHandler
-from mlair.helpers import remove_items, to_list, TimeTrackingWrapper
-from mlair.helpers.statistics import KolmogorovZurbenkoFilterMovingWindow as KZFilter
-
-# define a more general date type for type hinting
-str_or_list = Union[str, List[str]]
-
-
-class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
-    """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
-
-    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
-    _hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"]
-
-    DEFAULT_FILTER_DIM = "filter"
-
-    def __init__(self, *args, kz_filter_length, kz_filter_iter, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
-        self._check_sampling(**kwargs)
-        # self.original_data = None  # ToDo: implement here something to store unfiltered data
-        self.kz_filter_length = to_list(kz_filter_length)
-        self.kz_filter_iter = to_list(kz_filter_iter)
-        self.filter_dim = filter_dim
-        self.cutoff_period = None
-        self.cutoff_period_days = None
-        super().__init__(*args, **kwargs)
-
-    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
-        """
-        Adjust setup of transformation because kfz filtered data will have negative values which is not compatible with
-        the log transformation. Therefore, replace all log transformation methods by a default standardization. This is
-        only applied on input side.
-        """
-        transformation = super(__class__, self).setup_transformation(transformation)
-        if transformation[0] is not None:
-            for k, v in transformation[0].items():
-                if v["method"] == "log":
-                    transformation[0][k]["method"] = "standardise"
-        return transformation
-
-    def _check_sampling(self, **kwargs):
-        assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution
-
-    def make_input_target(self):
-        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
-                                         self.station_type, self.network, self.store_data_locally, self.data_origin)
-        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                      limit=self.interpolation_limit)
-        self.set_inputs_and_targets()
-        self.apply_kz_filter()
-        # this is just a code snippet to check the results of the kz filter
-        # import matplotlib
-        # matplotlib.use("TkAgg")
-        # import matplotlib.pyplot as plt
-        # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot()
-        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
-
-    @TimeTrackingWrapper
-    def apply_kz_filter(self):
-        """Apply kolmogorov zurbenko filter only on inputs."""
-        kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim)
-        filtered_data: List[xr.DataArray] = kz.run()
-        self.cutoff_period = kz.period_null()
-        self.cutoff_period_days = kz.period_null_days()
-        self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
-
-    def create_filter_index(self) -> pd.Index:
-        """
-        Round cut off periods in days and append 'res' for residuum index.
-
-        Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append
-        'res' for residuum index.
-        """
-        index = np.round(self.cutoff_period_days, 1)
-        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
-        index = list(map(f, index.tolist()))
-        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
-        return pd.Index(index, name=self.filter_dim)
-
-    def get_transposed_history(self) -> xr.DataArray:
-        """Return history.
-
-        :return: history with dimensions datetime, window, Stations, variables, filter.
-        """
-        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
-                                      self.filter_dim).copy()
-
-    def _create_lazy_data(self):
-        return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
-
-    def _extract_lazy(self, lazy_data):
-        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
-        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
-        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
-
-
-class DataHandlerKzFilter(DefaultDataHandler):
-    """Data handler using kz filtered data."""
-
-    data_handler = DataHandlerKzFilterSingleStation
-    data_handler_transformation = DataHandlerKzFilterSingleStation
-    _requirements = data_handler.requirements()
diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index a10364333f3671448c560b40283fb2645d251428..8205ae6c28f3683b1052c292e5d063d8bca555dc 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -2,11 +2,15 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-11-05'
 
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
-from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
+from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \
+    DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation
+from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter, \
+    DataHandlerKzFilter
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
+from mlair.helpers.filter import filter_width_kzf
 
 import inspect
 from typing import Callable
@@ -66,7 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
-                                limit=self.interpolation_limit[ind])
+                                limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
+
         return data
 
     def set_inputs_and_targets(self):
@@ -94,8 +99,8 @@ class DataHandlerMixedSampling(DefaultDataHandler):
 
 
 class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
-                                                      DataHandlerKzFilterSingleStation):
-    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
+                                                      DataHandlerFilterSingleStation):
+    _requirements1 = DataHandlerFilterSingleStation.requirements()
     _requirements2 = DataHandlerMixedSamplingSingleStation.requirements()
     _requirements = list(set(_requirements1 + _requirements2))
 
@@ -107,19 +112,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
 
     def make_input_target(self):
         """
-        A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values
+        A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values
         with daily resolution.
         """
         self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
-        self.apply_kz_filter()
+        self.apply_filter()
 
     def estimate_filter_width(self):
-        """
-        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
-        :return:
-        """
-        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
+        """Return maximum filter width."""
+        raise NotImplementedError
 
     @staticmethod
     def _add_time_delta(date, delta):
@@ -152,26 +154,120 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          start, end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
-                                limit=self.interpolation_limit[ind])
+                                limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
         return data
 
     def _extract_lazy(self, lazy_data):
-        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        _data, self.meta, _input_data, _target_data = lazy_data
         start_inp, end_inp = self.update_start_end(0)
         self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
         self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
 
 
-class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
+class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                        DataHandlerKzFilterSingleStation):
+    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """
+        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
+        :return:
+        """
+        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \
+        self.filter_dim_order = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+
+class DataHandlerMixedSamplingWithKzFilter(DataHandlerKzFilter):
+    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
+
+    data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation
+    data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation
+    _requirements = data_handler.requirements()
+
+
+class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                         DataHandlerFirFilterSingleStation):
+    _requirements1 = DataHandlerFirFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """Filter width is determined by the filter with the highest order."""
+        return max(self.filter_order)
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+    @staticmethod
+    def _get_fs(**kwargs):
+        """Return frequency in 1/day (not Hz)"""
+        sampling = kwargs.get("sampling")[0]
+        if sampling == "daily":
+            return 1
+        elif sampling == "hourly":
+            return 24
+        else:
+            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
+
+
+class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter):
+    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
+
+    data_handler = DataHandlerMixedSamplingWithFirFilterSingleStation
+    data_handler_transformation = DataHandlerMixedSamplingWithFirFilterSingleStation
+    _requirements = data_handler.requirements()
+
+
+class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                                DataHandlerClimateFirFilterSingleStation):
+    _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """Filter width is determined by the filter with the highest order."""
+        if isinstance(self.filter_order[0], tuple):
+            return max([filter_width_kzf(*e) for e in self.filter_order])
+        else:
+            return max(self.filter_order)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
+        self.filter_dim_order = lazy_data
+        DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))
+
+    @staticmethod
+    def _get_fs(**kwargs):
+        """Return frequency in 1/day (not Hz)"""
+        sampling = kwargs.get("sampling")[0]
+        if sampling == "daily":
+            return 1
+        elif sampling == "hourly":
+            return 24
+        else:
+            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
+
+
+class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
     """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
 
-    data_handler = DataHandlerMixedSamplingWithFilterSingleStation
-    data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation
+    data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
+    data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
     _requirements = data_handler.requirements()
 
 
-class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation):
+class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation):
     """
     Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the
     separation frequency of a filtered time series the time step delta for input data is adjusted (see image below).
@@ -181,8 +277,8 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
 
     """
 
-    _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
-    _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"]
+    _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements()
+    _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"]
 
     def __init__(self, *args, time_delta=np.sqrt, **kwargs):
         assert isinstance(time_delta, Callable)
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 89aafa2c7030427e105b663c97998c3ecf09eaaf..4330efd9ee5d3ae8a64c6eb9b95a0c58e18b3c36 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -280,7 +280,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
         self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                      limit=self.interpolation_limit)
+                                      limit=self.interpolation_limit, sampling=self.sampling)
         self.set_inputs_and_targets()
 
     def set_inputs_and_targets(self):
@@ -406,7 +406,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                      "propane", "so2", "toluene"]
         # used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
         used_chem_vars = list(set(chem_vars) & set(data.variables.values))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
+        if len(used_chem_vars) > 0:
+            data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
     def setup_data_path(self, data_path: str, sampling: str):
@@ -468,9 +469,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
         all_vars = sorted(statistics_per_var.keys())
         return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
 
-    @staticmethod
-    def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
-                    **kwargs):
+    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
+                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
         """
         Interpolate values according to different methods.
 
@@ -507,8 +507,22 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         :return: xarray.DataArray
         """
+        data = self.create_full_time_dim(data, dim, sampling)
         return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)
 
+    @staticmethod
+    def create_full_time_dim(data, dim, sampling):
+        """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
+        start = data.coords[dim].values[0]
+        end = data.coords[dim].values[-1]
+        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
+        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
+        t = data.sel({dim: start}, drop=True)
+        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
+        res = res.transpose(*data.dims)
+        res.loc[data.coords] = data
+        return res
+
     def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
         """
         Create a xr.DataArray containing history data.
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e76f396aea80b2db76e01ea5baacf71d024b0d23
--- /dev/null
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -0,0 +1,501 @@
+"""Data Handler using kz-filtered data."""
+
+__author__ = 'Lukas Leufen'
+__date__ = '2020-08-26'
+
+import inspect
+import numpy as np
+import pandas as pd
+import xarray as xr
+from typing import List, Union, Tuple, Optional
+from functools import partial
+import logging
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
+from mlair.data_handler import DefaultDataHandler
+from mlair.helpers import remove_items, to_list, TimeTrackingWrapper
+from mlair.helpers.filter import KolmogorovZurbenkoFilterMovingWindow as KZFilter
+from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf
+
+# define a more general date type for type hinting
+str_or_list = Union[str, List[str]]
+
+
+# cutoff_p = [(None, 14), (8, 6), (2, 0.8), (0.8, None)]
+# cutoff = list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), cutoff_p))
+# fs = 24.
+# # order = int(60 * fs) + 1
+# order = np.array([int(14 * fs) + 1, int(14 * fs) + 1, int(4 * fs) + 1, int(2 * fs) + 1])
+# print("cutoff period", cutoff_p)
+# print("cutoff", cutoff)
+# print("fs", fs)
+# print("order", order)
+# print("delay", 0.5 * (order-1) / fs)
+# window = ("kaiser", 5)
+# # low pass
+# y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low = cutoff[0][0], cutoff_high = cutoff[0][1], window=window)
+# filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)
+
+
+class DataHandlerFilterSingleStation(DataHandlerSingleStation):
+    """General data handler for a single station to be used by a superior data handler."""
+
+    _requirements = remove_items(DataHandlerSingleStation.requirements(), "station")
+    _hash = DataHandlerSingleStation._hash + ["filter_dim"]
+
+    DEFAULT_FILTER_DIM = "filter"
+
+    def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
+        # self.original_data = None  # ToDo: implement here something to store unfiltered data
+        self.filter_dim = filter_dim
+        self.filter_dim_order = None
+        super().__init__(*args, **kwargs)
+
+    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
+        """
+        Adjust setup of transformation because filtered data will have negative values which is not compatible with
+        the log transformation. Therefore, replace all log transformation methods by a default standardization. This is
+        only applied on input side.
+        """
+        transformation = super(__class__, self).setup_transformation(transformation)
+        if transformation[0] is not None:
+            for k, v in transformation[0].items():
+                if v["method"] == "log":
+                    transformation[0][k]["method"] = "standardise"
+                elif v["method"] == "min_max":
+                    transformation[0][k]["method"] = "standardise"
+        return transformation
+
+    def _check_sampling(self, **kwargs):
+        assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution, does it?
+
+    def make_input_target(self):
+        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
+                                         self.station_type, self.network, self.store_data_locally, self.data_origin,
+                                         self.start, self.end)
+        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
+                                      limit=self.interpolation_limit)
+        self.set_inputs_and_targets()
+        self.apply_filter()
+        # this is just a code snippet to check the results of the kz filter
+        # import matplotlib
+        # matplotlib.use("TkAgg")
+        # import matplotlib.pyplot as plt
+        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
+        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
+
+    def apply_filter(self):
+        raise NotImplementedError
+
+    def create_filter_index(self) -> pd.Index:
+        """Create name for filter dimension."""
+        raise NotImplementedError
+
+    def get_transposed_history(self) -> xr.DataArray:
+        """Return history.
+
+        :return: history with dimensions datetime, window, Stations, variables, filter.
+        """
+        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
+                                      self.filter_dim).copy()
+
+    def _create_lazy_data(self):
+        raise NotImplementedError
+
+    def _extract_lazy(self, lazy_data):
+        _data, self.meta, _input_data, _target_data = lazy_data
+        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
+        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
+
+
+class DataHandlerFilter(DefaultDataHandler):
+    """Data handler using FIR filtered data."""
+
+    data_handler = DataHandlerFilterSingleStation
+    data_handler_transformation = DataHandlerFilterSingleStation
+    _requirements = data_handler.requirements()
+
+    def __init__(self, *args, use_filter_branches=False, **kwargs):
+        self.use_filter_branches = use_filter_branches
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def own_args(cls, *args):
+        """Return all arguments (including kwonlyargs)."""
+        super_own_args = DefaultDataHandler.own_args(*args)
+        arg_spec = inspect.getfullargspec(cls)
+        list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args
+        return remove_items(list_of_args, ["self"] + list(args))
+
+    def get_X_original(self):
+        if self.use_filter_branches is True:
+            X = []
+            for data in self._collection:
+                X_total = data.get_X()
+                filter_dim = data.filter_dim
+                for filter_name in data.filter_dim_order:
+                    X.append(X_total.sel({filter_dim: filter_name}, drop=True))
+            return X
+        else:
+            return super().get_X_original()
+
+
+class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
+    """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered."""
+
+    _requirements = remove_items(DataHandlerFilterSingleStation.requirements(), "station")
+    _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type",
+                                                    "_add_unfiltered"]
+
+    DEFAULT_WINDOW_TYPE = ("kaiser", 5)
+    DEFAULT_ADD_UNFILTERED = False
+
+    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE,
+                 filter_add_unfiltered=DEFAULT_ADD_UNFILTERED, **kwargs):
+        # self._check_sampling(**kwargs)
+        # self.original_data = None  # ToDo: implement here something to store unfiltered data
+        self.fs = self._get_fs(**kwargs)
+        if filter_window_type == "kzf":
+            filter_cutoff_period = self._get_kzf_cutoff_period(filter_order, self.fs)
+        self.filter_cutoff_period, removed_index = self._prepare_filter_cutoff_period(filter_cutoff_period, self.fs)
+        self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period)
+        assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index))
+        self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs)
+        self.filter_window_type = filter_window_type
+        self._add_unfiltered = filter_add_unfiltered
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def _prepare_filter_order(filter_order, removed_index, fs):
+        order = []
+        for i, o in enumerate(filter_order):
+            if i not in removed_index:
+                if isinstance(o, tuple):
+                    fo = (o[0] * fs, o[1])
+                else:
+                    fo = int(o * fs)
+                    fo = fo + 1 if fo % 2 == 0 else fo
+                order.append(fo)
+        return order
+
+    @staticmethod
+    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
+        """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair."""
+        cutoff_tmp = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period)
+        cutoff = []
+        removed = []
+        for i, (low, high) in enumerate(cutoff_tmp):
+            low = low if (low is None or low > 2. / fs) else None
+            high = high if (high is None or high > 2. / fs) else None
+            if any([low, high]):
+                cutoff.append((low, high))
+            else:
+                removed.append(i)
+        return cutoff, removed
+
+    @staticmethod
+    def _get_kzf_cutoff_period(kzf_settings, fs):
+        cutoff = []
+        for (m, k) in kzf_settings:
+            w0 = omega_null_kzf(m * fs, k) * fs
+            cutoff.append(1. / w0)
+        return cutoff
+
+    @staticmethod
+    def _period_to_freq(cutoff_p):
+        return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None),
+                        cutoff_p))
+
+    @staticmethod
+    def _get_fs(**kwargs):
+        """Return frequency in 1/day (not Hz)"""
+        sampling = kwargs.get("sampling")
+        if sampling == "daily":
+            return 1
+        elif sampling == "hourly":
+            return 24
+        else:
+            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
+
+    @TimeTrackingWrapper
+    def apply_filter(self):
+        """Apply FIR filter only on inputs."""
+        fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq,
+                        self.filter_window_type, self.target_dim)
+        self.fir_coeff = fir.filter_coefficients()
+        fir_data = fir.filtered_data()
+        if self._add_unfiltered is True:
+            fir_data.append(self.input_data)
+        self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
+        # this is just a code snippet to check the results of the kz filter
+        # import matplotlib
+        # matplotlib.use("TkAgg")
+        # import matplotlib.pyplot as plt
+        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
+        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
+
+    def create_filter_index(self) -> pd.Index:
+        """
+        Create name for filter dimension. Use 'high' or 'low' for high/low pass data and 'bandi' for band pass data with
+        increasing numerator i (starting from 1). If 1 low, 2 band, and 1 high pass filter is used the filter index will
+        become to ['low', 'band1', 'band2', 'high'].
+        """
+        index = []
+        band_num = 1
+        for (low, high) in self.filter_cutoff_period:
+            if low is None:
+                index.append("low")
+            elif high is None:
+                index.append("high")
+            else:
+                index.append(f"band{band_num}")
+                band_num += 1
+        if self._add_unfiltered:
+            index.append("unfiltered")
+        self.filter_dim_order = index
+        return pd.Index(index, name=self.filter_dim)
+
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff, self.filter_dim_order]
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+
+class DataHandlerFirFilter(DataHandlerFilter):
+    """Data handler using FIR filtered data."""
+
+    data_handler = DataHandlerFirFilterSingleStation
+    data_handler_transformation = DataHandlerFirFilterSingleStation
+    _requirements = data_handler.requirements()
+
+
+class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation):
+    """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
+
+    _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"])
+    _hash = DataHandlerFilterSingleStation._hash + ["kz_filter_length", "kz_filter_iter"]
+
+    def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs):
+        self._check_sampling(**kwargs)
+        # self.original_data = None  # ToDo: implement here something to store unfiltered data
+        self.kz_filter_length = to_list(kz_filter_length)
+        self.kz_filter_iter = to_list(kz_filter_iter)
+        self.cutoff_period = None
+        self.cutoff_period_days = None
+        super().__init__(*args, **kwargs)
+
+    @TimeTrackingWrapper
+    def apply_filter(self):
+        """Apply kolmogorov zurbenko filter only on inputs."""
+        kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim)
+        filtered_data: List[xr.DataArray] = kz.run()
+        self.cutoff_period = kz.period_null()
+        self.cutoff_period_days = kz.period_null_days()
+        self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
+        # this is just a code snippet to check the results of the kz filter
+        # import matplotlib
+        # matplotlib.use("TkAgg")
+        # import matplotlib.pyplot as plt
+        # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot()
+        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
+
+    def create_filter_index(self) -> pd.Index:
+        """
+        Round cut off periods in days and append 'res' for residuum index.
+
+        Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append
+        'res' for residuum index.
+        """
+        index = np.round(self.cutoff_period_days, 1)
+        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
+        index = list(map(f, index.tolist()))
+        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
+        self.filter_dim_order = index
+        return pd.Index(index, name=self.filter_dim)
+
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days,
+                self.filter_dim_order]
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \
+        self.filter_dim_order = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+
+class DataHandlerKzFilter(DataHandlerFilter):
+    """Data handler using kz filtered data."""
+
+    data_handler = DataHandlerKzFilterSingleStation
+    data_handler_transformation = DataHandlerKzFilterSingleStation
+    _requirements = data_handler.requirements()
+
+
+class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation):
+    """
+    Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to
+    the simple DataHandlerFirFilterSingleStation, this data handler is centered around t0 to have no time delay. For
+    values in the future (t > t0), this data handler assumes a climatological value for the low pass data and values of
+    0 for all residuum components.
+
+    :param apriori: Data to use as apriori information. This should be either a xarray dataarray containing monthly or
+        any other heuristic to support the clim filter, or a list of such arrays containing heuristics for all residua
+        in addition. The 2nd can be used together with apriori_type `residuum_stats` which estimates the error of the
+        residuum when the clim filter should be applied with exogenous parameters. If apriori_type is None/`zeros` data
+        can be provided, but this is not required in this case.
+    :param apriori_type: set type of information that is provided to the clim filter. For the first low pass always a
+        calculated or given statistic is used. For residuum prediction a constant value of zero is assumed if
+        apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`.
+    :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by
+        parameter apriori_type. This is only applicable for hourly resolution data.
+    """
+
+    _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station")
+    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal"]
+    _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]
+
+    def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None,
+                 plot_path=None, name_affix=None, **kwargs):
+        self.apriori_type = apriori_type
+        self.climate_filter_coeff = None  # coefficents of the used FIR filter
+        self.apriori = apriori  # exogenous apriori information or None to calculate from data (endogenous)
+        self.apriori_diurnal = apriori_diurnal
+        self.all_apriori = None  # collection of all apriori information
+        self.apriori_sel_opts = apriori_sel_opts  # ensure to separate exogenous and endogenous information
+        self.plot_path = plot_path  # use this path to create insight plots
+        self.plot_name_affix = name_affix
+        super().__init__(*args, **kwargs)
+
+    @TimeTrackingWrapper
+    def apply_filter(self):
+        """Apply FIR filter only on inputs."""
+        self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori
+        logging.info(f"{self.station}: call ClimateFIRFilter")
+        plot_name = str(self)  # if self.plot_name_affix is None else f"{str(self)}_{self.plot_name_affix}"
+        climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order,
+                                          self.filter_cutoff_freq,
+                                          self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim,
+                                          apriori_type=self.apriori_type, apriori=self.apriori,
+                                          apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts,
+                                          plot_path=self.plot_path, plot_name=plot_name,
+                                          minimum_length=self.window_history_size, new_dim=self.window_dim)
+        self.climate_filter_coeff = climate_filter.filter_coefficients
+
+        # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori
+        if self.apriori_type == "residuum_stats":
+            self.apriori = climate_filter.apriori_data
+        else:
+            self.apriori = climate_filter.initial_apriori_data
+        self.all_apriori = climate_filter.apriori_data
+
+        climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, 0)}) for c in
+                               climate_filter.filtered_data]
+
+        # create input data with filter index
+        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
+
+        # add unfiltered raw data
+        if self._add_unfiltered is True:
+            data_raw = self.shift(self.input_data, self.time_dim, -self.window_history_size)
+            data_raw = data_raw.expand_dims({self.filter_dim: ["unfiltered"]}, -1)
+            input_data = xr.concat([input_data, data_raw], self.filter_dim)
+
+        self.input_data = input_data
+
+        # this is just a code snippet to check the results of the filter
+        # import matplotlib
+        # matplotlib.use("TkAgg")
+        # import matplotlib.pyplot as plt
+        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
+        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
+
+    def create_filter_index(self) -> pd.Index:
+        """
+        Round cut off periods in days and append 'res' for residuum index.
+
+        Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append
+        'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition.
+        """
+        index = np.round(self.filter_cutoff_period, 1)
+        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
+        index = list(map(f, index.tolist()))
+        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
+        if self._add_unfiltered:
+            index.append("unfiltered")
+        self.filter_dim_order = index
+        return pd.Index(index, name=self.filter_dim)
+
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff,
+                self.apriori, self.all_apriori, self.filter_dim_order]
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
+        self.filter_dim_order = lazy_data
+        DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))
+
+    @staticmethod
+    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
+        """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair."""
+        cutoff = []
+        removed = []
+        for i, period in enumerate(to_list(filter_cutoff_period)):
+            if period > 2. / fs:
+                cutoff.append(period)
+            else:
+                removed.append(i)
+        return cutoff, removed
+
+    @staticmethod
+    def _period_to_freq(cutoff_p):
+        return [1. / x for x in cutoff_p]
+
+    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
+        """
+        Create a xr.DataArray containing history data.
+
+        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
+        data. This is used to represent history in the data. Results are stored in history attribute.
+
+        :param dim_name_of_inputs: Name of dimension which contains the input variables
+        :param window: number of time steps to look back in history
+                Note: window will be treated as negative value. This should be in agreement with looking back on
+                a time line. Nonetheless positive values are allowed but they are converted to its negative
+                expression
+        :param dim_name_of_shift: Dimension along shift will be applied
+        """
+        data = self.input_data
+        sampling = {"daily": "D", "hourly": "h"}.get(to_list(self.sampling)[0])
+        data.coords[dim_name_of_shift] = data.coords[dim_name_of_shift] - np.timedelta64(self.window_history_offset,
+                                                                                         sampling)
+        data.coords[self.window_dim] = data.coords[self.window_dim] + self.window_history_offset
+        self.history = data
+
+    def call_transform(self, inverse=False):
+        opts_input = self._transformation[0]
+        self.input_data, opts_input = self.transform(self.input_data, dim=[self.time_dim, self.window_dim],
+                                                     inverse=inverse, opts=opts_input,
+                                                     transformation_dim=self.target_dim)
+        opts_target = self._transformation[1]
+        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
+                                                       opts=opts_target, transformation_dim=self.target_dim)
+        self._transformation = (opts_input, opts_target)
+
+
+class DataHandlerClimateFirFilter(DataHandlerFilter):
+    """Data handler using climatic adjusted FIR filtered data."""
+
+    data_handler = DataHandlerClimateFirFilterSingleStation
+    data_handler_transformation = DataHandlerClimateFirFilterSingleStation
+    _requirements = data_handler.requirements()
+    _store_attributes = data_handler.store_attributes()
+
+    # def get_X_original(self):
+    #     X = []
+    #     for data in self._collection:
+    #         X_total = data.get_X()
+    #         filter_dim = data.filter_dim
+    #         for filter in data.filter_dim_order:
+    #             X.append(X_total.sel({filter_dim: filter}, drop=True))
+    #     return X
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 8d977e115cf7ea85d4d83bfac4c59977412ab8a7..c97d57ef7edf26c258040047343a701974a9a8f1 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -33,14 +33,16 @@ class DefaultDataHandler(AbstractDataHandler):
     from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation
 
     _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
+    _store_attributes = data_handler.store_attributes()
 
     DEFAULT_ITER_DIM = "Stations"
     DEFAULT_TIME_DIM = "datetime"
+    MAX_NUMBER_MULTIPROCESSING = 16
 
     def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
                  extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
                  store_processed_data=True, iter_dim=DEFAULT_ITER_DIM, time_dim=DEFAULT_TIME_DIM,
-                 use_multiprocessing=True):
+                 use_multiprocessing=True, max_number_multiprocessing=MAX_NUMBER_MULTIPROCESSING):
         super().__init__()
         self.id_class = id_class
         self.time_dim = time_dim
@@ -51,6 +53,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._X_extreme = None
         self._Y_extreme = None
         self._use_multiprocessing = use_multiprocessing
+        self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
         self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
         self._collection = self._create_collection()
@@ -79,7 +82,7 @@ class DefaultDataHandler(AbstractDataHandler):
     def _cleanup(self):
         directory = os.path.dirname(self._save_file)
         if os.path.exists(directory) is False:
-            os.makedirs(directory)
+            os.makedirs(directory, exist_ok=True)
         if os.path.exists(self._save_file):
             shutil.rmtree(self._save_file, ignore_errors=True)
 
@@ -93,6 +96,16 @@ class DefaultDataHandler(AbstractDataHandler):
             logging.debug(f"save pickle data to {self._save_file}")
             self._reset_data()
 
+    def get_store_attributes(self):
+        attr_dict = {}
+        for attr in self.store_attributes():
+            try:
+                val = self.__getattribute__(attr)
+            except AttributeError:
+                val = self.id_class.__getattribute__(attr)
+            attr_dict[attr] = val
+        return attr_dict
+
     @staticmethod
     def _force_dask_computation(data):
         try:
@@ -333,7 +346,9 @@ class DefaultDataHandler(AbstractDataHandler):
                         if "feature_range" in opts.keys():
                             transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
 
-        if multiprocessing.cpu_count() > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
+        max_process = kwargs.get("max_number_multiprocessing", 16)
+        n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
+        if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
             logging.info("use parallel transformation approach")
             pool = multiprocessing.Pool(
                 min([psutil.cpu_count(logical=False), len(set_stations), 16]))  # use only physical cpus
diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63cef975888162f335e4528c2f99bdfc7a892d5
--- /dev/null
+++ b/mlair/helpers/filter.py
@@ -0,0 +1,918 @@
+import gc
+import warnings
+from typing import Union, Callable, Tuple
+import logging
+import os
+import time
+
+import datetime
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+from scipy import signal
+import xarray as xr
+import dask.array as da
+
+from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking
+
+
+class FIRFilter:
+
+    def __init__(self, data, fs, order, cutoff, window, dim):
+
+        filtered = []
+        h = []
+        for i in range(len(order)):
+            fi, hi = fir_filter(data, fs, order=order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1],
+                                window=window, dim=dim, h=None, causal=True, padlen=None)
+            filtered.append(fi)
+            h.append(hi)
+
+        self._filtered = filtered
+        self._h = h
+
+    def filter_coefficients(self):
+        return self._h
+
+    def filtered_data(self):
+        return self._filtered
+        #
+        # y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0], cutoff_high=cutoff[0][1],
+        #                   window=window)
+        # filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)
+        # # band pass
+        # y_band, h_band = fir_filter(station_data.values.flatten(), fs, order[1], cutoff_low=cutoff[1][0],
+        #                             cutoff_high=cutoff[1][1], window=window)
+        # filtered_band = xr.ones_like(station_data) * y_band.reshape(station_data.values.shape)
+        # # band pass 2
+        # y_band_2, h_band_2 = fir_filter(station_data.values.flatten(), fs, order[2], cutoff_low=cutoff[2][0],
+        #                                 cutoff_high=cutoff[2][1], window=window)
+        # filtered_band_2 = xr.ones_like(station_data) * y_band_2.reshape(station_data.values.shape)
+        # # high pass
+        # y_high, h_high = fir_filter(station_data.values.flatten(), fs, order[3], cutoff_low=cutoff[3][0],
+        #                             cutoff_high=cutoff[3][1], window=window)
+        # filtered_high = xr.ones_like(station_data) * y_high.reshape(station_data.values.shape)
+
+
+class ClimateFIRFilter:
+    from mlair.plotting.data_insight_plotting import PlotClimateFirFilter
+
+    def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None,
+                 apriori_diurnal=False, sel_opts=None, plot_path=None, plot_name=None,
+                 minimum_length=None, new_dim=None):
+        """
+        :param data: data to filter
+        :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24
+        :param order: a tuple with the order of the filter in same ordering like cutoff
+        :param cutoff: a tuple with the cutoff frequencies (all are applied as low pass)
+        :param window: window type of the filter (e.g. hamming)
+        :param time_dim: name of time dimension to apply filter along
+        :param var_dim: name of variables dimension
+        :param apriori: apriori information to use for the first low pass. If None, climatology is calculated on the
+            provided data.
+        :param apriori_type: type of apriori information to use. Climatology will be used always for first low pass. For
+            the residuum either the value zero is used (apriori_type is None or "zeros") or a climatology on the
+            residua is used ("residuum_stats").
+        :param apriori_diurnal: Use diurnal cycle as additional apriori information (only applicable for hourly
+            resoluted data). The mean anomaly of each hour is added to the apriori_type information.
+        """
+        logging.info(f"{plot_name}: start init ClimateFIRFilter")
+        self.plot_path = plot_path
+        self.plot_name = plot_name
+        self.plot_data = []
+        filtered = []
+        h = []
+        if sel_opts is not None:
+            sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts}
+        sampling = {1: "1d", 24: "1H"}.get(int(fs))
+        logging.debug(f"{plot_name}: create diurnal_anomalies")
+        if apriori_diurnal is True and sampling == "1H":
+            # diurnal_anomalies = self.create_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim,
+            #                                             as_anomaly=True)
+            diurnal_anomalies = self.create_seasonal_hourly_mean(data, sel_opts=sel_opts, sampling=sampling,
+                                                                 time_dim=time_dim,
+                                                                 as_anomaly=True)
+        else:
+            diurnal_anomalies = 0
+        logging.debug(f"{plot_name}: create monthly apriori")
+        if apriori is None:
+            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling,
+                                               time_dim=time_dim) + diurnal_anomalies
+            logging.debug(f"{plot_name}: apriori shape = {apriori.shape}")
+        apriori_list = to_list(apriori)
+        input_data = data.__deepcopy__()
+
+        # for viz
+        plot_dates = None
+
+        # create tmp dimension to apply filter, search for unused name
+        new_dim = self._create_tmp_dimension(input_data) if new_dim is None else new_dim
+
+        for i in range(len(order)):
+            logging.info(f"{plot_name}: start filter for order {order[i]}")
+            # calculate climatological filter
+            # ToDo: remove all methods except the vectorized version
+            _minimum_length = self._minimum_length(order, minimum_length, i, window)
+            fi, hi, apriori, plot_data = self.clim_filter(input_data, fs, cutoff[i], order[i],
+                                                          apriori=apriori_list[i],
+                                                          sel_opts=sel_opts, sampling=sampling, time_dim=time_dim,
+                                                          window=window, var_dim=var_dim,
+                                                          minimum_length=_minimum_length, new_dim=new_dim,
+                                                          plot_dates=plot_dates)
+
+            logging.info(f"{plot_name}: finished clim_filter calculation")
+            if minimum_length is None:
+                filtered.append(fi)
+            else:
+                filtered.append(fi.sel({new_dim: slice(-minimum_length, 0)}))
+            h.append(hi)
+            gc.collect()
+            self.plot_data.append(plot_data)
+            plot_dates = {e["t0"] for e in plot_data}
+
+            # calculate residuum
+            logging.info(f"{plot_name}: calculate residuum")
+            coord_range = range(fi.coords[new_dim].values.min(), fi.coords[new_dim].values.max() + 1)
+            if new_dim in input_data.coords:
+                input_data = input_data.sel({new_dim: coord_range}) - fi
+            else:
+                input_data = self._shift_data(input_data, coord_range, time_dim, var_dim, new_dim) - fi
+
+            # create new apriori information for next iteration if no further apriori is provided
+            if len(apriori_list) <= i + 1:
+                logging.info(f"{plot_name}: create diurnal_anomalies")
+                if apriori_diurnal is True and sampling == "1H":
+                    # diurnal_anomalies = self.create_hourly_mean(input_data.sel({new_dim: 0}, drop=True),
+                    #                                             sel_opts=sel_opts, sampling=sampling,
+                    #                                             time_dim=time_dim, as_anomaly=True)
+                    diurnal_anomalies = self.create_seasonal_hourly_mean(input_data.sel({new_dim: 0}, drop=True),
+                                                                         sel_opts=sel_opts, sampling=sampling,
+                                                                         time_dim=time_dim, as_anomaly=True)
+                else:
+                    diurnal_anomalies = 0
+                logging.info(f"{plot_name}: create monthly apriori")
+                if apriori_type is None or apriori_type == "zeros":  # zero version
+                    apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies)
+                elif apriori_type == "residuum_stats":  # calculate monthly statistic on residuum
+                    apriori_list.append(
+                        -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), sel_opts=sel_opts,
+                                                  sampling=sampling,
+                                                  time_dim=time_dim) + diurnal_anomalies)
+                else:
+                    raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, "
+                                     f"`zeros` or `residuum_stats`.")
+        # add last residuum to filtered
+        if minimum_length is None:
+            filtered.append(input_data)
+        else:
+            filtered.append(input_data.sel({new_dim: slice(-minimum_length, 0)}))
+        # filtered.append(input_data)
+        self._filtered = filtered
+        self._h = h
+        self._apriori = apriori_list
+
+        # visualize
+        if self.plot_path is not None:
+            self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, plot_name)
+            # self._plot(sampling, new_dim=new_dim)
+
+    @staticmethod
+    def _minimum_length(order, minimum_length, pos, window):
+        next_order = 0
+        if pos + 1 < len(order):
+            next_order = order[pos + 1]
+            if window == "kzf" and isinstance(next_order, tuple):
+                next_order = filter_width_kzf(*next_order)
+        if minimum_length is not None:
+            next_order = next_order + minimum_length
+        return next_order if next_order > 0 else None
+
+    @staticmethod
+    def create_unity_array(data, time_dim, extend_range=366):
+        """Create a xr data array filled with ones. time_dim is extended by extend_range days in future and past."""
+        coords = data.coords
+
+        # extend time_dim by given extend_range days
+        start = coords[time_dim][0].values.astype("datetime64[D]") - np.timedelta64(extend_range, "D")
+        end = coords[time_dim][-1].values.astype("datetime64[D]") + np.timedelta64(extend_range, "D")
+        new_time_axis = np.arange(start, end).astype("datetime64[ns]")
+
+        # construct data array with updated coords
+        new_coords = {k: data.coords[k].values if k != time_dim else new_time_axis for k in coords}
+        new_array = xr.DataArray(1, coords=new_coords, dims=new_coords.keys()).transpose(*data.dims)
+
+        # loffset is required because resampling uses last day in month as resampling timestamp
+        return new_array.resample({time_dim: "1m"}, loffset=datetime.timedelta(days=-15)).max()
+
+    def create_monthly_mean(self, data, sel_opts=None, sampling="1d", time_dim="datetime"):
+        """Calculate monthly statistics."""
+
+        # create unity xarray in monthly resolution with sampling point in mid of each month
+        monthly = self.create_unity_array(data, time_dim)
+
+        # apply selection if given (only use subset for monthly means)
+        if sel_opts is not None:
+            data = data.sel(**sel_opts)
+
+        # create monthly mean and replace entries in unity array
+        monthly_mean = data.groupby(f"{time_dim}.month").mean()
+        for month in monthly_mean.month.values:
+            monthly = xr.where((monthly[f"{time_dim}.month"] == month),
+                               monthly_mean.sel(month=month, drop=True),
+                               monthly)
+        # transform monthly information into original sampling rate
+        return monthly.resample({time_dim: sampling}).interpolate()
+
+        # for month in monthly_mean.month.values:
+        #     loc = (monthly[f"{time_dim}.month"] == month)
+        #     monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month, drop=True)
+        # aggregate monthly information (shift by half month, because resample base is last day)
+        # return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate()
+
+    @staticmethod
+    def create_hourly_mean(data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True):
+        """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True)."""
+        # can only be used for hourly sampling rate
+        assert sampling == "1H"
+
+        # create unity xarray in hourly resolution
+        hourly = xr.ones_like(data)
+
+        # apply selection if given (only use subset for hourly means)
+        if sel_opts is not None:
+            data = data.sel(**sel_opts)
+
+        # create mean for each hour and replace entries in unity array, calculate anomaly if enabled
+        hourly_mean = data.groupby(f"{time_dim}.hour").mean()
+        if as_anomaly is True:
+            hourly_mean = hourly_mean - hourly_mean.mean("hour")
+        for hour in hourly_mean.hour.values:
+            loc = (hourly[f"{time_dim}.hour"] == hour)
+            hourly.loc[{f"{time_dim}": loc}] = hourly_mean.sel(hour=hour)
+        return hourly
+
+    def create_seasonal_hourly_mean(self, data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True):
+        """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True)."""
+        # can only be used for hourly sampling rate
+        assert sampling == "1H"
+
+        # apply selection if given (only use subset for seasonal hourly means)
+        if sel_opts is not None:
+            data = data.sel(**sel_opts)
+
+        # create unity xarray in monthly resolution with sampling point in mid of each month
+        monthly = self.create_unity_array(data, time_dim) * np.nan
+
+        seasonal_hourly_means = {}
+
+        for month in data.groupby(f"{time_dim}.month").groups.keys():
+            # select each month
+            single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)})
+            hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean()
+            if as_anomaly is True:
+                hourly_mean = hourly_mean - hourly_mean.mean("hour")
+            seasonal_hourly_means[month] = hourly_mean
+
+        seasonal_coll = []
+        for hour in data.groupby(f"{time_dim}.hour").groups.keys():
+            h_coll = monthly.__deepcopy__()
+            for month in seasonal_hourly_means.keys():
+                hourly_mean_single_month = seasonal_hourly_means[month].sel(hour=hour, drop=True)
+                h_coll = xr.where((h_coll[f"{time_dim}.month"] == month),
+                                  hourly_mean_single_month,
+                                  h_coll)
+            h_coll = h_coll.resample({time_dim: sampling}).interpolate()
+            h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)})
+            seasonal_coll.append(h_coll)
+        hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate()
+
+        return hourly
+
+    @staticmethod
+    def extend_apriori(data, apriori, time_dim, sampling="1d"):
+        """
+        Extend time range of apriori information.
+
+        This method may not working properly if length of apriori is less then one year.
+        """
+        dates = data.coords[time_dim].values
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+
+        # apriori starts after data
+        if dates[0] < apriori.coords[time_dim].values[0]:
+            logging.debug(f"{data.coords['Stations'].values[0]}: apriori starts after data")
+
+            # add difference in full years
+            date_diff = abs(dates[0] - apriori.coords[time_dim].values[0]).astype("timedelta64[D]")
+            extend_range = np.ceil(date_diff / (np.timedelta64(1, "D") * 365)).astype(int) * 365
+            factor = 1 if td_type == "D" else 24
+
+            # get fill data range
+            start = apriori.coords[time_dim][0].values.astype("datetime64[%s]" % td_type)
+            end = apriori.coords[time_dim][0].values.astype("datetime64[%s]" % td_type) + np.timedelta64(
+                366 * factor + 1, td_type)
+
+            # fill year by year
+            for i in range(365, extend_range + 365, 365):
+                apriori_tmp = apriori.sel({time_dim: slice(start, end)})  # hint: slice includes end date
+                new_time_axis = apriori_tmp.coords[time_dim] - np.timedelta64(i * factor, td_type)
+                apriori_tmp.coords[time_dim] = new_time_axis
+                apriori = apriori.combine_first(apriori_tmp)
+
+        # apriori ends before data
+        if dates[-1] + np.timedelta64(365, "D") > apriori.coords[time_dim].values[-1]:
+            logging.debug(f"{data.coords['Stations'].values[0]}: apriori ends before data")
+
+            # add difference in full years + 1 year (because apriori is used as future estimate)
+            date_diff = abs(dates[-1] - apriori.coords[time_dim].values[-1]).astype("timedelta64[D]")
+            extend_range = np.ceil(date_diff / (np.timedelta64(1, "D") * 365)).astype(int) * 365 + 365
+            factor = 1 if td_type == "D" else 24
+
+            # get fill data range
+            start = apriori.coords[time_dim][-1].values.astype("datetime64[%s]" % td_type) - np.timedelta64(
+                366 * factor + 1, td_type)
+            end = apriori.coords[time_dim][-1].values.astype("datetime64[%s]" % td_type)
+
+            # fill year by year
+            for i in range(365, extend_range + 365, 365):
+                apriori_tmp = apriori.sel({time_dim: slice(start, end)})  # hint: slice includes end date
+                new_time_axis = apriori_tmp.coords[time_dim] + np.timedelta64(i * factor, td_type)
+                apriori_tmp.coords[time_dim] = new_time_axis
+                apriori = apriori.combine_first(apriori_tmp)
+
+        return apriori
+
+    @TimeTrackingWrapper
+    def clim_filter(self, data, fs, cutoff_high, order, apriori=None, sel_opts=None,
+                    sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming",
+                    minimum_length=None, new_dim="window", plot_dates=None):
+
+        logging.debug(f"{data.coords['Stations'].values[0]}: extend apriori")
+
+        # calculate apriori information from data if not given and extend its range if not sufficient long enough
+        if apriori is None:
+            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
+        apriori = apriori.astype(data.dtype)
+        apriori = self.extend_apriori(data, apriori, time_dim, sampling)
+
+        # calculate FIR filter coefficients
+        if window == "kzf":
+            h = firwin_kzf(*order)
+        else:
+            h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
+        length = len(h)
+
+        # use filter length if no minimum is given, otherwise use minimum + half filter length for extension
+        extend_length_history = length if minimum_length is None else minimum_length + int((length + 1) / 2)
+        extend_length_future = int((length + 1) / 2) + 1
+
+        # collect some data for visualization
+        plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * fs
+        if plot_dates is None:
+            plot_dates = [data.isel({time_dim: int(pos)}).coords[time_dim].values for pos in plot_pos if
+                          pos < len(data.coords[time_dim])]
+        plot_data = []
+
+        coll = []
+
+        for var in reversed(data.coords[var_dim].values):
+            logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data")
+
+            _start = pd.to_datetime(data.coords[time_dim].min().values).year
+            _end = pd.to_datetime(data.coords[time_dim].max().values).year
+            filt_coll = []
+            for _year in range(_start, _end + 1):
+                logging.info(f"{data.coords['Stations'].values[0]} ({var}): year={_year}")
+
+                time_slice = self._create_time_range_extend(_year, sampling, extend_length_history)
+                d = data.sel({var_dim: [var], time_dim: time_slice})
+                a = apriori.sel({var_dim: [var], time_dim: time_slice})
+                if len(d.coords[time_dim]) == 0:  # no data at all for this year
+                    continue
+
+                # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length]
+                if new_dim not in d.coords:
+                    history = self._shift_data(d, range(int(-extend_length_history), 1), time_dim, var_dim, new_dim)
+                else:
+                    history = d.sel({new_dim: slice(int(-extend_length_history), 0)})
+                if new_dim not in a.coords:
+                    future = self._shift_data(a, range(1, extend_length_future), time_dim, var_dim, new_dim)
+                else:
+                    future = a.sel({new_dim: slice(1, extend_length_future)})
+                filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left")
+                try:
+                    filter_input_data = filter_input_data.sel({time_dim: str(_year)})
+                except KeyError:  # no valid data for this year
+                    continue
+                if len(filter_input_data.coords[time_dim]) == 0:  # no valid data for this year
+                    continue
+
+                logging.debug(f"{data.coords['Stations'].values[0]} ({var}): start filter convolve")
+                with TimeTracking(name=f"{data.coords['Stations'].values[0]} ({var}): filter convolve",
+                                  logging_level=logging.DEBUG):
+                    filt = xr.apply_ufunc(fir_filter_convolve, filter_input_data,
+                                          input_core_dims=[[new_dim]],
+                                          output_core_dims=[[new_dim]],
+                                          vectorize=True,
+                                          kwargs={"h": h},
+                                          output_dtypes=[d.dtype])
+
+                if minimum_length is None:
+                    filt_coll.append(filt.sel({new_dim: slice(-extend_length_history, 0)}, drop=True))
+                else:
+                    filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True))
+
+                # visualization
+                for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values):
+                    try:
+                        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+                        t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type)
+                        t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type)
+                        if new_dim not in d.coords:
+                            tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}),
+                                                               range(int(-extend_length_history),
+                                                                     int(extend_length_future)),
+                                                               time_dim, var_dim, new_dim).sel({time_dim: viz_date})
+                        else:
+                            # tmp_filter_data = d.sel({time_dim: viz_date,
+                            #                          new_dim: slice(int(-extend_length_history), int(extend_length_future))})
+                            tmp_filter_data = None
+                        valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1)
+                        plot_data.append({"t0": viz_date,
+                                          "var": var,
+                                          "filter_input": filter_input_data.sel({time_dim: viz_date}),
+                                          "filter_input_nc": tmp_filter_data,
+                                          "valid_range": valid_range,
+                                          "time_range": d.sel(
+                                              {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[
+                                              time_dim].values,
+                                          "h": h,
+                                          "new_dim": new_dim})
+                    except:
+                        pass
+
+            # collect all filter results
+            coll.append(xr.concat(filt_coll, time_dim))
+            gc.collect()
+
+        logging.debug(f"{data.coords['Stations'].values[0]}: concat all variables")
+        res = xr.concat(coll, var_dim)
+        # create result array with same shape like input data, gabs are filled by nans
+        logging.debug(f"{data.coords['Stations'].values[0]}: create res_full")
+
+        new_coords = {**{k: data.coords[k].values for k in data.coords if k != new_dim}, new_dim: res.coords[new_dim]}
+        dims = [*data.dims, new_dim] if new_dim not in data.dims else data.dims
+        res = res.transpose(*dims)
+        # res_full = xr.DataArray(dims=dims, coords=new_coords)
+        # res_full.loc[res.coords] = res
+        # res_full.compute()
+        res_full = res.broadcast_like(xr.DataArray(dims=dims, coords=new_coords))
+        return res_full, h, apriori, plot_data
+
+    @staticmethod
+    def _create_time_range_extend(year, sampling, extend_length):
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+        delta = np.timedelta64(extend_length + 1, td_type)
+        start = np.datetime64(f"{year}-01-01") - delta
+        end = np.datetime64(f"{year}-12-31") + delta
+        return slice(start, end)
+
+    @staticmethod
+    def _create_tmp_dimension(data):
+        new_dim = "window"
+        count = 0
+        while new_dim in data.dims:
+            new_dim += new_dim
+            count += 1
+            if count > 10:
+                raise ValueError("Could not create new dimension.")
+        return new_dim
+
+    def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim):
+        coll = []
+        for i in index_value:
+            coll.append(data.shift({time_dim: -i}))
+        new_ind = self.create_index_array(new_dim, index_value, squeeze_dim)
+        return xr.concat(coll, dim=new_ind)
+
+    @staticmethod
+    def create_index_array(index_name: str, index_value, squeeze_dim: str):
+        ind = pd.DataFrame({'val': index_value}, index=index_value)
+        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
+            dim=squeeze_dim,
+            drop=True)
+        res.name = index_name
+        return res
+
+    def _plot(self, sampling, new_dim="window"):
+        h = None
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+        if self.plot_path is None:
+            return
+        plot_folder = os.path.join(os.path.abspath(self.plot_path), "climFIR")
+        if not os.path.exists(plot_folder):
+            os.makedirs(plot_folder)
+
+        # set plot parameter
+        rc_params = {'axes.labelsize': 'large',
+                     'xtick.labelsize': 'large',
+                     'ytick.labelsize': 'large',
+                     'legend.fontsize': 'medium',
+                     'axes.titlesize': 'large',
+                     }
+        plt.rcParams.update(rc_params)
+
+        plot_dict = {}
+        for i, o in enumerate(range(len(self.plot_data))):
+            plot_data = self.plot_data[i]
+            for p_d in plot_data:
+                var = p_d.get("var")
+                t0 = p_d.get("t0")
+                filter_input = p_d.get("filter_input")
+                filter_input_nc = p_d.get("filter_input_nc")
+                valid_range = p_d.get("valid_range")
+                time_range = p_d.get("time_range")
+                new_dim = p_d.get("new_dim")
+                h = p_d.get("h")
+                plot_dict_var = plot_dict.get(var, {})
+                plot_dict_t0 = plot_dict_var.get(t0, {})
+                plot_dict_order = {"filter_input": filter_input,
+                                   "filter_input_nc": filter_input_nc,
+                                   "valid_range": valid_range,
+                                   "time_range": time_range,
+                                   "order": o, "h": h}
+                plot_dict_t0[i] = plot_dict_order
+                plot_dict_var[t0] = plot_dict_t0
+                plot_dict[var] = plot_dict_var
+
+        for var, viz_date_dict in plot_dict.items():
+            for it0, t0 in enumerate(viz_date_dict.keys()):
+                viz_data = viz_date_dict[t0]
+                residuum_true = None
+                for ifilter in sorted(viz_data.keys()):
+                    data = viz_data[ifilter]
+                    filter_input = data["filter_input"]
+                    filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
+                        {new_dim: filter_input.coords[new_dim]})
+                    valid_range = data["valid_range"]
+                    time_axis = data["time_range"]
+                    # time_axis = pd.date_range(t_minus, t_plus, freq=sampling)
+                    filter_order = data["order"]
+                    h = data["h"]
+                    t_minus = t0 + np.timedelta64(-int(1.5 * valid_range.start), td_type)
+                    t_plus = t0 + np.timedelta64(int(0.5 * valid_range.start), td_type)
+                    fig, ax = plt.subplots()
+                    ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type),
+                               t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke",
+                               label="valid area")
+                    ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)")
+
+                    # original data
+                    ax.plot(time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed",
+                            label="original")
+
+                    # clim apriori
+                    if ifilter == 0:
+                        d_tmp = filter_input.sel(
+                            {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten()
+                    else:
+                        d_tmp = filter_input.values.flatten()
+                    ax.plot(time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid",
+                            label="estimated future")
+
+                    # clim filter response
+                    filt = xr.apply_ufunc(fir_filter_convolve, filter_input,
+                                          input_core_dims=[[new_dim]],
+                                          output_core_dims=[[new_dim]],
+                                          vectorize=True,
+                                          kwargs={"h": h},
+                                          output_dtypes=[filter_input.dtype])
+                    ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="solid",
+                            label="clim filter response", linewidth=2)
+                    residuum_estimated = filter_input - filt
+
+                    # ideal filter response
+                    filt = xr.apply_ufunc(fir_filter_convolve, filter_input_nc,
+                                          input_core_dims=[[new_dim]],
+                                          output_core_dims=[[new_dim]],
+                                          vectorize=True,
+                                          kwargs={"h": h},
+                                          output_dtypes=[filter_input.dtype])
+                    ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="dashed",
+                            label="ideal filter response", linewidth=2)
+                    residuum_true = filter_input_nc - filt
+
+                    # set title, legend, and save plot
+                    ax_start = max(t_minus, time_axis[0])
+                    ax_end = min(t_plus, time_axis[-1])
+                    ax.set_xlim((ax_start, ax_end))
+                    plt.title(f"Input of ClimFilter ({str(var)})")
+                    plt.legend()
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+                    plot_name = os.path.join(plot_folder,
+                                             f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}.pdf")
+                    plt.savefig(plot_name, dpi=300)
+                    plt.close('all')
+
+                    # plot residuum
+                    fig, ax = plt.subplots()
+                    ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type),
+                               t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke",
+                               label="valid area")
+                    ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)")
+                    ax.plot(time_axis, residuum_true.values.flatten(), color="black", linestyle="dashed",
+                            label="ideal filter residuum", linewidth=2)
+                    ax.plot(time_axis, residuum_estimated.values.flatten(), color="black", linestyle="solid",
+                            label="clim filter residuum", linewidth=2)
+                    ax.set_xlim((ax_start, ax_end))
+                    plt.title(f"Residuum of ClimFilter ({str(var)})")
+                    plt.legend(loc="upper left")
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+                    plot_name = os.path.join(plot_folder,
+                                             f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}_residuum.pdf")
+                    plt.savefig(plot_name, dpi=300)
+                    plt.close('all')
+
+    @property
+    def filter_coefficients(self):
+        return self._h
+
+    @property
+    def filtered_data(self):
+        return self._filtered
+
+    @property
+    def apriori_data(self):
+        return self._apriori
+
+    @property
+    def initial_apriori_data(self):
+        return self.apriori_data[0]
+
+
+def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables", h=None,
+               causal=True, padlen=None):
+    """Expects xarray."""
+    if h is None:
+        cutoff = []
+        if cutoff_low is not None:
+            cutoff += [cutoff_low]
+        if cutoff_high is not None:
+            cutoff += [cutoff_high]
+        if len(cutoff) == 2:
+            filter_type = "bandpass"
+        elif len(cutoff) == 1 and cutoff_low is not None:
+            filter_type = "highpass"
+        elif len(cutoff) == 1 and cutoff_high is not None:
+            filter_type = "lowpass"
+        else:
+            raise ValueError("Please provide either cutoff_low or cutoff_high.")
+        h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window)
+    filtered = xr.ones_like(data)
+    for var in data.coords[dim]:
+        d = data.sel({dim: var}).values.flatten()
+        if causal:
+            y = signal.lfilter(h, 1., d)
+        else:
+            padlen = padlen if padlen is not None else 3 * len(h)
+            y = signal.filtfilt(h, 1., d, padlen=padlen)
+        filtered.loc[{dim: var}] = y
+    return filtered, h
+
+
+def fir_filter_convolve(data, h):
+    return signal.convolve(data, h, mode='same', method="direct") / sum(h)
+
+
+class KolmogorovZurbenkoBaseClass:
+
+    def __init__(self, df, wl, itr, is_child=False, filter_dim="window"):
+        """
+        It create the variables associate with the Kolmogorov-Zurbenko-filter.
+
+        Args:
+            df(pd.DataFrame, None): time series of a variable
+            wl(list of int): window length
+            itr(list of int): number of iteration
+        """
+        self.df = df
+        self.filter_dim = filter_dim
+        self.wl = to_list(wl)
+        self.itr = to_list(itr)
+        if abs(len(self.wl) - len(self.itr)) > 0:
+            raise ValueError("Length of lists for wl and itr must agree!")
+        self._isChild = is_child
+        self.child = self.set_child()
+        self.type = type(self).__name__
+
+    def set_child(self):
+        if len(self.wl) > 1:
+            return KolmogorovZurbenkoBaseClass(None, self.wl[1:], self.itr[1:], True, self.filter_dim)
+        else:
+            return None
+
+    def kz_filter(self, df, m, k):
+        pass
+
+    def spectral_calc(self):
+        df_start = self.df
+        kz = self.kz_filter(df_start, self.wl[0], self.itr[0])
+        filtered = self.subtract(df_start, kz)
+        # case I: no child avail -> return kz and remaining
+        if self.child is None:
+            return [kz, filtered]
+        # case II: has child -> return current kz and all child results
+        else:
+            self.child.df = filtered
+            kz_next = self.child.spectral_calc()
+            return [kz] + kz_next
+
+    @staticmethod
+    def subtract(minuend, subtrahend):
+        try:  # pandas implementation
+            return minuend.sub(subtrahend, axis=0)
+        except AttributeError:  # general implementation
+            return minuend - subtrahend
+
+    def run(self):
+        return self.spectral_calc()
+
+    def transfer_function(self):
+        m = self.wl[0]
+        k = self.itr[0]
+        omega = np.linspace(0.00001, 0.15, 5000)
+        return omega, (np.sin(m * np.pi * omega) / (m * np.sin(np.pi * omega))) ** (2 * k)
+
+    def omega_null(self, alpha=0.5):
+        a = np.sqrt(6) / np.pi
+        b = 1 / (2 * np.array(self.itr))
+        c = 1 - alpha ** b
+        d = np.array(self.wl) ** 2 - alpha ** b
+        return a * np.sqrt(c / d)
+
+    def period_null(self, alpha=0.5):
+        return 1. / self.omega_null(alpha)
+
+    def period_null_days(self, alpha=0.5):
+        return self.period_null(alpha) / 24.
+
+    def plot_transfer_function(self, fig=None, name=None):
+        if fig is None:
+            fig = plt.figure()
+        omega, transfer_function = self.transfer_function()
+        if self.child is not None:
+            transfer_function_child = self.child.plot_transfer_function(fig)
+        else:
+            transfer_function_child = transfer_function * 0
+        plt.semilogx(omega, transfer_function - transfer_function_child,
+                     label="m={:3.0f}, k={:3.0f}, T={:6.2f}d".format(self.wl[0],
+                                                                     self.itr[0],
+                                                                     self.period_null_days()))
+        plt.axvline(x=self.omega_null())
+        if not self._isChild:
+            locs, labels = plt.xticks()
+            plt.xticks(locs, np.round(1. / (locs * 24), 1))
+            plt.xlim([0.00001, 0.15])
+            plt.legend()
+            if name is None:
+                plt.show()
+            else:
+                plt.savefig(name)
+        else:
+            return transfer_function
+
+
+class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
+
+    def __init__(self, df, wl: Union[list, int], itr: Union[list, int], is_child=False, filter_dim="window",
+                 method="mean", percentile=0.5):
+        """
+        It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class.
+
+        Args:
+            df(pd.DataFrame, xr.DataArray): time series of a variable
+            wl: window length
+            itr: number of iteration
+        """
+        self.valid_methods = ["mean", "percentile", "median", "max", "min"]
+        if method not in self.valid_methods:
+            raise ValueError("Method '{}' is not supported. Please select from [{}].".format(
+                method, ", ".join(self.valid_methods)))
+        else:
+            self.method = method
+            if percentile > 1 or percentile < 0:
+                raise ValueError("Percentile must be in range [0, 1]. Given was {}!".format(percentile))
+            else:
+                self.percentile = percentile
+        super().__init__(df, wl, itr, is_child, filter_dim)
+
+    def set_child(self):
+        if len(self.wl) > 1:
+            return KolmogorovZurbenkoFilterMovingWindow(self.df, self.wl[1:], self.itr[1:], is_child=True,
+                                                        filter_dim=self.filter_dim, method=self.method,
+                                                        percentile=self.percentile)
+        else:
+            return None
+
+    @TimeTrackingWrapper
+    def kz_filter_new(self, df, wl, itr):
+        """
+        It passes the low frequency time series.
+
+        If filter method is from mean, max, min this method will call construct and rechunk before the actual
+        calculation to improve performance. If filter method is either median or percentile this approach is not
+        applicable and depending on the data and window size, this method can become slow.
+
+        Args:
+             wl(int): a window length
+             itr(int): a number of iteration
+        """
+        warnings.filterwarnings("ignore")
+        df_itr = df.__deepcopy__()
+        try:
+            kwargs = {"min_periods": int(0.7 * wl),
+                      "center": True,
+                      self.filter_dim: wl}
+            for i in np.arange(0, itr):
+                print(i)
+                rolling = df_itr.chunk().rolling(**kwargs)
+                if self.method not in ["percentile", "median"]:
+                    rolling = rolling.construct("construct").chunk("auto")
+                if self.method == "median":
+                    df_mv_avg_tmp = rolling.median()
+                elif self.method == "percentile":
+                    df_mv_avg_tmp = rolling.quantile(self.percentile)
+                elif self.method == "max":
+                    df_mv_avg_tmp = rolling.max("construct")
+                elif self.method == "min":
+                    df_mv_avg_tmp = rolling.min("construct")
+                else:
+                    df_mv_avg_tmp = rolling.mean("construct")
+                df_itr = df_mv_avg_tmp.compute()
+                del df_mv_avg_tmp, rolling
+                gc.collect()
+            return df_itr
+        except ValueError:
+            raise ValueError
+
+    @TimeTrackingWrapper
+    def kz_filter(self, df, wl, itr):
+        """
+        It passes the low frequency time series.
+
+        Args:
+             wl(int): a window length
+             itr(int): a number of iteration
+        """
+        import warnings
+        warnings.filterwarnings("ignore")
+        df_itr = df.__deepcopy__()
+        try:
+            kwargs = {"min_periods": int(0.7 * wl),
+                      "center": True,
+                      self.filter_dim: wl}
+            iter_vars = df_itr.coords["variables"].values
+            for var in iter_vars:
+                df_itr_var = df_itr.sel(variables=[var])
+                for _ in np.arange(0, itr):
+                    df_itr_var = df_itr_var.chunk()
+                    rolling = df_itr_var.rolling(**kwargs)
+                    if self.method == "median":
+                        df_mv_avg_tmp = rolling.median()
+                    elif self.method == "percentile":
+                        df_mv_avg_tmp = rolling.quantile(self.percentile)
+                    elif self.method == "max":
+                        df_mv_avg_tmp = rolling.max()
+                    elif self.method == "min":
+                        df_mv_avg_tmp = rolling.min()
+                    else:
+                        df_mv_avg_tmp = rolling.mean()
+                    df_itr_var = df_mv_avg_tmp.compute()
+                df_itr.loc[{"variables": [var]}] = df_itr_var
+            return df_itr
+        except ValueError:
+            raise ValueError
+
+
+def firwin_kzf(m, k):
+    coef = np.ones(m)
+    for i in range(1, k):
+        t = np.zeros((m, m + i * (m - 1)))
+        for km in range(m):
+            t[km, km:km + coef.size] = coef
+        coef = np.sum(t, axis=0)
+    return coef / m ** k
+
+
+def omega_null_kzf(m, k, alpha=0.5):
+    a = np.sqrt(6) / np.pi
+    b = 1 / (2 * np.array(k))
+    c = 1 - alpha ** b
+    d = np.array(m) ** 2 - alpha ** b
+    return a * np.sqrt(c / d)
+
+
+def filter_width_kzf(m, k):
+    return k * (m - 1) + 1
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index b57b733b08c4635a16d7fd18e99538a991521fd8..5ddaa3ee3fe505eeb7c8082274d9cd888cec720f 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -9,7 +9,7 @@ import numpy as np
 import xarray as xr
 import dask.array as da
 
-from typing import Dict, Callable, Union, List, Any
+from typing import Dict, Callable, Union, List, Any, Tuple
 
 
 def to_list(obj: Any) -> List:
@@ -68,9 +68,9 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
     return round_type(number * multiplier) / multiplier
 
 
-def remove_items(obj: Union[List, Dict], items: Any):
+def remove_items(obj: Union[List, Dict, Tuple], items: Any):
     """
-    Remove item(s) from either list or dictionary.
+    Remove item(s) from either list, tuple or dictionary.
 
     :param obj: object to remove items from (either dictionary or list)
     :param items: elements to remove from obj. Can either be a list or single entry / key
@@ -99,6 +99,8 @@ def remove_items(obj: Union[List, Dict], items: Any):
         return remove_from_list(obj, items)
     elif isinstance(obj, dict):
         return remove_from_dict(obj, items)
+    elif isinstance(obj, tuple):
+        return tuple(remove_from_list(to_list(obj), items))
     else:
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
 
@@ -177,5 +179,3 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
             kwargs.update({'dims': dims, 'coords': coords})
 
         return xr.DataArray(arr, **kwargs)
-
-
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 30391998c65950f12fc6824626638788e1bd721b..a1e713a8c135800d02ff7c27894485a5da7fae37 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -9,12 +9,7 @@ import numpy as np
 import xarray as xr
 import pandas as pd
 from typing import Union, Tuple, Dict, List
-from matplotlib import pyplot as plt
 import itertools
-import gc
-import warnings
-
-from mlair.helpers import to_list, TimeTracking, TimeTrackingWrapper
 
 Data = Union[xr.DataArray, pd.DataFrame]
 
@@ -262,11 +257,12 @@ class SkillScores:
     """
     models_default = ["cnn", "persi", "ols"]
 
-    def __init__(self, external_data: Data, models=None, observation_name="obs"):
+    def __init__(self, external_data: Union[Data, None], models=None, observation_name="obs", ahead_dim="ahead"):
         """Set internal data."""
         self.external_data = external_data
         self.models = self.set_model_names(models)
         self.observation_name = observation_name
+        self.ahead_dim = ahead_dim
 
     def set_model_names(self, models: List[str]) -> List[str]:
         """Either use given models or use defaults."""
@@ -288,19 +284,17 @@ class SkillScores:
         combination_strings = [f"{first}-{second}" for (first, second) in combinations]
         return combinations, combination_strings
 
-    def skill_scores(self, window_lead_time: int) -> pd.DataFrame:
+    def skill_scores(self) -> pd.DataFrame:
         """
         Calculate skill scores for all combinations of model names.
 
-        :param window_lead_time: length of forecast steps
-
         :return: skill score for each comparison and forecast step
         """
-        ahead_names = list(range(1, window_lead_time + 1))
+        ahead_names = list(self.external_data[self.ahead_dim].data)
         combinations, combination_strings = self.get_model_name_combinations()
         skill_score = pd.DataFrame(index=combination_strings)
         for iahead in ahead_names:
-            data = self.external_data.sel(ahead=iahead)
+            data = self.external_data.sel({self.ahead_dim: iahead})
             skill_score[iahead] = [self.general_skill_score(data,
                                                             forecast_name=first,
                                                             reference_name=second,
@@ -308,8 +302,7 @@ class SkillScores:
                                    for (first, second) in combinations]
         return skill_score
 
-    def climatological_skill_scores(self, internal_data: Data, window_lead_time: int,
-                                    forecast_name: str) -> xr.DataArray:
+    def climatological_skill_scores(self, internal_data: Data, forecast_name: str) -> xr.DataArray:
         """
         Calculate climatological skill scores according to Murphy (1988).
 
@@ -317,20 +310,19 @@ class SkillScores:
         is part of parameters.
 
         :param internal_data: internal data
-        :param window_lead_time: interested time step of forecast horizon to select data
         :param forecast_name: name of the forecast to use for this calculation (must be available in `data`)
 
         :return: all CASES as well as all terms
         """
-        ahead_names = list(range(1, window_lead_time + 1))
+        ahead_names = list(self.external_data[self.ahead_dim].data)
 
         all_terms = ['AI', 'AII', 'AIII', 'AIV', 'BI', 'BII', 'BIV', 'CI', 'CIV', 'CASE I', 'CASE II', 'CASE III',
                      'CASE IV']
         skill_score = xr.DataArray(np.full((len(all_terms), len(ahead_names)), np.nan), coords=[all_terms, ahead_names],
-                                   dims=['terms', 'ahead'])
+                                   dims=['terms', self.ahead_dim])
 
         for iahead in ahead_names:
-            data = internal_data.sel(ahead=iahead)
+            data = internal_data.sel({self.ahead_dim: iahead})
 
             skill_score.loc[["CASE I", "AI", "BI", "CI"], iahead] = np.stack(self._climatological_skill_score(
                 data, mu_type=1, forecast_name=forecast_name, observation_name=self.observation_name).values.flatten())
@@ -338,8 +330,8 @@ class SkillScores:
             skill_score.loc[["CASE II", "AII", "BII"], iahead] = np.stack(self._climatological_skill_score(
                 data, mu_type=2, forecast_name=forecast_name, observation_name=self.observation_name).values.flatten())
 
-            if self.external_data is not None:
-                external_data = self.external_data.sel(ahead=iahead, type=[self.observation_name])
+            if self.external_data is not None and self.observation_name in self.external_data.coords["type"]:
+                external_data = self.external_data.sel({self.ahead_dim: iahead, "type": [self.observation_name]})
                 skill_score.loc[["CASE III", "AIII"], iahead] = np.stack(self._climatological_skill_score(
                     data, mu_type=3, forecast_name=forecast_name, observation_name=self.observation_name,
                     external_data=external_data).values.flatten())
@@ -378,12 +370,12 @@ class SkillScores:
         skill_score = 1 - mse(observation, forecast) / mse(observation, reference)
         return skill_score.values
 
-    @staticmethod
-    def skill_score_pre_calculations(data: Data, observation_name: str, forecast_name: str) -> Tuple[np.ndarray,
-                                                                                                     np.ndarray,
-                                                                                                     np.ndarray,
-                                                                                                     Data,
-                                                                                                     Dict[str, Data]]:
+    def skill_score_pre_calculations(self, data: Data, observation_name: str, forecast_name: str) -> Tuple[np.ndarray,
+                                                                                                           np.ndarray,
+                                                                                                           np.ndarray,
+                                                                                                           Data,
+                                                                                                           Dict[
+                                                                                                               str, Data]]:
         """
         Calculate terms AI, BI, and CI, mean, variance and pearson's correlation and clean up data.
 
@@ -396,7 +388,7 @@ class SkillScores:
 
         :returns: Terms AI, BI, and CI, internal data without nans and mean, variance, correlation and its p-value
         """
-        data = data.sel(type=[observation_name, forecast_name]).drop("ahead")
+        data = data.sel(type=[observation_name, forecast_name]).drop(self.ahead_dim)
         data = data.dropna("index")
 
         mean = data.mean("index")
@@ -483,212 +475,3 @@ class SkillScores:
 
         return monthly_mean
 
-
-class KolmogorovZurbenkoBaseClass:
-
-    def __init__(self, df, wl, itr, is_child=False, filter_dim="window"):
-        """
-        It create the variables associate with the Kolmogorov-Zurbenko-filter.
-
-        Args:
-            df(pd.DataFrame, None): time series of a variable
-            wl(list of int): window length
-            itr(list of int): number of iteration
-        """
-        self.df = df
-        self.filter_dim = filter_dim
-        self.wl = to_list(wl)
-        self.itr = to_list(itr)
-        if abs(len(self.wl) - len(self.itr)) > 0:
-            raise ValueError("Length of lists for wl and itr must agree!")
-        self._isChild = is_child
-        self.child = self.set_child()
-        self.type = type(self).__name__
-
-    def set_child(self):
-        if len(self.wl) > 1:
-            return KolmogorovZurbenkoBaseClass(None, self.wl[1:], self.itr[1:], True, self.filter_dim)
-        else:
-            return None
-
-    def kz_filter(self, df, m, k):
-        pass
-
-    def spectral_calc(self):
-        df_start = self.df
-        kz = self.kz_filter(df_start, self.wl[0], self.itr[0])
-        filtered = self.subtract(df_start, kz)
-        # case I: no child avail -> return kz and remaining
-        if self.child is None:
-            return [kz, filtered]
-        # case II: has child -> return current kz and all child results
-        else:
-            self.child.df = filtered
-            kz_next = self.child.spectral_calc()
-            return [kz] + kz_next
-
-    @staticmethod
-    def subtract(minuend, subtrahend):
-        try:  # pandas implementation
-            return minuend.sub(subtrahend, axis=0)
-        except AttributeError:  # general implementation
-            return minuend - subtrahend
-
-    def run(self):
-        return self.spectral_calc()
-
-    def transfer_function(self):
-        m = self.wl[0]
-        k = self.itr[0]
-        omega = np.linspace(0.00001, 0.15, 5000)
-        return omega, (np.sin(m * np.pi * omega) / (m * np.sin(np.pi * omega))) ** (2 * k)
-
-    def omega_null(self, alpha=0.5):
-        a = np.sqrt(6) / np.pi
-        b = 1 / (2 * np.array(self.itr))
-        c = 1 - alpha ** b
-        d = np.array(self.wl) ** 2 - alpha ** b
-        return a * np.sqrt(c / d)
-
-    def period_null(self, alpha=0.5):
-        return 1. / self.omega_null(alpha)
-
-    def period_null_days(self, alpha=0.5):
-        return self.period_null(alpha) / 24.
-
-    def plot_transfer_function(self, fig=None, name=None):
-        if fig is None:
-            fig = plt.figure()
-        omega, transfer_function = self.transfer_function()
-        if self.child is not None:
-            transfer_function_child = self.child.plot_transfer_function(fig)
-        else:
-            transfer_function_child = transfer_function * 0
-        plt.semilogx(omega, transfer_function - transfer_function_child,
-                     label="m={:3.0f}, k={:3.0f}, T={:6.2f}d".format(self.wl[0],
-                                                                     self.itr[0],
-                                                                     self.period_null_days()))
-        plt.axvline(x=self.omega_null())
-        if not self._isChild:
-            locs, labels = plt.xticks()
-            plt.xticks(locs, np.round(1. / (locs * 24), 1))
-            plt.xlim([0.00001, 0.15])
-            plt.legend()
-            if name is None:
-                plt.show()
-            else:
-                plt.savefig(name)
-        else:
-            return transfer_function
-
-
-class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
-
-    def __init__(self, df, wl: Union[list, int], itr: Union[list, int], is_child=False, filter_dim="window",
-                 method="mean", percentile=0.5):
-        """
-        It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class.
-
-        Args:
-            df(pd.DataFrame, xr.DataArray): time series of a variable
-            wl: window length
-            itr: number of iteration
-        """
-        self.valid_methods = ["mean", "percentile", "median", "max", "min"]
-        if method not in self.valid_methods:
-            raise ValueError("Method '{}' is not supported. Please select from [{}].".format(
-                method, ", ".join(self.valid_methods)))
-        else:
-            self.method = method
-            if percentile > 1 or percentile < 0:
-                raise ValueError("Percentile must be in range [0, 1]. Given was {}!".format(percentile))
-            else:
-                self.percentile = percentile
-        super().__init__(df, wl, itr, is_child, filter_dim)
-
-    def set_child(self):
-        if len(self.wl) > 1:
-            return KolmogorovZurbenkoFilterMovingWindow(self.df, self.wl[1:], self.itr[1:], is_child=True,
-                                                        filter_dim=self.filter_dim, method=self.method,
-                                                        percentile=self.percentile)
-        else:
-            return None
-
-    @TimeTrackingWrapper
-    def kz_filter_new(self, df, wl, itr):
-        """
-        It passes the low frequency time series.
-
-        If filter method is from mean, max, min this method will call construct and rechunk before the actual
-        calculation to improve performance. If filter method is either median or percentile this approach is not
-        applicable and depending on the data and window size, this method can become slow.
-
-        Args:
-             wl(int): a window length
-             itr(int): a number of iteration
-        """
-        warnings.filterwarnings("ignore")
-        df_itr = df.__deepcopy__()
-        try:
-            kwargs = {"min_periods": int(0.7 * wl),
-                      "center": True,
-                      self.filter_dim: wl}
-            for i in np.arange(0, itr):
-                print(i)
-                rolling = df_itr.chunk().rolling(**kwargs)
-                if self.method not in ["percentile", "median"]:
-                    rolling = rolling.construct("construct").chunk("auto")
-                if self.method == "median":
-                    df_mv_avg_tmp = rolling.median()
-                elif self.method == "percentile":
-                    df_mv_avg_tmp = rolling.quantile(self.percentile)
-                elif self.method == "max":
-                    df_mv_avg_tmp = rolling.max("construct")
-                elif self.method == "min":
-                    df_mv_avg_tmp = rolling.min("construct")
-                else:
-                    df_mv_avg_tmp = rolling.mean("construct")
-                df_itr = df_mv_avg_tmp.compute()
-                del df_mv_avg_tmp, rolling
-                gc.collect()
-            return df_itr
-        except ValueError:
-            raise ValueError
-
-    @TimeTrackingWrapper
-    def kz_filter(self, df, wl, itr):
-        """
-        It passes the low frequency time series.
-
-        Args:
-             wl(int): a window length
-             itr(int): a number of iteration
-        """
-        import warnings
-        warnings.filterwarnings("ignore")
-        df_itr = df.__deepcopy__()
-        try:
-            kwargs = {"min_periods": int(0.7 * wl),
-                      "center": True,
-                      self.filter_dim: wl}
-            iter_vars = df_itr.coords["variables"].values
-            for var in iter_vars:
-                df_itr_var = df_itr.sel(variables=[var])
-                for _ in np.arange(0, itr):
-                    df_itr_var = df_itr_var.chunk()
-                    rolling = df_itr_var.rolling(**kwargs)
-                    if self.method == "median":
-                        df_mv_avg_tmp = rolling.median()
-                    elif self.method == "percentile":
-                        df_mv_avg_tmp = rolling.quantile(self.percentile)
-                    elif self.method == "max":
-                        df_mv_avg_tmp = rolling.max()
-                    elif self.method == "min":
-                        df_mv_avg_tmp = rolling.min()
-                    else:
-                        df_mv_avg_tmp = rolling.mean()
-                    df_itr_var = df_mv_avg_tmp.compute()
-                df_itr.loc[{"variables": [var]}] = df_itr_var
-            return df_itr
-        except ValueError:
-            raise ValueError
diff --git a/mlair/helpers/time_tracking.py b/mlair/helpers/time_tracking.py
index c85a6a047943a589a9d076584ae40186634db767..3105ebcd04406b7d449ba312bd3af46f83e3a716 100644
--- a/mlair/helpers/time_tracking.py
+++ b/mlair/helpers/time_tracking.py
@@ -68,11 +68,12 @@ class TimeTracking(object):
     The only disadvantage of the latter implementation is, that the duration is logged but not returned.
     """
 
-    def __init__(self, start=True, name="undefined job"):
+    def __init__(self, start=True, name="undefined job", logging_level=logging.INFO):
         """Construct time tracking and start if enabled."""
         self.start = None
         self.end = None
         self._name = name
+        self._logging = {logging.INFO: logging.info, logging.DEBUG: logging.debug}.get(logging_level, logging.info)
         if start:
             self._start()
 
@@ -128,4 +129,4 @@ class TimeTracking(object):
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """Stop time tracking on exit and log info about passed time."""
         self.stop()
-        logging.info(f"{self._name} finished after {self}")
\ No newline at end of file
+        self._logging(f"{self._name} finished after {self}")
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index 9fb08cdf6efacab12c2828ed221966586bce1d08..0338033315d294c2e54de8b038bba2123d2fee77 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -1,11 +1,11 @@
 __author__ = "Lukas Leufen"
-__date__ = '2021-02-'
+__date__ = '2021-02-18'
 
 from functools import reduce, partial
 
 from mlair.model_modules import AbstractModelClass
 from mlair.helpers import select_from_dict
-from mlair.model_modules.loss import var_loss, custom_loss
+from mlair.model_modules.loss import var_loss, custom_loss, l_p_loss
 
 import keras
 
@@ -20,7 +20,8 @@ class FCN(AbstractModelClass):
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                    "linear": partial(keras.layers.Activation, "linear"),
                    "selu": partial(keras.layers.Activation, "selu"),
-                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
     _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                     "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                     "prelu": keras.initializers.he_normal()}
@@ -31,12 +32,31 @@ class FCN(AbstractModelClass):
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 **kwargs):
+                 batch_normalization=False, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
         :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
         :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this FCN model via the following parameters:
+
+        :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu,
+            leakyrelu. (Default relu)
+        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
+            layer. (Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
         """
 
         assert len(input_shape) == 1
@@ -49,6 +69,7 @@ class FCN(AbstractModelClass):
         self.activation_output = self._set_activation(activation_output)
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.bn = batch_normalization
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self._update_model_name()
         self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
@@ -58,7 +79,7 @@ class FCN(AbstractModelClass):
         # apply to model
         self.set_model()
         self.set_compile_options()
-        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss, l_p_loss=l_p_loss(.5))
 
     def _set_activation(self, activation):
         try:
@@ -115,27 +136,29 @@ class FCN(AbstractModelClass):
         """
         Build the model.
         """
-        x_input = keras.layers.Input(shape=self._input_shape)
-        x_in = keras.layers.Flatten()(x_input)
         if isinstance(self.layer_configuration, tuple) is True:
             n_layer, n_hidden = self.layer_configuration
-            for layer in range(n_layer):
-                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
-                                          kernel_regularizer=self.kernel_regularizer)(x_in)
-                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
-                if self.dropout is not None:
-                    x_in = self.dropout(self.dropout_rate)(x_in)
+            conf = [n_hidden for _ in range(n_layer)]
         else:
             assert isinstance(self.layer_configuration, list) is True
-            for layer, n_hidden in enumerate(self.layer_configuration):
-                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
-                                          kernel_regularizer=self.kernel_regularizer)(x_in)
-                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
-                if self.dropout is not None:
-                    x_in = self.dropout(self.dropout_rate)(x_in)
+            conf = self.layer_configuration
+
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Flatten()(x_input)
+
+        for layer, n_hidden in enumerate(conf):
+            x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                      kernel_regularizer=self.kernel_regularizer)(x_in)
+            if self.bn is True:
+                x_in = keras.layers.BatchNormalization()(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+            if self.dropout is not None:
+                x_in = self.dropout(self.dropout_rate)(x_in)
+
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
 
     def set_compile_options(self):
         self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
@@ -167,3 +190,191 @@ class FCN_64_32_16(FCN):
     def _update_model_name(self):
         self.model_name = "FCN"
         super()._update_model_name()
+
+
+class BranchedInputFCN(AbstractModelClass):
+    """
+    A customisable fully connected network (64, 32, 16, window_lead_time), where the last layer is the output layer depending
+    on the window_lead_time parameter.
+    """
+
+    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
+                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
+                   "linear": partial(keras.layers.Activation, "linear"),
+                   "selu": partial(keras.layers.Activation, "selu"),
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
+    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
+                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
+                    "prelu": keras.initializers.he_normal()}
+    _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
+    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
+    _dropout = {"selu": keras.layers.AlphaDropout}
+
+    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
+                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
+                 batch_normalization=False, **kwargs):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this FCN model via the following parameters:
+
+        :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu,
+            leakyrelu. (Default relu)
+        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
+            layer. (Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
+        """
+
+        super().__init__(input_shape, output_shape[0])
+
+        # settings
+        self.activation = self._set_activation(activation)
+        self.activation_name = activation
+        self.activation_output = self._set_activation(activation_output)
+        self.activation_output_name = activation_output
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.bn = batch_normalization
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
+        self._update_model_name()
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
+
+    def _set_activation(self, activation):
+        try:
+            return self._activation.get(activation.lower())
+        except KeyError:
+            raise AttributeError(f"Given activation {activation} is not supported in this model class.")
+
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
+    def _set_regularizer(self, regularizer, **kwargs):
+        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+            return None
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _set_dropout(self, activation, dropout_rate):
+        if dropout_rate is None:
+            return None, None
+        assert 0 <= dropout_rate < 1
+        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
+
+    def _update_model_name(self):
+        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
+        n_output = str(self._output_shape)
+
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            branch = [f"{n_hidden}" for _ in range(n_layer)]
+        else:
+            branch = [f"{n}" for n in self.layer_configuration]
+
+        concat = []
+        n_neurons_concat = int(branch[-1]) * len(self._input_shape)
+        for exp in reversed(range(2, len(self._input_shape) + 1)):
+            n_neurons = self._output_shape ** exp
+            if n_neurons < n_neurons_concat:
+                if len(concat) == 0:
+                    concat.append(f"1x{n_neurons}")
+                else:
+                    concat.append(str(n_neurons))
+        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            conf = [n_hidden for _ in range(n_layer)]
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            conf = self.layer_configuration
+
+        x_input = []
+        x_in = []
+
+        for branch in range(len(self._input_shape)):
+            x_input_b = keras.layers.Input(shape=self._input_shape[branch])
+            x_input.append(x_input_b)
+            x_in_b = keras.layers.Flatten()(x_input_b)
+
+            for layer, n_hidden in enumerate(conf):
+                x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                            kernel_regularizer=self.kernel_regularizer,
+                                            name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b)
+                if self.bn is True:
+                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
+                x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
+                if self.dropout is not None:
+                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
+            x_in.append(x_in_b)
+        x_concat = keras.layers.Concatenate()(x_in)
+
+        n_neurons_concat = int(conf[-1]) * len(self._input_shape)
+        layer_concat = 0
+        for exp in reversed(range(2, len(self._input_shape) + 1)):
+            n_neurons = self._output_shape ** exp
+            if n_neurons < n_neurons_concat:
+                layer_concat += 1
+                x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat)
+                if self.bn is True:
+                    x_concat = keras.layers.BatchNormalization()(x_concat)
+                x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat)
+                if self.dropout is not None:
+                    x_concat = self.dropout(self.dropout_rate)(x_concat)
+        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
+
+    def set_compile_options(self):
+        self.compile_options = {"loss": [keras.losses.mean_squared_error],
+                                "metrics": ["mse", "mae", var_loss]}
+        # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])],
+        #                         "metrics": ["mse", "mae", var_loss]}
diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 33358e566ef80f28ee7740531b71d1a83abde115..e0f54282010e765fb3d8b0aca191a75c0b22fdf9 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -8,6 +8,7 @@ import math
 import pickle
 from typing import Union, List
 from typing_extensions import TypedDict
+from time import time
 
 import numpy as np
 from keras import backend as K
@@ -111,6 +112,20 @@ class LearningRateDecay(History):
         return K.get_value(self.model.optimizer.lr)
 
 
+class EpoTimingCallback(Callback):
+    def __init__(self):
+        self.epo_timing = {'epo_timing': []}
+        self.logs = []
+        self.starttime = None
+        super().__init__()
+
+    def on_epoch_begin(self, epoch: int, logs=None):
+        self.starttime = time()
+
+    def on_epoch_end(self, epoch: int, logs=None):
+        self.epo_timing["epo_timing"].append(time()-self.starttime)
+
+
 class ModelCheckpointAdvanced(ModelCheckpoint):
     """
     Enhance the standard ModelCheckpoint class by additional saves of given callbacks.
diff --git a/mlair/model_modules/loss.py b/mlair/model_modules/loss.py
index ba871e983ecfa1e91676d53b834ebd622c00fe49..2034c5a7795fad302d2a289e6fadbd5e295117cc 100644
--- a/mlair/model_modules/loss.py
+++ b/mlair/model_modules/loss.py
@@ -16,10 +16,10 @@ def l_p_loss(power: int) -> Callable:
     :return: loss for given power
     """
 
-    def loss(y_true, y_pred):
+    def l_p_loss(y_true, y_pred):
         return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
 
-    return loss
+    return l_p_loss
 
 
 def var_loss(y_true, y_pred) -> Callable:
diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c48bc8659354c7c669bb03a7591dafbbe9f262
--- /dev/null
+++ b/mlair/model_modules/recurrent_networks.py
@@ -0,0 +1,194 @@
+__author__ = "Lukas Leufen"
+__date__ = '2021-05-25'
+
+from functools import reduce, partial
+
+from mlair.model_modules import AbstractModelClass
+from mlair.helpers import select_from_dict
+from mlair.model_modules.loss import var_loss, custom_loss
+
+import keras
+
+
+class RNN(AbstractModelClass):
+    """
+
+    """
+
+    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
+                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
+                   "linear": partial(keras.layers.Activation, "linear"),
+                   "selu": partial(keras.layers.Activation, "selu"),
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
+    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
+                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
+                    "prelu": keras.initializers.he_normal()}
+    _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
+    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
+    _dropout = {"selu": keras.layers.AlphaDropout}
+    _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU}
+
+    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
+                 activation_rnn="tanh", dropout_rnn=0,
+                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
+                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this RNN model via the following parameters:
+
+        :param activation: set your desired activation function for appended dense layers (add_dense_layer=True=. Choose
+            from relu, tanh, sigmoid, linear, selu, prelu, leakyrelu. (Default relu)
+        :param activation_rnn: set your desired activation function of the rnn output. Choose from relu, tanh, sigmoid,
+            linear, selu, prelu, leakyrelu. (Default tanh)
+        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
+            layer. (Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param dropout_rnn: use recurrent dropout with given rate. This is applied along the recursion and not after
+            a rnn layer. (Default 0)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
+        :param rnn_type: define which kind of recurrent network should be applied. Chose from either lstm or gru. All
+            units will be of this kind. (Default lstm)
+        """
+
+        assert len(input_shape) == 1
+        assert len(output_shape) == 1
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.activation = self._set_activation(activation.lower())
+        self.activation_name = activation
+        self.activation_rnn = self._set_activation(activation_rnn.lower())
+        self.activation_rnn_name = activation
+        self.activation_output = self._set_activation(activation_output.lower())
+        self.activation_output_name = activation_output
+        self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
+        self.bn = batch_normalization
+        self.add_dense_layer = add_dense_layer
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
+        self.RNN = self._rnn.get(rnn_type.lower())
+        self._update_model_name(rnn_type)
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
+        assert 0 <= dropout_rnn <= 1
+        self.dropout_rnn = dropout_rnn
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            conf = [n_hidden for _ in range(n_layer)]
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            conf = self.layer_configuration
+
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Reshape((self._input_shape[0], reduce((lambda x, y: x * y), self._input_shape[1:])))(
+            x_input)
+
+        for layer, n_hidden in enumerate(conf):
+            return_sequences = (layer < len(conf) - 1)
+            x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn)(x_in)
+            if self.bn is True:
+                x_in = keras.layers.BatchNormalization()(x_in)
+            x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
+            if self.dropout is not None:
+                x_in = self.dropout(self.dropout_rate)(x_in)
+
+        if self.add_dense_layer is True:
+            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
+                                      kernel_initializer=self.kernel_initializer, )(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
+        x_in = keras.layers.Dense(self._output_shape)(x_in)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
+
+        # x_in = keras.layers.LSTM(32)(x_in)
+        # if self.dropout is not None:
+        #     x_in = self.dropout(self.dropout_rate)(x_in)
+        # x_in = keras.layers.RepeatVector(self._output_shape)(x_in)
+        # x_in = keras.layers.LSTM(32, return_sequences=True)(x_in)
+        # if self.dropout is not None:
+        #     x_in = self.dropout(self.dropout_rate)(x_in)
+        # out = keras.layers.TimeDistributed(keras.layers.Dense(1))(x_in)
+        # out = keras.layers.Flatten()(out)
+
+    def _set_dropout(self, activation, dropout_rate):
+        if dropout_rate is None:
+            return None, None
+        assert 0 <= dropout_rate < 1
+        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
+
+    def _set_activation(self, activation):
+        try:
+            return self._activation.get(activation.lower())
+        except KeyError:
+            raise AttributeError(f"Given activation {activation} is not supported in this model class.")
+
+    def set_compile_options(self):
+        self.compile_options = {"loss": [keras.losses.mean_squared_error],
+                                "metrics": ["mse", "mae", var_loss]}
+
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+    #
+    # def _set_regularizer(self, regularizer, **kwargs):
+    #     if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+    #         return None
+    #     try:
+    #         reg_name = regularizer.lower()
+    #         reg = self._regularizer.get(reg_name)
+    #         reg_kwargs = {}
+    #         if reg_name in ["l1", "l2"]:
+    #             reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+    #             if reg_name in reg_kwargs:
+    #                 reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+    #         elif reg_name == "l1_l2":
+    #             reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+    #         return reg(**reg_kwargs)
+    #     except KeyError:
+    #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _update_model_name(self, rnn_type):
+        n_input = str(reduce(lambda x, y: x * y, self._input_shape))
+        n_output = str(self._output_shape)
+        self.model_name = rnn_type.upper()
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
+        else:
+            self.model_name += "_".join(["", n_input, *[f"{n}" for n in self.layer_configuration], n_output])
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index 26376637b947f6cd97b66d584583a70c09ae868b..c4c1f4af8c6077a0f2a07b08ebc1d97d68eaf549 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -3,6 +3,7 @@ __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2021-04-13'
 
 from typing import List, Dict
+import dill
 import os
 import logging
 import multiprocessing
@@ -16,7 +17,7 @@ from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, date
 from astropy.timeseries import LombScargle
 
 from mlair.data_handler import DataCollection
-from mlair.helpers import TimeTrackingWrapper, to_list
+from mlair.helpers import TimeTrackingWrapper, to_list, remove_items
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 @TimeTrackingWrapper
@@ -526,16 +527,18 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         self.variables_dim = variables_dim
         self.time_dim = time_dim
         self.window_dim = window_dim
-        self.inputs, self.targets = self._get_inputs_targets(generators, self.variables_dim)
+        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
         self.bins = {}
         self.interval_width = {}
         self.bin_edges = {}
 
         # input plots
-        self._calculate_hist(generators, self.inputs, input_data=True)
-        for subset in generators.keys():
-            self._plot(add_name="input", subset=subset)
-        self._plot_combined(add_name="input")
+        for branch_pos in range(number_of_branches):
+            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
+            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
+            for subset in generators.keys():
+                self._plot(add_name=add_name, subset=subset)
+            self._plot_combined(add_name=add_name)
 
         # target plots
         self._calculate_hist(generators, self.targets, input_data=False)
@@ -549,16 +552,17 @@ class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
         gen = gens[k][0]
         inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist())
         targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
-        return inputs, targets
+        n_branches = len(gen.get_X(as_numpy=False))
+        return inputs, targets, n_branches
 
-    def _calculate_hist(self, generators, variables, input_data=True):
+    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
         n_bins = 100
         for set_type, generator in generators.items():
             tmp_bins = {}
             tmp_edges = {}
             end = {}
             start = {}
-            f = lambda x: x.get_X(as_numpy=False)[0] if input_data is True else x.get_Y(as_numpy=False)
+            f = lambda x: x.get_X(as_numpy=False)[branch_pos] if input_data is True else x.get_Y(as_numpy=False)
             for gen in generator:
                 w = min(abs(f(gen).coords[self.window_dim].values))
                 data = f(gen).sel({self.window_dim: w})
@@ -866,13 +870,15 @@ class PlotPeriodogram(AbstractPlotClass):  # pragma: no cover
         plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
         logging.info(f"... plotting {plot_name}")
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
-        colors = ["blue", "red", "green", "orange", "purple", "black", "grey"]
+        colors = ["grey", "blue", "red", "green", "orange", "purple", "black"]
         label_names = ["orig"] + label_names
         max_iter = len(self.plot_data)
         var_keys = self.plot_data[0].keys()
         for var in var_keys:
             fig, ax = plt.subplots()
             for i in reversed(range(max_iter)):
+                if label_names[i] == "unfiltered":
+                    continue  # do not include the filter 'unfiltered' because this is equal to the 'orig' data
                 plot_data = self.plot_data[i]
                 c = colors[i]
                 ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
@@ -889,9 +895,13 @@ class PlotPeriodogram(AbstractPlotClass):  # pragma: no cover
         plt.close('all')
 
 
-def f_proc(var, d_var, f_index):  # pragma: no cover
+def f_proc(var, d_var, f_index, time_dim="datetime"):  # pragma: no cover
     var_str = str(var)
-    t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
+    t = (d_var[time_dim] - d_var[time_dim][0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
+    if len(d_var.shape) > 1:  # use only max value if dimensions are remaining (e.g. max(window) -> latest value)
+        to_remove = remove_items(d_var.coords.dims, time_dim)
+        for e in to_list(to_remove):
+            d_var = d_var.sel({e: d_var[e].max()})
     pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").power(f_index)
     # f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").autopower()
     return var_str, f_index, pgram
@@ -923,3 +933,218 @@ def f_proc_hist(data, variables, n_bins, variables_dim):  # pragma: no cover
         res[var], bin_edges[var] = np.histogram(d.values, n_bins)
         interval_width[var] = bin_edges[var][1] - bin_edges[var][0]
     return res, interval_width, bin_edges
+
+
+class PlotClimateFirFilter(AbstractPlotClass):
+    """
+    Plot climate FIR filter components.
+
+    * Creates a separate folder climFIR inside the given plot directory.
+    * For each station up to 4 examples are shown (1 for each season).
+    * Each filtered component and its residuum is drawn in a separate plot.
+    * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter
+      input, and the corresponding ideal response (containing information about future)
+    * A filter residuum plot include the climate FIR residuum and the ideal filter residuum.
+    """
+
+    def __init__(self, plot_folder, plot_data, sampling, name):
+
+        from mlair.helpers.filter import fir_filter_convolve
+
+        # adjust default plot parameters
+        rc_params = {
+            'axes.labelsize': 'large',
+            'xtick.labelsize': 'large',
+            'ytick.labelsize': 'large',
+            'legend.fontsize': 'medium',
+            'axes.titlesize': 'large'}
+        if plot_folder is None:
+            return
+
+        self.style_dict = {
+            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
+            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
+            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
+            "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
+            "valid_area": {"color": "whitesmoke", "label": "valid area"},
+            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
+        }
+
+        plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR")
+        self.fir_filter_convolve = fir_filter_convolve
+        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
+        plot_dict, new_dim = self._prepare_data(plot_data)
+        self._name = name
+        self._plot(plot_dict, sampling, new_dim)
+        self._store_plot_data(plot_data)
+
+    def _prepare_data(self, data):
+        """Restructure plot data."""
+        plot_dict = {}
+        new_dim = None
+        for i, o in enumerate(range(len(data))):
+            plot_data = data[i]
+            for p_d in plot_data:
+                var = p_d.get("var")
+                t0 = p_d.get("t0")
+                filter_input = p_d.get("filter_input")
+                filter_input_nc = p_d.get("filter_input_nc")
+                valid_range = p_d.get("valid_range")
+                time_range = p_d.get("time_range")
+                if new_dim is None:
+                    new_dim = p_d.get("new_dim")
+                else:
+                    assert new_dim == p_d.get("new_dim")
+                h = p_d.get("h")
+                plot_dict_var = plot_dict.get(var, {})
+                plot_dict_t0 = plot_dict_var.get(t0, {})
+                plot_dict_order = {"filter_input": filter_input,
+                                   "filter_input_nc": filter_input_nc,
+                                   "valid_range": valid_range,
+                                   "time_range": time_range,
+                                   "order": len(h), "h": h}
+                plot_dict_t0[i] = plot_dict_order
+                plot_dict_var[t0] = plot_dict_t0
+                plot_dict[var] = plot_dict_var
+        return plot_dict, new_dim
+
+    def _plot(self, plot_dict, sampling, new_dim="window"):
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+        for var, viz_date_dict in plot_dict.items():
+            for it0, t0 in enumerate(viz_date_dict.keys()):
+                viz_data = viz_date_dict[t0]
+                residuum_true = None
+                for ifilter in sorted(viz_data.keys()):
+                    data = viz_data[ifilter]
+                    filter_input = data["filter_input"]
+                    filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
+                        {new_dim: filter_input.coords[new_dim]})
+                    valid_range = data["valid_range"]
+                    time_axis = data["time_range"]
+                    filter_order = data["order"]
+                    h = data["h"]
+                    fig, ax = plt.subplots()
+
+                    # plot backgrounds
+                    self._plot_valid_area(ax, t0, valid_range, td_type)
+                    self._plot_t0(ax, t0)
+
+                    # original data
+                    self._plot_original_data(ax, time_axis, filter_input_nc)
+
+                    # clim apriori
+                    self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter)
+
+                    # clim filter response
+                    residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
+                                                                output_dtypes=filter_input.dtype)
+
+                    # ideal filter response
+                    residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h,
+                                                            output_dtypes=filter_input.dtype)
+
+                    # set title, legend, and save plot
+                    xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)
+
+                    plt.title(f"Input of ClimFilter ({str(var)})")
+                    plt.legend()
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+                    self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}"
+                    self._save()
+
+                    # plot residuum
+                    fig, ax = plt.subplots()
+                    self._plot_valid_area(ax, t0, valid_range, td_type)
+                    self._plot_t0(ax, t0)
+                    self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
+                    self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
+                    ax.set_xlim(xlims)
+                    plt.title(f"Residuum of ClimFilter ({str(var)})")
+                    plt.legend(loc="upper left")
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+
+                    self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
+                    self._save()
+
+    def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
+        """
+        Set xlims
+
+        Use order and valid_range to find a good zoom in that hides edges of filter values that are effected by reduced
+        filter order. Limits are returned to be usable for other plots.
+        """
+        t_minus_delta = max(1.5 * valid_range.start, 0.3 * order)
+        t_plus_delta = max(0.5 * valid_range.start, 0.3 * order)
+        t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type)
+        t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type)
+        ax_start = max(t_minus, time_axis[0])
+        ax_end = min(t_plus, time_axis[-1])
+        ax.set_xlim((ax_start, ax_end))
+        return ax_start, ax_end
+
+    def _plot_valid_area(self, ax, t0, valid_range, td_type):
+        ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type),
+                   t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"])
+
+    def _plot_t0(self, ax, t0):
+        ax.axvline(t0, **self.style_dict["t0"])
+
+    def _plot_series(self, ax, time_axis, data, style):
+        ax.plot(time_axis, data, **self.style_dict[style])
+
+    def _plot_original_data(self, ax, time_axis, data):
+        # original data
+        filter_input_nc = data
+        self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), style="original")
+        # self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed",
+        #                   label="original")
+
+    def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter):
+        # clim apriori
+        filter_input = data
+        if ifilter == 0:
+            d_tmp = filter_input.sel(
+                {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten()
+        else:
+            d_tmp = filter_input.values.flatten()
+        self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori")
+        # self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid",
+        #                   label="estimated future")
+
+    def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
+        filter_input = data
+        # clim filter response
+        filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input,
+                              input_core_dims=[[new_dim]],
+                              output_core_dims=[[new_dim]],
+                              vectorize=True,
+                              kwargs={"h": h},
+                              output_dtypes=[output_dtypes])
+        self._plot_series(ax, time_axis, filt.values.flatten(), style="clim")
+        # self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="solid",
+        #                   label="clim filter response", linewidth=2)
+        residuum_estimated = filter_input - filt
+        return residuum_estimated
+
+    def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
+        filter_input_nc = data
+        # ideal filter response
+        filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input_nc,
+                              input_core_dims=[[new_dim]],
+                              output_core_dims=[[new_dim]],
+                              vectorize=True,
+                              kwargs={"h": h},
+                              output_dtypes=[output_dtypes])
+        self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal")
+        # self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="dashed",
+        #                   label="ideal filter response", linewidth=2)
+        residuum_true = filter_input_nc - filt
+        return residuum_true
+
+    def _store_plot_data(self, data):
+        """Store plot data. Could be loaded in a notebook to redraw."""
+        file = os.path.join(self.plot_folder, "plot_data.pickle")
+        with open(file, "wb") as f:
+            dill.dump(data, f)
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index b5e76e5540a06aa5ae33ec85b0e4dfe73931dc9b..29ed4054206f77ca919a416dd1792193dec4aef6 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -217,6 +217,8 @@ class PlotMonthlySummary(AbstractPlotClass):
             data_nn = data.sel(type=self._model_name).squeeze()
             if len(data_nn.shape) > 1:
                 data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values])
+            else:
+                data_nn.coords["ahead"].values = str(data_nn.coords["ahead"].values) + "d"
 
             data_obs = data.sel(type="obs", ahead=1).squeeze()
             data_obs.coords["ahead"] = "obs"
@@ -744,7 +746,9 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
     """
 
-    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None):
+    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None,
+                 sampling: str = "daily", ahead_dim: str = "ahead", bootstrap_type: str = None,
+                 bootstrap_method: str = None):
         """
         Set attributes and create plot.
 
@@ -752,20 +756,46 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param plot_folder: path to save the plot (default: current directory)
         :param model_setup: architecture type to specify plot name (default "CNN")
         :param separate_vars: variables to plot separated (default: ['o3'])
+        :param sampling: type of sampling rate, should be either hourly or daily (default: "daily")
+        :param ahead_dim: name of the ahead dimensions (default: "ahead")
+        :param bootstrap_annotation: additional information to use in the file name (default: None)
         """
-        super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}")
+        annotation = ["_".join([s for s in ["", bootstrap_type, bootstrap_method] if s is not None])][0]
+        super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}{annotation}")
         if separate_vars is None:
             separate_vars = ['o3']
         self._labels = None
         self._x_name = "boot_var"
-        self._data = self._prepare_data(data)
-        self._plot()
-        self._save()
-        self.plot_name += '_separated'
-        self._plot(separate_vars=separate_vars)
-        self._save(bbox_inches='tight')
+        self._ahead_dim = ahead_dim
+        self._boot_type = self._set_bootstrap_type(bootstrap_type)
+        self._boot_method = self._set_bootstrap_method(bootstrap_method)
+
+        self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type})"
+        self._data = self._prepare_data(data, sampling)
+        if "branch" in self._data.columns:
+            plot_name = self.plot_name
+            for branch in self._data["branch"].unique():
+                self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type}, {branch})"
+                self._plot(branch=branch)
+                self.plot_name = f"{plot_name}_{branch}"
+                self._save()
+        else:
+            self._plot()
+            self._save()
+            if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0:
+                self.plot_name += '_separated'
+                self._plot(separate_vars=separate_vars)
+                self._save(bbox_inches='tight')
+
+    @staticmethod
+    def _set_bootstrap_type(boot_type):
+        return {"singleinput": "single input"}.get(boot_type, boot_type)
+
+    @staticmethod
+    def _set_bootstrap_method(boot_method):
+        return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method)
 
-    def _prepare_data(self, data: Dict) -> pd.DataFrame:
+    def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame:
         """
         Shrink given data, if only scores are relevant.
 
@@ -775,23 +805,53 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param data: dictionary with station names as keys and 2D xarrays as values
         :return: pre-processed data set
         """
-        data = helpers.dict_to_xarray(data, "station").sortby(self._x_name)
-        new_boot_coords = self._return_vars_without_number_tag(data.coords['boot_var'].values, split_by='_', keep=1)
-        data = data.assign_coords({'boot_var': new_boot_coords})
-        self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
-        if "station" not in data.dims:
-            data = data.expand_dims("station")
-        return data.to_dataframe("data").reset_index(level=[0, 1, 2])
+        station_dim = "station"
+        data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name)
+        if self._boot_type == "single input":
+            number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_')
+            new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
+                                                                   keep=1, as_unique=True)
+            values = data.values.reshape((data.shape[0], len(new_boot_coords), len(number_tags), data.shape[-1]))
+            data = xr.DataArray(values, coords={station_dim: data.coords["station"], self._x_name: new_boot_coords,
+                                                "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim]},
+                                dims=[station_dim, self._x_name, "branch", self._ahead_dim])
+        else:
+            try:
+                new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
+                                                                       keep=1)
+                data = data.assign_coords({self._x_name: new_boot_coords})
+            except NotImplementedError:
+                pass
+        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values]
+        if station_dim not in data.dims:
+            data = data.expand_dims(station_dim)
+        return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist())
+
+    @staticmethod
+    def _get_target_sampling(sampling, pos):
+        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
+        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
+        return sampling, sampling_letter
 
-    def _return_vars_without_number_tag(self, values, split_by, keep):
+    def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False):
         arr = np.array([v.split(split_by) for v in values])
         num = arr[:, 0]
+        if arr.shape[keep] == 1:  # keep dim has only length 1, no number tags required
+            return num
         new_val = arr[:, keep]
         if self._all_values_are_equal(num, axis=0):
             return new_val
+        elif as_unique is True:
+            return np.unique(new_val)
         else:
             raise NotImplementedError
 
+    @staticmethod
+    def _get_number_tag(values, split_by):
+        arr = np.array([v.split(split_by) for v in values])
+        num = arr[:, 0]
+        return np.unique(num).tolist()
 
     @staticmethod
     def _all_values_are_equal(arr, axis=0):
@@ -809,45 +869,36 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         """
         return "" if score_only else "terms and "
 
-    def _plot(self, separate_vars=None):
+    def _plot(self, branch=None, separate_vars=None):
         """Plot climatological skill score."""
         if separate_vars is None:
-            self._plot_all_variables()
+            self._plot_all_variables(branch)
         else:
             self._plot_selected_variables(separate_vars)
 
     def _plot_selected_variables(self, separate_vars: List):
-        # if separate_vars is None:
-        #     separate_vars = ['o3']
         data = self._data
-        self.raise_error_if_separate_vars_do_not_exist(data, separate_vars)
-        all_variables = self._get_unique_values_from_column_of_df(data, 'boot_var')
-        # remaining_vars = helpers.list_pop(all_variables, separate_vars) #remove_items
+        self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name)
+        all_variables = self._get_unique_values_from_column_of_df(data, self._x_name)
         remaining_vars = helpers.remove_items(all_variables, separate_vars)
-        data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var')
-        data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var')
-
-        fig, ax = plt.subplots(nrows=1, ncols=2,
-                               gridspec_kw={'width_ratios': [len(separate_vars),
-                                                             len(remaining_vars)
-                                                             ]
-                                            }
-                               )
+        data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name)
+        data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name)
+
+        fig, ax = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [len(separate_vars),
+                                                                               len(remaining_vars)]})
         if len(separate_vars) > 1:
             first_box_width = .8
         else:
             first_box_width = 2.
 
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d",
-                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."}, width=first_box_width
-                    )
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1.,
+                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."}, width=first_box_width)
         ax[0].set(ylabel=f"skill score", xlabel="")
 
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d",
-                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."},
-                    )
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.,
+                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."})
         ax[1].set(ylabel="", xlabel="")
         ax[1].yaxis.tick_right()
         handles, _ = ax[1].get_legend_handles_labels()
@@ -882,9 +933,11 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
         align_yaxis(ax[0], ax[1])
         align_yaxis(ax[0], ax[1])
+        plt.title(self._title)
 
     @staticmethod
     def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame:
+        selected_data = None
         for i, variable in enumerate(variables):
             if i == 0:
                 selected_data = df.loc[df[column_name] == variable]
@@ -893,28 +946,29 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
                 selected_data = pd.concat([selected_data, tmp_var], axis=0)
         return selected_data
 
-    def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars):
-        if not self._variables_exist_in_df(df=data, variables=separate_vars):
+    def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name):
+        if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name):
             raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ")
 
     @staticmethod
     def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List:
         return list(df[column_name].unique())
 
-    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str = 'boot_var'):
+    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str):
         vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name))
         return set(variables).issubset(vars_in_df)
 
-    def _plot_all_variables(self):
+    def _plot_all_variables(self, branch=None):
         """
 
         """
         fig, ax = plt.subplots()
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
+        plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d",
                     showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
         ax.axhline(y=0, color="grey", linewidth=.5)
         plt.xticks(rotation=45)
-        ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations")
+        ax.set(ylabel=f"skill score", xlabel="", title=self._title)
         handles, _ = ax.get_legend_handles_labels()
         ax.legend(handles, self._labels)
         plt.tight_layout()
@@ -1029,8 +1083,6 @@ class PlotTimeSeries:
     def _plot_obs(self, ax, data):
         ahead = 1
         obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead)
-        # index = data.index + np.timedelta64(1, self._sampling)
-        # ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
         ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs")
 
     @staticmethod
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index c5687e372298f9625794243324c77f2ed6abedb9..4755fff5b1709c688e420ed585e22b1ad9eab124 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -6,6 +6,7 @@ import logging
 import os
 import sys
 from typing import Union, Dict, Any, List, Callable
+from dill.source import getsource
 
 from mlair.configuration import path_config
 from mlair import helpers
@@ -20,7 +21,9 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
     DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, \
-    DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD
+    DEFAULT_OVERSAMPLING_RATES_CAP, DEFAULT_OVERSAMPLING_METHOD, \
+    DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
+    DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -215,11 +218,12 @@ class ExperimentSetup(RunEnvironment):
                  create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
-                 number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
+                 number_of_bootstraps=None, create_new_bootstraps=None, bootstrap_method=None, bootstrap_type=None,
+                 data_path: str = None, batch_path: str = None, login_nodes=None,
                  hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
+                 max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
                  oversampling_bins=None, oversampling_rates_cap=None, oversampling_method = None, **kwargs):
 
         # create run framework
@@ -273,6 +277,8 @@ class ExperimentSetup(RunEnvironment):
                             default=DEFAULT_USE_MULTIPROCESSING_ON_DEBUG)
         else:
             self._set_param("use_multiprocessing", use_multiprocessing, default=DEFAULT_USE_MULTIPROCESSING)
+        self._set_param("max_number_multiprocessing", max_number_multiprocessing,
+                        default=DEFAULT_MAX_NUMBER_MULTIPROCESSING)
 
         # batch path (temporary)
         self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
@@ -357,6 +363,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing")
         self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS,
                         scope="general.postprocessing")
+        self._set_param("bootstrap_method", bootstrap_method, default=DEFAULT_BOOTSTRAP_METHOD)
+        self._set_param("bootstrap_type", bootstrap_type, default=DEFAULT_BOOTSTRAP_TYPE)
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
         self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing
 
@@ -373,6 +381,9 @@ class ExperimentSetup(RunEnvironment):
         # set model architecture class
         self._set_param("model_class", model, VanillaModel)
 
+        # store starting script if provided
+        if start_script is not None:
+            self._store_start_script(start_script, experiment_path)
 
         # set remaining kwargs
         if len(kwargs) > 0:
@@ -395,6 +406,18 @@ class ExperimentSetup(RunEnvironment):
         logging.debug(f"set experiment attribute: {param}({scope})={value}")
         return value
 
+    @staticmethod
+    def _store_start_script(start_script, store_path):
+        out_file = os.path.join(store_path, "start_script.txt")
+        if isinstance(start_script, Callable):
+            with open(out_file, "w") as fh:
+                fh.write(getsource(start_script))
+        if isinstance(start_script, str):
+            with open(start_script, 'r') as f:
+                with open(out_file, "w") as out:
+                    for line in (f.readlines()):
+                        print(line, end='', file=out)
+
     def _compare_variables_and_statistics(self):
         """
         Compare variables and statistics.
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 8fae430fb48a28bdd8b21f8bfcfc7c569eb24f6c..83f4a2bd96314d6f8c53f8cc9407cbc12e7b9a16 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -12,7 +12,7 @@ import keras
 import pandas as pd
 import tensorflow as tf
 
-from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
+from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.configuration import path_config
 
@@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment):
         """
         lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
         hist = HistoryAdvanced()
+        epo_timing = EpoTimingCallback()
         self.data_store.set("hist", hist, scope="model")
+        self.data_store.set("epo_timing", epo_timing, scope="model")
         callbacks = CallbackHandler()
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+        callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                           save_best_only=True, mode='auto')
         self.data_store.set("callbacks", callbacks, self.scope)
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 8a594808536ca5552003e88c4dbfd181237bb526..cef2c6510ae283b5ce5ca826b0d721edf6a57e76 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -86,6 +86,7 @@ class PostProcessing(RunEnvironment):
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
+        self.ahead_dim = "ahead"
         self._run()
 
     def _run(self):
@@ -103,7 +104,10 @@ class PostProcessing(RunEnvironment):
         if self.data_store.get("evaluate_bootstraps", "postprocessing"):
             with TimeTracking(name="calculate bootstraps"):
                 create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-                self.bootstrap_postprocessing(create_new_bootstraps)
+                bootstrap_method = self.data_store.get("bootstrap_method", "postprocessing")
+                bootstrap_type = self.data_store.get("bootstrap_type", "postprocessing")
+                self.bootstrap_postprocessing(create_new_bootstraps, bootstrap_type=bootstrap_type,
+                                              bootstrap_method=bootstrap_method)
 
         # skill scores and error metrics
         with TimeTracking(name="calculate skill scores"):
@@ -136,7 +140,8 @@ class PostProcessing(RunEnvironment):
                 continue
         return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
-    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None:
+    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput",
+                                 bootstrap_method="shuffle") -> None:
         """
         Calculate skill scores of bootstrapped data.
 
@@ -149,18 +154,26 @@ class PostProcessing(RunEnvironment):
         :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something
             went wrong).
         """
-        try:
-            if create_new_bootstraps:
-                self.create_bootstrap_forecast()
-            self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
-        except FileNotFoundError:
-            if _iter != 0:
-                raise RuntimeError("bootstrap_postprocessing is called for the 2nd time. This means, that calling"
-                                   "manually the reason for the failure.")
-            logging.info("Couldn't load all files, restart bootstrap postprocessing with create_new_bootstraps=True.")
-            self.bootstrap_postprocessing(True, _iter=1)
-
-    def create_bootstrap_forecast(self) -> None:
+        self.bootstrap_skill_scores = {}
+        for boot_type in to_list(bootstrap_type):
+            self.bootstrap_skill_scores[boot_type] = {}
+            for boot_method in to_list(bootstrap_method):
+                try:
+                    if create_new_bootstraps:
+                        self.create_bootstrap_forecast(bootstrap_type=boot_type, bootstrap_method=boot_method)
+                    boot_skill_score = self.calculate_bootstrap_skill_scores(bootstrap_type=boot_type,
+                                                                             bootstrap_method=boot_method)
+                    self.bootstrap_skill_scores[boot_type][boot_method] = boot_skill_score
+                except FileNotFoundError:
+                    if _iter != 0:
+                        raise RuntimeError(f"bootstrap_postprocessing ({boot_type}, {boot_type}) was called for the 2nd"
+                                           f" time. This means, that something internally goes wrong. Please check for "
+                                           f"possible errors")
+                    logging.info(f"Could not load all files for bootstrapping ({boot_type}, {boot_type}), restart "
+                                 f"bootstrap postprocessing with create_new_bootstraps=True.")
+                    self.bootstrap_postprocessing(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method)
+
+    def create_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None:
         """
         Create bootstrapped predictions for all stations and variables.
 
@@ -168,16 +181,15 @@ class PostProcessing(RunEnvironment):
         `bootstraps_labels_{station}.nc`.
         """
         # forecast
-        with TimeTracking(name=inspect.stack()[0].function):
+        with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-            dims = ["index", "ahead", "type"]
+            dims = ["index", self.ahead_dim, "type"]
             for station in self.test_data:
-                logging.info(str(station))
                 X, Y = None, None
-                bootstraps = BootStraps(station, number_of_bootstraps)
+                bootstraps = BootStraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type,
+                                        bootstrap_method=bootstrap_method)
                 for boot in bootstraps:
                     X, Y, (index, dimension) = boot
                     # make bootstrap predictions
@@ -188,18 +200,19 @@ class PostProcessing(RunEnvironment):
                     bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
                     shape = bootstrap_predictions.shape
                     coords = (range(shape[0]), range(1, shape[1] + 1))
-                    var = f"{index}_{dimension}"
+                    var = f"{index}_{dimension}" if index is not None else str(dimension)
                     tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{var}.nc")
+                    file_name = os.path.join(forecast_path,
+                                             f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc")
                     tmp.to_netcdf(file_name)
                 else:
                     # store also true labels for each station
                     labels = np.expand_dims(Y, axis=-1)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_labels.nc")
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
                     labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
                     labels.to_netcdf(file_name)
 
-    def calculate_bootstrap_skill_scores(self) -> Dict[str, xr.DataArray]:
+    def calculate_bootstrap_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]:
         """
         Calculate skill score of bootstrapped variables.
 
@@ -209,53 +222,64 @@ class PostProcessing(RunEnvironment):
 
         :return: The result dictionary with station-wise skill scores
         """
-        with TimeTracking(name=inspect.stack()[0].function):
+        with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
             forecast_file = f"forecasts_norm_%s_test.nc"
-            bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
-            skill_scores = statistics.SkillScores(None)
+
+            bootstraps = BootStraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type,
+                                    bootstrap_method=bootstrap_method)
+            number_of_bootstraps = bootstraps.number_of_bootstraps
+            bootstrap_iter = bootstraps.bootstraps()
+            skill_scores = statistics.SkillScores(None, ahead_dim=self.ahead_dim)
             score = {}
             for station in self.test_data:
-                logging.info(station)
-
                 # get station labels
-                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_labels.nc")
-                labels = xr.open_dataarray(file_name)
+                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc")
+                with xr.open_dataarray(file_name) as da:
+                    labels = da.load()
                 shape = labels.shape
 
                 # get original forecasts
                 orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps)
                 orig = orig.reshape(shape)
                 coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
-                orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
+                orig = xr.DataArray(orig, coords=coords, dims=["index", self.ahead_dim, "type"])
 
                 # calculate skill scores for each variable
                 skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
-                for boot_set in bootstraps:
-                    boot_var = f"{boot_set[0]}_{boot_set[1]}"
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
-                    boot_data = xr.open_dataarray(file_name)
+                for boot_set in bootstrap_iter:
+                    boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set)
+                    file_name = os.path.join(forecast_path,
+                                             f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc")
+                    with xr.open_dataarray(file_name) as da:
+                        boot_data = da.load()
                     boot_data = boot_data.combine_first(labels).combine_first(orig)
                     boot_scores = []
                     for ahead in range(1, self.window_lead_time + 1):
-                        data = boot_data.sel(ahead=ahead)
+                        data = boot_data.sel({self.ahead_dim: ahead})
                         boot_scores.append(
                             skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
                     skill.loc[boot_var] = np.array(boot_scores)
 
                 # collect all results in single dictionary
-                score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"])
+                score[str(station)] = xr.DataArray(skill, dims=["boot_var", self.ahead_dim])
             return score
 
     def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None):
         if prediction_name is None:
             prediction_name = self.forecast_indicator
         file = os.path.join(path, file_name)
-        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
-        vals = np.tile(prediction.data, (number_of_bootstraps, 1))
+        with xr.open_dataarray(file) as da:
+            prediction = da.load().sel(type=prediction_name).squeeze()
+        return self.repeat_data(prediction, number_of_bootstraps)
+
+    @staticmethod
+    def repeat_data(data, number_of_repetition):
+        if isinstance(data, xr.DataArray):
+            data = data.data
+        vals = np.tile(data, (number_of_repetition, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
 
     def _get_model_name(self):
@@ -335,8 +359,16 @@ class PostProcessing(RunEnvironment):
 
         try:
             if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
-                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
-                                        model_setup=self.forecast_indicator)
+                for boot_type, boot_data in self.bootstrap_skill_scores.items():
+                    for boot_method, boot_skill_score in boot_data.items():
+                        try:
+                            PlotBootstrapSkillScore(boot_skill_score, plot_folder=self.plot_path,
+                                                    model_setup=self.forecast_indicator, sampling=self._sampling,
+                                                    ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var),
+                                                    bootstrap_type=boot_type, bootstrap_method=boot_method)
+                        except Exception as e:
+                            logging.error(f"Could not create plot PlotBootstrapSkillScore ({boot_type}, {boot_method}) "
+                                          f"due to the following error: {e}")
         except Exception as e:
             logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
 
@@ -486,7 +518,8 @@ class PostProcessing(RunEnvironment):
                                    "obs": observation,
                                    "ols": ols_prediction}
                 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
-                                                              time_dimension, **prediction_dict)
+                                                              time_dimension, ahead_dim=self.ahead_dim,
+                                                              **prediction_dict)
 
                 # save all forecasts locally
                 path = self.data_store.get("forecast_path")
@@ -512,8 +545,8 @@ class PostProcessing(RunEnvironment):
         """
         path = os.path.join(self.competitor_path, competitor_name)
         file = os.path.join(path, f"forecasts_{station_name}_test.nc")
-        data = xr.open_dataarray(file)
-        # data = data.expand_dims(Stations=[station_name])  # ToDo: remove line
+        with xr.open_dataarray(file) as da:
+            data = da.load()
         forecast = data.sel(type=[self.forecast_indicator])
         forecast.coords["type"] = [competitor_name]
         return forecast
@@ -550,7 +583,14 @@ class PostProcessing(RunEnvironment):
         """
         tmp_ols = self.ols_model.predict(input_data)
         target_shape = ols_prediction.values.shape
-        ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
+        if target_shape != tmp_ols.shape:
+            if len(target_shape)==2:
+                new_values = np.swapaxes(tmp_ols,1,0)
+            else:
+                new_values = np.swapaxes(tmp_ols, 2, 0)
+        else:
+            new_values = tmp_ols
+        ols_prediction.values = new_values
         if not normalised:
             ols_prediction = transformation_func(ols_prediction, "target", inverse=True)
         return ols_prediction
@@ -637,7 +677,8 @@ class PostProcessing(RunEnvironment):
         return index
 
     @staticmethod
-    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
+    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
+                               ahead_dim="ahead", **kwargs):
         """
         Combine different forecast types into single xarray.
 
@@ -650,7 +691,7 @@ class PostProcessing(RunEnvironment):
         """
         keys = list(kwargs.keys())
         res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
-                           coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
+                           coords=[index.index, ahead_names, keys], dims=['index', ahead_dim, 'type'])
         for k, v in kwargs.items():
             intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
             match_index = np.array(list(intersection))
@@ -668,7 +709,8 @@ class PostProcessing(RunEnvironment):
         """
         try:
             file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc")
-            return xr.open_dataarray(file)
+            with xr.open_dataarray(file) as da:
+                return da.load()
         except (IndexError, KeyError, FileNotFoundError):
             return None
 
@@ -683,7 +725,8 @@ class PostProcessing(RunEnvironment):
         """
         try:
             file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
-            return xr.open_dataarray(file)
+            with xr.open_dataarray(file) as da:
+                return da.load()
         except (IndexError, KeyError, FileNotFoundError):
             return None
 
@@ -725,14 +768,14 @@ class PostProcessing(RunEnvironment):
             competitor = self.load_competitors(station)
             combined = self._combine_forecasts(external_data, competitor, dim="type")
             model_list = remove_items(list(combined.type.values), "obs") if combined is not None else None
-            skill_score = statistics.SkillScores(combined, models=model_list)
+            skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim)
             if external_data is not None:
-                skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
+                skill_score_competitive[station] = skill_score.skill_scores()
 
             internal_data = self._get_internal_data(station, path)
             if internal_data is not None:
                 skill_score_climatological[station] = skill_score.climatological_skill_scores(
-                    internal_data, self.window_lead_time, forecast_name=self.forecast_indicator)
+                    internal_data, forecast_name=self.forecast_indicator)
 
         errors.update({"total": self.calculate_average_errors(errors)})
         return skill_score_competitive, skill_score_climatological, errors
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index c00655239a6c0da727cee3462595ea959356a73a..3354e78c0c9ee85dad71f15a7a0171248913c0b7 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -284,10 +284,11 @@ class PreProcessing(RunEnvironment):
         kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)
         use_multiprocessing = self.data_store.get("use_multiprocessing")
 
-        if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
+        max_process = self.data_store.get("max_number_multiprocessing")
+        n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
             logging.info("use parallel validate station approach")
-            pool = multiprocessing.Pool(
-                min([psutil.cpu_count(logical=False), len(set_stations), 16]))  # use only physical cpus
+            pool = multiprocessing.Pool(n_process)
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
             output = [
                 pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs)
@@ -309,40 +310,22 @@ class PreProcessing(RunEnvironment):
 
         logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
                      f"{len(set_stations)} valid stations.")
-        return collection, valid_stations
-
-    def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None,
-                             store_processed_data=True):
-        """
-        Check if all given stations in `all_stations` are valid.
-
-        Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the
-        loading time are logged in debug mode.
-
-        :return: Corrected list containing only valid station IDs.
-        """
-        t_outer = TimeTracking()
-        logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
-        # calculate transformation using train data
         if set_name == "train":
-            logging.info("setup transformation using train data exclusively")
-            self.transformation(data_handler, set_stations)
-        # start station check
-        collection = DataCollection()
-        valid_stations = []
-        kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)
-        for station in set_stations:
-            try:
-                dp = data_handler.build(station, name_affix=set_name, store_processed_data=store_processed_data,
-                                        **kwargs)
-                collection.add(dp)
-                valid_stations.append(station)
-            except (AttributeError, EmptyQueryResult):
-                continue
-        logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
-                     f"{len(set_stations)} valid stations.")
+            self.store_data_handler_attributes(data_handler, collection)
         return collection, valid_stations
 
+    def store_data_handler_attributes(self, data_handler, collection):
+        store_attributes = data_handler.store_attributes()
+        if len(store_attributes) > 0:
+            logging.info("store data requested by the data handler")
+            attrs = {}
+            for dh in collection:
+                station = str(dh)
+                for k, v in dh.get_store_attributes().items():
+                    attrs[k] = dict(attrs.get(k, {}), **{station: v})
+            for k, v in attrs.items():
+                self.data_store.set(k, v)
+
     def transformation(self, data_handler: AbstractDataHandler, stations):
         if hasattr(data_handler, "transformation"):
             kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
@@ -378,10 +361,11 @@ def f_proc(data_handler, station, name_affix, store, **kwargs):
     """
     try:
         res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs)
-    except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError) as e:
+    except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e:
         formatted_lines = traceback.format_exc().splitlines()
         logging.info(
             f"remove station {station} because it raised an error: {e} -> {' | '.join(f_inspect_error(formatted_lines))}")
+        logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}")
         res = None
     return res, station
 
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 5f895b77d53d45bedc255bc7ff051f9d6a8d20a3..00e8eae1581453666d3ca11f48fcdaedf6a24ad0 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -166,7 +166,11 @@ class Training(RunEnvironment):
             lr = self.callbacks.get_callback_by_name("lr")
         except IndexError:
             lr = None
-        self.save_callbacks_as_json(history, lr)
+        try:
+            epo_timing = self.callbacks.get_callback_by_name("epo_timing")
+        except IndexError:
+            epo_timing = None
+        self.save_callbacks_as_json(history, lr, epo_timing)
         self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
 
@@ -190,7 +194,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None:
+    def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
         """
         Save callbacks (history, learning rate) of training.
 
@@ -207,6 +211,9 @@ class Training(RunEnvironment):
         if lr_sc:
             with open(os.path.join(path, "history_lr.json"), "w") as f:
                 json.dump(lr_sc.lr, f)
+        if epo_timing is not None:
+            with open(os.path.join(path, "epo_timing.json"), "w") as f:
+                json.dump(epo_timing.epo_timing, f)
 
     def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
         """
diff --git a/run.py b/run.py
index 05b43ade453a4eb36952e18ad1c7ebab788dc37d..bd93db698c55bc8bae49c5d39a85f9d26cc49780 100644
--- a/run.py
+++ b/run.py
@@ -29,7 +29,7 @@ def main(parser_args):
         evaluate_bootstraps=False,  # plot_list=["PlotCompetitiveSkillScore"],
         competitors=["test_model", "test_model2"],
         competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
-        **parser_args.__dict__)
+        **parser_args.__dict__, start_script=__file__)
     workflow.run()
 
 
diff --git a/run_HPC.py b/run_HPC.py
index d6dbb4dc61e88a1e139b3cbe549bc6a3f2f0ab8a..dfa5045bbccf993d2381ff32c5aead90ea6957f3 100644
--- a/run_HPC.py
+++ b/run_HPC.py
@@ -7,7 +7,7 @@ from mlair.workflows import DefaultWorkflowHPC
 
 def main(parser_args):
 
-    workflow = DefaultWorkflowHPC(**parser_args.__dict__)
+    workflow = DefaultWorkflowHPC(**parser_args.__dict__, start_script=__file__)
     workflow.run()
 
 
diff --git a/run_hourly.py b/run_hourly.py
index 48c7205883eda7e08ee1c14fe3c0a8a9f429e3da..869f8ea16cd4093e04e40f1b05f863ca45ce3c99 100644
--- a/run_hourly.py
+++ b/run_hourly.py
@@ -22,7 +22,7 @@ def main(parser_args):
                                train_model=False,
                                create_new_model=False,
                                network="UBA",
-                               plot_list=["PlotStationMap"], **parser_args.__dict__)
+                               plot_list=["PlotStationMap"], **parser_args.__dict__, start_script=__file__)
     workflow.run()
 
 
diff --git a/run_hourly_kz.py b/run_hourly_kz.py
index 5536b56e732d81b84dfee7f34bd68d0d2ba49020..ba2939162c3fd22fc6a611bc7bc21b9334fbfd3b 100644
--- a/run_hourly_kz.py
+++ b/run_hourly_kz.py
@@ -19,7 +19,7 @@ def main(parser_args):
                 test_end="2011-12-31",
                 stations=["DEBW107", "DEBW013"]
                 )
-    workflow = DefaultWorkflow(**args)
+    workflow = DefaultWorkflow(**args, start_script=__file__)
     workflow.run()
 
 
diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py
index 6ffb659953157060c39afb5960821e729df555dd..819ef51129854b4539632ef91a55e33a2607eb55 100644
--- a/run_mixed_sampling.py
+++ b/run_mixed_sampling.py
@@ -36,7 +36,7 @@ def main(parser_args):
                 test_end="2011-12-31",
                 **parser_args.__dict__,
                 )
-    workflow = DefaultWorkflow(**args)
+    workflow = DefaultWorkflow(**args, start_script=__file__)
     workflow.run()
 
 
diff --git a/run_zam347.py b/run_zam347.py
index 352f04177167441d3636359a9f6ade5f039c12c1..49fce3e7a0c0f2b24691c5b02590ff435300f552 100644
--- a/run_zam347.py
+++ b/run_zam347.py
@@ -31,7 +31,7 @@ def load_stations():
 
 def main(parser_args):
 
-    workflow = DefaultWorkflowHPC(stations=load_stations(), **parser_args.__dict__)
+    workflow = DefaultWorkflowHPC(stations=load_stations(), **parser_args.__dict__, start_script=__file__)
     workflow.run()
 
 
diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py
index 922de3599dc7dc40717e0aeb8c7b8158ad21da38..27f38ce67b65c93a465051ab24fac1e8479fea59 100644
--- a/test/test_configuration/test_defaults.py
+++ b/test/test_configuration/test_defaults.py
@@ -68,4 +68,5 @@ class TestAllDefaults:
         assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore",
                                      "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore",
                                      "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram",
-                                     "PlotDataHistogram","PlotOversampling","PlotOversamplingContingency"]
+                                     "PlotDataHistogram", "PlotPeriodogram","PlotOversampling",
+                                     "PlotOversamplingContingency"]
diff --git a/test/test_data_handler/old_t_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py
index 9616ed3f457d74e44e8a9eae5a3ed862fa804011..21c18c6c2d6f6a6a38a41250f00d3d14a29ed457 100644
--- a/test/test_data_handler/old_t_bootstraps.py
+++ b/test/test_data_handler/old_t_bootstraps.py
@@ -160,7 +160,7 @@ class TestCreateShuffledData:
 
     def test_shuffle(self, shuffled_data_no_creation):
         dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
-        res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
+        res = shuffled_data_no_creation.apply_bootstrap_method(dummy, chunks=(2, 3)).compute()
         assert res.shape == dummy.shape
         assert dummy.max() >= res.max()
         assert dummy.min() <= res.min()
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index 2a6553b7f495bb4eb8aeddf7c39f2f2517edc967..7418a435008f06a9016f903fe140b51d0a7c8106 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -2,10 +2,10 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-12-10'
 
 from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
-    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \
-    DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \
-    DataHandlerSeparationOfScalesSingleStation
-from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
+    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \
+    DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \
+    DataHandlerSeparationOfScalesSingleStation, DataHandlerMixedSamplingWithFilterSingleStation
+from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD
@@ -86,19 +86,19 @@ class TestDataHandlerMixedSamplingSingleStation:
         pass
 
 
-class TestDataHandlerMixedSamplingWithFilter:
+class TestDataHandlerMixedSamplingWithKzFilter:
 
     def test_data_handler(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
 
     def test_data_handler_transformation(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
 
     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
         req2 = object.__new__(DataHandlerKzFilterSingleStation)
         req = list(set(req1.requirements() + req2.requirements()))
         assert sorted(obj._requirements) == sorted(remove_items(req, "station"))
@@ -119,8 +119,8 @@ class TestDataHandlerSeparationOfScales:
         assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
 
     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
         req2 = object.__new__(DataHandlerKzFilterSingleStation)
         req = list(set(req1.requirements() + req2.requirements()))
         assert sorted(obj._requirements) == sorted(remove_items(req, "station"))
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
index 8a7572148869537b505b2bd8e7f16cfdf7af1cdd..7cefd0e58f5b9b0787bafddffe1ad07e4851a068 100644
--- a/test/test_run_modules/test_model_setup.py
+++ b/test/test_run_modules/test_model_setup.py
@@ -80,7 +80,7 @@ class TestModelSetup:
         setup._set_callbacks()
         assert "general.model" in setup.data_store.search_name("callbacks")
         callbacks = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 3
+        assert len(callbacks.get_callbacks()) == 4
 
     def test_set_callbacks_no_lr_decay(self, setup):
         setup.data_store.set("lr_decay", None, "general.model")
@@ -88,7 +88,7 @@ class TestModelSetup:
         setup.checkpoint_name = "TestName"
         setup._set_callbacks()
         callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 2
+        assert len(callbacks.get_callbacks()) == 3
         with pytest.raises(IndexError):
             callbacks.get_callback_by_name("lr_decay")
 
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index 5ae64bf3d535e72d9361394741ed8b8094091b1d..0f2ee7a10fd2e3190c0b66da558626747d4c03c9 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -109,7 +109,7 @@ class TestPreProcessing:
         assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
                                                                     r'station\(s\). Found 5/6 valid stations.'))
 
-    @mock.patch("multiprocessing.cpu_count", return_value=3)
+    @mock.patch("psutil.cpu_count", return_value=3)
     @mock.patch("multiprocessing.Pool", return_value=multiprocessing.Pool(3))
     def test_validate_station_parallel(self, mock_pool, mock_cpu, caplog, obj_with_exp_setup):
         pre = obj_with_exp_setup
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index c2b58cbd2160bd958c76ba67649ef8caba09fcb4..ed0d8264326f5299403c47deb46859ccde4a85d7 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
 from mlair.helpers import PyTestRegex
 from mlair.model_modules.flatten import flatten_tail
 from mlair.model_modules.inception_model import InceptionModelBase
-from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
+from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.run_modules.training import Training
 
@@ -100,6 +100,12 @@ class TestTraining:
         h.model = mock.MagicMock()
         return h
 
+    @pytest.fixture
+    def epo_timing(self):
+        epo_timing = EpoTimingCallback()
+        epo_timing.epoch = [0, 1]
+        epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
+
     @pytest.fixture
     def path(self):
         return os.path.join(os.path.dirname(__file__), "TestExperiment")
@@ -144,9 +150,11 @@ class TestTraining:
     def callbacks(self, path):
         clbk = CallbackHandler()
         hist = HistoryAdvanced()
+        epo_timing = EpoTimingCallback()
         clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
         lr = LearningRateDecay()
         clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
         clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                      save_best_only=True)
         return clbk, hist, lr
@@ -256,22 +264,22 @@ class TestTraining:
         assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
-    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history.json" in os.listdir(model_path)
 
-    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history_lr.json" in os.listdir(model_path)
 
-    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         with open(os.path.join(model_path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
-        init_without_run.save_callbacks_as_json(history, learning_rate)
+    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
+        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         with open(os.path.join(model_path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
             assert lr == learning_rate.lr