diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 1a8f7c4c400eaf75bdd1dc6af2e0993f662eac49..0e31cd66fd64d19f1cecc7c9906a5c3b9446fe75 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -154,3 +154,88 @@ class MyLittleModel(AbstractModelClass):
         """
 
         self.loss = keras.losses.mean_squared_error
+
+
+class MyBranchedModel(AbstractModelClass):
+
+    """
+    A customised model
+
+
+    with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the
+    output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first
+    Dense layer.
+    """
+
+    def __init__(self, window_history_size, window_lead_time, channels):
+
+        """
+        Sets model and loss depending on the given arguments.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param regularizer: <not used here>
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        """
+
+        super().__init__()
+
+        # settings
+        self.window_history_size = window_history_size
+        self.window_lead_time = window_lead_time
+        self.channels = channels
+        self.dropout_rate = 0.1
+        self.regularizer = keras.regularizers.l2(0.1)
+        self.initial_lr = 1e-2
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+        self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
+        self.epochs = 2
+        self.batch_size = int(256)
+        self.activation = keras.layers.PReLU
+
+        # apply to model
+        self.set_model()
+        self.set_loss()
+
+    def set_model(self):
+
+        """
+        Build the model.
+        :param activation: activation function
+        :param window_history_size: number of historical time steps included in the input data
+        :param channels: number of variables used in input data
+        :param dropout_rate: dropout rate used in the model [0, 1)
+        :param window_lead_time: number of time steps to forecast in the output layer
+        :return: built keras model
+        """
+
+        # add 1 to window_size to include current time step t0
+        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
+        x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
+        x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
+        x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
+        x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
+        x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
+        out_minor_1 = self.activation()(out_minor_1)
+        x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
+        out_minor_2 = self.activation()(out_minor_2)
+        x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
+        out_main = self.activation()(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])
+
+    def set_loss(self):
+
+        """
+        Set the loss
+        :return: loss function
+        """
+
+        self.loss = [keras.losses.mean_absolute_error] + [keras.losses.mean_squared_error] + \
+                    [keras.losses.mean_squared_error]
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index a47ef67ad5781ff37ce812aa931dbd195d4513dc..0f3ff6d436b8a65528626f5f80508af222a1e68f 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -15,7 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
 from src.helpers import l_p_loss, LearningRateDecay
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
-from src.model_modules.model_class import MyLittleModel
+# from src.model_modules.model_class import MyBranchedModel as MyModel
+from src.model_modules.model_class import MyLittleModel as MyModel
 
 
 class ModelSetup(RunEnvironment):
@@ -76,7 +77,7 @@ class ModelSetup(RunEnvironment):
     def build_model(self):
         args_list = ["window_history_size", "window_lead_time", "channels"]
         args = self.data_store.create_args_dict(args_list, self.scope)
-        self.model = MyLittleModel(**args)
+        self.model = MyModel(**args)
         self.get_model_settings()
 
     def get_model_settings(self):
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 35d93dcbd932d1c298c0744fcd0205697576bb4c..e5739e5f15e1c2f20758e388b3493c28f577bb9a 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -109,9 +109,25 @@ class PostProcessing(RunEnvironment):
         return persistence_prediction
 
     def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method):
+        """
+        create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output
+        in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if
+        the network has multiple output branches). The main branch is defined to be the last entry of all outputs.
+        :param input_data:
+        :param nn_prediction:
+        :param mean:
+        :param std:
+        :param transformation_method:
+        :return:
+        """
         tmp_nn = self.model.predict(input_data)
         tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
-        nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
+        if tmp_nn.ndim == 3:
+            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
+        elif tmp_nn.ndim == 2:
+            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
+        else:
+            raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
         return nn_prediction
 
     @staticmethod