diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index f4d042f003042319b3867857b756665a2aa3ddfc..eacbe3e26323e0a0bf1579cba53e2e12ecfd27c0 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -42,7 +42,7 @@ tests (from scratch):
     - ./CI/update_badge.sh > /dev/null
   script:
     - pip install --upgrade pip
-    - pip install numpy wheel six
+    - pip install numpy wheel six==1.15.0
     - zypper --non-interactive install binutils libproj-devel gdal-devel
     - zypper --non-interactive install proj geos-devel
     #    - cat requirements.txt | cut -f1 -d"#" | sed '/^\s*$/d' | xargs -L 1 pip install
diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index 12e09ccdd620c0c81c78ae6d4781d4feb5b94baf..b2a29fbfb353f24d8c99d8429693022ea1fd406f 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -2,6 +2,7 @@ absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
 attrs==20.3.0
+bottleneck==1.3.2
 cached-property==1.5.2
 certifi==2020.12.5
 cftime==1.4.1
@@ -9,6 +10,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index 12e09ccdd620c0c81c78ae6d4781d4feb5b94baf..b2a29fbfb353f24d8c99d8429693022ea1fd406f 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -2,6 +2,7 @@ absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
 attrs==20.3.0
+bottleneck==1.3.2
 cached-property==1.5.2
 certifi==2020.12.5
 cftime==1.4.1
@@ -9,6 +10,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index f085d18bb8d33839a0e3b5f6f3d5ada92134e7f6..419db059a58beeb4ed7e3e198e41b565f8dc7d25 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -55,3 +55,6 @@ class AbstractDataHandler:
     def get_coordinates(self) -> Union[None, Dict]:
         """Return coordinates as dictionary with keys `lon` and `lat`."""
         return None
+
+    def _hash_list(self):
+        return []
diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
index 78638a13b4ea50cd073ca4599a291342fad849d4..539712b39e51c32203e1c55e28ce2eff24069479 100644
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ b/mlair/data_handler/data_handler_kz_filter.py
@@ -7,7 +7,8 @@ import inspect
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import List, Union
+from typing import List, Union, Tuple, Optional
+from functools import partial
 
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
 from mlair.data_handler import DefaultDataHandler
@@ -22,6 +23,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
     """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
 
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
+    _hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"]
 
     DEFAULT_FILTER_DIM = "filter"
 
@@ -35,13 +37,23 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         self.cutoff_period_days = None
         super().__init__(*args, **kwargs)
 
+    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
+        """
+        Adjust setup of transformation because kfz filtered data will have negative values which is not compatible with
+        the log transformation. Therefore, replace all log transformation methods by a default standardization. This is
+        only applied on input side.
+        """
+        transformation = super(__class__, self).setup_transformation(transformation)
+        if transformation[0] is not None:
+            for k, v in transformation[0].items():
+                if v["method"] == "log":
+                    transformation[0][k]["method"] = "standardise"
+        return transformation
+
     def _check_sampling(self, **kwargs):
         assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution
 
-    def setup_samples(self):
-        """
-        Setup samples. This method prepares and creates samples X, and labels Y.
-        """
+    def make_input_target(self):
         data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                          self.station_type, self.network, self.store_data_locally, self.data_origin)
         self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
@@ -54,9 +66,6 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         # import matplotlib.pyplot as plt
         # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot()
         # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
-        if self.do_transformation is True:
-            self.call_transform()
-        self.make_samples()
 
     @TimeTrackingWrapper
     def apply_kz_filter(self):
@@ -88,6 +97,15 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
                                       self.filter_dim).copy()
 
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
+
+    def _extract_lazy(self, lazy_data):
+        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
+        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
+
+
 class DataHandlerKzFilter(DefaultDataHandler):
     """Data handler using kz filtered data."""
 
diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index caaa7a62d1b772808dcaf58abdfa5483e80861e7..75e9e64506231f32406934b67e65454d87a43f61 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -12,6 +12,7 @@ import inspect
 from typing import Callable
 import datetime as dt
 from typing import Any
+from functools import partial
 
 import numpy as np
 import pandas as pd
@@ -54,15 +55,9 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         assert len(parameter) == 2  # (inputs, targets)
         kwargs.update({parameter_name: parameter})
 
-    def setup_samples(self):
-        """
-        Setup samples. This method prepares and creates samples X, and labels Y.
-        """
+    def make_input_target(self):
         self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
-        if self.do_transformation is True:
-            self.call_transform()
-        self.make_samples()
 
     def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
         vars = [self.variables, self.target_var]
@@ -83,6 +78,12 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         assert len(sampling) == 2
         return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling))
 
+    def _extract_lazy(self, lazy_data):
+        _data, self.meta, _input_data, _target_data = lazy_data
+        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
+        self._data = f_prep(_data[0]), f_prep(_data[1])
+        self.input_data, self.target_data = list(map(f_prep, [_input_data, _target_data]))
+
 
 class DataHandlerMixedSampling(DefaultDataHandler):
     """Data handler using mixed sampling for input and target."""
@@ -104,19 +105,14 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
     def _check_sampling(self, **kwargs):
         assert kwargs.get("sampling") == ("hourly", "daily")
 
-    def setup_samples(self):
+    def make_input_target(self):
         """
-        Setup samples. This method prepares and creates samples X, and labels Y.
-
         A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values
         with daily resolution.
         """
         self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
         self.apply_kz_filter()
-        if self.do_transformation is True:
-            self.call_transform()
-        self.make_samples()
 
     def estimate_filter_width(self):
         """
@@ -130,14 +126,24 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
         return new_date.strftime("%Y-%m-%d")
 
-    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
-
+    def update_start_end(self, ind):
         if ind == 0:  # for inputs
             estimated_filter_width = self.estimate_filter_width()
             start = self._add_time_delta(self.start, -estimated_filter_width)
             end = self._add_time_delta(self.end, estimated_filter_width)
         else:  # target
             start, end = self.start, self.end
+        return start, end
+
+    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
+
+        start, end = self.update_start_end(ind)
+        # if ind == 0:  # for inputs
+        #     estimated_filter_width = self.estimate_filter_width()
+        #     start = self._add_time_delta(self.start, -estimated_filter_width)
+        #     end = self._add_time_delta(self.end, estimated_filter_width)
+        # else:  # target
+        #     start, end = self.start, self.end
 
         vars = [self.variables, self.target_var]
         stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
@@ -149,6 +155,13 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
                                 limit=self.interpolation_limit[ind])
         return data
 
+    def _extract_lazy(self, lazy_data):
+        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        start_inp, end_inp = self.update_start_end(0)
+        self._data = list(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
+        self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
+        self.target_data = self._slice_prep(_target_data, self.start, self.end)
+
 
 class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
     """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
@@ -169,6 +182,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
     """
 
     _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"]
 
     def __init__(self, *args, time_delta=np.sqrt, **kwargs):
         assert isinstance(time_delta, Callable)
@@ -204,7 +218,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
         time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int)
         start, end = window, 1
         res = []
-        window_array = self.create_index_array(self.window_dim.range(start, end), squeeze_dim=self.target_dim)
+        window_array = self.create_index_array(self.window_dim, range(start, end), squeeze_dim=self.target_dim)
         for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]):
             res_filter = []
             data_filter = data.sel({"filter": filter_name})
@@ -212,7 +226,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
                 res_filter.append(data_filter.shift({dim: -w * delta}))
             res_filter = xr.concat(res_filter, dim=window_array).chunk()
             res.append(res_filter)
-        res = xr.concat(res, dim="filter")
+        res = xr.concat(res, dim="filter").compute()
         return res
 
     def estimate_filter_width(self):
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index a894c635282b5879d79426168eb96d64ff5fa2a2..e9db27a9ff88efa2cc800723ac99279ec66d6cbb 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -5,9 +5,11 @@ __date__ = '2020-07-20'
 
 import copy
 import datetime as dt
+import dill
+import hashlib
 import logging
 import os
-from functools import reduce
+from functools import reduce, partial
 from typing import Union, List, Iterable, Tuple, Dict, Optional
 
 import numpy as np
