diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index 4a323f46ff95a7ca66c157f2e4d6d3184f244a4a..d8e275101e7ec1a2388cc52111034d2497c1e82d 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -253,5 +253,17 @@ class AbstractModelClass(ABC):
     def own_args(cls, *args):
         """Return all arguments (including kwonlyargs)."""
         arg_spec = inspect.getfullargspec(cls)
-        list_of_args = arg_spec.args + arg_spec.kwonlyargs
-        return remove_items(list_of_args, ["self"] + list(args))
+        list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args()
+        return list(set(remove_items(list_of_args, ["self"] + list(args))))
+
+    @classmethod
+    def super_args(cls):
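+        """Return all arguments of all super classes (collected via their own_args method)."""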
+        args = []
+        for super_cls in cls.__mro__:
+            if super_cls == cls:
+                continue
+            if hasattr(super_cls, "own_args"):
+                args.extend(super_cls.own_args())
+        return list(set(args))
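+    # Example (illustrative): for a hypothetical subclass chain MyModel -> RNN -> AbstractModelClass, calling
+    # MyModel.own_args() now also returns the __init__ arguments of RNN and AbstractModelClass, collected through
+    # super_args and deduplicated by the set() call in own_args.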
+
diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c62c3cafc1537979e4a21bdb3bb6aa798e6e193
--- /dev/null
+++ b/mlair/model_modules/branched_input_networks.py
@@ -0,0 +1,294 @@
+from functools import partial, reduce
+
+from tensorflow import keras as keras
+
+from mlair import AbstractModelClass
+from mlair.helpers import select_from_dict
+from mlair.model_modules.loss import var_loss
+from mlair.model_modules.recurrent_networks import RNN
+
+
+class BranchedInputRNN(RNN):  # pragma: no cover
+    """A recurrent neural network with multiple input branches."""
+
+    def __init__(self, input_shape, output_shape, *args, **kwargs):
+
+        super().__init__([input_shape], output_shape, *args, **kwargs)
+
+        # building the model, setting the compile options, and registering the custom objects are already handled by
+        # the parent class RNN inside super().__init__, so these calls are not repeated here
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            conf = [n_hidden for _ in range(n_layer)]
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            conf = self.layer_configuration
+
+        x_input = []
+        x_in = []
+
+        for branch in range(len(self._input_shape)):
+            shape_b = self._input_shape[branch]
+            x_input_b = keras.layers.Input(shape=shape_b)
+            x_input.append(x_input_b)
+            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
+                                          name=f"reshape_branch{branch + 1}")(x_input_b)
+
+            for layer, n_hidden in enumerate(conf):
+                return_sequences = (layer < len(conf) - 1)
+                x_in_b = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
+                                  name=f"{self.RNN.__name__}_branch{branch + 1}_{layer + 1}",
+                                  kernel_regularizer=self.kernel_regularizer)(x_in_b)
+                if self.bn is True:
+                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
+                x_in_b = self.activation_rnn(name=f"{self.activation_rnn_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
+                if self.dropout is not None:
+                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
+            x_in.append(x_in_b)
+        x_concat = keras.layers.Concatenate()(x_in)
+
+        if self.add_dense_layer is True:
+            if len(self.dense_layer_configuration) == 0:
+                x_concat = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
+                                              kernel_initializer=self.kernel_initializer, )(x_concat)
+                x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_concat)
+                if self.dropout is not None:
+                    x_concat = self.dropout(self.dropout_rate)(x_concat)
+            else:
+                for layer, n_hidden in enumerate(self.dense_layer_configuration):
+                    if n_hidden < self._output_shape:
+                        break
+                    x_concat = keras.layers.Dense(n_hidden, name=f"Dense_{len(conf) + layer + 1}",
+                                                  kernel_initializer=self.kernel_initializer, )(x_concat)
+                    x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + layer + 1}")(x_concat)
+                    if self.dropout is not None:
+                        x_concat = self.dropout(self.dropout_rate)(x_concat)
+
+        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
+
+    def set_compile_options(self):
+        self.compile_options = {"loss": [keras.losses.mean_squared_error],
+                                "metrics": ["mse", "mae", var_loss]}
+
+    def _update_model_name(self, rnn_type):
+        n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
+                  f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
+        n_output = str(self._output_shape)
+        self.model_name = rnn_type.upper()
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            branch = [f"r{n_hidden}" for _ in range(n_layer)]
+        else:
+            branch = [f"r{n}" for n in self.layer_configuration]
+
+        concat = []
+        if self.add_dense_layer is True:
+            if len(self.dense_layer_configuration) == 0:
+                # branch entries are formatted as "r{n_hidden}", so strip the leading "r" before casting to int
+                n_hidden = min(self._output_shape ** 2, int(branch[-1][1:]))
+                concat.append(f"1x{n_hidden}")
+            else:
+                for n_hidden in self.dense_layer_configuration:
+                    if n_hidden < self._output_shape:
+                        break
+                    if len(concat) == 0:
+                        concat.append(f"1x{n_hidden}")
+                    else:
+                        concat.append(str(n_hidden))
+        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
+
+
+class BranchedInputFCN(AbstractModelClass):  # pragma: no cover
+    """
+    A fully connected network with multiple input branches that are combined by a concatenate layer.
+    """
+
+    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
+                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
+                   "linear": partial(keras.layers.Activation, "linear"),
+                   "selu": partial(keras.layers.Activation, "selu"),
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
+    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
+                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
+                    "prelu": keras.initializers.he_normal()}
+    _optimizer = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}
+    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
+    _dropout = {"selu": keras.layers.AlphaDropout}
+
+    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
+                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
+                 batch_normalization=False, **kwargs):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes, one entry per branch (each with shape=(window_hist, station,
+            variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this FCN model via the following parameters:
+
+        :param activation: set your desired activation function. Choose from relu, tanh, sigmoid, linear, selu, prelu,
+            leakyrelu. (Default relu)
+        :param activation_output: same as activation parameter but applied to the output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define the number of hidden layers in the network. The given number of hidden neurons is used
+            in each layer. (Default 1)
+        :param n_hidden: define the number of hidden units per layer. This number is used in each hidden layer.
+            (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represents the number of units in
+            the corresponding hidden layer. The number of hidden layers is equal to the length of this list.
+        :param dropout: use dropout with the given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param batch_normalization: use batch normalization layers in the network if enabled. These layers are
+            inserted between the linear part of a layer (the nn part) and the non-linear part (activation function).
+            No BN layer is added if set to false. (Default false)
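+
+        Example (illustrative, all shapes and hyperparameters are placeholders)::
+
+            model = BranchedInputFCN([(72, 1, 9), (24, 1, 2)], [4], n_layer=2, n_hidden=32, activation="prelu",
+                                     dropout=0.2)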
+        """
+
+        super().__init__(input_shape, output_shape[0])
+
+        # settings
+        self.activation = self._set_activation(activation)
+        self.activation_name = activation
+        self.activation_output = self._set_activation(activation_output)
+        self.activation_output_name = activation_output
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.bn = batch_normalization
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
+        self._update_model_name()
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
+
+    def _set_activation(self, activation):
+        try:
+            return self._activation.get(activation.lower())
+        except KeyError:
+            raise AttributeError(f"Given activation {activation} is not supported in this model class.")
+
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
+    def _set_regularizer(self, regularizer, **kwargs):
+        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+            return None
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _set_dropout(self, activation, dropout_rate):
+        if dropout_rate is None:
+            return None, None
+        assert 0 <= dropout_rate < 1
+        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
+
+    def _update_model_name(self):
+        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
+        n_output = str(self._output_shape)
+
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            branch = [f"{n_hidden}" for _ in range(n_layer)]
+        else:
+            branch = [f"{n}" for n in self.layer_configuration]
+
+        concat = []
+        n_neurons_concat = int(branch[-1]) * len(self._input_shape)
+        for exp in reversed(range(2, len(self._input_shape) + 1)):
+            n_neurons = self._output_shape ** exp
+            if n_neurons < n_neurons_concat:
+                if len(concat) == 0:
+                    concat.append(f"1x{n_neurons}")
+                else:
+                    concat.append(str(n_neurons))
+        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            conf = [n_hidden for _ in range(n_layer)]
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            conf = self.layer_configuration
+
+        x_input = []
+        x_in = []
+
+        for branch in range(len(self._input_shape)):
+            x_input_b = keras.layers.Input(shape=self._input_shape[branch])
+            x_input.append(x_input_b)
+            x_in_b = keras.layers.Flatten()(x_input_b)
+
+            for layer, n_hidden in enumerate(conf):
+                x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                            kernel_regularizer=self.kernel_regularizer,
+                                            name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b)
+                if self.bn is True:
+                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
+                x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
+                if self.dropout is not None:
+                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
+            x_in.append(x_in_b)
+        x_concat = keras.layers.Concatenate()(x_in)
+
+        n_neurons_concat = int(conf[-1]) * len(self._input_shape)
+        layer_concat = 0
+        for exp in reversed(range(2, len(self._input_shape) + 1)):
+            n_neurons = self._output_shape ** exp
+            if n_neurons < n_neurons_concat:
+                layer_concat += 1
+                x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat)
+                if self.bn is True:
+                    x_concat = keras.layers.BatchNormalization()(x_concat)
+                x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat)
+                if self.dropout is not None:
+                    x_concat = self.dropout(self.dropout_rate)(x_concat)
+        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
+
+    def set_compile_options(self):
+        self.compile_options = {"loss": [keras.losses.mean_squared_error],
+                                "metrics": ["mse", "mae", var_loss]}
+        # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])],
+        #                         "metrics": ["mse", "mae", var_loss]}
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index 372473ee22b5174a3beca91898509a3582391587..6da427e56f36b1af11ec88ea039abc571d69367b 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -190,191 +190,3 @@ class FCN_64_32_16(FCN):
     def _update_model_name(self):
         self.model_name = "FCN"
         super()._update_model_name()
-
-
-class BranchedInputFCN(AbstractModelClass):  # pragma: no cover
-    """
-    A customisable fully connected network (64, 32, 16, window_lead_time), where the last layer is the output layer depending
-    on the window_lead_time parameter.
-    """
-
-    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
-                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
-                   "linear": partial(keras.layers.Activation, "linear"),
-                   "selu": partial(keras.layers.Activation, "selu"),
-                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
-                   "leakyrelu": partial(keras.layers.LeakyReLU)}
-    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
-                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
-                    "prelu": keras.initializers.he_normal()}
-    _optimizer = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}
-    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
-    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
-    _dropout = {"selu": keras.layers.AlphaDropout}
-
-    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
-                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 batch_normalization=False, **kwargs):
-        """
-        Sets model and loss depending on the given arguments.
-
-        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
-        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
-
-        Customize this FCN model via the following parameters:
-
-        :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu,
-            leakyrelu. (Default relu)
-        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
-            linear)
-        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
-        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
-            layer. (Default 1)
-        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
-        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
-            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
-            hidden layer. The number of hidden layers is equal to the total length of this list.
-        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
-            network at all. (Default None)
-        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
-            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
-            is added if set to false. (Default false)
-        """
-
-        super().__init__(input_shape, output_shape[0])
-
-        # settings
-        self.activation = self._set_activation(activation)
-        self.activation_name = activation
-        self.activation_output = self._set_activation(activation_output)
-        self.activation_output_name = activation_output
-        self.optimizer = self._set_optimizer(optimizer, **kwargs)
-        self.bn = batch_normalization
-        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
-        self._update_model_name()
-        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
-        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
-        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
-
-        # apply to model
-        self.set_model()
-        self.set_compile_options()
-        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
-
-    def _set_activation(self, activation):
-        try:
-            return self._activation.get(activation.lower())
-        except KeyError:
-            raise AttributeError(f"Given activation {activation} is not supported in this model class.")
-
-    def _set_optimizer(self, optimizer, **kwargs):
-        try:
-            opt_name = optimizer.lower()
-            opt = self._optimizer.get(opt_name)
-            opt_kwargs = {}
-            if opt_name == "adam":
-                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
-            elif opt_name == "sgd":
-                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
-            return opt(**opt_kwargs)
-        except KeyError:
-            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
-
-    def _set_regularizer(self, regularizer, **kwargs):
-        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
-            return None
-        try:
-            reg_name = regularizer.lower()
-            reg = self._regularizer.get(reg_name)
-            reg_kwargs = {}
-            if reg_name in ["l1", "l2"]:
-                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
-                if reg_name in reg_kwargs:
-                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
-            elif reg_name == "l1_l2":
-                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
-            return reg(**reg_kwargs)
-        except KeyError:
-            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
-
-    def _set_dropout(self, activation, dropout_rate):
-        if dropout_rate is None:
-            return None, None
-        assert 0 <= dropout_rate < 1
-        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
-
-    def _update_model_name(self):
-        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
-        n_output = str(self._output_shape)
-
-        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
-            n_layer, n_hidden = self.layer_configuration
-            branch = [f"{n_hidden}" for _ in range(n_layer)]
-        else:
-            branch = [f"{n}" for n in self.layer_configuration]
-
-        concat = []
-        n_neurons_concat = int(branch[-1]) * len(self._input_shape)
-        for exp in reversed(range(2, len(self._input_shape) + 1)):
-            n_neurons = self._output_shape ** exp
-            if n_neurons < n_neurons_concat:
-                if len(concat) == 0:
-                    concat.append(f"1x{n_neurons}")
-                else:
-                    concat.append(str(n_neurons))
-        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
-
-    def set_model(self):
-        """
-        Build the model.
-        """
-
-        if isinstance(self.layer_configuration, tuple) is True:
-            n_layer, n_hidden = self.layer_configuration
-            conf = [n_hidden for _ in range(n_layer)]
-        else:
-            assert isinstance(self.layer_configuration, list) is True
-            conf = self.layer_configuration
-
-        x_input = []
-        x_in = []
-
-        for branch in range(len(self._input_shape)):
-            x_input_b = keras.layers.Input(shape=self._input_shape[branch])
-            x_input.append(x_input_b)
-            x_in_b = keras.layers.Flatten()(x_input_b)
-
-            for layer, n_hidden in enumerate(conf):
-                x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
-                                            kernel_regularizer=self.kernel_regularizer,
-                                            name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b)
-                if self.bn is True:
-                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
-                x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
-                if self.dropout is not None:
-                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
-            x_in.append(x_in_b)
-        x_concat = keras.layers.Concatenate()(x_in)
-
-        n_neurons_concat = int(conf[-1]) * len(self._input_shape)
-        layer_concat = 0
-        for exp in reversed(range(2, len(self._input_shape) + 1)):
-            n_neurons = self._output_shape ** exp
-            if n_neurons < n_neurons_concat:
-                layer_concat += 1
-                x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat)
-                if self.bn is True:
-                    x_concat = keras.layers.BatchNormalization()(x_concat)
-                x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat)
-                if self.dropout is not None:
-                    x_concat = self.dropout(self.dropout_rate)(x_concat)
-        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
-        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
-        self.model = keras.Model(inputs=x_input, outputs=[out])
-        print(self.model.summary())
-
-    def set_compile_options(self):
-        self.compile_options = {"loss": [keras.losses.mean_squared_error],
-                                "metrics": ["mse", "mae", var_loss]}
-        # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])],
-        #                         "metrics": ["mse", "mae", var_loss]}
diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
index e909ae7696bdf90d4e9a95e020b75a97e15dfd50..13e6fbecc7f3936a788dd6b035b9a7abe7b42857 100644
--- a/mlair/model_modules/recurrent_networks.py
+++ b/mlair/model_modules/recurrent_networks.py
@@ -2,6 +2,7 @@ __author__ = "Lukas Leufen"
 __date__ = '2021-05-25'
 
 from functools import reduce, partial
+from typing import Union
 
 from mlair.model_modules import AbstractModelClass
 from mlair.helpers import select_from_dict
@@ -33,7 +34,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  activation_rnn="tanh", dropout_rnn=0,
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, dense_layer_configuration=None, **kwargs):
+                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, dense_layer_configuration=None,
+                 kernel_regularizer=None, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
@@ -42,10 +44,12 @@ class RNN(AbstractModelClass):  # pragma: no cover
 
         Customize this RNN model via the following parameters:
 
-        :param activation: set your desired activation function for appended dense layers (add_dense_layer=True=. Choose
+        :param activation: set your desired activation function for appended dense layers (add_dense_layer=True). Choose
             from relu, tanh, sigmoid, linear, selu, prelu, leakyrelu. (Default relu)
         :param activation_rnn: set your desired activation function of the rnn output. Choose from relu, tanh, sigmoid,
-            linear, selu, prelu, leakyrelu. (Default tanh)
+            linear, selu, prelu, leakyrelu. To use the fast cuDNN implementation, tensorflow requires tanh as the
+            activation. Note that this is not the recurrent activation (which is not mutable in this class) but the
+            activation of the cell. (Default tanh)
         :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
             linear)
         :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
@@ -58,7 +62,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
         :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
             network at all. (Default None)
         :param dropout_rnn: use recurrent dropout with given rate. This is applied along the recursion and not after
-            a rnn layer. (Default 0)
+            a rnn layer. Be aware that tensorflow can only use the fast cuDNN implementation if no recurrent dropout
+            is applied. (Default 0)
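+        :param kernel_regularizer: apply a kernel regularizer to the rnn layers. Set to l1, l2, or l1_l2 to use the
+            corresponding keras regularizer; the regularization factors are taken from the kwargs (e.g. l2=0.01).
+            (Default None)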
         :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
             between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
             is added if set to false. (Default false)
@@ -94,7 +99,7 @@ class RNN(AbstractModelClass):  # pragma: no cover
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
         self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
-        # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.kernel_regularizer, self.kernel_regularizer_opts = self._set_regularizer(kernel_regularizer, **kwargs)
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
         assert 0 <= dropout_rnn <= 1
         self.dropout_rnn = dropout_rnn
@@ -121,7 +126,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
 
         for layer, n_hidden in enumerate(conf):
             return_sequences = (layer < len(conf) - 1)
-            x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn)(x_in)
+            x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
+                            kernel_regularizer=self.kernel_regularizer)(x_in)
             if self.bn is True:
                 x_in = keras.layers.BatchNormalization()(x_in)
             x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
@@ -188,23 +194,23 @@ class RNN(AbstractModelClass):  # pragma: no cover
             return opt(**opt_kwargs)
         except KeyError:
             raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
-    #
-    # def _set_regularizer(self, regularizer, **kwargs):
-    #     if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
-    #         return None
-    #     try:
-    #         reg_name = regularizer.lower()
-    #         reg = self._regularizer.get(reg_name)
-    #         reg_kwargs = {}
-    #         if reg_name in ["l1", "l2"]:
-    #             reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
-    #             if reg_name in reg_kwargs:
-    #                 reg_kwargs["l"] = reg_kwargs.pop(reg_name)
-    #         elif reg_name == "l1_l2":
-    #             reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
-    #         return reg(**reg_kwargs)
-    #     except KeyError:
-    #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _set_regularizer(self, regularizer: Union[None, str], **kwargs):
+        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+            return None, None
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs), reg_kwargs
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
 
     def _update_model_name(self, rnn_type):
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))