diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index dee36166abacf275c79213305ac15e1918e1957c..617ff3135734056e51746ae2924a123a3eb34f8f 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -25,7 +25,7 @@ class PlotModelHistory:
     metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
     path for the plot.
     """
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
@@ -35,10 +35,22 @@ class PlotModelHistory:
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
-        self._plot_metric = plot_metric
+        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)
 
+    @staticmethod
+    def _get_plot_metric(history, plot_metric, main_branch):
+        if plot_metric.lower() == "mse":
+            plot_metric = "mean_squared_error"
+        elif plot_metric.lower() == "mae":
+            plot_metric = "mean_absolute_error"
+        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)]
+        print(available_keys)
+        available_keys.sort(key=len)
+        print(available_keys)
+        return available_keys[0]
+
     def _filter_columns(self, history: Dict) -> List[str]:
         """
         Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 0f3ff6d436b8a65528626f5f80508af222a1e68f..a7722018c52275b390a10199cb30b7b936ed37a3 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-# from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyBranchedModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 272609a31a3e3c91d6857ed841d5dd2783c66f35..d15060daeb2596d048ff44f5f9948c0002069762 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -134,5 +134,10 @@ class Training(RunEnvironment):
         """
         path = self.data_store.get("plot_path", "general")
         name = self.data_store.get("experiment_name", "general")
-        PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history)
+        filename = os.path.join(path, f"{name}_history_loss.pdf")
+        PlotModelHistory(filename=filename, history=history, main_branch=False)
+        filename = os.path.join(path, f"{name}_history_main_loss.pdf")
+        PlotModelHistory(filename=filename, history=history)
+        filename = os.path.join(path, f"{name}_history_main_mse.pdf")
+        PlotModelHistory(filename=filename, history=history, plot_metric="mse")
         PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)