@@ -45,6 +47,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
     DEFAULT_INTERPOLATION_LIMIT = 0
     DEFAULT_INTERPOLATION_METHOD = "linear"
 
+    _hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim",
+             "target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset",
+             "window_lead_time", "interpolation_limit", "interpolation_method"]
+
     def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
                  network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
                  target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
@@ -54,10 +60,16 @@ class DataHandlerSingleStation(AbstractDataHandler):
                  interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
                  interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
-                 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, **kwargs):
+                 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
+                 lazy_preprocessing: bool = False, **kwargs):
         super().__init__()
         self.station = helpers.to_list(station)
         self.path = self.setup_data_path(data_path, sampling)
+        self.lazy = lazy_preprocessing
+        self.lazy_path = None
+        if self.lazy is True:
+            self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
+            check_path_and_create(self.lazy_path)
         self.statistics_per_var = statistics_per_var
         self.data_origin = data_origin
         self.do_transformation = transformation is not None
@@ -183,7 +195,17 @@ class DataHandlerSingleStation(AbstractDataHandler):
             else:
                 raise NotImplementedError
 
-        def f_apply(data, method, mean=None, std=None, min=None, max=None):
+        def f_apply(data, method, **kwargs):
+            for k, v in kwargs.items():
+                if not (isinstance(v, xr.DataArray) or v is None):
+                    _, opts = statistics.min_max(data, dim)
+                    helper = xr.ones_like(opts['min'])
+                    kwargs[k] = helper * v
+            mean = kwargs.pop('mean', None)
+            std = kwargs.pop('std', None)
+            min = kwargs.pop('min', None)
+            max = kwargs.pop('max', None)
+
             if method == "standardise":
                 return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
             elif method == "centre":
@@ -215,15 +237,48 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
         """
+        if self.lazy is False:
+            self.make_input_target()
+        else:
+            self.load_lazy()
+            self.store_lazy()
+        if self.do_transformation is True:
+            self.call_transform()
+        self.make_samples()
+
+    def store_lazy(self):
+        hash = self._get_hash()
+        filename = os.path.join(self.lazy_path, hash + ".pickle")
+        if not os.path.exists(filename):
+            dill.dump(self._create_lazy_data(), file=open(filename, "wb"))
+
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data]
+
+    def load_lazy(self):
+        hash = self._get_hash()
+        filename = os.path.join(self.lazy_path, hash + ".pickle")
+        try:
+            with open(filename, "rb") as pickle_file:
+                lazy_data = dill.load(pickle_file)
+            self._extract_lazy(lazy_data)
+            logging.debug(f"{self.station[0]}: used lazy data")
+        except FileNotFoundError:
+            logging.debug(f"{self.station[0]}: could not use lazy data")
+            self.make_input_target()
+
+    def _extract_lazy(self, lazy_data):
+        _data, self.meta, _input_data, _target_data = lazy_data
+        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
+        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
+
+    def make_input_target(self):
         data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
         self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                       limit=self.interpolation_limit)
         self.set_inputs_and_targets()
-        if self.do_transformation is True:
-            self.call_transform()
-        self.make_samples()
 
     def set_inputs_and_targets(self):
         inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
@@ -551,8 +606,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
         return data.loc[{coord: slice(str(start), str(end))}]
 
-    @staticmethod
-    def setup_transformation(transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
+    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
         """
         Set up transformation by extracting all relevant information.
 
@@ -658,6 +712,13 @@ class DataHandlerSingleStation(AbstractDataHandler):
         return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                               transformation_dim=self.target_dim)
 
+    def _hash_list(self):
+        return sorted(list(set(self._hash)))
+
+    def _get_hash(self):
+        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
+        return hashlib.md5(hash).hexdigest()
+
 
 if __name__ == "__main__":
     # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 704b07cfb83baaa0539f134058d6096adc5554a9..ff1f13aed4b8d829edc653c6e99b6cff82287476 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -8,6 +8,7 @@ import gc
 import logging
 import os
 import pickle
+import dill
 import shutil
 from functools import reduce
 from typing import Tuple, Union, List
@@ -86,7 +87,7 @@ class DefaultDataHandler(AbstractDataHandler):
             data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
             data = self._force_dask_computation(data)
             with open(self._save_file, "wb") as f:
