diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 8b99acd0f5723d3b00ec1bd0098712753da21b52..d36e808b1024e597e04d25c38853d79425cd89e7 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -163,6 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
+        self.epoch_best = kwargs.pop("epoch_best", 0)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -197,6 +198,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
+                        self.epoch_best = epoch
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
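
For reference, a minimal standalone sketch of the bookkeeping the new `epoch_best` attribute adds to `ModelCheckpointAdvanced` (hypothetical example class, no keras dependency; it assumes a "lower is better" metric such as `val_loss`, and simplifies the best-value comparison that keras performs internally):

```python
# Hypothetical standalone illustration of the epoch_best bookkeeping above
# (not MLAir code): remember the 0-based epoch at which the monitored
# validation metric reached its best value so far.
class BestEpochTracker:
    def __init__(self, epoch_best: int = 0):
        self.best = float("inf")      # assumes "lower is better", e.g. val_loss
        self.epoch_best = epoch_best  # 0-based index of the best epoch

    def on_epoch_end(self, epoch: int, val_loss: float) -> None:
        if val_loss < self.best:
            self.best = val_loss
            self.epoch_best = epoch


tracker = BestEpochTracker()
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.4, 0.6]):
    tracker.on_epoch_end(epoch, val_loss)
print(tracker.epoch_best + 1)  # -> 4, i.e. the best epoch reported 1-based
```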
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 39dd80651226519463d7b503fb612e43983d73cf..4884dcb81c2b98546da3edce099c02b47aebd7b2 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -27,7 +27,8 @@ class PlotModelHistory:
     parameter filename must include the absolute path for the plot.
     """
 
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
+                 epoch_best: int = None):
         """
         Set attributes and create plot.
 
@@ -37,12 +38,15 @@ class PlotModelHistory:
         :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
         :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
             False -> look for losses from all branches, not only from main)
+        :param epoch_best: epoch at which the best training result was achieved (0-based counting)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._data.index += 1  # display epochs 1-based in the plot
         self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
+        self._epoch_best = epoch_best
         self._plot(filename)
 
     def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
@@ -88,6 +92,9 @@ class PlotModelHistory:
         :param filename: name (including total path) of the plot to save.
         """
         ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
+        if self._epoch_best is not None:  # mark the best epoch on the validation curve
+            ax.scatter(self._epoch_best + 1, self._data[f"val_{self._plot_metric}"].iloc[self._epoch_best],
+                       s=100, marker="*", c="black")
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
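
To see the effect of the index shift and the new marker in isolation, here is a small self-contained sketch with made-up history data (pandas/matplotlib only, not MLAir's `PlotModelHistory`); the plotted epoch axis is 1-based while `epoch_best` stays 0-based, hence the `epoch_best + 1` offset:

```python
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

# made-up training history; in MLAir this comes from the keras History callback
history = {"loss": [0.9, 0.6, 0.5, 0.45], "val_loss": [1.0, 0.7, 0.65, 0.68]}
epoch_best = 2  # 0-based epoch with the lowest validation loss

data = pd.DataFrame.from_dict(history)
data.index += 1  # epochs are displayed 1-based, as in the patch above

ax = data[["loss", "val_loss"]].plot(linewidth=0.7)
ax.scatter(epoch_best + 1, data["val_loss"].iloc[epoch_best],
           s=100, marker="*", c="black")  # star marks the best epoch
ax.set_yscale("log")
ax.set_xlabel("epoch")
plt.savefig("history_loss_example.pdf")
```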
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index a38837dce041295d37fae1ea86ef2a215d51dc89..5ddf91ebf6659d08e1163aceee6000a8082f0bef 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -165,6 +165,8 @@ class Training(RunEnvironment):
                                initial_epoch=initial_epoch,
                                workers=psutil.cpu_count(logical=False))
             history = hist
+        epoch_best = checkpoint.epoch_best
+        logging.info(f"best epoch: {epoch_best + 1}")
         try:
             lr = self.callbacks.get_callback_by_name("lr")
         except IndexError:
@@ -175,7 +177,7 @@ class Training(RunEnvironment):
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
         self.load_best_model(checkpoint.filepath)
-        self.create_monitoring_plots(history, lr)
+        self.create_monitoring_plots(history, lr, epoch_best)
 
     def save_model(self) -> None:
         """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
@@ -194,9 +196,9 @@ class Training(RunEnvironment):
         logging.debug(f"load best model: {name}")
         try:
             self.model.load_model(name, compile=True)
-            logging.info('reload model...')
+            logging.info(f"reload model...")
         except OSError:
-            logging.info('no weights to reload...')
+            logging.info("no weights to reload...")
 
     def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
         """
@@ -219,7 +221,7 @@ class Training(RunEnvironment):
             with open(os.path.join(path, "epo_timing.json"), "w") as f:
                 json.dump(epo_timing.epo_timing, f)
 
-    def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
+    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None:
         """
         Create plot of history and learning rate in dependence of the number of epochs.
 
@@ -228,22 +230,23 @@ class Training(RunEnvironment):
 
         :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
         :param lr_sc:  learning rate decay object with 'lr' attribute
+        :param epoch_best: number of the best epoch (0-based counting)
         """
         path = self.data_store.get("plot_path")
         name = self.data_store.get("experiment_name")
 
         # plot history of loss and mse (if available)
         filename = os.path.join(path, f"{name}_history_loss.pdf")
-        PlotModelHistory(filename=filename, history=history)
+        PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
         multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
         if multiple_branches_used:
             filename = os.path.join(path, f"{name}_history_main_loss.pdf")
-            PlotModelHistory(filename=filename, history=history, main_branch=True)
+            PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best)
         mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"]))
         if len(mse_indicator) > 0:
             filename = os.path.join(path, f"{name}_history_main_mse.pdf")
             PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0],
-                             main_branch=multiple_branches_used)
+                             main_branch=multiple_branches_used, epoch_best=epoch_best)
 
         # plot learning rate
         if lr_sc:
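
Taken together, `epoch_best` now travels checkpoint → training → plotting. A hedged sketch of the resulting call pattern in `Training.train()` (the stub below stands in for the real `ModelCheckpointAdvanced` callback; the `PlotModelHistory` call is shown as a comment because it needs a full history object):

```python
import logging

logging.basicConfig(level=logging.INFO)


class _CheckpointStub:
    """Hypothetical stand-in for the ModelCheckpointAdvanced callback."""
    epoch_best = 3  # 0-based epoch at which the best weights were written


checkpoint = _CheckpointStub()
epoch_best = checkpoint.epoch_best
logging.info(f"best epoch: {epoch_best + 1}")  # reported 1-based, as in the patch

# the same value is then handed to every history plot, e.g.:
# PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
```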