diff --git a/src/model_modules/flatten.py b/src/model_modules/flatten.py
index 2c6920413fe2348f06f7bb3f1b9d3a1696ac03b3..dd1e8e21eeb96f75372add0208b03dc06f5dc25c 100644
--- a/src/model_modules/flatten.py
+++ b/src/model_modules/flatten.py
@@ -1,46 +1,102 @@
 __author__ = "Felix Kleinert, Lukas Leufen"
 __date__ = '2019-12-02'
 
-from typing import Callable
+from typing import Union, Callable
 
 import keras
 
 
-def flatten_tail(input_X: keras.layers, name: str, bound_weight: bool = False, dropout_rate: float = 0.0,
-                 window_lead_time: int = 4, activation: Callable = keras.activations.relu,
-                 reduction_filter: int = 64, first_dense: int = 64):
+def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs):
     """
-    Flatten output of
-
-    :param input_X:
-    :param name:
-    :param bound_weight:
-    :param dropout_rate:
-    :param window_lead_time:
-    :param activation:
-    :param reduction_filter:
-    :param first_dense:
-
-    :return:
+    Apply activation on a given input layer.
+
+    This helper function is able to handle advanced keras activations as well as strings for standard activations.
+
+    :param input_to_activate: keras layer to apply activation on
+    :param activation: activation to apply on `input_to_activate'. Can be a standard keras activation string or an advanced activation layer
+    :param kwargs: keyword arguments used inside activation layer
+
+    :return: the activation applied to `input_to_activate'
+
+    .. code-block:: python
+
+        input_x = ... # your input data
+        x_in = keras.layers.Dense(32)(input_x)  # e.g. any layer without an activation
+
+        # get activation via string
+        x_act_string = get_activation(x_in, 'relu')
+        # or get activation via layer callable
+        x_act_layer = get_activation(x_in, keras.layers.advanced_activations.ELU)
+
+    """
+    if isinstance(activation, str):
+        name = kwargs.pop('name', None)
+        kwargs['name'] = f'{name}_{activation}'
+        act = keras.layers.Activation(activation, **kwargs)(input_to_activate)
+    else:
+        act = activation(**kwargs)(input_to_activate)
+    return act
+
+
+def flatten_tail(input_x: keras.layers, inner_neurons: int, activation: Union[Callable, str],
+                 output_neurons: int, output_activation: Union[Callable, str],
+                 reduction_filter: int = None,
+                 name: str = None,
+                 bound_weight: bool = False,
+                 dropout_rate: float = None,
+                 kernel_regularizer: keras.regularizers = None
+                 ):
     """
-    X_in = keras.layers.Conv2D(reduction_filter, (1, 1), padding='same', name='{}_Conv_1x1'.format(name))(input_X)
+    Flatten output of convolutional layers.
+
+    :param input_x: Multidimensional keras layer (ConvLayer)
+    :param output_neurons: Number of neurons in the last layer (must fit the shape of labels)
+    :param output_activation: final activation function
+    :param name: Name of the flatten tail.
+    :param bound_weight: if True, use `tanh' as the inner activation instead of `activation'
+    :param dropout_rate: Dropout rate to be applied between trainable layers
+    :param activation: activation to apply after the conv and inner dense layers
+    :param reduction_filter: number of filters used for information compression on `input_x' before flatten()
+    :param inner_neurons: Number of neurons in inner dense layer
+    :param kernel_regularizer: regularizer to apply on conv and dense layers
 
-    X_in = activation(name='{}_conv_act'.format(name))(X_in)
+    :return: flattened branch with size n=output_neurons
 
-    X_in = keras.layers.Flatten(name='{}'.format(name))(X_in)
+    .. code-block:: python
 
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format(name))(X_in)
-    X_in = keras.layers.Dense(first_dense, kernel_regularizer=keras.regularizers.l2(0.01),
-                              name='{}_Dense_1'.format(name))(X_in)
+        input_x = ... # your input data
+        conv_out = keras.layers.Conv2D(*args)(input_x)  # your convolution stack
+        out = flatten_tail(conv_out, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                           output_neurons=4,
+                           output_activation='linear', reduction_filter=64,
+                           name='Main', bound_weight=False, dropout_rate=.3,
+                           kernel_regularizer=keras.regularizers.l2()
+                           )
+        model = keras.Model(inputs=input_x, outputs=[out])
+
+    """
+    # compression layer
+    if reduction_filter is None:
+        x_in = input_x
+    else:
+        x_in = keras.layers.Conv2D(reduction_filter, (1, 1), name=f'{name}_Conv_1x1',
+                                   kernel_regularizer=kernel_regularizer)(input_x)
+        x_in = get_activation(x_in, activation, name=f'{name}_conv_act')
+
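+    # flatten the (reduced) input and map it onto `inner_neurons' via a dense layer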
+    x_in = keras.layers.Flatten(name=f'{name}')(x_in)
+
+    if dropout_rate is not None:
+        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_1')(x_in)
+    x_in = keras.layers.Dense(inner_neurons, kernel_regularizer=kernel_regularizer,
+                              name=f'{name}_inner_Dense')(x_in)
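+    # either bound the inner output to [-1, 1] via tanh or apply the configured activation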
     if bound_weight:
-        X_in = keras.layers.Activation('tanh')(X_in)
+        x_in = keras.layers.Activation('tanh')(x_in)
     else:
-        try:
-            X_in = activation(name='{}_act'.format(name))(X_in)
-        except:
-            X_in = activation()(X_in)
-
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_2'.format(name))(X_in)
-    out = keras.layers.Dense(window_lead_time, activation='linear', kernel_regularizer=keras.regularizers.l2(0.01),
-                             name='{}_Dense_2'.format(name))(X_in)
+        x_in = get_activation(x_in, activation, name=f'{name}_act')
+
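+    # optional dropout, then the output dense layer followed by the final activation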
+    if dropout_rate is not None:
+        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_2')(x_in)
+    out = keras.layers.Dense(output_neurons, kernel_regularizer=kernel_regularizer,
+                             name=f'{name}_out_Dense')(x_in)
+    out = get_activation(out, output_activation, name=f'{name}_final_act')
     return out
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index d516ab77781221d72be0e209133b8b78170259f3..a838c8291983b8c9da5027f19eff0c578b8bdbb8 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -498,8 +498,14 @@ class MyTowerModel(AbstractModelClass):
                                                batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
-                                reduction_filter=64, first_dense=64, window_lead_time=self.window_lead_time)
+        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+                                output_activation='linear', reduction_filter=64,
+                                name='Main', bound_weight=True, dropout_rate=self.dropout_rate,
+                                kernel_regularizer=self.regularizer
+                                )
 
         self.model = keras.Model(inputs=X_input, outputs=[out_main])
 
@@ -619,8 +625,13 @@ class MyPaperModel(AbstractModelClass):
                                                regularizer=self.regularizer,
                                                batch_normalisation=True,
                                                padding=self.padding)
-        out_minor1 = flatten_tail(X_in, 'minor_1', False, self.dropout_rate, self.window_lead_time,
-                                  self.activation, 32, 64)
+        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+                                  output_activation='linear', reduction_filter=32,
+                                  name='minor_1', bound_weight=False, dropout_rate=self.dropout_rate,
+                                  kernel_regularizer=self.regularizer
+                                  )
 
         X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
 
@@ -634,8 +645,11 @@ class MyPaperModel(AbstractModelClass):
         #                                        batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=False, dropout_rate=self.dropout_rate,
-                                reduction_filter=64 * 2, first_dense=64 * 2, window_lead_time=self.window_lead_time)
+        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.window_lead_time,
+                                output_activation='linear',  reduction_filter=64 * 2,
+                                name='Main', bound_weight=False, dropout_rate=self.dropout_rate,
+                                kernel_regularizer=self.regularizer
+                                )
 
         self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
 
diff --git a/test/test_model_modules/test_flatten_tail.py b/test/test_model_modules/test_flatten_tail.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de138ec2323aea3409d5deadfb26c9741b89f50
--- /dev/null
+++ b/test/test_model_modules/test_flatten_tail.py
@@ -0,0 +1,119 @@
+import keras
+import pytest
+from src.model_modules.flatten import flatten_tail, get_activation
+
+
+class TestGetActivation:
+
+    @pytest.fixture()
+    def model_input(self):
+        input_x = keras.layers.Input(shape=(7, 1, 2))
+        return input_x
+
+    def test_string_act(self, model_input):
+        x_in = get_activation(model_input, activation='relu', name='String')
+        act = x_in._keras_history[0]
+        assert act.name == 'String_relu'
+
+    def test_string_act_unknown(self, model_input):
+        with pytest.raises(ValueError) as einfo:
+            get_activation(model_input, activation='invalid_activation', name='String')
+        assert 'Unknown activation function:invalid_activation' in str(einfo.value)
+
+    def test_layer_act(self, model_input):
+        x_in = get_activation(model_input, activation=keras.layers.advanced_activations.ELU, name='adv_layer')
+        act = x_in._keras_history[0]
+        assert act.name == 'adv_layer'
+
+    def test_layer_act_invalid(self, model_input):
+        with pytest.raises(TypeError):
+            get_activation(model_input, activation=keras.layers.Conv2D, name='adv_layer')
+
+
+class TestFlattenTail:
+
+    @pytest.fixture()
+    def model_input(self):
+        input_x = keras.layers.Input(shape=(7, 1, 2))
+        return input_x
+
+    @staticmethod
+    def step_in(element, depth=1):
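+        """Return the layer `depth' steps upstream of `element' by following its keras input history."""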
+        for _ in range(depth):
+            element = element.input._keras_history[0]
+        return element
+
+    def test_flatten_tail_no_bound_no_regul_no_drop(self, model_input):
+        tail = flatten_tail(input_x=model_input, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                            output_neurons=2, output_activation='linear',
+                            reduction_filter=None,
+                            name='Main_tail',
+                            bound_weight=False,
+                            dropout_rate=None,
+                            kernel_regularizer=None)
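+        # walk backwards from the returned tensor through the layer graph and check each stage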
+        final_act = tail._keras_history[0]
+        assert final_act.name == 'Main_tail_final_act_linear'
+        final_dense = self.step_in(final_act)
+        assert final_dense.name == 'Main_tail_out_Dense'
+        assert final_dense.units == 2
+        assert final_dense.kernel_regularizer is None
+        inner_act = self.step_in(final_dense)
+        assert inner_act.name == 'Main_tail_act'
+        assert inner_act.__class__.__name__ == 'ELU'
+        inner_dense = self.step_in(inner_act)
+        assert inner_dense.name == 'Main_tail_inner_Dense'
+        assert inner_dense.units == 64
+        assert inner_dense.kernel_regularizer is None
+        flatten = self.step_in(inner_dense)
+        assert flatten.name == 'Main_tail'
+        input_layer = self.step_in(flatten)
+        assert input_layer.input_shape == (None, 7, 1, 2)
+
+    def test_flatten_tail_all_settings(self, model_input):
+        tail = flatten_tail(input_x=model_input, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                            output_neurons=3, output_activation='linear',
+                            reduction_filter=32,
+                            name='Main_tail_all',
+                            bound_weight=True,
+                            dropout_rate=.35,
+                            kernel_regularizer=keras.regularizers.l2())
+
+        final_act = tail._keras_history[0]
+        assert final_act.name == 'Main_tail_all_final_act_linear'
+
+        final_dense = self.step_in(final_act)
+        assert final_dense.name == 'Main_tail_all_out_Dense'
+        assert final_dense.units == 3
+        assert isinstance(final_dense.kernel_regularizer, keras.regularizers.L1L2)
+
+        final_dropout = self.step_in(final_dense)
+        assert final_dropout.name == 'Main_tail_all_Dropout_2'
+        assert final_dropout.rate == 0.35
+
+        inner_act = self.step_in(final_dropout)
+        assert inner_act.get_config() == {'name': 'activation_1', 'trainable': True, 'activation': 'tanh'}
+
+        inner_dense = self.step_in(inner_act)
+        assert inner_dense.units == 64
+        assert isinstance(inner_dense.kernel_regularizer, keras.regularizers.L1L2)
+
+        inner_dropout = self.step_in(inner_dense)
+        assert inner_dropout.get_config() == {'name': 'Main_tail_all_Dropout_1', 'trainable': True, 'rate': 0.35,
+                                              'noise_shape': None, 'seed': None}
+
+        flatten = self.step_in(inner_dropout)
+        assert flatten.get_config() == {'name': 'Main_tail_all', 'trainable': True, 'data_format': 'channels_last'}
+
+        reduc_act = self.step_in(flatten)
+        assert reduc_act.get_config() == {'name': 'Main_tail_all_conv_act', 'trainable': True, 'alpha': 1.0}
+
+        reduc_conv = self.step_in(reduc_act)
+
+        assert reduc_conv.kernel_size == (1, 1)
+        assert reduc_conv.name == 'Main_tail_all_Conv_1x1'
+        assert reduc_conv.filters == 32
+        assert isinstance(reduc_conv.kernel_regularizer, keras.regularizers.L1L2)
+
+        input_layer = self.step_in(reduc_conv)
+        assert input_layer.input_shape == (None, 7, 1, 2)
+
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index bb319fe96d7e610aa3115489f4e19c87b9a8f349..83b109940ca74475e7d865eaf690e1a757075815 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -28,11 +28,19 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
     X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
     X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
     if add_minor_branch:
-        out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
+        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
+                            output_activation='linear', reduction_filter=64,
+                            name='Minor_1', dropout_rate=dropout_rate,
+                            )]
     else:
         out = []
     X_in = keras.layers.Dropout(dropout_rate)(X_in)
-    out.append(flatten_tail(X_in, 'Main', activation=activation))
+    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
+                            output_activation='linear', reduction_filter=64,
+                            name='Main', dropout_rate=dropout_rate,
+                            ))
     return keras.Model(inputs=X_input, outputs=out)