diff --git a/CHANGELOG.md b/CHANGELOG.md
index 266cb33ec8666099ffcb638ff85d814d7e2cf184..988e3e5a7863868cead1a2fec7c7b6d1c750b8d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,29 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+## v2.1.0 -  2022-06-07  - new evaluation metrics and improved training
+
+### general:
+* new evaluation metrics, IOA and MNMB
+* advanced training options for early stopping
+* reduced execution time by refactoring
+
+### new features:
+* uncertainty estimation of MSE is now additionally applied to each season separately (#374)
+* added early stopping configurations to use either the last trained epoch or the best epoch (#378)
+* training monitoring plots now mark the best epoch with a star when early stopping is used (#367)
+* new evaluation metric index of agreement, IOA (#376)
+* new evaluation metric modified normalised mean bias, MNMB (#380)
+* new plot available that shows temporal evolution of MSE for each station (#381)
+
+### technical:
+* reduced repeated loading of the forecast path from the data store (#328)
+* bug fix for an uncaught error during transformation (#385)
+* bug fix for the data handler with climate and FIR filter that always calculated the transformation with the FIR filter (#387)
+* reduced duration of latex report creation at the end of preprocessing (#388)
+* enhanced speed of the make prediction step in postprocessing (#389)
+* fix to always create the version badge from the version and not from the tag name (#382)
+
 ## v2.0.0 -  2022-04-08  - tf2 usage, new model classes, and improved uncertainty estimate
 
 ### general:
diff --git a/CI/create_version_badge.sh b/CI/create_version_badge.sh
index c01bf913073c80f7b838c1a8afe93aa0bcd77fed..c7a85af2b89eccb48679601dffe6a31396739cfc 100644
--- a/CI/create_version_badge.sh
+++ b/CI/create_version_badge.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-VERSION="$(git describe --tags $(git rev-list --tags --max-count=1))"
+VERSION="$(git describe master)"
 COLOR="blue"
 BADGE_NAME="version"
 
diff --git a/README.md b/README.md
index 8decf00b29f91e0a3a014bbf57e92aff12c5e035..792c6d4a06564eb050467f271f660761ec4d3d71 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ HPC systems, see [here](#special-instructions-for-installation-on-jülich-hpc-sy
 * Installation of **MLAir**:
     * Either clone MLAir from the [gitlab repository](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git) 
       and use it without installation (beside the requirements) 
-    * or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl)) 
+    * or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl)) 
       and install it via `pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script 
       inside your virtual environment using `import mlair`.
 
diff --git a/dist/mlair-2.1.0-py3-none-any.whl b/dist/mlair-2.1.0-py3-none-any.whl
new file mode 100644
index 0000000000000000000000000000000000000000..b5069f2ae900ff7bf43428d3adba8a50be742588
Binary files /dev/null and b/dist/mlair-2.1.0-py3-none-any.whl differ
diff --git a/docs/_source/installation.rst b/docs/_source/installation.rst
index 6ac4937e6a729c12e54007aa32f0e59635289fdd..6cbf8c424bdd29470c23eb95a9b5d3a5071cf39f 100644
--- a/docs/_source/installation.rst
+++ b/docs/_source/installation.rst
@@ -27,7 +27,7 @@ Installation of MLAir
 * Install all requirements from `requirements.txt <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/requirements.txt>`_
   preferably in a virtual environment
 * Either clone MLAir from the `gitlab repository <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git>`_
-* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl>`_)
+* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl>`_)
   and install it via :py:`pip install <dist_file>.whl`. In this case, you can simply
   import MLAir in any python script inside your virtual environment using :py:`import mlair`.
 
diff --git a/mlair/__init__.py b/mlair/__init__.py
index 2ca5c3ab96fb3f96fa2343efab02860d465db870..901947e5313a183e3909687b1fea0096075f836c 100644
--- a/mlair/__init__.py
+++ b/mlair/__init__.py
@@ -1,6 +1,6 @@
 __version_info__ = {
     'major': 2,
-    'minor': 0,
+    'minor': 1,
     'micro': 0,
 }
 
diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index ca569720dc41d95621d0613a2170cc4d9d46c082..b630261dbf58d7402f8c3cacaee153347ad4f1e3 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -2,6 +2,9 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-25'
 
 
+import numpy as np
+
+
 DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                         'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
@@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False
 DEFAULT_PERMUTE_DATA = False
 DEFAULT_BATCH_SIZE = int(256 * 2)
 DEFAULT_EPOCHS = 20
+DEFAULT_EARLY_STOPPING_EPOCHS = np.inf
+DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True
 DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_WINDOW_LEAD_TIME = 3
diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 0bdd9b216073bd6d045233afb3fd945718117a98..84596ad081b922a92a91b3df0513a4e730b8eb53 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -316,10 +316,17 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
 
     @classmethod
     def _split_chem_and_meteo_variables(cls, **kwargs):
+        """
+        Select all used variables and split them into categories chem and other.
+
+        Chemical variables are indicated by `cls.data_handler_climate_fir.chem_vars`. To indicate used variables, this
+        method uses 1) parameter `variables`, 2) keys from `statistics_per_var`, 3) keys from
+        `cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT`. Option 3) is also applied if 1) or 2) are given but None.
+        """
         if "variables" in kwargs:
             variables = kwargs.get("variables")
         elif "statistics_per_var" in kwargs:
-            variables = kwargs.get("statistics_per_var")
+            variables = kwargs.get("statistics_per_var").keys()
         else:
             variables = None
         if variables is None:
@@ -348,14 +355,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
                 cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
                 sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
         if len(meteo_vars) > 0:
-            if cls.data_handler_fir_pos is None:
-                if "extend_length_opts" in kwargs:
-                    if isinstance(kwargs["extend_length_opts"], dict) and cls.meteo_indicator not in kwargs["extend_length_opts"].keys():
-                        cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
-                    else:
-                        cls.data_handler_fir_pos = 1  # use slower fir version with climate estimate
-                else:
-                    cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
+            cls.set_data_handler_fir_pos(**kwargs)
             sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir[cls.data_handler_fir_pos].requirements() if k in kwargs}
             sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_meteo")
             cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
@@ -369,8 +369,36 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
         dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
         return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args)
 
+    @classmethod
+    def set_data_handler_fir_pos(cls, **kwargs):
+        """
+        Set position of fir data handler to use either faster FIR version or slower climate FIR.
+
+        This method will set data handler indicator to 0 if either no parameter "extend_length_opts" is given or the
+        parameter is of type dict but has no entry for the meteo_indicator. In all other cases, indicator is set to 1.
+        """
+        p_name = "extend_length_opts"
+        if cls.data_handler_fir_pos is None:
+            if p_name in kwargs:
+                if isinstance(kwargs[p_name], dict) and cls.meteo_indicator not in kwargs[p_name].keys():
+                    cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
+                else:
+                    cls.data_handler_fir_pos = 1  # use slower fir version with climate estimate
+            else:
+                cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
+
     @classmethod
     def prepare_build(cls, kwargs, var_list, var_type):
+        """
+        Prepares for build of class.
+
+        `variables` parameter is updated by `var_list`, which should only include variables of a specific type (e.g.
+        only chemical variables) indicated by `var_type`. Furthermore, this method cleans the `kwargs` dictionary as
+        follows: For all parameters provided as dict to separate between chem and meteo options (dict must have keys
+        from `cls.chem_indicator` and/or `cls.meteo_indicator`), this parameter is removed from kwargs and its value
+        related to `var_type` added again. In case there is no value for given `var_type`, the parameter is not added
+        at all (as this parameter is assumed to affect only other types of variables).
+        """
         kwargs.update({"variables": var_list})
         for k in list(kwargs.keys()):
             v = kwargs[k]
@@ -382,17 +410,6 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
                     except KeyError:
                         pass
 
-    @staticmethod
-    def adjust_window_opts(key: str, parameter_name: str, kwargs: dict):
-        try:
-            if parameter_name in kwargs:
-                window_opt = kwargs.pop(parameter_name)
-                if isinstance(window_opt, dict):
-                    window_opt = window_opt[key]
-                kwargs[parameter_name] = window_opt
-        except KeyError:
-            pass
-
     def _create_collection(self):
         collection = super()._create_collection()
         if self.id_class_other is not None:
@@ -419,9 +436,10 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
 
         # meteo transformation
         if len(meteo_vars) > 0:
+            cls.set_data_handler_fir_pos(**kwargs)
             kwargs_meteo = copy.deepcopy(kwargs)
             cls.prepare_build(kwargs_meteo, meteo_vars, cls.meteo_indicator)
-            dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos or 0], cls.data_handler_unfiltered)
+            dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos], cls.data_handler_unfiltered)
             transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
                                                           dh_transformation=dh_transformation, **kwargs_meteo)
 
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index d158726e5f433d40cfa272e6a9c7f808057f88e4..300e0435c4e8441e299675319e2c72604ebb3200 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -125,8 +125,9 @@ class DefaultDataHandler(AbstractDataHandler):
 
     def get_data(self, upsampling=False, as_numpy=True):
         self._load()
-        X = self.get_X(upsampling, as_numpy)
-        Y = self.get_Y(upsampling, as_numpy)
+        as_numpy_X, as_numpy_Y = as_numpy if isinstance(as_numpy, tuple) else (as_numpy, as_numpy)
+        X = self.get_X(upsampling, as_numpy_X)
+        Y = self.get_Y(upsampling, as_numpy_Y)
         self._reset_data()
         return X, Y
 
@@ -378,7 +379,7 @@ def f_proc(data_handler, station, return_strategy="", tmp_path=None, **sp_keys):
     assert return_strategy in ["result", "reference"]
     try:
         res = data_handler(station, **sp_keys)
-    except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
+    except (AttributeError, EmptyQueryResult, KeyError, ValueError, IndexError) as e:
         logging.info(f"remove station {station} because it raised an error: {e}")
         res = None
     if return_strategy == "result":
diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index e353f84d85a0871b00964899efb2a79bf555aefc..3fc25a90f861c65d38aa6b7019095210035d4c2d 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -144,8 +144,8 @@ class KerasIterator(keras.utils.Sequence):
         mod_rank = self._get_model_rank()
         for data in self._collection:
             logging.debug(f"prepare batches for {str(data)}")
-            X = data.get_X(upsampling=self.upsampling)
-            Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)]
+            X, _Y = data.get_data(upsampling=self.upsampling)
+            Y = [_Y[0] for _ in range(mod_rank)]
             if self.upsampling:
                 X, Y = self._permute_data(X, Y)
             if remaining is not None:
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index 8104c7c50517e05be14b05aaa9cea8d0e5ba32f4..b583cf7dc473db96181f88b0ab26e60ee225240d 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -122,6 +122,21 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
     return round_type(number * multiplier) / multiplier
 
 
+def relative_round(x: float, sig: int) -> float:
+    """
+    Round a number to the given number of significant digits.
+
+    Example: relative_round(0.03112, 2) -> 0.031, relative_round(0.03112, 1) -> 0.03
+
+    :param x: number to round
+    :param sig: number of significant digits to keep
+
+    :return: rounded number
+    """
+    assert sig >= 1
+    return round(x, sig-int(np.floor(np.log10(abs(x))))-1)
+
+
 def remove_items(obj: Union[List, Dict, Tuple], items: Any):
     """
     Remove item(s) from either list, tuple or dictionary.
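
The `relative_round` helper added above keeps `sig` significant digits by shifting the rounding position with the order of magnitude of `x`. A short worked example of the arithmetic (illustrative only, not part of the patch):

    import numpy as np

    x, sig = 0.03112, 2
    decimals = sig - int(np.floor(np.log10(abs(x)))) - 1  # log10(0.03112) ~ -1.51, floor -> -2, so decimals = 3
    print(round(x, decimals))  # -> 0.031, matching relative_round(0.03112, 2)
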
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 57143c9aed0e730c81adbed33f7ba62fea39b298..7633a2a9c1842219d7af7b9c7b2b4f23a034cbdf 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -11,6 +11,8 @@ import pandas as pd
 from typing import Union, Tuple, Dict, List
 import itertools
 from collections import OrderedDict
+from mlair.helpers import to_list
+
 
 Data = Union[xr.DataArray, pd.DataFrame]
 
@@ -211,13 +213,42 @@ def mean_absolute_error(a, b, dim=None):
     return np.abs(a - b).mean(dim)
 
 
+def index_of_agreement(a, b, dim=None):
+    """Calculate index of agreement (IOA) where a is the forecast and b the reference (e.g. observation)."""
+    num = (np.square(b - a)).sum(dim)
+    b_mean = (b * np.ones(1)).mean(dim)
+    den = (np.square(np.abs(b - b_mean) + np.abs(a - b_mean))).sum(dim)
+    frac = num / den
+    # issue with 0/0 division for exactly equal arrays
+    if isinstance(frac, (int, float)):
+        frac = 0 if num == 0 else frac
+    else:
+        frac[num == 0] = 0
+    return 1 - frac
+
+
+def modified_normalized_mean_bias(a, b, dim=None):
+    """Calculate modified normalized mean bias (MNMB) where a is the forecast and b the reference (e.g. observation)."""
+    N = np.count_nonzero(a) if len(a.shape) == 1 else a.notnull().sum(dim)
+    return 2 * ((a - b) / (a + b)).sum(dim) / N
+
+
 def calculate_error_metrics(a, b, dim):
-    """Calculate MSE, RMSE, and MAE. Additionally return number of used values for calculation."""
+    """Calculate MSE, RMSE, MAE, IOA, and MNMB. Additionally, return number of used values for calculation.
+
+    :param a: forecast data to calculate metrics for
+    :param b: reference (e.g. observation)
+    :param dim: dimension to calculate metrics along
+
+    :returns: dict with results for all metrics indicated by lowercase metric short name
+    """
     mse = mean_squared_error(a, b, dim)
     rmse = np.sqrt(mse)
     mae = mean_absolute_error(a, b, dim)
+    ioa = index_of_agreement(a, b, dim)
+    mnmb = modified_normalized_mean_bias(a, b, dim)
     n = (a - b).notnull().sum(dim)
-    return {"mse": mse, "rmse": rmse, "mae": mae, "n": n}
+    return {"mse": mse, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n}
 
 
 def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
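
In the implementation added above, IOA = 1 - Σ(f_i - o_i)² / Σ(|f_i - ō| + |o_i - ō|)² with forecast f and observation o, and MNMB = (2/N) Σ (f_i - o_i)/(f_i + o_i). A toy call of `calculate_error_metrics` showing the extended return value (illustrative sketch only, assuming an installed mlair package):

    import numpy as np
    import xarray as xr
    from mlair.helpers.statistics import calculate_error_metrics

    # two lead times ("ahead") with four time steps each
    forecast = xr.DataArray(np.array([[1.0, 2.0, 3.0, 4.0], [2.0, 2.5, 3.5, 3.9]]).T, dims=("index", "ahead"))
    obs = xr.DataArray(np.array([[1.1, 1.9, 3.2, 3.8], [1.1, 1.9, 3.2, 3.8]]).T, dims=("index", "ahead"))
    res = calculate_error_metrics(forecast, obs, dim="index")
    print(res["ioa"].values, res["mnmb"].values)  # IOA near 1 and MNMB near 0 indicate close agreement
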
@@ -540,7 +571,6 @@ def create_single_bootstrap_realization(data: xr.DataArray, dim_name_time: str)
     :param dim_name_time: name of time dimension
     :return: bootstrapped realization of data
     """
-
     num_of_blocks = data.coords[dim_name_time].shape[0]
     boot_idx = np.random.choice(num_of_blocks, size=num_of_blocks, replace=True)
     return data.isel({dim_name_time: boot_idx})
@@ -556,7 +586,7 @@ def calculate_average(data: xr.DataArray, **kwargs) -> xr.DataArray:
 
 
 def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time: str, dim_name_model: str, n_boots: int = 1000,
-                                    dim_name_boots: str = 'boots') -> xr.DataArray:
+                                    dim_name_boots: str = 'boots', seasons: List = None) -> Dict[str, xr.DataArray]:
     """
     Create n bootstrap realizations and calculate averages across realizations
 
@@ -565,26 +595,23 @@ def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time: str, dim_
     :param dim_name_model: name of model dimension
     :param n_boots: number of bootstrap realizations
     :param dim_name_boots: name of bootstrap dimension
+    :param seasons: additionally calculate errors for the given seasons (default None)
     :return:
     """
+    seasons = [] if seasons is None else to_list(seasons)  # ensure seasons is an empty list if None
     res_dims = [dim_name_boots]
     dims = list(data.dims)
     other_dims = [v for v in dims if v in set(dims).difference([dim_name_time])]
     coords = {dim_name_boots: range(n_boots), **{dim_name: data.coords[dim_name] for dim_name in other_dims}}
     if len(dims) > 1:
         res_dims = res_dims + other_dims
-    res = xr.DataArray(np.nan, dims=res_dims, coords=coords)
+    realizations = {k: xr.DataArray(np.nan, dims=res_dims, coords=coords) for k in seasons + [""]}
     for boot in range(n_boots):
-        res[boot] = (calculate_average(
-            create_single_bootstrap_realization(data, dim_name_time=dim_name_time),
-            dim=dim_name_time, skipna=True))
-    return res
-
-
-
-
-
-
-
-
-
+        shuffled = create_single_bootstrap_realization(data, dim_name_time=dim_name_time)
+        realizations[""][boot] = calculate_average(shuffled, dim=dim_name_time, skipna=True)
+        for season in seasons:
+            assert season in ["DJF", "MAM", "JJA", "SON"]
+            sel = shuffled[dim_name_time].dt.season == season
+            realizations[season][boot] = calculate_average(shuffled.sel({dim_name_time: sel}),
+                                                           dim=dim_name_time, skipna=True)
+    return realizations
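
With the `seasons` argument added above, the function returns a dict of bootstrap estimates in which the empty-string key holds the all-season estimate. A hedged usage sketch on toy data (dimension names follow the MLAir defaults, everything else is illustrative):

    import numpy as np
    import pandas as pd
    import xarray as xr
    from mlair.helpers.statistics import create_n_bootstrap_realizations

    time = pd.date_range("2020-01-01", periods=365, freq="D")
    block_mse = xr.DataArray(np.random.rand(365, 2), dims=("index", "type"),
                             coords={"index": time, "type": ["nn", "ols"]})
    boots = create_n_bootstrap_realizations(block_mse, dim_name_time="index", dim_name_model="type",
                                            n_boots=100, seasons=["DJF", "JJA"])
    overall, summer = boots[""], boots["JJA"]  # "" -> estimate over all data, "JJA" -> summer only
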
diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index 9820b4956dac09e213df3b9addc317a00ee381b8..eb8982ae3625cfccedf894717eebf299faffb3ee 100644
--- a/mlair/helpers/testing.py
+++ b/mlair/helpers/testing.py
@@ -105,7 +105,10 @@ def get_all_args(*args, remove=None, add=None):
     return res
 
 
-def check_nested_equality(obj1, obj2):
+def check_nested_equality(obj1, obj2, precision=None):
+    """Check for equality in nested structures. Use precision to indicate number of decimals to check for consistency"""
+
+    assert precision is None or isinstance(precision, int)
 
     try:
         print(f"check type {type(obj1)} and {type(obj2)}")
@@ -116,22 +119,38 @@ def check_nested_equality(obj1, obj2):
             assert len(obj1) == len(obj2)
             for pos in range(len(obj1)):
                 print(f"check pos {obj1[pos]} and {obj2[pos]}")
-                assert check_nested_equality(obj1[pos], obj2[pos]) is True
+                assert check_nested_equality(obj1[pos], obj2[pos], precision) is True
         elif isinstance(obj1, dict):
             print(f"check keys {obj1.keys()} and {obj2.keys()}")
             assert sorted(obj1.keys()) == sorted(obj2.keys())
             for k in obj1.keys():
                 print(f"check pos {obj1[k]} and {obj2[k]}")
-                assert check_nested_equality(obj1[k], obj2[k]) is True
+                assert check_nested_equality(obj1[k], obj2[k], precision) is True
         elif isinstance(obj1, xr.DataArray):
-            print(f"check xr {obj1} and {obj2}")
-            assert xr.testing.assert_equal(obj1, obj2) is None
+            if precision is None:
+                print(f"check xr {obj1} and {obj2}")
+                assert xr.testing.assert_equal(obj1, obj2) is None
+            else:
+                print(f"check xr {obj1} and {obj2} with precision {precision}")
+                assert xr.testing.assert_allclose(obj1, obj2, atol=10**(-precision)) is None
         elif isinstance(obj1, np.ndarray):
-            print(f"check np {obj1} and {obj2}")
-            assert np.testing.assert_array_equal(obj1, obj2) is None
+            if precision is None:
+                print(f"check np {obj1} and {obj2}")
+                assert np.testing.assert_array_equal(obj1, obj2) is None
+            else:
+                print(f"check np {obj1} and {obj2} with precision {precision}")
+                assert np.testing.assert_array_almost_equal(obj1, obj2, decimal=precision) is None
         else:
-            print(f"check equal {obj1} and {obj2}")
-            assert obj1 == obj2
+            if isinstance(obj1, (int, float)) and isinstance(obj2, (int, float)):
+                if precision is None:
+                    print(f"check number equal {obj1} and {obj2}")
+                    assert np.testing.assert_equal(obj1, obj2) is None
+                else:
+                    print(f"check number equal {obj1} and {obj2} with precision {precision}")
+                    assert np.testing.assert_almost_equal(obj1, obj2, decimal=precision) is None
+            else:
+                print(f"check equal {obj1} and {obj2}")
+                assert obj1 == obj2
     except AssertionError:
         return False
     return True
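
A short usage sketch of the extended helper (illustrative only): with `precision`, nested numeric content is compared up to the given number of decimals instead of exactly.

    import numpy as np
    from mlair.helpers.testing import check_nested_equality

    a = {"scores": np.array([0.5012, 1.0001]), "model": "nn"}
    b = {"scores": np.array([0.5013, 1.0002]), "model": "nn"}
    print(check_nested_equality(a, b))               # False, exact comparison of the arrays fails
    print(check_nested_equality(a, b, precision=3))  # True, arrays agree up to 3 decimals
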
diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 8b99acd0f5723d3b00ec1bd0098712753da21b52..39b0da5b49f470d11ea64b9ddd344b9ad11e2b7f 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -163,6 +163,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
+        self.epoch_best = None
+        self.restore_best_weights = kwargs.pop("restore_best_weights", True)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -176,7 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
 
         :param hist: The History object from the previous (interrupted) training.
         """
-        self.best = hist.history.get(self.monitor)[-1]
+        if self.restore_best_weights:
+            f = np.min if self.monitor_op.__name__ == "less" else np.max
+            f_loc = lambda x: np.where(x == f(x))[0][-1]
+            _d = hist.history.get(self.monitor)
+            loc = f_loc(_d)
+            assert f(_d) == _d[loc]
+            self.epoch_best = loc
+            self.best = _d[loc]
+            logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
+        else:
+            _d = hist.history.get(self.monitor)[-1]
+            self.best = _d
+            logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch")
 
     def update_callbacks(self, callbacks):
         """
@@ -197,6 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
+                        if self.restore_best_weights:
+                            self.epoch_best = epoch
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
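
The epoch selection in `update_best` above reduces to a small numpy operation; a standalone illustration (not MLAir code) for a minimised monitor such as val_loss:

    import numpy as np

    val_loss = np.array([0.9, 0.4, 0.6, 0.4, 0.7])   # hypothetical history of the monitored metric
    best = np.min(val_loss)                          # best value reached during training
    epoch_best = np.where(val_loss == best)[0][-1]   # last epoch at which the best value occurred
    print(epoch_best + 1, best)                      # -> 4 0.4 (reported 1-based, as in the log message above)
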
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index bd2012c3fe9f53e7e07bfb4bfc4cde096c2dc891..c7647ef5bf5b5b6c46eae9318c0fd99b294292c6 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -25,6 +25,7 @@ from mlair.helpers import TimeTrackingWrapper
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
 
+
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 
@@ -1095,7 +1096,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
     def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type",
                  error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots',
                  block_length: str = None, model_name: str = "NN", model_indicator: str = "nn",
-                 ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = ""):
+                 ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None):
         super().__init__(plot_folder, "sample_uncertainty_from_bootstrap")
         self.default_plot_name = self.plot_name
         self.model_type_dim = model_type_dim
@@ -1105,6 +1106,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         self.error_unit = error_unit
         self.block_length = block_length
         self.model_name = model_name
+        _season = season_annotation or ""
         self.sampling = {"daily": "d", "hourly": "H"}.get(sampling[1] if isinstance(sampling, tuple) else sampling, "")
         data = self.rename_model_indicator(data, model_name, model_indicator)
         self.prepare_data(data)
@@ -1114,12 +1116,12 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
 
         # plot raw metric (mse)
         for orientation, utest, agg_type in variants:
-            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type)
+            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, season=_season)
 
         # plot root of metric (rmse)
         self._apply_root()
         for orientation, utest, agg_type in variants:
-            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt")
+            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt", season=_season)
 
         self._data_table = None
         self._n_boots = None
@@ -1148,9 +1150,10 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         self.error_measure = f"root {self.error_measure}"
         self.error_unit = self.error_unit.replace("$^2$", "")
 
-    def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag=""):
+    def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag="", season=""):
         self.plot_name = self.default_plot_name + {"v": "_vertical", "h": "_horizontal"}[orientation] + \
-                         {True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag
+                         {True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag + \
+                         {"": ""}.get(season, f"_{season}")
         if apply_u_test is True and agg_type == "multi":
             return  # not implemented
         data_table = self._data_table
@@ -1198,10 +1201,13 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
             ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
             xlims = list(ax.get_xlim())
             ax.set_xlim([xlims[0], xlims[1] * 1.015])
-
         else:
             raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
-        text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}"
+        text = f"n={n_boots}"
+        if self.block_length is not None:
+            text = f"{self.block_length}, {text}"
+        if len(season) > 0:
+            text = f"{season}, {text}"
         loc = "lower left"
         text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5, bbox_to_anchor=(0., 1.0),
                                 bbox_transform=ax.transAxes)
@@ -1234,6 +1240,85 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         return ax
 
 
+@TimeTrackingWrapper
+class PlotTimeEvolutionMetric(AbstractPlotClass):
+
+    def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".",
+                 error_measure: str = "mse", error_unit: str = None, model_name: str = "NN",
+                 model_indicator: str = "nn", time_dim="index"):
+        super().__init__(plot_folder, "time_evolution_mse")
+        self.title = error_measure + (f" (in {error_unit})" if error_unit is not None else "")
+        plot_name = self.plot_name
+        vmin = int(data.quantile(0.05))
+        vmax = int(data.quantile(0.95))
+        data = self._prepare_data(data, time_dim, model_type_dim, model_indicator, model_name)
+
+        for t in data[model_type_dim]:
+            # note: could be expanded to create plot per ahead step
+            plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas()
+            years = plot_data.columns.strftime("%Y").to_list()
+            months = plot_data.columns.strftime("%b").to_list()
+            plot_data.columns = plot_data.columns.strftime("%b %Y")
+            self.plot_name = f"{plot_name}_{t.values}"
+            self._plot(plot_data, years, months, vmin, vmax, str(t.values))
+
+    @staticmethod
+    def _find_nan_edge(data, time_dim):
+        coll = []
+        for i in data:
+            if bool(i) is False:
+                break
+            else:
+                coll.append(i[time_dim].values)
+        return coll
+
+    def _prepare_data(self, data, time_dim, model_type_dim, model_indicator, model_name):
+        # remove nans at begin and end
+        nan_locs = data.isnull().all(helpers.remove_items(data.dims, time_dim))
+        nans_at_end = self._find_nan_edge(reversed(nan_locs), time_dim)
+        nans_at_begin = self._find_nan_edge(nan_locs, time_dim)
+        data = data.drop(nans_at_begin + nans_at_end, time_dim)
+        # rename nn model
+        data[model_type_dim] = [v if v != model_indicator else model_name for v in data[model_type_dim].data.tolist()]
+        return data
+
+    @staticmethod
+    def _set_ticks(ax, years, months):
+        from matplotlib.ticker import IndexLocator
+        ax.xaxis.set_major_locator(IndexLocator(1, 0.5))
+        locs = ax.get_xticks(minor=False).tolist()[:len(months)]
+        ax.set_xticks(locs, minor=True)
+        ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
+        locs_major = []
+        labels_major = []
+        for l, major, minor in zip(locs, years, months):
+            if minor == "Jan":
+                locs_major.append(l + 0.001)
+                labels_major.append(major)
+        if len(locs_major) == 0:  # in case there is less than a year of data and no Jan included
+            locs_major = [locs[0] + 0.001]
+            labels_major = [years[0]]
+        ax.set_xticks(locs_major)
+        ax.set_xticklabels(labels_major, minor=False, rotation=0)
+        ax.tick_params(axis="x", which="major", pad=15)
+
+    @staticmethod
+    def _aspect_cbar(val):
+        return min(max(1.25 * val + 7.5, 10), 30)
+
+    def _plot(self, data, years, months, vmin=None, vmax=None, subtitle=None):
+        fig, ax = plt.subplots(figsize=(max(data.shape[1] / 6, 12), max(data.shape[0] / 3.5, 2)))
+        data.sort_index(inplace=True)
+        sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax,
+                    cbar_kws={"aspect": self._aspect_cbar(data.shape[0])})
+        # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm",
+        # square=True
+        self._set_ticks(ax, years, months)
+        ax.set(xlabel=None, ylabel=None, title=self.title if subtitle is None else f"{subtitle}\n{self.title}")
+        plt.tight_layout()
+        self._save()
+
+
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
     path = "../../testrun_network/forecasts"
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 39dd80651226519463d7b503fb612e43983d73cf..e651078fe66a5d95e6e01d5384865eee621ba6f4 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 
 from mlair.model_modules.keras_extensions import LearningRateDecay
+from mlair.helpers.helpers import relative_round
 
 # matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
@@ -27,7 +28,8 @@ class PlotModelHistory:
     parameter filename must include the absolute path for the plot.
     """
 
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
+                 epoch_best: int = None):
         """
         Set attributes and create plot.
 
@@ -37,12 +39,15 @@ class PlotModelHistory:
         :param plot_metric: the metric to plot (e.b. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
         :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
             False -> look for losses from all branches, not only from main)
+        :param epoch_best: epoch at which the best training result was achieved (counting starts at 0)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._data.index += 1
         self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
+        self._epoch_best = epoch_best
         self._plot(filename)
 
     def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
@@ -88,10 +93,19 @@ class PlotModelHistory:
         :param filename: name (including total path) of the plot to save.
         """
         ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
+        if self._epoch_best is not None:
+            ax.scatter(self._epoch_best+1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best],
+                       s=100, marker="*", c="black")
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
-        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        if self._epoch_best is not None:
+            final_res = self._data[[f'val_{self._plot_metric}']].min().values[0]
+            annotation = f"best epoch {self._epoch_best}"
+        else:
+            final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0]
+            annotation = "final"
+        title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}"
         ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index df797ffc23370bf4f45bb2b4f76e5f71e9bd030f..d807db14c96a4a30fde791e54c8b1b32e519fb9c 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -23,7 +23,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
     DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
     DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
-    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE
+    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \
+    DEFAULT_RESTORE_BEST_MODEL_WEIGHTS
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -178,6 +179,11 @@ class ExperimentSetup(RunEnvironment):
         (partly) trained model is lower than this parameter, training is continued. In case this number is higher than
         the given epochs parameter, no training is resumed. Epochs is set to 20 per default, but this value is just a
         placeholder that should be adjusted for a meaningful training.
+    :param early_stopping_epochs: number of consecutive epochs without improvement on val loss after which training is
+        stopped. When set to `np.inf` or not provided at all, training is not stopped before reaching `epochs`.
+    :param restore_best_model_weights: indicates whether to use the model state with the best val loss (if True) or the
+        model state at the end of training (if False). The latter depends on the parameters `epochs` and
+        `early_stopping_epochs`, which trigger the stopping of training.
     :param data_handler:
     :param data_origin:
     :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`.
@@ -221,7 +227,9 @@ class ExperimentSetup(RunEnvironment):
                  feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None,
                  feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None,
                  data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None,
+                 early_stopping_epochs: int = None, restore_best_model_weights: bool = None,
+                 data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                  max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
@@ -255,6 +263,9 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("permute_data", permute_data or upsampling, scope="train")
         self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
         self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
+        self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS)
+        self._set_param("restore_best_model_weights", restore_best_model_weights,
+                        default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS)
 
         # set experiment name
         sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
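
A hedged usage sketch for the two new training parameters; it assumes the usual `DefaultWorkflow` entry point, which passes keyword arguments on to `ExperimentSetup` (station list and epoch numbers are placeholders):

    from mlair.workflows import DefaultWorkflow

    workflow = DefaultWorkflow(stations=["DEBW107"], epochs=100,
                               early_stopping_epochs=10,         # stop after 10 epochs without val_loss improvement
                               restore_best_model_weights=True)  # keep the weights of the best epoch
    workflow.run()
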
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 4e9f8fa4439e9885a6c16c2b2eccfee2c97fd936..eab8012b983a0676620bbc66f65ff79b31165aeb 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -11,6 +11,7 @@ from dill.source import getsource
 import tensorflow.keras as keras
 import pandas as pd
 import tensorflow as tf
+import numpy as np
 
 from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
@@ -117,18 +118,36 @@ class ModelSetup(RunEnvironment):
 
         Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
         """
-        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
-        hist = HistoryAdvanced()
-        epo_timing = EpoTimingCallback()
-        self.data_store.set("hist", hist, scope="model")
-        self.data_store.set("epo_timing", epo_timing, scope="model")
+        # create callback handler
         callbacks = CallbackHandler()
+
+        # add callback: learning rate
+        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+
+        # add callback: advanced history
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="model")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+
+        # add callback: epo timing
+        epo_timing = EpoTimingCallback()
+        self.data_store.set("epo_timing", epo_timing, scope="model")
         callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
+
+        # add callback: early stopping
+        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
+        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
+        assert bool(isinstance(patience, int) or np.isinf(patience)) is True
+        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
+        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")
+
+        # create model checkpoint
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                          save_best_only=True, mode='auto')
+                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)
+
+        # store callbacks
         self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_model(self):
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 8f6bf05d29b2534e4918d72aa59f91aace0ec982..00d82f3c6f48c3560e31d62b5bed4ddbd2bc49be 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -10,20 +10,19 @@ import sys
 import traceback
 from typing import Dict, Tuple, Union, List, Callable
 
-import tensorflow.keras as keras
 import numpy as np
 import pandas as pd
 import xarray as xr
 
 from mlair.configuration import path_config
 from mlair.data_handler import Bootstraps, KerasIterator
-from mlair.helpers.datastore import NameNotFoundInDataStore
+from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope
 from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \
-    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap
+    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
     PlotPeriodogram, PlotDataHistogram
 from mlair.run_modules.run_environment import RunEnvironment
@@ -41,7 +40,7 @@ class PostProcessing(RunEnvironment):
         #. create plots
 
     Required objects [scope] from data store:
-        * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model]
+        * `model` [.] or locally saved model plus `model_name` [model] and `model` [model]
         * `generator` [train, val, test, train_val]
         * `forecast_path` [.]
         * `plot_path` [postprocessing]
@@ -79,6 +78,7 @@ class PostProcessing(RunEnvironment):
         self.train_data = self.data_store.get("data_collection", "train")
         self.val_data = self.data_store.get("data_collection", "val")
         self.train_val_data = self.data_store.get("data_collection", "train_val")
+        self.forecast_path = self.data_store.get("forecast_path")
         self.plot_path: str = self.data_store.get("plot_path")
         self.target_var = self.data_store.get("target_var")
         self._sampling = self.data_store.get("sampling")
@@ -86,6 +86,8 @@ class PostProcessing(RunEnvironment):
         self.skill_scores = None
         self.feature_importance_skill_scores = None
         self.uncertainty_estimate = None
+        self.uncertainty_estimate_seasons = {}
+        self.block_mse_per_station = None
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
@@ -148,16 +150,20 @@ class PostProcessing(RunEnvironment):
         block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate")
         evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True,
                                                            scope="uncertainty_estimate")
-        block_mse = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, separate_ahead=separate_ahead,
-                                             block_length=block_length)
-        self.uncertainty_estimate = statistics.create_n_bootstrap_realizations(
+        block_mse, block_mse_per_station = self.calculate_block_mse(evaluate_competitors=evaluate_competitors,
+                                                                    separate_ahead=separate_ahead,
+                                                                    block_length=block_length)
+        self.block_mse_per_station = block_mse_per_station
+        estimate = statistics.create_n_bootstrap_realizations(
             block_mse, dim_name_time=self.index_dim, dim_name_model=self.model_type_dim,
-            dim_name_boots=self.uncertainty_estimate_boot_dim, n_boots=n_boots)
+            dim_name_boots=self.uncertainty_estimate_boot_dim, n_boots=n_boots, seasons=["DJF", "MAM", "JJA", "SON"])
+        self.uncertainty_estimate = estimate.pop("")
+        self.uncertainty_estimate_seasons = estimate
         self.report_sample_uncertainty()
 
     def report_sample_uncertainty(self, percentiles: list = None):
         """
-        Store raw results of uncertainty estimate and calculate aggregate statistcs and store as raw data but also as
+        Store raw results of uncertainty estimate and calculate aggregate statistics and store as raw data but also as
         markdown and latex.
         """
         report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
@@ -166,31 +172,37 @@ class PostProcessing(RunEnvironment):
         # store raw results as nc
         file_name = os.path.join(report_path, "uncertainty_estimate_raw_results.nc")
         self.uncertainty_estimate.to_netcdf(path=file_name)
+        for season in self.uncertainty_estimate_seasons.keys():
+            file_name = os.path.join(report_path, f"uncertainty_estimate_raw_results_{season}.nc")
+            self.uncertainty_estimate_seasons[season].to_netcdf(path=file_name)
 
         # store statistics
         if percentiles is None:
             percentiles = [.05, .1, .25, .5, .75, .9, .95]
 
-        for ahead_steps in ["single", "multi"]:
-            if ahead_steps == "single":
-                try:
-                    df_descr = self.uncertainty_estimate.to_pandas().describe(percentiles=percentiles).astype("float32")
-                except ValueError:
-                    df_descr = self.uncertainty_estimate.mean(self.ahead_dim).to_pandas().describe(percentiles=percentiles).astype("float32")
-            else:
-                if self.ahead_dim not in self.uncertainty_estimate.dims:
-                    continue
-                df_descr = self.uncertainty_estimate.to_dataframe(self.model_type_dim).unstack().groupby(level=self.ahead_dim).describe(
-                    percentiles=percentiles).astype("float32")
-                df_descr = df_descr.stack(-1)
-                df_descr = df_descr.reorder_levels(df_descr.index.names[::-1])
-                df_sorter = ["count", "mean", "std", "min", *[f"{round(p * 100)}%" for p in percentiles], "max"]
-                df_descr = df_descr.loc[df_sorter]
-            column_format = tables.create_column_format_for_tex(df_descr)
-            file_name = os.path.join(report_path, f"uncertainty_estimate_statistics_{ahead_steps}.%s")
-            tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df_descr)
-            tables.save_to_md(report_path, file_name % "md", df=df_descr)
-            df_descr.to_csv(file_name % "csv", sep=";")
+        for season in [None] + list(self.uncertainty_estimate_seasons.keys()):
+            estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season]
+            affix = "" if season is None else f"_{season}"
+            for ahead_steps in ["single", "multi"]:
+                if ahead_steps == "single":
+                    try:
+                        df_descr = estimate.to_pandas().describe(percentiles=percentiles).astype("float32")
+                    except ValueError:
+                        df_descr = estimate.mean(self.ahead_dim).to_pandas().describe(percentiles=percentiles).astype("float32")
+                else:
+                    if self.ahead_dim not in estimate.dims:
+                        continue
+                    df_descr = estimate.to_dataframe(self.model_type_dim).unstack().groupby(level=self.ahead_dim).describe(
+                        percentiles=percentiles).astype("float32")
+                    df_descr = df_descr.stack(-1)
+                    df_descr = df_descr.reorder_levels(df_descr.index.names[::-1])
+                    df_sorter = ["count", "mean", "std", "min", *[f"{round(p * 100)}%" for p in percentiles], "max"]
+                    df_descr = df_descr.loc[df_sorter]
+                column_format = tables.create_column_format_for_tex(df_descr)
+                file_name = os.path.join(report_path, f"uncertainty_estimate_statistics_{ahead_steps}{affix}.%s")
+                tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df_descr)
+                tables.save_to_md(report_path, file_name % "md", df=df_descr)
+                df_descr.to_csv(file_name % "csv", sep=";")
 
     def calculate_block_mse(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"):
         """
@@ -199,7 +211,6 @@ class PostProcessing(RunEnvironment):
         station or actual data contained. This is intended to analyze not only the robustness against the time but also
         against the number of observations and the diversity of stations.
         """
-        path = self.data_store.get("forecast_path")
         all_stations = self.data_store.get("stations", "test")
         start = self.data_store.get("start", "test")
         end = self.data_store.get("end", "test")
@@ -208,7 +219,7 @@ class PostProcessing(RunEnvironment):
         collector = []
         for station in all_stations:
             # test data
-            external_data = self._get_external_data(station, path)
+            external_data = self._get_external_data(station, self.forecast_path)
             if external_data is not None:
                 pass
             # competitors
@@ -227,12 +238,16 @@ class PostProcessing(RunEnvironment):
                 # calc mse for each block (single station)
                 mse = errors.resample(indexer={index_dim: block_length}).mean(skipna=True)
                 collector.append(mse.assign_coords({coll_dim: station}))
+
+        # combine all mse blocks
+        mse_blocks_per_station = xr.concat(collector, dim=coll_dim)
         # calc mse for each block (average over all stations)
-        mse_blocks = xr.concat(collector, dim=coll_dim).mean(dim=coll_dim, skipna=True)
+        mse_blocks = mse_blocks_per_station.mean(dim=coll_dim, skipna=True)
         # average also on ahead steps
         if separate_ahead is False:
             mse_blocks = mse_blocks.mean(dim=self.ahead_dim, skipna=True)
-        return mse_blocks
+            mse_blocks_per_station = mse_blocks_per_station.mean(dim=self.ahead_dim, skipna=True)
+        return mse_blocks, mse_blocks_per_station
 
     def create_error_array(self, data):
         """Calculate squared error of all given time series in relation to observation."""
@@ -328,7 +343,6 @@ class PostProcessing(RunEnvironment):
         # forecast
         with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("n_boots", "feature_importance")
             dims = [self.uncertainty_estimate_boot_dim, self.index_dim, self.ahead_dim, self.model_type_dim]
             for station in self.test_data:
@@ -348,13 +362,13 @@ class PostProcessing(RunEnvironment):
                     coords = (range(number_of_bootstraps), range(shape[0]), range(1, shape[1] + 1))
                     var = f"{index}_{dimension}" if index is not None else str(dimension)
                     tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
-                    file_name = os.path.join(forecast_path,
+                    file_name = os.path.join(self.forecast_path,
                                              f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc")
                     tmp.to_netcdf(file_name)
                 else:
                     # store also true labels for each station
                     labels = np.expand_dims(Y[..., 0], axis=-1)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
+                    file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
                     labels = xr.DataArray(labels, coords=(*coords[1:], [self.observation_indicator]), dims=dims[1:])
                     labels.to_netcdf(file_name)
 
@@ -370,7 +384,6 @@ class PostProcessing(RunEnvironment):
         """
         with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("n_boots", "feature_importance")
             forecast_file = f"forecasts_norm_%s_test.nc"
             reference_name = "orig"
@@ -386,19 +399,20 @@ class PostProcessing(RunEnvironment):
             score = {}
             for station in self.test_data:
                 # get station labels
-                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc")
+                file_name = os.path.join(self.forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc")
                 with xr.open_dataarray(file_name) as da:
                     labels = da.load()
 
                 # get original forecasts
-                orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), reference_name=reference_name)
+                orig = self.get_orig_prediction(self.forecast_path, forecast_file % str(station),
+                                                reference_name=reference_name)
                 orig.coords[self.index_dim] = labels.coords[self.index_dim]
 
                 # calculate skill scores for each variable
                 skill = []
                 for boot_set in bootstrap_iter:
                     boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set)
-                    file_name = os.path.join(forecast_path,
+                    file_name = os.path.join(self.forecast_path,
                                              f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc")
                     with xr.open_dataarray(file_name) as da:
                         boot_data = da.load()
@@ -473,8 +487,8 @@ class PostProcessing(RunEnvironment):
         :return: the model
         """
         try:  # is only available if a model was trained in training stage
-            model = self.data_store.get("best_model")
-        except NameNotFoundInDataStore:
+            model = self.data_store.get("model")
+        except (NameNotFoundInDataStore, NameNotFoundInScope):
             logging.info("No model was saved in data store. Try to load model from experiment path.")
             model_name = self.data_store.get("model_name", "model")
             model: AbstractModelClass = self.data_store.get("model", "model")
@@ -501,7 +515,6 @@ class PostProcessing(RunEnvironment):
 
         """
         logging.info("Run plotting routines...")
-        path = self.data_store.get("forecast_path")
         use_multiprocessing = self.data_store.get("use_multiprocessing")
 
         plot_list = self.data_store.get("plot_list", "postprocessing")
@@ -540,8 +553,8 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotConditionalQuantiles" in plot_list:
-                PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path,
-                                         forecast_indicator=self.forecast_indicator,
+                PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=self.forecast_path,
+                                         plot_folder=self.plot_path, forecast_indicator=self.forecast_indicator,
                                          obs_indicator=self.observation_indicator)
         except Exception as e:
             logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:"
@@ -549,7 +562,7 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotMonthlySummary" in plot_list:
-                PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
+                PlotMonthlySummary(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", self.target_var,
                                    plot_folder=self.plot_path)
         except Exception as e:
             logging.error(f"Could not create plot PlotMonthlySummary due to the following error:"
@@ -575,8 +588,8 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotTimeSeries" in plot_list:
-                PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path,
-                               sampling=self._sampling, ahead_dim=self.ahead_dim)
+                PlotTimeSeries(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc",
+                               plot_folder=self.plot_path, sampling=self._sampling, ahead_dim=self.ahead_dim)
         except Exception as e:
             logging.error(f"Could not create plot PlotTimeSeries due to the following error:\n{sys.exc_info()[0]}\n"
                           f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}\n{traceback.format_exc()}")
@@ -584,11 +597,13 @@ class PostProcessing(RunEnvironment):
         try:
             if "PlotSampleUncertaintyFromBootstrap" in plot_list and self.uncertainty_estimate is not None:
                 block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate")
-                PlotSampleUncertaintyFromBootstrap(
-                    data=self.uncertainty_estimate, plot_folder=self.plot_path, model_type_dim=self.model_type_dim,
-                    dim_name_boots=self.uncertainty_estimate_boot_dim, error_measure="mean squared error",
-                    error_unit=r"ppb$^2$", block_length=block_length, model_name=self.model_display_name,
-                    model_indicator=self.forecast_indicator, sampling=self._sampling)
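+                # create one plot for the overall estimate (season is None) and one plot per season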
+                for season in [None] + list(self.uncertainty_estimate_seasons.keys()):
+                    estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season]
+                    PlotSampleUncertaintyFromBootstrap(
+                        data=estimate, plot_folder=self.plot_path, model_type_dim=self.model_type_dim,
+                        dim_name_boots=self.uncertainty_estimate_boot_dim, error_measure="mean squared error",
+                        error_unit=r"ppb$^2$", block_length=block_length, model_name=self.model_display_name,
+                        model_indicator=self.forecast_indicator, sampling=self._sampling, season_annotation=season)
         except Exception as e:
             logging.error(f"Could not create plot PlotSampleUncertaintyFromBootstrap due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
@@ -647,7 +662,18 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
-        
+
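+        # plot temporal evolution of the mean squared error for each station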
+        try:
+            if "PlotTimeEvolutionMetric" in plot_list:
+                PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path,
+                                        model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim,
+                                        error_measure="Mean Squared Error", error_unit=r"ppb$^2$",
+                                        model_indicator=self.forecast_indicator, model_name=self.model_display_name,
+                                        time_dim=self.index_dim)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}"
+                          f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
+
     @TimeTrackingWrapper
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
@@ -668,6 +694,7 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start train_ols_model on train data")
         self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
 
+    @TimeTrackingWrapper
     def make_prediction(self, subset):
         """
         Create predictions for NN, OLS, and persistence and add true observation as reference.
@@ -680,8 +707,7 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start make_prediction for {subset_type}")
         time_dimension = self.data_store.get("time_dim")
         window_dim = self.data_store.get("window_dim")
-        path = self.data_store.get("forecast_path")
-        subset_type = subset.name
+
         for i, data in enumerate(subset):
             input_data = data.get_X()
             target_data = data.get_Y(as_numpy=False)
@@ -721,7 +747,7 @@ class PostProcessing(RunEnvironment):
 
                 # save all forecasts locally
                 prefix = "forecasts_norm" if normalised is True else "forecasts"
-                file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc")
+                file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc")
                 all_predictions.to_netcdf(file)
 
     def _get_frequency(self) -> str:
@@ -949,14 +975,13 @@ class PostProcessing(RunEnvironment):
 
         :return: competitive and climatological skill scores, error metrics
         """
-        path = self.data_store.get("forecast_path")
         all_stations = self.data_store.get("stations")
         skill_score_competitive = {}
         skill_score_competitive_count = {}
         skill_score_climatological = {}
         errors = {}
         for station in all_stations:
-            external_data = self._get_external_data(station, path)  # test data
+            external_data = self._get_external_data(station, self.forecast_path)  # test data
 
             # test errors
             if external_data is not None:
@@ -995,7 +1020,7 @@ class PostProcessing(RunEnvironment):
             if external_data is not None:
                 skill_score_competitive[station], skill_score_competitive_count[station] = skill_score.skill_scores()
 
-            internal_data = self._get_internal_data(station, path)
+            internal_data = self._get_internal_data(station, self.forecast_path)
             if internal_data is not None:
                 skill_score_climatological[station] = skill_score.climatological_skill_scores(
                     internal_data, forecast_name=self.forecast_indicator)
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 8443b10d4d16b71819b795c0579b4d61cb739b70..0e416acbca4d66d5844e1179c7653ac5a9934f28 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -114,17 +114,17 @@ class PreProcessing(RunEnvironment):
         +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
 
         """
-        meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
+        meta_cols = ['station_name', 'station_lon', 'station_lat', 'station_alt']
         meta_round = ["station_lon", "station_lat", "station_alt"]
         precision = 4
         path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(path)
         names_of_set = ["train", "val", "test"]
-        df = self.create_info_df(meta_data, meta_round, names_of_set, precision)
+        df = self.create_info_df(meta_cols, meta_round, names_of_set, precision)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df)
         tables.save_to_md(path=path, filename="station_sample_size.md", df=df)
-        df_nometa = df.drop(meta_data, axis=1)
+        df_nometa = df.drop(meta_cols, axis=1)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format,
                            df=df_nometa)
@@ -150,15 +150,35 @@ class PreProcessing(RunEnvironment):
         df_descr = df_descr[df_descr_colnames]
         return df_descr
 
-    def create_info_df(self, meta_data, meta_round, names_of_set, precision):
-        df = pd.DataFrame(columns=meta_data + names_of_set)
+    def create_info_df(self, meta_cols, meta_round, names_of_set, precision):
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
+        max_process = self.data_store.get("max_number_multiprocessing")
+        df = pd.DataFrame(columns=meta_cols + names_of_set)
         for set_name in names_of_set:
             data = self.data_store.get("data_collection", set_name)
-            for station in data:
-                station_name = str(station.id_class)
-                df.loc[station_name, set_name] = station.get_Y()[0].shape[0]
-                if df.loc[station_name, meta_data].isnull().any():
-                    df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten()
+            n_process = min([psutil.cpu_count(logical=False), len(data), max_process])  # use only physical cpus
+            if n_process > 1 and use_multiprocessing is True:  # parallel solution
+                logging.info(f"use parallel create_info_df ({set_name})")
+                pool = multiprocessing.Pool(n_process)
+                logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+                output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data]
+                for i, p in enumerate(output):
+                    res = p.get()
+                    station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                    df.loc[station_name, set_name] = shape
+                    if df.loc[station_name, meta_cols].isnull().any():
+                        df.loc[station_name, meta_cols] = meta
+                    logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)")
+                pool.close()
+                pool.join()
+            else:  # serial solution
+                logging.info(f"use serial create_info_df ({set_name})")
+                for station in data:
+                    res = f_proc_create_info_df(station, meta_cols)
+                    station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                    df.loc[station_name, set_name] = shape
+                    if df.loc[station_name, meta_cols].isnull().any():
+                        df.loc[station_name, meta_cols] = meta
             df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
             assert len(data) == df.loc[:, set_name].count() - 1
             df.loc["# Stations", set_name] = len(data)
@@ -380,6 +400,13 @@ def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_pat
         return _tmp_file, station
 
 
+def f_proc_create_info_df(data, meta_cols):
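+    """Return station name, number of samples (first target component), and station meta data for one station."""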
+    station_name = str(data.id_class)
+    res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0],
+           "meta": data.id_class.meta.loc[meta_cols].values.flatten()}
+    return res
+
+
 def f_inspect_error(formatted):
     for i in range(len(formatted) - 1, -1, -1):
         if "mlair/mlair" not in formatted[i]:
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index a38837dce041295d37fae1ea86ef2a215d51dc89..5ce906122ef184d6dcad5527e923e44f04028fe5 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -54,7 +54,7 @@ class Training(RunEnvironment):
         * `upsampling` [train, val, test]
 
     Sets
-        * `best_model` [.]
+        * `model` [.]
 
     Creates
         * `<exp_name>_model-best.h5`
@@ -165,6 +165,9 @@ class Training(RunEnvironment):
                                initial_epoch=initial_epoch,
                                workers=psutil.cpu_count(logical=False))
             history = hist
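+        # log the best epoch (counting from 1) if the checkpoint callback tracked one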
+        epoch_best = checkpoint.epoch_best
+        if epoch_best is not None:
+            logging.info(f"best epoch: {epoch_best + 1}")
         try:
             lr = self.callbacks.get_callback_by_name("lr")
         except IndexError:
@@ -174,29 +177,15 @@ class Training(RunEnvironment):
         except IndexError:
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
-        self.load_best_model(checkpoint.filepath)
-        self.create_monitoring_plots(history, lr)
+        self.create_monitoring_plots(history, lr, epoch_best)
 
     def save_model(self) -> None:
         """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
         model_name = self.data_store.get("model_name", "model")
-        logging.debug(f"save best model to {model_name}")
-        self.model.save(model_name, save_format='h5')
-        self.model.save(model_name)
-        self.data_store.set("best_model", self.model)
-
-    def load_best_model(self, name: str) -> None:
-        """
-        Load model weights for model with name. Skip if no weights are available.
-
-        :param name: name of the model to load weights for
-        """
-        logging.debug(f"load best model: {name}")
-        try:
-            self.model.load_model(name, compile=True)
-            logging.info('reload model...')
-        except OSError:
-            logging.info('no weights to reload...')
+        logging.debug(f"save model to {model_name}")
+        self.model.save(model_name, save_format="h5")
+        self.model.save(model_name, save_format="tf")
+        self.data_store.set("model", self.model)
 
     def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
         """
@@ -219,7 +208,7 @@ class Training(RunEnvironment):
             with open(os.path.join(path, "epo_timing.json"), "w") as f:
                 json.dump(epo_timing.epo_timing, f)
 
-    def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
+    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None:
         """
         Create plots of the loss history and the learning rate as a function of the number of epochs.
 
@@ -228,22 +217,23 @@ class Training(RunEnvironment):
 
         :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
         :param lr_sc:  learning rate decay object with 'lr' attribute
+        :param epoch_best: index of the best epoch (zero-based)
         """
         path = self.data_store.get("plot_path")
         name = self.data_store.get("experiment_name")
 
         # plot history of loss and mse (if available)
         filename = os.path.join(path, f"{name}_history_loss.pdf")
-        PlotModelHistory(filename=filename, history=history)
+        PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
         multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
         if multiple_branches_used:
             filename = os.path.join(path, f"{name}_history_main_loss.pdf")
-            PlotModelHistory(filename=filename, history=history, main_branch=True)
+            PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best)
         mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"]))
         if len(mse_indicator) > 0:
             filename = os.path.join(path, f"{name}_history_main_mse.pdf")
             PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0],
-                             main_branch=multiple_branches_used)
+                             main_branch=multiple_branches_used, epoch_best=epoch_best)
 
         # plot learning rate
         if lr_sc:
diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py
index f6bc6d24724c2620083602d3864bcbca0a709681..07a5aa2f543b1992baf10421de4b28133feb0eac 100644
--- a/test/test_configuration/test_defaults.py
+++ b/test/test_configuration/test_defaults.py
@@ -21,6 +21,8 @@ class TestAllDefaults:
         assert DEFAULT_PERMUTE_DATA is False
         assert DEFAULT_BATCH_SIZE == int(256 * 2)
         assert DEFAULT_EPOCHS == 20
+        assert bool(np.isinf(DEFAULT_EARLY_STOPPING_EPOCHS)) is True
+        assert DEFAULT_RESTORE_BEST_MODEL_WEIGHTS is True
 
     def test_data_handler_parameters(self):
         assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py
index e47d725a4fd78fec98e81a6de9c18869e7b47637..bb8ecb5d216519b3662a5baa4d463780b4c29d8c 100644
--- a/test/test_data_handler/test_iterator.py
+++ b/test/test_data_handler/test_iterator.py
@@ -106,6 +106,9 @@ class DummyData:
         Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
         return [Y1, Y2]
 
+    def get_data(self, upsampling=False, as_numpy=True):
+        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
+
 
 class TestKerasIterator:
 
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index b850b361b09a8d180c5c70c2257d2d7be27c6cc0..70640be9d56d71e4f68145b3bb68fb835e1e27a5 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -15,7 +15,7 @@ import string
 from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like
 from mlair.helpers import PyTestRegex
 from mlair.helpers import Logger, TimeTracking
-from mlair.helpers.helpers import is_xarray, convert2xrda
+from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round
 
 
 class TestToList:
@@ -171,6 +171,39 @@ class TestFloatRound:
         assert float_round(-34.9221, 0) == -34.
 
 
+class TestRelativeRound:
+
+    def test_relative_round_big_numbers(self):
+        assert relative_round(101, 1) == 100
+        assert relative_round(99, 1) == 100
+        assert relative_round(105, 2) == 100
+        assert relative_round(106, 2) == 110
+        assert relative_round(106, 3) == 106
+
+    def test_relative_round_float_numbers(self):
+        assert relative_round(101.2033, 4) == 101.2
+        assert relative_round(101.2033, 5) == 101.2
+        assert relative_round(101.2033, 6) == 101.203
+
+    def test_relative_round_small_numbers(self):
+        assert relative_round(0.03112, 2) == 0.031
+        assert relative_round(0.03112, 1) == 0.03
+        assert relative_round(0.031126, 4) == 0.03113
+
+    def test_relative_round_negative_numbers(self):
+        assert relative_round(-101.2033, 5) == -101.2
+        assert relative_round(-106, 2) == -110
+        assert relative_round(-0.03112, 2) == -0.031
+        assert relative_round(-0.03112, 1) == -0.03
+        assert relative_round(-0.031126, 4) == -0.03113
+
+    def test_relative_round_wrong_significance(self):
+        with pytest.raises(AssertionError):
+            relative_round(300, -1)
+        with pytest.raises(TypeError):
+            relative_round(300, 1.1)
+
+
 class TestSelectFromDict:
 
     @pytest.fixture
diff --git a/test/test_helpers/test_statistics.py b/test/test_helpers/test_statistics.py
index a3f645937258604c2dbbda07b36a58d83e879065..6f1952a9c1df2b16a2e298b433d732d9e69200d4 100644
--- a/test/test_helpers/test_statistics.py
+++ b/test/test_helpers/test_statistics.py
@@ -6,7 +6,7 @@ import xarray as xr
 from mlair.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, \
     centre_apply, apply_inverse_transformation, min_max, min_max_inverse, min_max_apply, log, log_inverse, log_apply, \
     create_single_bootstrap_realization, calculate_average, create_n_bootstrap_realizations, mean_squared_error, \
-    mean_absolute_error, calculate_error_metrics
+    mean_absolute_error, calculate_error_metrics, index_of_agreement, modified_normalized_mean_bias
 from mlair.helpers.testing import check_nested_equality
 
 lazy = pytest.lazy_fixture
@@ -250,13 +250,20 @@ class TestCreateBootstrapRealizations:
     def test_create_n_bootstrap_realizations(self, data):
         boot_data = create_n_bootstrap_realizations(data, dim_name_time='time', dim_name_model='model',
                                                     n_boots=1000, dim_name_boots='boots')
-        assert isinstance(boot_data, xr.DataArray)
-        assert boot_data.shape == (1000, 2)
+        assert isinstance(boot_data, dict)
+        assert "" in boot_data.keys()
+        assert isinstance(boot_data[""], xr.DataArray)
+        assert boot_data[""].shape == (1000, 2)
 
         boot_data = create_n_bootstrap_realizations(data.sel(model='m1').squeeze(), dim_name_time='time',
                                                     dim_name_model='model', n_boots=1000, dim_name_boots='boots')
-        assert isinstance(boot_data, xr.DataArray)
-        assert boot_data.shape == (1000,)
+        assert isinstance(boot_data[""], xr.DataArray)
+        assert boot_data[""].shape == (1000,)
+
+        data["time"] = pd.date_range("2022-01", periods=10, freq="1m")
+        boot_data = create_n_bootstrap_realizations(data, dim_name_time='time', dim_name_model='model',
+                                                    n_boots=100, dim_name_boots='boots', seasons=["JJA", "DJF"])
+        assert sorted(list(boot_data.keys())) == sorted(["", "JJA", "DJF"])
 
 
 class TestMeanSquaredError:
@@ -297,6 +304,67 @@ class TestMeanAbsoluteError:
         assert xr.testing.assert_equal(mean_absolute_error(x_array1, x_array2, "value"), expected) is None
 
 
+class TestIndexOfAgreement:
+
+    def test_index_of_agreement(self):
+        d1 = np.array([1, 2, 3, 4, 5])
+        d2 = np.array([1, 2, 3, 4, 5])
+        assert index_of_agreement(d1, d2) == 1
+        d1 = np.array([1, 2, 3, 4, 7])
+        assert np.testing.assert_almost_equal(index_of_agreement(d1, d2), 0.9333, 3) is None
+        d1 = np.array([3, 4, 5, 6, 7])
+        assert np.testing.assert_almost_equal(index_of_agreement(d1, d2), 0.687, 3) is None
+
+    def test_index_of_agreement_xarray(self):
+        d1 = np.array([np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])])
+        d2 = np.array([np.array([2, 4, 3, 4, 6]), np.array([2, 3, 3, 4, 5]), np.array([0, 1, 3, 4, 5])])
+        shape = d1.shape
+        coords = {'index': range(shape[0]), 'value': range(shape[1])}
+        x_array1 = xr.DataArray(d1, coords=coords, dims=coords.keys())
+        x_array2 = xr.DataArray(d2, coords=coords, dims=coords.keys())
+        expected = xr.DataArray(np.array([1, 1, 1]), coords={"index": [0, 1, 2]}, dims=["index"])
+        res = index_of_agreement(x_array1, x_array1, dim="value")
+        assert xr.testing.assert_equal(res, expected) is None
+        expected = xr.DataArray(np.array([0.8478, 0.9333, 0.9629]), coords={"index": [0, 1, 2]}, dims=["index"])
+        res = index_of_agreement(x_array1, x_array2, dim="value")
+        assert xr.testing.assert_allclose(res, expected, atol=10**-2) is None
+
+
+class TestMNMB:
+
+    def test_modified_normalized_mean_bias(self):
+        d1 = np.array([1, 2, 3, 4, 5])
+        d2 = np.array([1, 2, 3, 4, 5])
+        assert modified_normalized_mean_bias(d1, d2) == 0
+        d1 = np.array([1, 2, 3, 4, 7])
+        assert np.testing.assert_almost_equal(modified_normalized_mean_bias(d1, d2), 0.0666, 3) is None
+        d1 = np.array([3, 4, 5, 6, 7])
+        assert np.testing.assert_almost_equal(modified_normalized_mean_bias(d1, d2), 0.58, 3) is None
+        assert np.testing.assert_almost_equal(modified_normalized_mean_bias(d2, d1), -0.58, 3) is None
+
+    def test_modified_normalized_mean_bias_xarray(self):
+        d1 = np.array([np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])])
+        d2 = np.array([np.array([2, 4, 3, 4, 6]), np.array([2, 3, 3, 4, 5]), np.array([0, 1, 3, 4, 5])])
+        shape = d1.shape
+        coords = {'index': range(shape[0]), 'value': range(shape[1])}
+        x_array1 = xr.DataArray(d1, coords=coords, dims=coords.keys())
+        x_array2 = xr.DataArray(d2, coords=coords, dims=coords.keys())
+        expected = xr.DataArray(np.array([0, 0, 0]), coords={"index": [0, 1, 2]}, dims=["index"])
+        res = modified_normalized_mean_bias(x_array1, x_array1, dim="value")
+        assert xr.testing.assert_equal(res, expected) is None
+        expected = xr.DataArray(np.array([0, 0, 0, 0, 0]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"])
+        res = modified_normalized_mean_bias(x_array1, x_array1, dim="index")
+        assert xr.testing.assert_equal(res, expected) is None
+        expected = xr.DataArray(np.array([-0.3030, -0.2133, 0.5333]), coords={"index": [0, 1, 2]}, dims=["index"])
+        res = modified_normalized_mean_bias(x_array1, x_array2, dim="value")
+        assert xr.testing.assert_allclose(res, expected, atol=10**-2) is None
+        res = modified_normalized_mean_bias(x_array2, x_array1, dim="value")
+        assert xr.testing.assert_allclose(res, -expected, atol=10**-2) is None
+        expected = xr.DataArray(np.array([0.2222, -0.1333, 0, 0, -0.0606]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"])
+        res = modified_normalized_mean_bias(x_array1, x_array2, dim="index")
+        assert xr.testing.assert_allclose(res, expected, atol=10**-2) is None
+
+
 class TestCalculateErrorMetrics:
 
     def test_calculate_error_metrics(self):
@@ -309,20 +377,15 @@ class TestCalculateErrorMetrics:
         expected = {"mse": xr.DataArray(np.array([1, 2, 0, 0, 1./3]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"]),
                     "rmse": np.sqrt(xr.DataArray(np.array([1, 2, 0, 0, 1./3]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"])),
                     "mae": xr.DataArray(np.array([1, 4./3, 0, 0, 1./3]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"]),
+                    "ioa": xr.DataArray(np.array([0.3721, 0.4255, 1, 1, 0.4706]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"]),
+                    "mnmb": xr.DataArray(np.array([0.2222, -0.1333, 0, 0, -0.0606]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"]),
                     "n": xr.DataArray(np.array([3, 3, 3, 3, 3]), coords={"value": [0, 1, 2, 3, 4]}, dims=["value"])}
-        assert check_nested_equality(expected, calculate_error_metrics(x_array1, x_array2, "index")) is True
+        assert check_nested_equality(expected, calculate_error_metrics(x_array1, x_array2, "index"), 3) is True
 
         expected = {"mse": xr.DataArray(np.array([1.2, 0.4, 0.4]), coords={"index": [0, 1, 2]}, dims=["index"]),
                     "rmse": np.sqrt(xr.DataArray(np.array([1.2, 0.4, 0.4]), coords={"index": [0, 1, 2]}, dims=["index"])),
                     "mae": xr.DataArray(np.array([0.8, 0.4, 0.4]), coords={"index": [0, 1, 2]}, dims=["index"]),
+                    "ioa": xr.DataArray(np.array([0.8478, 0.9333, 0.9629]), coords={"index": [0, 1, 2]}, dims=["index"]),
+                    "mnmb": xr.DataArray(np.array([-0.3030, -0.2133, 0.5333]), coords={"index": [0, 1, 2]}, dims=["index"]),
                     "n": xr.DataArray(np.array([5, 5, 5]), coords={"index": [0, 1, 2]}, dims=["index"])}
-        assert check_nested_equality(expected, calculate_error_metrics(x_array1, x_array2, "value")) is True
-
-
-
-        # expected = xr.DataArray(np.array([1.2, 0.4, 0.4]), coords={"index": [0, 1, 2]}, dims=["index"])
-        # assert xr.testing.assert_equal(mean_squared_error(x_array1, x_array2, "value"), expected) is None
-        #
-        #
-        # expected = xr.DataArray(np.array([0.8, 0.4, 0.4]), coords={"index": [0, 1, 2]}, dims=["index"])
-        # assert xr.testing.assert_equal(mean_absolute_error(x_array1, x_array2, "value"), expected) is None
\ No newline at end of file
+        assert check_nested_equality(expected, calculate_error_metrics(x_array1, x_array2, "value"), 3) is True
diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py
index 9b888a91a7c88a31764bd272632b1aab8e6e170f..8a4bdb92e41f14a8680ea797dcd74db74bd95c9c 100644
--- a/test/test_helpers/test_testing_helpers.py
+++ b/test/test_helpers/test_testing_helpers.py
@@ -58,16 +58,25 @@ class TestNestedEquality:
         assert check_nested_equality("3", 3) is False
         assert check_nested_equality("3", "3") is True
         assert check_nested_equality(None, None) is True
+        assert check_nested_equality(3.92, 3.9, 1) is True
+        assert check_nested_equality(3.92, 3.9, 2) is False
 
     def test_nested_equality_xarray(self):
         obj1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
         obj2 = xr.ones_like(obj1) * obj1
         assert check_nested_equality(obj1, obj2) is True
+        obj2 = obj2 * 1.0001
+        assert check_nested_equality(obj1, obj2) is False
+        assert check_nested_equality(obj1, obj2, 3) is True
 
     def test_nested_equality_numpy(self):
         obj1 = np.random.randn(2, 3)
         obj2 = obj1 * 1
         assert check_nested_equality(obj1, obj2) is True
+        obj2 = obj2 * 1.001
+        assert check_nested_equality(obj1, obj2) is False
+        assert check_nested_equality(obj1, obj2, 5) is False
+        assert check_nested_equality(obj1, obj2, 1) is True
 
     def test_nested_equality_list_tuple(self):
         assert check_nested_equality([3, 3], [3, 3]) is True
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
index 60b37207ceefc4088b33fa002dac9db7c6c35399..6e8d3ea9ebab40c79b17b2fba386322a630f00e1 100644
--- a/test/test_run_modules/test_model_setup.py
+++ b/test/test_run_modules/test_model_setup.py
@@ -80,7 +80,7 @@ class TestModelSetup:
         setup._set_callbacks()
         assert "general.model" in setup.data_store.search_name("callbacks")
         callbacks = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 4
+        assert len(callbacks.get_callbacks()) == 5
 
     def test_set_callbacks_no_lr_decay(self, setup):
         setup.data_store.set("lr_decay", None, "general.model")
@@ -88,7 +88,7 @@ class TestModelSetup:
         setup.checkpoint_name = "TestName"
         setup._set_callbacks()
         callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
-        assert len(callbacks.get_callbacks()) == 3
+        assert len(callbacks.get_callbacks()) == 4
         with pytest.raises(IndexError):
             callbacks.get_callback_by_name("lr_decay")
 
@@ -150,3 +150,6 @@ class DummyData:
         Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples, window
         Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples, window
         return [Y1, Y2]
+
+    def get_data(self, upsampling=False, as_numpy=True):
+        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index 0f2ee7a10fd2e3190c0b66da558626747d4c03c9..1dafdbd5c4882932e3d57e726e7a06bea22a745d 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -46,12 +46,14 @@ class TestPreProcessing:
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
             assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)')
-            assert caplog.record_tuples[-3] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
+            assert caplog.record_tuples[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
                                                                         r'station\(s\). Found 5/5 valid stations.'))
+            assert caplog.record_tuples[-5] == ('root', 20, "use serial create_info_df (train)")
+            assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (val)")
+            assert caplog.record_tuples[-3] == ('root', 20, "use serial create_info_df (test)")
             assert caplog.record_tuples[-2] == ('root', 20, "Searching for competitors to be prepared for use.")
-            assert caplog.record_tuples[-1] == (
-            'root', 20, "No preparation required because no competitor was provided "
-                        "to the workflow.")
+            assert caplog.record_tuples[-1] == ('root', 20, "No preparation required because no competitor was provided"
+                                                            " to the workflow.")
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 1b83b3823519d63d5dcbc10f0e31fc3433f98f34..8f1fcd1943f9f203e738053017e00f8c269afef1 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -326,16 +326,10 @@ class TestTraining:
         model_name = "test_model.h5"
         assert model_name not in os.listdir(model_path)
         init_without_run.save_model()
-        message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}")
+        message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
         assert caplog.record_tuples[1] == ("root", 10, message)
         assert model_name in os.listdir(model_path)
 
-    def test_load_best_model_no_weights(self, init_without_run, caplog):
-        caplog.set_level(logging.DEBUG)
-        init_without_run.load_best_model("notExisting.h5")
-        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
-        assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
-
     def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
         init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
         assert "history.json" in os.listdir(model_path)
@@ -360,7 +354,7 @@ class TestTraining:
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
         history.model.output_names = mock.MagicMock(return_value=["Main"])
         history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
-        init_without_run.create_monitoring_plots(history, learning_rate)
+        init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1)
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
 
     def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,