-                pickle.dump(data, f)
+                dill.dump(data, f)
             logging.debug(f"save pickle data to {self._save_file}")
             self._reset_data()
 
@@ -101,7 +102,7 @@ class DefaultDataHandler(AbstractDataHandler):
     def _load(self):
         try:
             with open(self._save_file, "rb") as f:
-                data = pickle.load(f)
+                data = dill.load(f)
             logging.debug(f"load pickle data from {self._save_file}")
             self._X, self._Y = data["X"], data["Y"]
             self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
@@ -240,6 +241,8 @@ class DefaultDataHandler(AbstractDataHandler):
 
         * standardise (default, if method is not given)
         * centre
+        * min_max
+        * log
 
         ### mean and std estimation
 
@@ -255,14 +258,16 @@ class DefaultDataHandler(AbstractDataHandler):
 
         If mean and std are not None, the default data handler expects this parameters to match the data and applies
         this values to the data. Make sure that all dimensions and/or coordinates are in agreement.
+
+        ### min and max given
+        If min and max are not None, the default data handler expects this parameters to match the data and applies
+        this values to the data. Make sure that all dimensions and/or coordinates are in agreement.
         """
 
         sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        transformation_dict = sp_keys.get("transformation", None)
-        if transformation_dict is None:
+        if "transformation" not in sp_keys.keys():
             return
-        if isinstance(transformation_dict, dict):  # tuple for (input, target) transformation
-            transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+        transformation_dict = ({}, {})
 
         def _inner():
             """Inner method that is performed in both serial and parallel approach."""
@@ -272,7 +277,9 @@ class DefaultDataHandler(AbstractDataHandler):
                         if var not in transformation_dict[i].keys():
                             transformation_dict[i][var] = {}
                         opts = transformation[var]
-                        assert transformation_dict[i][var].get("method", opts["method"]) == opts["method"]
+                        if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
+                            # data handlers with filters are allowed to change transformation method to standardise
+                            assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
                         transformation_dict[i][var]["method"] = opts["method"]
                         for k in ["mean", "std", "min", "max"]:
                             old = transformation_dict[i][var].get(k, None)
diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 30c45417a64e949b0c0535a96a20c933641fdcbb..564bf3bfd6e4f5b814c9d090733cfbfbf26a850b 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -9,6 +9,7 @@ import math
 import os
 import shutil
 import pickle
+import dill
 from typing import Tuple, List
 
 
@@ -109,7 +110,7 @@ class KerasIterator(keras.utils.Sequence):
         """Load pickle data from disk."""
         file = self._path % index
         with open(file, "rb") as f:
-            data = pickle.load(f)
+            data = dill.load(f)
         return data["X"], data["Y"]
 
     @staticmethod
@@ -167,7 +168,7 @@ class KerasIterator(keras.utils.Sequence):
         data = {"X": X, "Y": Y}
         file = self._path % index
         with open(file, "wb") as f:
-            pickle.dump(data, f)
+            dill.dump(data, f)
 
     def _get_number_of_mini_batches(self, number_of_samples: int) -> int:
         """Return number of mini batches as the floored ration of number of samples to batch size."""
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index ee727ef59ff35334be0a52a4d78dbae814d6c205..b57b733b08c4635a16d7fd18e99538a991521fd8 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -103,7 +103,7 @@ def remove_items(obj: Union[List, Dict], items: Any):
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
 
 
-def select_from_dict(dict_obj: dict, sel_list: Any):
+def select_from_dict(dict_obj: dict, sel_list: Any, remove_none=False):
     """
     Extract all key values pairs whose key is contained in the sel_list.
 
@@ -113,6 +113,7 @@ def select_from_dict(dict_obj: dict, sel_list: Any):
     sel_list = to_list(sel_list)
     assert isinstance(dict_obj, dict)
     sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list}
+    sel_dict = sel_dict if not remove_none else {k: v for k, v in sel_dict.items() if v is not None}
     return sel_dict
 
 
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 3631597aedb90b3411163a42490e9c023bad706a..3e99357c36d556f093701325964500bf8d46c698 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -11,8 +11,10 @@ import pandas as pd
 from typing import Union, Tuple, Dict, List
 from matplotlib import pyplot as plt
 import itertools
