diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index 617ff3135734056e51746ae2924a123a3eb34f8f..87e4071a0ac98c15c22131cd5d3418eb7c1b6976 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -25,12 +25,18 @@ class PlotModelHistory:
     metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
     path for the plot.
     """
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
             format ending like .pdf or .png to work.
         :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
+        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss; default: loss)
+        :param main_branch: if True, only consider metrics that belong to the 'main' branch; if False, consider the
+            metrics of all branches, not only of main (default: False)
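+
+        Minimal usage sketch (filename and plot_metric are placeholders):
+            PlotModelHistory("/abs/path/history_mse.pdf", history, plot_metric="mse")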
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
@@ -45,10 +48,9 @@
             plot_metric = "mean_squared_error"
         elif plot_metric.lower() == "mae":
             plot_metric = "mean_absolute_error"
-        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)]
-        print(available_keys)
+        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)]
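+        # sorting by length puts the shortest match first, e.g. "loss" before "val_loss" and "main_val_loss"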
         available_keys.sort(key=len)
-        print(available_keys)
         return available_keys[0]
 
     def _filter_columns(self, history: Dict) -> List[str]:
@@ -60,6 +61,10 @@
         :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
         """
         cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
+        # heuristic: <plot_metric> and val_<plot_metric> are always present in cols, because keras generates both
+        # entries. If the metric were missing entirely, self._get_plot_metric() would already have failed, simply
+        # because it is executed before this method.
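+        # illustration: for plot_metric "loss", {"loss", "val_loss", "main_loss"} reduces to ["main_loss"]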
         cols.remove(f"val_{self._plot_metric}")
         cols.remove(self._plot_metric)
         return cols
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index d15060daeb2596d048ff44f5f9948c0002069762..fee1b38b97d8f4649730f0f7110cd3801ba7db33 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -134,10 +134,19 @@ class Training(RunEnvironment):
         """
         path = self.data_store.get("plot_path", "general")
         name = self.data_store.get("experiment_name", "general")
+
+        # plot history of loss and mse (if available)
         filename = os.path.join(path, f"{name}_history_loss.pdf")
-        PlotModelHistory(filename=filename, history=history, main_branch=False)
-        filename = os.path.join(path, f"{name}_history_main_loss.pdf")
         PlotModelHistory(filename=filename, history=history)
-        filename = os.path.join(path, f"{name}_history_main_mse.pdf")
-        PlotModelHistory(filename=filename, history=history, plot_metric="mse")
+        multiple_branches_used = len(history.model.output_names) > 1  # keras lists one name per output branch
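+        # with multiple branches, additionally plot the loss of the main branch on its own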
+        if multiple_branches_used:
+            filename = os.path.join(path, f"{name}_history_main_loss.pdf")
+            PlotModelHistory(filename=filename, history=history, main_branch=True)
+        if "mean_squared_error" in history.model.metrics_names:
+            filename = os.path.join(path, f"{name}_history_main_mse.pdf")
+            PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
+
+        # plot learning rate
         PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 3426eb1d355a6690ee57c3ee45e5088d7df9c249..accb32e5e3ec0fc425065ae6199c0418c524b174 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -37,7 +37,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path, model, checkpoint):
+    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -82,6 +82,8 @@ class TestTraining:
                      'loss': [0.6795708956961347, 0.45963566494176616],
                      'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                      'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
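+        # attach a mock model so create_monitoring_plots can access output_names and metrics_names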
+        h.model = mock.MagicMock()
         return h
 
     @pytest.fixture
@@ -103,7 +104,7 @@ class TestTraining:
         return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
 
     @pytest.fixture
-    def ready_to_train(self, generator, init_without_run):
+    def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
         init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
@@ -208,5 +209,7 @@ class TestTraining:
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
+        history.model.output_names = ["Main"]  # a plain list: output_names is an attribute, not a callable
+        history.model.metrics_names = ["loss", "mean_squared_error"]
         init_without_run.create_monitoring_plots(history, learning_rate)
-        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
+        # loss, mse and learning rate plots are created for a single-branch model
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 3
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index d38fc623b7fd2f5f8c779ea03757d7bf0794089a..e2d17057facee8ae07284109392f3e42e8e2fb6e 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -17,50 +17,81 @@ def path():
 class TestPlotModelHistory:
 
     @pytest.fixture
-    def history(self):
+    def default_history(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
-                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
-                     'loss': [0.6795708956961347, 0.45963566494176616],
-                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist
 
     @pytest.fixture
-    def history_var(self):
+    def history_additional_loss(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                     'test_loss': [0.595368885413389, 0.530547587585537],
-                     'loss': [0.6795708956961347, 0.45963566494176616],
-                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'test_loss': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist
 
+    @pytest.fixture
+    def history_with_main(self, default_history):
+        default_history.history["main_val_loss"] = [0.5586272982587484, 0.45712877659670287]
+        default_history.history["main_loss"] = [0.6795708956961347, 0.45963566494176616]
+        return default_history
+
     @pytest.fixture
     def no_init(self):
         return object.__new__(PlotModelHistory)
 
-    def test_plot_from_hist_obj(self, history, path):
+    def test_get_plot_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=False)
+        assert metric == "loss"
+        metric = no_init._get_plot_metric(history, plot_metric="mean_squared_error", main_branch=False)
+        assert metric == "mean_squared_error"
+
+    def test_get_plot_metric_short_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="mse", main_branch=False)
+        assert metric == "mean_squared_error"
+        metric = no_init._get_plot_metric(history, plot_metric="mae", main_branch=False)
+        assert metric == "mean_absolute_error"
+
+    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
+        history = history_with_main.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=True)
+        assert metric == "main_loss"
+
+    def test_filter_columns(self, no_init):
+        no_init._plot_metric = "loss"
+        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+        assert res == ['another_loss']
+        no_init._plot_metric = "mean_squared_error"
+        res = no_init._filter_columns({'mean_squared_error': None, 'another_loss': None, 'val_mean_squared_error': None,
+                                       'wrong': None})
+        assert res == []
+
+    def test_plot_from_hist_obj(self, default_history, path):
         assert "hist_obj.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), history)
+        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), default_history)
         assert "hist_obj.pdf" in os.listdir(path)
 
-    def test_plot_from_hist_dict(self, history, path):
+    def test_plot_from_hist_dict(self, default_history, path):
         assert "hist_dict.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), history.history)
+        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), default_history.history)
         assert "hist_dict.pdf" in os.listdir(path)
 
-    def test_plot_additional_loss(self, history_var, path):
+    def test_plot_additional_loss(self, history_additional_loss, path):
         assert "hist_additional.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_var)
+        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_additional_loss)
         assert "hist_additional.pdf" in os.listdir(path)
 
-    def test_filter_list(self, no_init):
-        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
-        assert res == ['another_loss']
 
 
 class TestPlotModelLearningRate: