diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index b18cce7a8993899621295644016f1e126d0dfac8..dee36166abacf275c79213305ac15e1918e1957c 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -19,12 +19,13 @@ lr_object = Union[Dict, LearningRateDecay]
 
 class PlotModelHistory:
     """
-    Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are
-    provided (name must somehow include the word `loss`), this additional information is added to the plot with an
-    separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper
-    saving behaviour, the parameter filename must include the absolute path for the plot.
+    Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric
+    are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional
+    information is added to the plot with an separate y-axis scale on the right side (shared for all additional
+    metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
+    path for the plot.
     """
-    def __init__(self, filename: str, history: history_object):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
@@ -34,31 +35,34 @@ class PlotModelHistory:
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._plot_metric = plot_metric
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)
 
-    @staticmethod
-    def _filter_columns(history: Dict) -> List[str]:
+    def _filter_columns(self, history: Dict) -> List[str]:
         """
-        Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed.
-        :param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history)
-        :return: filtered columns including all loss variations except loss and val_loss.
+        Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
+        also removed.
+        :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras
+            History.history)
+        :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
         """
-        cols = list(filter(lambda x: "loss" in x, history.keys()))
-        cols.remove("val_loss")
-        cols.remove("loss")
+        cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
+        cols.remove(f"val_{self._plot_metric}")
+        cols.remove(self._plot_metric)
         return cols
 
     def _plot(self, filename: str) -> None:
         """
-        Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with
-        an additional yaxis on the right side. The plot is saved in filename.
+        Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided,
+        they will be added with an additional yaxis on the right side. The plot is saved in filename.
         :param filename: name (including total path) of the plot to save.
         """
-        ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7)
+        ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
-        ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}")
+        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
         plt.savefig(filename)