+import gc
+import warnings
 
-from mlair.helpers import to_list
+from mlair.helpers import to_list, TimeTracking, TimeTrackingWrapper
 
 Data = Union[xr.DataArray, pd.DataFrame]
 
@@ -438,7 +440,7 @@ class SkillScores:
         """Calculate CASE IV."""
         AI, BI, CI, data, suffix = self.skill_score_pre_calculations(internal_data, observation_name, forecast_name)
         monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, index=data.index)
-        data = xr.concat([data, monthly_mean_external], dim="type")
+        data = xr.concat([data, monthly_mean_external], dim="type").dropna(dim="index")
         mean, sigma = suffix["mean"], suffix["sigma"]
         mean_external = monthly_mean_external.mean()
         sigma_external = np.sqrt(monthly_mean_external.var())
@@ -608,6 +610,48 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
         else:
             return None
 
+    @TimeTrackingWrapper
+    def kz_filter_new(self, df, wl, itr):
+        """
+        It passes the low frequency time series.
+
+        If filter method is from mean, max, min this method will call construct and rechunk before the actual
+        calculation to improve performance. If filter method is either median or percentile this approach is not
+        applicable and depending on the data and window size, this method can become slow.
+
+        Args:
+             wl(int): a window length
+             itr(int): a number of iteration
+        """
+        warnings.filterwarnings("ignore")
+        df_itr = df.__deepcopy__()
+        try:
+            kwargs = {"min_periods": int(0.7 * wl),
+                      "center": True,
+                      self.filter_dim: wl}
+            for i in np.arange(0, itr):
+                print(i)
+                rolling = df_itr.chunk().rolling(**kwargs)
+                if self.method not in ["percentile", "median"]:
+                    rolling = rolling.construct("construct").chunk("auto")
+                if self.method == "median":
+                    df_mv_avg_tmp = rolling.median()
+                elif self.method == "percentile":
+                    df_mv_avg_tmp = rolling.quantile(self.percentile)
+                elif self.method == "max":
+                    df_mv_avg_tmp = rolling.max("construct")
+                elif self.method == "min":
+                    df_mv_avg_tmp = rolling.min("construct")
+                else:
+                    df_mv_avg_tmp = rolling.mean("construct")
+                df_itr = df_mv_avg_tmp.compute()
+                del df_mv_avg_tmp, rolling
+                gc.collect()
+            return df_itr
+        except ValueError:
+            raise ValueError
+
+    @TimeTrackingWrapper
     def kz_filter(self, df, wl, itr):
         """
         It passes the low frequency time series.
@@ -616,15 +660,18 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
              wl(int): a window length
              itr(int): a number of iteration
         """
+        import warnings
+        warnings.filterwarnings("ignore")
         df_itr = df.__deepcopy__()
         try:
