diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 0e31cd66fd64d19f1cecc7c9906a5c3b9446fe75..6b0fe236ff8ee726c34a721a6be0ed8be91f2bb8 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -219,15 +219,17 @@ class MyBranchedModel(AbstractModelClass):
         x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
-        out_minor_1 = self.activation()(out_minor_1)
+        out_minor_1 = self.activation(name="minor_1")(out_minor_1)
         x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
-        out_minor_2 = self.activation()(out_minor_2)
+        out_minor_2 = self.activation(name="minor_2")(out_minor_2)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
-        out_main = self.activation()(x_in)
+        out_main = self.activation(name="main")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])
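+        # naming the output activations ("minor_1", "minor_2", "main") makes keras log per-branch metrics such as
+        # `main_loss` / `val_main_loss`, which the main_branch filter of PlotModelHistory relies on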
 
     def set_loss(self):
diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index b18cce7a8993899621295644016f1e126d0dfac8..87e4071a0ac98c15c22131cd5d3418eb7c1b6976 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -19,46 +19,81 @@ lr_object = Union[Dict, LearningRateDecay]
 
 class PlotModelHistory:
     """
-    Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are
-    provided (name must somehow include the word `loss`), this additional information is added to the plot with an
-    separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper
-    saving behaviour, the parameter filename must include the absolute path for the plot.
+    Plots the history of a given metric (default: loss) for a training event. By default, <plot_metric> and
+    val_<plot_metric> are plotted. If further metrics are provided (their name must include <plot_metric>), this
+    additional information is added to the plot with a separate y-axis scale on the right side (shared by all
+    additional metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must
+    include the absolute path of the plot.
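+
+    A minimal usage sketch (assuming ``hist`` is a ``keras.callbacks.History`` returned by ``model.fit``)::
+
+        PlotModelHistory(filename="/absolute/path/history_mse.pdf", history=hist, plot_metric="mse")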
     """
-    def __init__(self, filename: str, history: history_object):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
         """
-        Sets attributes and create plot
+        Sets attributes and creates the plot.
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
             format ending like .pdf or .png to work.
         :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
+        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss; default: loss)
+        :param main_branch: if True, only metrics that belong to the main branch are considered; otherwise the metrics
+            of all branches are taken into account (default: False)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)
 
     @staticmethod
-    def _filter_columns(history: Dict) -> List[str]:
+    def _get_plot_metric(history, plot_metric, main_branch):
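+        """
+        Translate the short metric names mse and mae into their long keras names and return the shortest history key
+        that contains the metric name. If main_branch is True, only keys that also contain 'main' are considered,
+        e.g. _get_plot_metric(hist, "loss", main_branch=True) -> "main_loss" (assuming hist holds a "main_loss" key).
+        """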
+        if plot_metric.lower() == "mse":
+            plot_metric = "mean_squared_error"
+        elif plot_metric.lower() == "mae":
+            plot_metric = "mean_absolute_error"
+        available_keys = [k for k in history.keys()
+                          if plot_metric in k and ("main" in k.lower() if main_branch else True)]
+        available_keys.sort(key=len)
+        return available_keys[0]
+
+    def _filter_columns(self, history: Dict) -> List[str]:
         """
-        Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed.
-        :param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history)
-        :return: filtered columns including all loss variations except loss and val_loss.
+        Select only columns whose name contains <plot_metric>. The default entries '<plot_metric>' and
+        'val_<plot_metric>' are removed from the result.
+        :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras
+            History.history)
+        :return: all columns containing <plot_metric>, except <plot_metric> and val_<plot_metric> themselves.
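+
+        Example, assuming ``self._plot_metric == "loss"``::
+
+            self._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+            # returns ['another_loss']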
         """
-        cols = list(filter(lambda x: "loss" in x, history.keys()))
-        cols.remove("val_loss")
-        cols.remove("loss")
+        cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
+        # assumption: <plot_metric> and val_<plot_metric> are always present in cols, because keras generates both
+        # entries. If the metric were missing entirely, self._get_plot_metric() would already have failed, simply
+        # because it runs first.
+        cols.remove(f"val_{self._plot_metric}")
+        cols.remove(self._plot_metric)
         return cols
 
     def _plot(self, filename: str) -> None:
         """
-        Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with
-        an additional yaxis on the right side. The plot is saved in filename.
+        Actual plot routine. Plots <plot_metric> and val_<plot_metric> by default. If more metrics are provided, they
+        will be added with an additional y-axis on the right side. The plot is saved to filename.
         :param filename: name (including total path) of the plot to save.
         """
-        ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7)
+        ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
-        ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}")
+        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
         plt.savefig(filename)
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 0f3ff6d436b8a65528626f5f80508af222a1e68f..a7722018c52275b390a10199cb30b7b936ed37a3 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-# from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyBranchedModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 272609a31a3e3c91d6857ed841d5dd2783c66f35..fee1b38b97d8f4649730f0f7110cd3801ba7db33 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -134,5 +134,19 @@ class Training(RunEnvironment):
         """
         path = self.data_store.get("plot_path", "general")
         name = self.data_store.get("experiment_name", "general")
-        PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history)
+
+        # plot history of loss and mse (if available)
+        filename = os.path.join(path, f"{name}_history_loss.pdf")
+        PlotModelHistory(filename=filename, history=history)
+        multiple_branches_used = len(history.model.output_names) > 1  # model has more than one output branch
+        if multiple_branches_used:
+            filename = os.path.join(path, f"{name}_history_main_loss.pdf")
+            PlotModelHistory(filename=filename, history=history, main_branch=True)
+        if "mean_squared_error" in history.model.metrics_names:
+            filename = os.path.join(path, f"{name}_history_main_mse.pdf")
+            PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
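+            # for a branched model, keras should log per-branch keys such as "main_mean_squared_error" and
+            # "val_main_mean_squared_error"; main_branch=True lets the plot pick those instead of the totals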
+
+        # plot learning rate
         PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 3426eb1d355a6690ee57c3ee45e5088d7df9c249..accb32e5e3ec0fc425065ae6199c0418c524b174 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -37,7 +37,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:
 
     @pytest.fixture
-    def init_without_run(self, path, model, checkpoint):
+    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -82,6 +82,7 @@ class TestTraining:
                      'loss': [0.6795708956961347, 0.45963566494176616],
                      'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                      'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+        h.model = mock.MagicMock()
         return h
 
     @pytest.fixture
@@ -103,7 +104,7 @@ class TestTraining:
         return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
 
     @pytest.fixture
-    def ready_to_train(self, generator, init_without_run):
+    def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
         init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
@@ -208,5 +209,7 @@ class TestTraining:
 
     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
+        history.model.output_names = ["main"]
+        history.model.metrics_names = ["loss", "mean_squared_error"]
         init_without_run.create_monitoring_plots(history, learning_rate)
-        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 3
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index d38fc623b7fd2f5f8c779ea03757d7bf0794089a..e2d17057facee8ae07284109392f3e42e8e2fb6e 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -17,50 +17,80 @@ def path():
 class TestPlotModelHistory:
 
     @pytest.fixture
-    def history(self):
+    def default_history(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
-                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
-                     'loss': [0.6795708956961347, 0.45963566494176616],
-                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist
 
     @pytest.fixture
-    def history_var(self):
+    def history_additional_loss(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                     'test_loss': [0.595368885413389, 0.530547587585537],
-                     'loss': [0.6795708956961347, 0.45963566494176616],
-                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'test_loss': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist
 
+    @pytest.fixture
+    def history_with_main(self, default_history):
+        default_history.history["main_val_loss"] = [0.5586272982587484, 0.45712877659670287]
+        default_history.history["main_loss"] = [0.6795708956961347, 0.45963566494176616]
+        return default_history
+
     @pytest.fixture
     def no_init(self):
         return object.__new__(PlotModelHistory)
 
-    def test_plot_from_hist_obj(self, history, path):
+    def test_get_plot_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=False)
+        assert metric == "loss"
+        metric = no_init._get_plot_metric(history, plot_metric="mean_squared_error", main_branch=False)
+        assert metric == "mean_squared_error"
+
+    def test_get_plot_metric_short_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="mse", main_branch=False)
+        assert metric == "mean_squared_error"
+        metric = no_init._get_plot_metric(history, plot_metric="mae", main_branch=False)
+        assert metric == "mean_absolute_error"
+
+    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
+        history = history_with_main.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=True)
+        assert metric == "main_loss"
+
+    def test_filter_columns(self, no_init):
+        no_init._plot_metric = "loss"
+        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+        assert res == ['another_loss']
+        no_init._plot_metric = "mean_squared_error"
+        res = no_init._filter_columns({'mean_squared_error': None, 'another_loss': None, 'val_mean_squared_error': None,
+                                       'wrong': None})
+        assert res == []
+
+    def test_plot_from_hist_obj(self, default_history, path):
         assert "hist_obj.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), history)
+        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), default_history)
         assert "hist_obj.pdf" in os.listdir(path)
 
-    def test_plot_from_hist_dict(self, history, path):
+    def test_plot_from_hist_dict(self, default_history, path):
         assert "hist_dict.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), history.history)
+        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), default_history.history)
         assert "hist_dict.pdf" in os.listdir(path)
 
-    def test_plot_additional_loss(self, history_var, path):
+    def test_plot_additional_loss(self, history_additional_loss, path):
         assert "hist_additional.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_var)
+        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_additional_loss)
         assert "hist_additional.pdf" in os.listdir(path)
 
-    def test_filter_list(self, no_init):
-        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
-        assert res == ['another_loss']
-
 
 class TestPlotModelLearningRate: