diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index ca569720dc41d95621d0613a2170cc4d9d46c082..b630261dbf58d7402f8c3cacaee153347ad4f1e3 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -2,6 +2,9 @@ __author__ = "Lukas Leufen"
 __date__ = '2020-06-25'
 
 
+import numpy as np
+
+
 DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                         'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
@@ -24,6 +27,8 @@ DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY = False
 DEFAULT_PERMUTE_DATA = False
 DEFAULT_BATCH_SIZE = int(256 * 2)
 DEFAULT_EPOCHS = 20
+DEFAULT_EARLY_STOPPING_EPOCHS = np.inf
+DEFAULT_RESTORE_BEST_MODEL_WEIGHTS = True
 DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_WINDOW_LEAD_TIME = 3
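The `np.inf` default acts as a "never stop early" sentinel: early stopping only ever compares a stale-epoch counter against `patience`, and no finite counter satisfies `wait >= np.inf`. A minimal sketch of that mechanism (illustrative loop, not MLAir or Keras code):

```python
import numpy as np

# With patience = np.inf the stopping condition can never fire, so
# training always runs for the full number of epochs.
patience = np.inf
wait = 0
for epoch in range(20):
    improved = False  # pretend the monitored metric never improves
    wait = 0 if improved else wait + 1
    if wait >= patience:  # never True for np.inf
        print(f"early stopping at epoch {epoch}")
        break
else:
    print("trained for all epochs")  # this branch is reached
```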
diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 72f40e453c37bcfe16566336fbfd56eb2734ae9c..39b0da5b49f470d11ea64b9ddd344b9ad11e2b7f 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -164,6 +164,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
         self.epoch_best = None
+        self.restore_best_weights = kwargs.pop("restore_best_weights", True)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -177,14 +178,19 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
 
         :param hist: The History object from the previous (interrupted) training.
         """
-        f = np.min if self.monitor_op.__name__ == "less" else np.max
-        f_loc = lambda x: np.where(x == f(x))[0][-1]
-        _d = hist.history.get(self.monitor)
-        loc = f_loc(_d)
-        assert f(_d) == _d[loc]
-        self.epoch_best = loc
-        self.best = _d[loc]
-        logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
+        if self.restore_best_weights:
+            f = np.min if self.monitor_op.__name__ == "less" else np.max
+            f_loc = lambda x: np.where(x == f(x))[0][-1]
+            _d = hist.history.get(self.monitor)
+            loc = f_loc(_d)
+            assert f(_d) == _d[loc]
+            self.epoch_best = loc
+            self.best = _d[loc]
+            logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
+        else:
+            _d = hist.history.get(self.monitor)[-1]
+            self.best = _d
+            logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch")
 
     def update_callbacks(self, callbacks):
         """
@@ -205,7 +211,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
-                        self.epoch_best = epoch
+                        if self.restore_best_weights:
+                            self.epoch_best = epoch
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 4884dcb81c2b98546da3edce099c02b47aebd7b2..e651078fe66a5d95e6e01d5384865eee621ba6f4 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 
 from mlair.model_modules.keras_extensions import LearningRateDecay
+from mlair.helpers.helpers import relative_round
 
 # matplotlib.use('Agg')
 history_object = Union[Dict, keras.callbacks.History]
@@ -98,7 +99,13 @@ class PlotModelHistory:
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
-        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        if self._epoch_best is not None:
+            final_res = self._data[[f'val_{self._plot_metric}']].min().values[0]
+            annotation = f"best epoch {self._epoch_best}"
+        else:
+            final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0]
+            annotation = "final"
+        title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}"
         ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
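`relative_round` is imported from `mlair.helpers.helpers`; assuming it rounds to a given number of significant digits (which the call with `5` suggests), a hypothetical stand-in would look like this:

```python
import math

def relative_round_sketch(x: float, sig: int) -> float:
    # Hypothetical re-implementation for illustration only: round x to
    # `sig` significant digits.
    if x == 0:
        return 0.0
    return round(x, sig - int(math.floor(math.log10(abs(x)))) - 1)

print(relative_round_sketch(0.00123456, 5))  # -> 0.0012346
```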
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index df797ffc23370bf4f45bb2b4f76e5f71e9bd030f..d807db14c96a4a30fde791e54c8b1b32e519fb9c 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -23,7 +23,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
     DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
     DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
-    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE
+    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_EARLY_STOPPING_EPOCHS, \
+    DEFAULT_RESTORE_BEST_MODEL_WEIGHTS
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -178,6 +179,11 @@ class ExperimentSetup(RunEnvironment):
         (partly) trained model is lower than this parameter, training is continued. In case this number is higher than
         the given epochs parameter, no training is resumed. Epochs is set to 20 by default, but this value is just a
         placeholder that should be adjusted for a meaningful training.
+    :param early_stopping_epochs: number of consecutive epochs without improvement on val loss after which training is
+        stopped. When set to `np.inf` or not provided at all, training is not stopped before reaching `epochs`.
+    :param restore_best_model_weights: indicates whether to use the model state with the best val loss (if True) or the
+        model state at the end of training (if False). The latter depends on the parameters `epochs` and
+        `early_stopping_epochs`, which trigger the stopping of training.
     :param data_handler:
     :param data_origin:
     :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`.
@@ -221,7 +227,9 @@ class ExperimentSetup(RunEnvironment):
                  feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None,
                  feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None,
                  data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None,
+                 early_stopping_epochs: int = None, restore_best_model_weights: bool = None,
+                 data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                  max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
@@ -255,6 +263,9 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("permute_data", permute_data or upsampling, scope="train")
         self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
         self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
+        self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS)
+        self._set_param("restore_best_model_weights", restore_best_model_weights,
+                        default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS)
 
         # set experiment name
         sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
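Both new options travel through `_set_param` into the data store, so they can be supplied like any other experiment keyword. A hedged usage sketch with illustrative values, assuming `mlair.run` forwards its keyword arguments to `ExperimentSetup` as in MLAir's quickstart:

```python
import mlair

# Illustrative: stop after 10 consecutive epochs without val_loss
# improvement and roll back to the best weights seen so far; passing
# np.inf (the default) for early_stopping_epochs disables stopping.
mlair.run(
    epochs=100,
    early_stopping_epochs=10,
    restore_best_model_weights=True,
)
```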
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 4e9f8fa4439e9885a6c16c2b2eccfee2c97fd936..eab8012b983a0676620bbc66f65ff79b31165aeb 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -11,6 +11,7 @@ from dill.source import getsource
 import tensorflow.keras as keras
 import pandas as pd
 import tensorflow as tf
+import numpy as np
 
 from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
 from mlair.run_modules.run_environment import RunEnvironment
@@ -117,18 +118,36 @@ class ModelSetup(RunEnvironment):
 
         Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.
         """
-        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
-        hist = HistoryAdvanced()
-        epo_timing = EpoTimingCallback()
-        self.data_store.set("hist", hist, scope="model")
-        self.data_store.set("epo_timing", epo_timing, scope="model")
+        # create callback handler
         callbacks = CallbackHandler()
+
+        # add callback: learning rate
+        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
         if lr is not None:
             callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+
+        # add callback: advanced history
+        hist = HistoryAdvanced()
+        self.data_store.set("hist", hist, scope="model")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
+
+        # add callback: epo timing
+        epo_timing = EpoTimingCallback()
+        self.data_store.set("epo_timing", epo_timing, scope="model")
         callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
+
+        # add callback: early stopping
+        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
+        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
+        assert bool(isinstance(patience, int) or np.isinf(patience)) is True
+        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
+        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")
+
+        # create model checkpoint
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
-                                          save_best_only=True, mode='auto')
+                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)
+
+        # store callbacks
         self.data_store.set("callbacks", callbacks, self.scope)
 
     def load_model(self):
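In plain Keras terms, the early-stopping callback assembled here behaves as in the following standalone sketch (toy model and random data, not MLAir code); `EarlyStopping` monitors `val_loss` by default:

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(128, 4).astype("float32")
y = np.random.rand(128, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=10,                # np.inf disables stopping entirely
    restore_best_weights=True,  # mirrors restore_best_model_weights
)
model.fit(x, y, validation_split=0.25, epochs=50,
          callbacks=[early_stopping], verbose=0)
```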
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 7a31f83fdf09ceed649aa02d4eecb74a9165eba0..8a5aa98b22d3f3c2cc4e1f32b8f816a14146b716 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -41,7 +41,7 @@ class PostProcessing(RunEnvironment):
         #. create plots
 
     Required objects [scope] from data store:
-        * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model]
+        * `model` [.] or locally saved model plus `model_name` [model] and `model` [model]
         * `generator` [train, val, test, train_val]
         * `forecast_path` [.]
         * `plot_path` [postprocessing]
@@ -479,7 +479,7 @@ class PostProcessing(RunEnvironment):
         :return: the model
         """
         try:  # is only available if a model was trained in training stage
-            model = self.data_store.get("best_model")
+            model = self.data_store.get("model")
         except NameNotFoundInDataStore:
             logging.info("No model was saved in data store. Try to load model from experiment path.")
             model_name = self.data_store.get("model_name", "model")
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index b0233eb4090fdd72b5635109c74dcfe388e44e37..cb9527ff9243c0d35c2fddb3d22368ef918ac2af 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -54,7 +54,7 @@ class Training(RunEnvironment):
         * `upsampling` [train, val, test]
 
     Sets
-        * `best_model` [.]
+        * `model` [.]
 
     Creates
         * `<exp_name>_model-best.h5`
@@ -177,16 +177,15 @@ class Training(RunEnvironment):
         except IndexError:
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
-        self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr, epoch_best)
 
     def save_model(self) -> None:
         """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
         model_name = self.data_store.get("model_name", "model")
-        logging.debug(f"save best model to {model_name}")
-        self.model.save(model_name, save_format='h5')
-        self.model.save(model_name)
-        self.data_store.set("best_model", self.model)
+        logging.debug(f"save model to {model_name}")
+        self.model.save(model_name, save_format="h5")
+        self.model.save(model_name, save_format="tf")
+        self.data_store.set("model", self.model)
 
     def load_best_model(self, name: str) -> None:
         """
diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py
index f6bc6d24724c2620083602d3864bcbca0a709681..07a5aa2f543b1992baf10421de4b28133feb0eac 100644
--- a/test/test_configuration/test_defaults.py
+++ b/test/test_configuration/test_defaults.py
@@ -21,6 +21,8 @@ class TestAllDefaults:
         assert DEFAULT_PERMUTE_DATA is False
         assert DEFAULT_BATCH_SIZE == int(256 * 2)
         assert DEFAULT_EPOCHS == 20
+        assert bool(np.isinf(DEFAULT_EARLY_STOPPING_EPOCHS)) is True
+        assert DEFAULT_RESTORE_BEST_MODEL_WEIGHTS is True
 
     def test_data_handler_parameters(self):
         assert DEFAULT_STATIONS == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']