-            kwargs = {"min_periods": 1,
+            kwargs = {"min_periods": int(0.7 * wl),
                       "center": True,
                       self.filter_dim: wl}
             iter_vars = df_itr.coords["variables"].values
             for var in iter_vars:
-                df_itr_var = df_itr.sel(variables=[var]).chunk()
+                df_itr_var = df_itr.sel(variables=[var])
                 for _ in np.arange(0, itr):
+                    df_itr_var = df_itr_var.chunk()
                     rolling = df_itr_var.rolling(**kwargs)
                     if self.method == "median":
                         df_mv_avg_tmp = rolling.median()
@@ -637,7 +684,7 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
                     else:
                         df_mv_avg_tmp = rolling.mean()
                     df_itr_var = df_mv_avg_tmp.compute()
-                df_itr = df_itr.drop_sel(variables=var).combine_first(df_itr_var)
+                df_itr.loc[{"variables": [var]}] = df_itr_var
             return df_itr
         except ValueError:
             raise ValueError
diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index 894ff7ac4e787a8b31f75ff932f60bec8c561094..989f4578f78e6566dfca5a63f671ced8120491d8 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -82,7 +82,7 @@ class AbstractModelClass(ABC):
         self.__custom_objects = value
 
     @property
-    def compile_options(self) -> Callable:
+    def compile_options(self) -> Dict:
         """
         The compile options property allows the user to use all keras.compile() arguments. They can ether be passed as
         dictionary (1), as attribute, without setting compile_options (2) or as mixture (partly defined as instance
@@ -116,7 +116,7 @@ class AbstractModelClass(ABC):
             def set_compile_options(self):
                 self.optimizer = keras.optimizers.SGD()
                 self.loss = keras.losses.mean_squared_error
-                self.compile_options = {"optimizer" = keras.optimizers.Adam(), "metrics": ["mse", "mae"]}
+                self.compile_options = {"optimizer": keras.optimizers.Adam(), "metrics": ["mse", "mae"]}
 
         Note:
         * As long as the attribute and the dict value have exactly the same values, the setter method will not raise
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index dbcd3a9f41ca1b9a7435be95b93eb40c2b37c5a0..9fb08cdf6efacab12c2828ed221966586bce1d08 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -5,57 +5,11 @@ from functools import reduce, partial
 
 from mlair.model_modules import AbstractModelClass
 from mlair.helpers import select_from_dict
+from mlair.model_modules.loss import var_loss, custom_loss
 
 import keras
 
 
-class FCN_64_32_16(AbstractModelClass):
-    """
-    A customised model 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending
-    on the window_lead_time parameter.
-    """
-
-    def __init__(self, input_shape: list, output_shape: list):
-        """
-        Sets model and loss depending on the given arguments.
-
-        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
-        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
-        """
-
-        assert len(input_shape) == 1
-        assert len(output_shape) == 1
-        super().__init__(input_shape[0], output_shape[0])
-
-        # settings
-        self.activation = keras.layers.PReLU
-
-        # apply to model
-        self.set_model()
-        self.set_compile_options()
-        self.set_custom_objects(loss=self.compile_options['loss'])
-
-    def set_model(self):
-        """
-        Build the model.
-        """
-        x_input = keras.layers.Input(shape=self._input_shape)
-        x_in = keras.layers.Flatten()(x_input)
-        x_in = keras.layers.Dense(64, name="Dense_64")(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(32, name="Dense_32")(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(16, name="Dense_16")(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self._output_shape, name="Dense_output")(x_in)
-        out_main = self.activation()(x_in)
-        self.model = keras.Model(inputs=x_input, outputs=[out_main])
-
-    def set_compile_options(self):
-        self.optimizer = keras.optimizers.adam(lr=1e-2)
-        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
-
-
 class FCN(AbstractModelClass):
     """
     A customisable fully connected network (64, 32, 16, window_lead_time), where the last layer is the output layer depending
@@ -64,12 +18,20 @@ class FCN(AbstractModelClass):
 
     _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
-                   "linear": partial(keras.layers.Activation, "linear")}
+                   "linear": partial(keras.layers.Activation, "linear"),
+                   "selu": partial(keras.layers.Activation, "selu"),
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
+    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
+                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
+                    "prelu": keras.initializers.he_normal()}
     _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
-    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov"]
+    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
+    _dropout = {"selu": keras.layers.AlphaDropout}
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
-                 optimizer="adam", n_layer=1, n_hidden=10, **kwargs):
+                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
+                 **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
@@ -83,15 +45,20 @@ class FCN(AbstractModelClass):
 
         # settings
         self.activation = self._set_activation(activation)
+        self.activation_name = activation
         self.activation_output = self._set_activation(activation_output)
+        self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
-        self.layer_configuration = (n_layer, n_hidden)
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self._update_model_name()
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
 
         # apply to model
         self.set_model()
         self.set_compile_options()
-        # self.set_custom_objects(loss=self.compile_options['loss'])
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
 
     def _set_activation(self, activation):
         try:
@@ -112,11 +79,37 @@ class FCN(AbstractModelClass):
         except KeyError:
             raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
 
+    def _set_regularizer(self, regularizer, **kwargs):
+        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+            return None
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _set_dropout(self, activation, dropout_rate):
+        if dropout_rate is None:
+            return None, None
+        assert 0 <= dropout_rate < 1
+        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
+
     def _update_model_name(self):
-        n_layer, n_hidden = self.layer_configuration
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))
         n_output = str(self._output_shape)
-        self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
+        else:
+            self.model_name += "_".join(["", n_input, *[f"{n}" for n in self.layer_configuration], n_output])
 
     def set_model(self):
         """
@@ -124,13 +117,53 @@ class FCN(AbstractModelClass):
         """
         x_input = keras.layers.Input(shape=self._input_shape)
         x_in = keras.layers.Flatten()(x_input)
-        n_layer, n_hidden = self.layer_configuration
-        for layer in range(n_layer):
-            x_in = keras.layers.Dense(n_hidden)(x_in)
-            x_in = self.activation()(x_in)
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            for layer in range(n_layer):
+                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                          kernel_regularizer=self.kernel_regularizer)(x_in)
+                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+                if self.dropout is not None:
+                    x_in = self.dropout(self.dropout_rate)(x_in)
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            for layer, n_hidden in enumerate(self.layer_configuration):
+                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                          kernel_regularizer=self.kernel_regularizer)(x_in)
+                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+                if self.dropout is not None:
+                    x_in = self.dropout(self.dropout_rate)(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
-        out = self.activation_output()(x_in)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
 
+    def set_compile_options(self):
+        self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
+                                "metrics": ["mse", "mae", var_loss]}
+
+
+class FCN_64_32_16(FCN):
+    """
+    A customised model 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending
+    on the window_lead_time parameter.
+    """
+
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]
+
+    def __init__(self, input_shape: list, output_shape: list, **kwargs):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+        """
+        lr = kwargs.pop("lr", 1e-2)
+        super().__init__(input_shape, output_shape, activation="prelu", activation_output="linear",
+                         layer_configuration=[64, 32, 16], optimizer="adam", lr=lr, **kwargs)
+
     def set_compile_options(self):
         self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
+
+    def _update_model_name(self):
+        self.model_name = "FCN"
+        super()._update_model_name()
diff --git a/mlair/model_modules/loss.py b/mlair/model_modules/loss.py
index bcb85282d0fa15f18ebd65a89e4020c2a0170224..ba871e983ecfa1e91676d53b834ebd622c00fe49 100644
--- a/mlair/model_modules/loss.py
+++ b/mlair/model_modules/loss.py
@@ -20,3 +20,21 @@ def l_p_loss(power: int) -> Callable:
         return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
 
     return loss
+
+
+def var_loss(y_true, y_pred) -> Callable:
+    return K.mean(K.square(K.var(y_true) - K.var(y_pred)))
+
+
+def custom_loss(loss_list, loss_weights=None) -> Callable:
+    n = len(loss_list)
+    if loss_weights is None:
+        loss_weights = [1. / n for _ in range(n)]
+    else:
+        assert len(loss_weights) == n
+        loss_weights = [w / sum(loss_weights) for w in loss_weights]
+
+    def loss(y_true, y_pred):
+        return sum([loss_weights[i] * loss_list[i](y_true, y_pred) for i in range(n)])
+
+    return loss
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 3b9b563426a80816f7cf1ea9e114a8395d9fbba0..73aebb008ebf1f61eb2878293fc160cf549d19cb 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -306,7 +306,7 @@ class PostProcessing(RunEnvironment):
         try:
             if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and (
                     "PlotSeparationOfScales" in plot_list):
-                filter_dim = self.data_store.get("filter_dim", None)
+                filter_dim = self.data_store.get_default("filter_dim", None)
                 PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim,
                                        window_dim=window_dim, target_dim=target_dim, **{"filter_dim": filter_dim})
         except Exception as e:
diff --git a/requirements.txt b/requirements.txt
index b0a6e7f59896fd0edf08977ee553c803f6c2e960..85655e237f8e10e98f77c379be6acd0a7bb65d46 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
 attrs==20.3.0
+bottleneck==1.3.2
 cached-property==1.5.2
 certifi==2020.12.5
 cftime==1.4.1
@@ -9,6 +10,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 35fe0d5ee2a03f01737bc185d2a5bbaf26383806..cc189496bdf4e1e1ee86902a1953c2058d58c8e4 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -2,6 +2,7 @@ absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
 attrs==20.3.0
+bottleneck==1.3.2
 cached-property==1.5.2
 certifi==2020.12.5
 cftime==1.4.1
@@ -9,6 +10,7 @@ chardet==4.0.0
 coverage==5.4
 cycler==0.10.0
 dask==2021.2.0
+dill==0.3.3
 fsspec==0.8.5
 gast==0.4.0
 grpcio==1.35.0
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index d2f9ce00224a61815c89e44b7c37a667d239b2f5..2a6553b7f495bb4eb8aeddf7c39f2f2517edc967 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -37,7 +37,7 @@ class TestDataHandlerMixedSamplingSingleStation:
         req = object.__new__(DataHandlerSingleStation)
         assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
 
-    @mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples")
+    @mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples")
     def test_init(self, mock_super_init):
         obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly",
                                                     interpolation_limit=(1, 10))
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index f2e2b341afa424ce351c0253f41c75e362b77eba..91f2278ae7668b623f8d2434ebac7e959dc9c805 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -175,7 +175,7 @@ class TestSelectFromDict:
 
     @pytest.fixture
     def dictionary(self):
-        return {"a": 1, "b": 23, "c": "last"}
+        return {"a": 1, "b": 23, "c": "last", "e": None}
 
     def test_select(self, dictionary):
         assert select_from_dict(dictionary, "c") == {"c": "last"}
@@ -186,6 +186,10 @@ class TestSelectFromDict:
         with pytest.raises(AssertionError):
             select_from_dict(["we"], "now")
 
+    def test_select_remove_none(self, dictionary):
+        assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None}
+        assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1}
+
 
 class TestRemoveItems:
 
diff --git a/test/test_model_modules/test_loss.py b/test/test_model_modules/test_loss.py
index e54e0b00de4a71d241f30e0b6b0c1a2e8fa1a19c..c993830c5290c9beeec392dfd806354ca02eb490 100644
--- a/test/test_model_modules/test_loss.py
+++ b/test/test_model_modules/test_loss.py
@@ -1,10 +1,12 @@
 import keras
 import numpy as np
 
-from mlair.model_modules.loss import l_p_loss
+from mlair.model_modules.loss import l_p_loss, var_loss, custom_loss
 
+import pytest
 
-class TestLoss:
+
+class TestLPLoss:
 
     def test_l_p_loss(self):
         model = keras.Sequential()
@@ -14,4 +16,42 @@ class TestLoss:
         assert hist.history['loss'][0] == 1.25
         model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3))
         hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
-        assert hist.history['loss'][0] == 2.25
\ No newline at end of file
+        assert hist.history['loss'][0] == 2.25
+
+
+class TestVarLoss:
+
+    def test_var_loss(self):
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=var_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        assert hist.history['loss'][0] == 0.140625
+
+
+class TestCustomLoss:
+
+    def test_custom_loss_no_weights(self):
+        cust_loss = custom_loss([l_p_loss(2), var_loss])
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=cust_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        assert hist.history['loss'][0] == (0.5 * 0.140625 + 0.5 * 1.25)
+
+    @pytest.mark.parametrize("weights", [[0.3, 0.7], [0.5, 0.5], [1, 1], [4, 1]])
+    def test_custom_loss_with_weights(self, weights):
+        cust_loss = custom_loss([l_p_loss(2), var_loss], weights)
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=cust_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        weights_adjusted = list(map(lambda x: x / sum(weights), weights))
+        expected = (weights_adjusted[0] * 1.25 + weights_adjusted[1] * 0.140625)
+        assert np.testing.assert_almost_equal(hist.history['loss'][0], expected, decimal=6) is None
+
+    def test_custom_loss_invalid_weights(self):
+        with pytest.raises(AssertionError):
+            custom_loss([l_p_loss(2), var_loss], [0.3])
+        with pytest.raises(AssertionError):
+            custom_loss([l_p_loss(2), var_loss], [0.4, 3, 1])